Commit afbe4893 authored by sbl1996@126.com

Fix checkpoint

parent 4c0edbf8
......@@ -798,7 +798,7 @@ def main():
if args.checkpoint:
with open(args.checkpoint, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
-agent_state = agent_state.replace(params=params)
+agent_state = agent_state.replace(params=(params, params_rp))
print(f"loaded checkpoint from {args.checkpoint}")
agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
......@@ -1258,7 +1258,7 @@ def main():
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = tb_global_step // 2**20
ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
-ckpt_maneger.save(unreplicated_params, ckpt_name)
+ckpt_maneger.save(unreplicated_params[0], ckpt_name)
if args.gcs_bucket is not None:
lastest_path = ckpt_maneger.get_latest()
copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment