Commit 1b81fc0d authored by sbl1996@126.com

Fix orbax

parent d9334955
......@@ -1154,10 +1154,10 @@ def main():
*list(zip(*sharded_data_list)),
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(get_state(agent_state))
new_state = get_state(agent_state)
params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params = jax.device_put(flax.jax_utils.unreplicate(new_state), local_devices[d_id])
device_params["encoder"]['id_embed']["embedding"].value.block_until_ready()
params_queue_put_start = time.time()
for thread_id in range(args.num_actor_threads):
......@@ -1197,7 +1197,8 @@ def main():
if learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = tb_global_step // 2**20
ckpt_name = f"{timestamp}_{M_steps}M"
ckpt_maneger.save(unreplicated_params, ckpt_name)
new_state = jax.tree.map(orbax.utils.fully_replicated_host_local_array_to_global_array, new_state)
ckpt_maneger.save(new_state, ckpt_name)
if learner_policy_version >= args.num_updates:
break
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment