Commit 3ac7492c authored by sbl1996@126.com

Fix battle when num_envs is 1

parent ad9980e5
...@@ -174,9 +174,10 @@ if __name__ == "__main__": ...@@ -174,9 +174,10 @@ if __name__ == "__main__":
start = time.time() start = time.time()
start_step = step start_step = step
num_envs_half = num_envs // 2
player1_ = np.concatenate([ player1_ = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64), np.zeros(num_envs_half, dtype=np.int64),
np.ones(num_envs // 2, dtype=np.int64) np.ones(num_envs - num_envs_half, dtype=np.int64)
]) ])
player1 = torch.from_numpy(player1_).to(device=device) player1 = torch.from_numpy(player1_).to(device=device)
...@@ -221,6 +222,11 @@ if __name__ == "__main__": ...@@ -221,6 +222,11 @@ if __name__ == "__main__":
win_rates.append(win) win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0) win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n") sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
if args.verbose:
player1_ = 1 - player1_
player1 = 1 - player1
if len(episode_lengths) >= args.num_episodes: if len(episode_lengths) >= args.num_episodes:
break break
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment