Commit 4ef751bf authored by sbl1996@126.com's avatar sbl1996@126.com

Accurate evaluate (battle)

parent b55996bf
...@@ -37,6 +37,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None): ...@@ -37,6 +37,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
assert num_episodes == envs.num_envs
num_envs = envs.num_envs num_envs = envs.num_envs
episode_rewards = [] episode_rewards = []
episode_lengths = [] episode_lengths = []
...@@ -45,6 +46,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): ...@@ -45,6 +46,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
obs, infos = envs.reset() obs, infos = envs.reset()
next_to_play = infos['to_play'] next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_) dones = np.zeros(num_envs, dtype=np.bool_)
collected = np.zeros((num_episodes,), dtype=np.bool_)
main_player = np.concatenate([ main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64), np.zeros(num_envs // 2, dtype=np.int64),
...@@ -60,8 +62,9 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): ...@@ -60,8 +62,9 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
next_to_play = infos['to_play'] next_to_play = infos['to_play']
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if not d: if not d or collected[idx]:
continue continue
collected[idx] = True
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * (1 if main[idx] else -1) episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
......
...@@ -2230,7 +2230,7 @@ public: ...@@ -2230,7 +2230,7 @@ public:
if (play_mode_ == kSelfPlay) { if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player // to_play_ is the previous player
reward = winner_ == to_play_ ? base_reward : -base_reward; reward = winner_ == player ? base_reward : -base_reward;
} else { } else {
reward = winner_ == ai_player_ ? base_reward : -base_reward; reward = winner_ == ai_player_ ? base_reward : -base_reward;
} }
...@@ -2403,21 +2403,21 @@ private: ...@@ -2403,21 +2403,21 @@ private:
const SpecInfo& find_spec_info(SpecInfos &spec_infos, const std::string &spec) { const SpecInfo& find_spec_info(SpecInfos &spec_infos, const std::string &spec) {
auto it = spec_infos.find(spec); auto it = spec_infos.find(spec);
if (it == spec_infos.end()) { if (it == spec_infos.end()) {
// TODO(2): find the root cause // TODO(2): find the root cause
// print spec2index // print spec2index
show_deck(0); show_deck(0);
show_deck(1); show_deck(1);
show_buffer(); show_buffer();
show_turn(); show_turn();
fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_); fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_);
fmt::println("Spec: {}, Spec2index:", spec); fmt::println("Spec: {}, Spec2index:", spec);
for (auto &[k, v] : spec_infos) { for (auto &[k, v] : spec_infos) {
fmt::print("{}: {} {}, ", k, v.index, v.cid); fmt::print("{}: {} {}, ", k, v.index, v.cid);
} }
fmt::print("\n"); fmt::print("\n");
// throw std::runtime_error("Spec not found: " + spec); // throw std::runtime_error("Spec not found: " + spec);
spec_infos[spec] = {1, 1}; spec_infos[spec] = {0, 0};
return spec_infos[spec]; return spec_infos[spec];
} }
return it->second; return it->second;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment