Commit 2a419375 authored by sbl1996@126.com

Add league training

parent cd2974ce
@@ -43,8 +43,8 @@ class Args:
     """seed of the experiment"""
     log_frequency: int = 10
     """the logging frequency of the model performance (in terms of `updates`)"""
-    time_log_freq: int = 1000
-    """the logging frequency of the deck time statistics"""
+    time_log_freq: int = 0
+    """the logging frequency of the deck time statistics, 0 to disable"""
     save_interval: int = 400
     """the frequency of saving the model (in terms of `updates`)"""
     checkpoint: Optional[str] = None
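
Note on the new default: `time_log_freq = 0` now disables the per-deck timing statistics entirely; the rollout loop further below only enters the timing branch when the flag is truthy. A minimal sketch of the intended semantics (the `Args` dataclass and its CLI parsing come from the surrounding script and are assumed here, not shown in this hunk):

```python
from dataclasses import dataclass

@dataclass
class Args:
    time_log_freq: int = 0
    """the logging frequency of the deck time statistics, 0 to disable"""

args = Args(time_log_freq=1000)  # opt back in to the old behaviour explicitly
if args.time_log_freq:           # 0 is falsy, so the whole timing branch is skipped
    print("deck time logging enabled")
```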
@@ -360,7 +360,7 @@ def rollout(
         if args.concurrency:
             if update != 2:
                 params = params_queue.get()
-                params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
+                # params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
                 actor_policy_version += 1
             else:
                 params = params_queue.get()
@@ -416,20 +416,21 @@ def rollout(
                         t.next_dones[idx] = True
                         t.rewards[idx] = -next_reward[idx]
                         break
-                for i in range(2):
-                    deck_time = info['step_time'][idx][i]
-                    deck_name = deck_names[info['deck'][idx][i]]
-
-                    time_count = deck_time_count[deck_name]
-                    avg_time = deck_avg_times[deck_name]
-                    avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
-                    max_time = max(deck_time, deck_max_times[deck_name])
-                    deck_avg_times[deck_name] = avg_time
-                    deck_max_times[deck_name] = max_time
-                    deck_time_count[deck_name] += 1
-                    if deck_time_count[deck_name] % args.time_log_freq == 0:
-                        print(f"Deck {deck_name}, avg: {avg_time * 1000:.2f}, max: {max_time * 1000:.2f}")
+                if args.time_log_freq:
+                    for i in range(2):
+                        deck_time = info['step_time'][idx][i]
+                        deck_name = deck_names[info['deck'][idx][i]]
+
+                        time_count = deck_time_count[deck_name]
+                        avg_time = deck_avg_times[deck_name]
+                        avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
+                        max_time = max(deck_time, deck_max_times[deck_name])
+                        deck_avg_times[deck_name] = avg_time
+                        deck_max_times[deck_name] = max_time
+                        deck_time_count[deck_name] += 1
+                        if deck_time_count[deck_name] % args.time_log_freq == 0:
+                            print(f"Deck {deck_name}, avg: {avg_time * 1000:.2f}, max: {max_time * 1000:.2f}")

                 episode_reward = info['r'][idx] * (1 if cur_main else -1)
                 win = 1 if episode_reward > 0 else 0
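
The guard added above matters beyond tidiness: with the new default `time_log_freq = 0`, the unguarded `% args.time_log_freq` would raise `ZeroDivisionError`. The statistics themselves are a plain incremental mean plus a running maximum per deck; a self-contained sketch under the assumption that the counters are `defaultdict`s (the actual script initialises them elsewhere in `rollout`):

```python
from collections import defaultdict

deck_avg_times = defaultdict(float)
deck_max_times = defaultdict(float)
deck_time_count = defaultdict(int)

def record(deck_name: str, deck_time: float, time_log_freq: int) -> None:
    n = deck_time_count[deck_name]
    # incremental mean: new_avg = old_avg * n/(n+1) + sample/(n+1)
    deck_avg_times[deck_name] = deck_avg_times[deck_name] * (n / (n + 1)) + deck_time / (n + 1)
    deck_max_times[deck_name] = max(deck_time, deck_max_times[deck_name])
    deck_time_count[deck_name] += 1
    if time_log_freq and deck_time_count[deck_name] % time_log_freq == 0:
        print(f"Deck {deck_name}, avg: {deck_avg_times[deck_name] * 1000:.2f}, "
              f"max: {deck_max_times[deck_name] * 1000:.2f}")
```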
@@ -474,14 +475,12 @@ def rollout(
         else:
             eval_stats = None
-        learn_opponent = False
         payload = (
             global_step,
             update,
             sharded_storage,
             *sharded_data,
             np.mean(params_queue_get_time),
-            learn_opponent,
             eval_stats,
         )
         rollout_queue.put(payload)
@@ -758,7 +757,6 @@ def main():
         sharded_next_inputs: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,
-        learn_opponent: bool = False,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
         # TODO: rstate will be out-date after the first update, maybe consider R2D2
@@ -862,7 +860,6 @@ def main():
         single_device_update,
         axis_name="local_devices",
         devices=global_learner_decices,
-        static_broadcasted_argnums=(7,),
     )
     params_queues = []
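
Dropping `learn_opponent` from `single_device_update` is what allows removing `static_broadcasted_argnums=(7,)`: the bool was the only non-array argument (position 7 in the signature), so it had to be broadcast statically; every remaining argument can be sharded as usual. An illustrative sketch, not the project's actual update function, with the static argument at index 1 instead of 7:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.ones((n, 4))

def update_with_flag(x, learn_opponent):
    # a plain Python bool cannot be sharded across devices, so it must be static
    return -x if learn_opponent else x

def update(x):
    return x

old_style = jax.pmap(update_with_flag, axis_name="d", static_broadcasted_argnums=(1,))(x, True)
new_style = jax.pmap(update, axis_name="d")(x)  # no static arguments needed any more
```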
@@ -906,7 +903,6 @@ def main():
                     update,
                     *sharded_data,
                     avg_params_queue_get_time,
-                    learn_opponent,
                     eval_stats,
                 ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                 sharded_data_list.append(sharded_data)
@@ -929,13 +925,12 @@ def main():
                 agent_state,
                 *list(zip(*sharded_data_list)),
                 learner_keys,
-                learn_opponent,
             )
             unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
             params_queue_put_time = 0
             for d_idx, d_id in enumerate(args.actor_device_ids):
                 device_params = jax.device_put(unreplicated_params, local_devices[d_id])
-                # device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
+                device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
                 params_queue_put_start = time.time()
                 for thread_id in range(args.num_actor_threads):
                     params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
...
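
The two `block_until_ready()` edits are two halves of one move: the explicit sync is dropped from the actor thread (second hunk above) and added on the learner side right after `jax.device_put`. Since `device_put` is asynchronous, forcing one leaf of the parameter pytree to be ready ensures the copy onto the actor device has finished before `params_queue_put_start` is sampled, so the measured queue-put time no longer includes transfer time. A standalone sketch of that pattern (the nested key path mirrors the diff; any concrete leaf of the param pytree would serve):

```python
import time
import jax
import jax.numpy as jnp

params = {"params": {"Encoder_0": {"Embed_0": {"embedding": jnp.ones((100, 64))}}}}

device_params = jax.device_put(params, jax.local_devices()[0])
# Wait for the asynchronous transfer to complete before starting the timer.
device_params["params"]["Encoder_0"]["Embed_0"]["embedding"].block_until_ready()

params_queue_put_start = time.time()
# params_queue.put(device_params)  # the real script pushes to each actor thread's queue
print(f"params queue put took {time.time() - params_queue_put_start:.6f}s")
```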