Commit 2a419375 authored by sbl1996@126.com's avatar sbl1996@126.com

Add league training

parent cd2974ce
......@@ -43,8 +43,8 @@ class Args:
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
time_log_freq: int = 1000
"""the logging frequency of the deck time statistics"""
time_log_freq: int = 0
"""the logging frequency of the deck time statistics, 0 to disable"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
......@@ -360,7 +360,7 @@ def rollout(
if args.concurrency:
if update != 2:
params = params_queue.get()
params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
# params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
actor_policy_version += 1
else:
params = params_queue.get()
......@@ -417,6 +417,7 @@ def rollout(
t.rewards[idx] = -next_reward[idx]
break
if args.time_log_freq:
for i in range(2):
deck_time = info['step_time'][idx][i]
deck_name = deck_names[info['deck'][idx][i]]
......@@ -474,14 +475,12 @@ def rollout(
else:
eval_stats = None
learn_opponent = False
payload = (
global_step,
update,
sharded_storage,
*sharded_data,
np.mean(params_queue_get_time),
learn_opponent,
eval_stats,
)
rollout_queue.put(payload)
......@@ -758,7 +757,6 @@ def main():
sharded_next_inputs: List,
sharded_next_main: List,
key: jax.random.PRNGKey,
learn_opponent: bool = False,
):
storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
# TODO: rstate will be out-date after the first update, maybe consider R2D2
......@@ -862,7 +860,6 @@ def main():
single_device_update,
axis_name="local_devices",
devices=global_learner_decices,
static_broadcasted_argnums=(7,),
)
params_queues = []
......@@ -906,7 +903,6 @@ def main():
update,
*sharded_data,
avg_params_queue_get_time,
learn_opponent,
eval_stats,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_data_list.append(sharded_data)
......@@ -929,13 +925,12 @@ def main():
agent_state,
*list(zip(*sharded_data_list)),
learner_keys,
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
# device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
params_queue_put_start = time.time()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment