Commit 6c7ac38a authored by sbl1996@126.com's avatar sbl1996@126.com

Count global eval_stats in impala

parent 6c08e7d2
...@@ -220,7 +220,6 @@ def rollout( ...@@ -220,7 +220,6 @@ def rollout(
args: Args, args: Args,
rollout_queue, rollout_queue,
params_queue, params_queue,
eval_queue,
writer, writer,
learner_devices, learner_devices,
device_thread_id, device_thread_id,
...@@ -407,6 +406,23 @@ def rollout( ...@@ -407,6 +406,23 @@ def rollout(
sharded_data = jax.tree.map(lambda x: jax.device_put_sharded( sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
np.split(x, len(learner_devices)), devices=learner_devices), np.split(x, len(learner_devices)), devices=learner_devices),
(init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main)) (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
if args.eval_interval and update % args.eval_interval == 0:
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_time = time.time() - _start
other_time += eval_time
eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
else:
eval_stats = None
learn_opponent = False learn_opponent = False
payload = ( payload = (
global_step, global_step,
...@@ -415,6 +431,7 @@ def rollout( ...@@ -415,6 +431,7 @@ def rollout(
*sharded_data, *sharded_data,
np.mean(params_queue_get_time), np.mean(params_queue_get_time),
learn_opponent, learn_opponent,
eval_stats,
) )
rollout_queue.put(payload) rollout_queue.put(payload)
...@@ -441,34 +458,6 @@ def rollout( ...@@ -441,34 +458,6 @@ def rollout(
writer.add_scalar("charts/SPS", SPS, global_step) writer.add_scalar("charts/SPS", SPS, global_step)
writer.add_scalar("charts/SPS_update", SPS_update, global_step) writer.add_scalar("charts/SPS_update", SPS_update, global_step)
if args.eval_interval and update % args.eval_interval == 0:
# Eval with rule-based policy
_start = time.time()
if eval_mode == 'bot':
predict_fn = lambda x: get_action(params, x)
eval_return, eval_ep_len, eval_win_rate = evaluate(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
else:
predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
eval_return, eval_ep_len, eval_win_rate = battle(
eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
eval_stat = np.array([eval_return, eval_win_rate])
if device_thread_id != 0:
eval_queue.put(eval_stat)
else:
eval_stats = []
eval_stats.append(eval_stat)
for _ in range(1, n_actors):
eval_stats.append(eval_queue.get())
eval_stats = np.stack(eval_stats)
eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
writer.add_scalar(f"charts/eval_return", eval_return, global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
if device_thread_id == 0:
eval_time = time.time() - _start
other_time += eval_time
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
if __name__ == "__main__": if __name__ == "__main__":
args = tyro.cli(Args) args = tyro.cli(Args)
...@@ -515,6 +504,10 @@ if __name__ == "__main__": ...@@ -515,6 +504,10 @@ if __name__ == "__main__":
for process_index in range(args.world_size) for process_index in range(args.world_size)
for d_id in args.learner_device_ids for d_id in args.learner_device_ids
] ]
global_main_devices = [
global_devices[process_index * len(local_devices)]
for process_index in range(args.world_size)
]
print("global_learner_decices", global_learner_decices) print("global_learner_decices", global_learner_decices)
args.global_learner_decices = [str(item) for item in global_learner_decices] args.global_learner_decices = [str(item) for item in global_learner_decices]
args.actor_devices = [str(item) for item in actor_devices] args.actor_devices = [str(item) for item in actor_devices]
...@@ -762,6 +755,12 @@ if __name__ == "__main__": ...@@ -762,6 +755,12 @@ if __name__ == "__main__":
approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean() approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
all_reduce_value = jax.pmap(
lambda x: jax.lax.pmean(x, axis_name="main_devices"),
axis_name="main_devices",
devices=global_main_devices,
)
multi_device_update = jax.pmap( multi_device_update = jax.pmap(
single_device_update, single_device_update,
axis_name="local_devices", axis_name="local_devices",
...@@ -771,7 +770,6 @@ if __name__ == "__main__": ...@@ -771,7 +770,6 @@ if __name__ == "__main__":
params_queues = [] params_queues = []
rollout_queues = [] rollout_queues = []
eval_queue = queue.Queue()
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params) unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
for d_idx, d_id in enumerate(args.actor_device_ids): for d_idx, d_id in enumerate(args.actor_device_ids):
...@@ -790,7 +788,6 @@ if __name__ == "__main__": ...@@ -790,7 +788,6 @@ if __name__ == "__main__":
args, args,
rollout_queues[-1], rollout_queues[-1],
params_queues[-1], params_queues[-1],
eval_queue,
writer if d_idx == 0 and thread_id == 0 else dummy_writer, writer if d_idx == 0 and thread_id == 0 else dummy_writer,
learner_devices, learner_devices,
actor_thread_id, actor_thread_id,
...@@ -805,6 +802,7 @@ if __name__ == "__main__": ...@@ -805,6 +802,7 @@ if __name__ == "__main__":
learner_policy_version += 1 learner_policy_version += 1
rollout_queue_get_time_start = time.time() rollout_queue_get_time_start = time.time()
sharded_data_list = [] sharded_data_list = []
eval_stat_list = []
for d_idx, d_id in enumerate(args.actor_device_ids): for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads): for thread_id in range(args.num_actor_threads):
( (
...@@ -813,8 +811,21 @@ if __name__ == "__main__": ...@@ -813,8 +811,21 @@ if __name__ == "__main__":
*sharded_data, *sharded_data,
avg_params_queue_get_time, avg_params_queue_get_time,
learn_opponent, learn_opponent,
eval_stats,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get() ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data) sharded_data_list.append(sharded_data)
if eval_stats is not None:
eval_stat_list.append(eval_stats)
if update % args.eval_interval == 0:
eval_stats = np.mean(eval_stat_list, axis=0)
eval_stats = jax.device_put(eval_stats, local_devices[0])
eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
eval_time, eval_return, eval_win_rate = eval_stats
writer.add_scalar(f"charts/eval_return", eval_return, global_step)
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time() training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update( (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
......
...@@ -430,7 +430,6 @@ def rollout( ...@@ -430,7 +430,6 @@ def rollout(
eval_time = time.time() - _start eval_time = time.time() - _start
other_time += eval_time other_time += eval_time
eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32) eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
print(eval_stats)
else: else:
eval_stats = None eval_stats = None
...@@ -846,7 +845,6 @@ if __name__ == "__main__": ...@@ -846,7 +845,6 @@ if __name__ == "__main__":
if update % args.eval_interval == 0: if update % args.eval_interval == 0:
eval_stats = np.mean(eval_stat_list, axis=0) eval_stats = np.mean(eval_stat_list, axis=0)
print(eval_stats)
eval_stats = jax.device_put(eval_stats, local_devices[0]) eval_stats = jax.device_put(eval_stats, local_devices[0])
eval_stats = np.array(all_reduce_value(eval_stats[None])[0]) eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
eval_time, eval_return, eval_win_rate = eval_stats eval_time, eval_return, eval_win_rate = eval_stats
...@@ -854,7 +852,6 @@ if __name__ == "__main__": ...@@ -854,7 +852,6 @@ if __name__ == "__main__":
writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step) writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}") print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time() training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update( (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment