Commit 1b81fc0d authored by sbl1996@126.com

Fix orbax

parent d9334955
...@@ -1154,10 +1154,10 @@ def main(): ...@@ -1154,10 +1154,10 @@ def main():
*list(zip(*sharded_data_list)), *list(zip(*sharded_data_list)),
learner_keys, learner_keys,
) )
unreplicated_params = flax.jax_utils.unreplicate(get_state(agent_state)) new_state = get_state(agent_state)
params_queue_put_time = 0 params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids): for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id]) device_params = jax.device_put(flax.jax_utils.unreplicate(new_state), local_devices[d_id])
device_params["encoder"]['id_embed']["embedding"].value.block_until_ready() device_params["encoder"]['id_embed']["embedding"].value.block_until_ready()
params_queue_put_start = time.time() params_queue_put_start = time.time()
for thread_id in range(args.num_actor_threads): for thread_id in range(args.num_actor_threads):
...@@ -1197,7 +1197,8 @@ def main(): ...@@ -1197,7 +1197,8 @@ def main():
if learner_policy_version % args.save_interval == 0 and not args.debug: if learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = tb_global_step // 2**20 M_steps = tb_global_step // 2**20
ckpt_name = f"{timestamp}_{M_steps}M" ckpt_name = f"{timestamp}_{M_steps}M"
ckpt_maneger.save(unreplicated_params, ckpt_name) new_state = jax.tree.map(orbax.utils.fully_replicated_host_local_array_to_global_array, new_state)
ckpt_maneger.save(new_state, ckpt_name)
if learner_policy_version >= args.num_updates: if learner_policy_version >= args.num_updates:
break break
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment