Commit 4ef751bf authored by sbl1996@126.com's avatar sbl1996@126.com

Accurate evaluate (battle)

parent b55996bf
......@@ -37,6 +37,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
assert num_episodes == envs.num_envs
num_envs = envs.num_envs
episode_rewards = []
episode_lengths = []
......@@ -45,6 +46,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
obs, infos = envs.reset()
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
collected = np.zeros((num_episodes,), dtype=np.bool_)
main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
......@@ -60,8 +62,9 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
next_to_play = infos['to_play']
for idx, d in enumerate(dones):
if not d:
if not d or collected[idx]:
continue
collected[idx] = True
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
win = 1 if episode_reward > 0 else 0
......
......@@ -2230,7 +2230,7 @@ public:
if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player
reward = winner_ == to_play_ ? base_reward : -base_reward;
reward = winner_ == player ? base_reward : -base_reward;
} else {
reward = winner_ == ai_player_ ? base_reward : -base_reward;
}
......@@ -2416,7 +2416,7 @@ private:
}
fmt::print("\n");
// throw std::runtime_error("Spec not found: " + spec);
spec_infos[spec] = {1, 1};
spec_infos[spec] = {0, 0};
return spec_infos[spec];
}
return it->second;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment