Commit 6c08e7d2 authored by sbl1996@126.com's avatar sbl1996@126.com

Fix

parent 78a3bc47
......@@ -220,7 +220,6 @@ def rollout(
args: Args,
rollout_queue,
params_queue,
eval_queue,
writer,
learner_devices,
device_thread_id,
......@@ -428,9 +427,8 @@ def rollout(
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
if device_thread_id == 0:
eval_time = time.time() - _start
other_time += eval_time
eval_time = time.time() - _start
other_time += eval_time
eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
print(eval_stats)
else:
......@@ -799,7 +797,6 @@ if __name__ == "__main__":
params_queues = []
rollout_queues = []
eval_queue = queue.Queue()
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids):
......@@ -818,7 +815,6 @@ if __name__ == "__main__":
args,
rollout_queues[-1],
params_queues[-1],
eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices,
actor_thread_id,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment