Commit 78a3bc47 authored by sbl1996@126.com

Count global eval_stats

parent 81f63996
@@ -417,6 +417,25 @@ def rollout(
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
             (init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
+
+        if args.eval_interval and update % args.eval_interval == 0:
+            _start = time.time()
+            if eval_mode == 'bot':
+                predict_fn = lambda x: get_action(params, x)
+                eval_return, eval_ep_len, eval_win_rate = evaluate(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+            else:
+                predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
+                eval_return, eval_ep_len, eval_win_rate = battle(
+                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+            if device_thread_id == 0:
+                eval_time = time.time() - _start
+                other_time += eval_time
+                eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
+                print(eval_stats)
+        else:
+            eval_stats = None
+
         learn_opponent = False
         payload = (
             global_step,
@@ -425,6 +444,7 @@ def rollout(
             *sharded_data,
             np.mean(params_queue_get_time),
             learn_opponent,
+            eval_stats,
         )
         rollout_queue.put(payload)
@@ -451,34 +471,6 @@ def rollout(
             writer.add_scalar("charts/SPS", SPS, global_step)
             writer.add_scalar("charts/SPS_update", SPS_update, global_step)
-
-        if args.eval_interval and update % args.eval_interval == 0:
-            # Eval with rule-based policy
-            _start = time.time()
-            if eval_mode == 'bot':
-                predict_fn = lambda x: get_action(params, x)
-                eval_return, eval_ep_len, eval_win_rate = evaluate(
-                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
-            else:
-                predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
-                eval_return, eval_ep_len, eval_win_rate = battle(
-                    eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
-            eval_stat = np.array([eval_return, eval_win_rate])
-            if device_thread_id != 0:
-                eval_queue.put(eval_stat)
-            else:
-                eval_stats = []
-                eval_stats.append(eval_stat)
-                for _ in range(1, n_actors):
-                    eval_stats.append(eval_queue.get())
-                eval_stats = np.stack(eval_stats)
-                eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
-                writer.add_scalar(f"charts/eval_return", eval_return, global_step)
-                writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
-            if device_thread_id == 0:
-                eval_time = time.time() - _start
-                other_time += eval_time
-                print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")


 if __name__ == "__main__":
     args = tyro.cli(Args)
@@ -525,6 +517,10 @@ if __name__ == "__main__":
         for process_index in range(args.world_size)
         for d_id in args.learner_device_ids
     ]
+    global_main_devices = [
+        global_devices[process_index * len(local_devices)]
+        for process_index in range(args.world_size)
+    ]
     print("global_learner_decices", global_learner_decices)
     args.global_learner_decices = [str(item) for item in global_learner_decices]
     args.actor_devices = [str(item) for item in actor_devices]
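
A note on the new global_main_devices list above: the indexing picks the first local device of every process. Below is a minimal standalone sketch of that indexing, not part of the commit; the process count, device count, and device names are made up, and it assumes the global device list enumerates every host's devices grouped by process index (process 0's devices first, then process 1's, and so on).

# Standalone sketch, not from the repo: illustrates the global_main_devices indexing.
# Assumption: global_devices lists every host's devices grouped by process index.
world_size = 2                                          # hypothetical process count
local_devices = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"]    # hypothetical per-host devices
global_devices = [
    f"proc{p}/{d}" for p in range(world_size) for d in local_devices
]

# Same indexing as the diff: the first local device of each process becomes that
# host's "main" device.
global_main_devices = [
    global_devices[process_index * len(local_devices)]
    for process_index in range(world_size)
]
print(global_main_devices)  # ['proc0/gpu:0', 'proc1/gpu:0']
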
@@ -788,6 +784,12 @@ if __name__ == "__main__":
             approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
             return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key

+    all_reduce_value = jax.pmap(
+        lambda x: jax.lax.pmean(x, axis_name="main_devices"),
+        axis_name="main_devices",
+        devices=global_main_devices,
+    )
+
     multi_device_update = jax.pmap(
         single_device_update,
         axis_name="local_devices",
@@ -831,6 +833,7 @@ if __name__ == "__main__":
         learner_policy_version += 1
         rollout_queue_get_time_start = time.time()
         sharded_data_list = []
+        eval_stat_list = []
         for d_idx, d_id in enumerate(args.actor_device_ids):
             for thread_id in range(args.num_actor_threads):
                 (
@@ -839,8 +842,23 @@ if __name__ == "__main__":
                     *sharded_data,
                     avg_params_queue_get_time,
                     learn_opponent,
+                    eval_stats,
                 ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                 sharded_data_list.append(sharded_data)
+                if eval_stats is not None:
+                    eval_stat_list.append(eval_stats)
+
+        if update % args.eval_interval == 0:
+            eval_stats = np.mean(eval_stat_list, axis=0)
+            print(eval_stats)
+            eval_stats = jax.device_put(eval_stats, local_devices[0])
+            eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
+            eval_time, eval_return, eval_win_rate = eval_stats
+            writer.add_scalar(f"charts/eval_return", eval_return, global_step)
+            writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
+            print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")
+
         rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
         training_time_start = time.time()
         (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
...