Commit 4ef751bf authored by sbl1996@126.com

Accurate evaluate (battle)

parent b55996bf
...@@ -37,6 +37,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None): ...@@ -37,6 +37,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
assert num_episodes == envs.num_envs
num_envs = envs.num_envs num_envs = envs.num_envs
episode_rewards = [] episode_rewards = []
episode_lengths = [] episode_lengths = []
...@@ -45,6 +46,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): ...@@ -45,6 +46,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
obs, infos = envs.reset() obs, infos = envs.reset()
next_to_play = infos['to_play'] next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_) dones = np.zeros(num_envs, dtype=np.bool_)
collected = np.zeros((num_episodes,), dtype=np.bool_)
main_player = np.concatenate([ main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64), np.zeros(num_envs // 2, dtype=np.int64),
...@@ -60,8 +62,9 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): ...@@ -60,8 +62,9 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
next_to_play = infos['to_play'] next_to_play = infos['to_play']
for idx, d in enumerate(dones): for idx, d in enumerate(dones):
if not d: if not d or collected[idx]:
continue continue
collected[idx] = True
episode_length = infos['l'][idx] episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * (1 if main[idx] else -1) episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
win = 1 if episode_reward > 0 else 0 win = 1 if episode_reward > 0 else 0
......
...@@ -2230,7 +2230,7 @@ public: ...@@ -2230,7 +2230,7 @@ public:
if (play_mode_ == kSelfPlay) { if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player // to_play_ is the previous player
reward = winner_ == to_play_ ? base_reward : -base_reward; reward = winner_ == player ? base_reward : -base_reward;
} else { } else {
reward = winner_ == ai_player_ ? base_reward : -base_reward; reward = winner_ == ai_player_ ? base_reward : -base_reward;
} }
...@@ -2416,7 +2416,7 @@ private: ...@@ -2416,7 +2416,7 @@ private:
} }
fmt::print("\n"); fmt::print("\n");
// throw std::runtime_error("Spec not found: " + spec); // throw std::runtime_error("Spec not found: " + spec);
spec_infos[spec] = {1, 1}; spec_infos[spec] = {0, 0};
return spec_infos[spec]; return spec_infos[spec];
} }
return it->second; return it->second;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment