Commit 4ef751bf authored by sbl1996@126.com

Accurate evaluate (battle)

parent b55996bf
......@@ -37,6 +37,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
assert num_episodes == envs.num_envs
num_envs = envs.num_envs
episode_rewards = []
episode_lengths = []
......@@ -45,6 +46,7 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
obs, infos = envs.reset()
next_to_play = infos['to_play']
dones = np.zeros(num_envs, dtype=np.bool_)
collected = np.zeros((num_episodes,), dtype=np.bool_)
main_player = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
......@@ -60,8 +62,9 @@ def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
next_to_play = infos['to_play']
for idx, d in enumerate(dones):
if not d:
if not d or collected[idx]:
continue
collected[idx] = True
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
win = 1 if episode_reward > 0 else 0
......
......@@ -2230,7 +2230,7 @@ public:
if (play_mode_ == kSelfPlay) {
// to_play_ is the previous player
reward = winner_ == to_play_ ? base_reward : -base_reward;
reward = winner_ == player ? base_reward : -base_reward;
} else {
reward = winner_ == ai_player_ ? base_reward : -base_reward;
}
......@@ -2403,21 +2403,21 @@ private:
// Looks up `spec` in `spec_infos` and returns a reference to the stored SpecInfo.
//
// On a miss this does not throw (the throw is commented out below). Instead it
// prints diagnostic state, inserts a placeholder entry under `spec` into
// `spec_infos` (which is why the map is taken by non-const reference), and
// returns a reference to that new entry.
//
// NOTE(review): the miss branch contains two near-identical copies of the
// diagnostic sequence. As written, the first copy ends in `return`, so the
// second copy (placeholder `{0, 0}`) is unreachable and the effective
// placeholder is `{1, 1}`. Presumably these are the before/after sides of a
// diff hunk rendered without +/- markers -- confirm which placeholder is
// intended before relying on either value.
const SpecInfo& find_spec_info(SpecInfos &spec_infos, const std::string &spec) {
auto it = spec_infos.find(spec);
if (it == spec_infos.end()) {
// TODO(2): find the root cause
// print spec2index
// Dump deck, buffer and turn state via the show_* helpers (defined elsewhere)
// before printing the multi-select (ms_*) members and the map contents.
show_deck(0);
show_deck(1);
show_buffer();
show_turn();
fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_);
fmt::println("Spec: {}, Spec2index:", spec);
// Print every known entry as "key: index cid".
for (auto &[k, v] : spec_infos) {
fmt::print("{}: {} {}, ", k, v.index, v.cid);
}
fmt::print("\n");
// throw std::runtime_error("Spec not found: " + spec);
// Insert a placeholder so the caller still gets a valid reference.
// NOTE(review): presumably the two values map to SpecInfo's index and cid -- verify field order.
spec_infos[spec] = {1, 1};
return spec_infos[spec];
// Everything from here to the end of this branch is unreachable as written
// (see the NOTE(review) at the top of the function).
// TODO(2): find the root cause
// print spec2index
show_deck(0);
show_deck(1);
show_buffer();
show_turn();
fmt::println("MS: idx: {}, mode: {}, min: {}, max: {}, must: {}, specs: {}, combs: {}", ms_idx_, ms_mode_, ms_min_, ms_max_, ms_must_, ms_specs_, ms_combs_);
fmt::println("Spec: {}, Spec2index:", spec);
for (auto &[k, v] : spec_infos) {
fmt::print("{}: {} {}, ", k, v.index, v.cid);
}
fmt::print("\n");
// throw std::runtime_error("Spec not found: " + spec);
spec_infos[spec] = {0, 0};
return spec_infos[spec];
}
// Hit: return the existing entry.
return it->second;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment