Commit d9334955 authored by sbl1996@126.com's avatar sbl1996@126.com

orbax need to call save on all hosts

parent fcaf7bf7
......@@ -1194,7 +1194,7 @@ def main():
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), tb_global_step)
writer.add_scalar("losses/loss", loss, tb_global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
if learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = tb_global_step // 2**20
ckpt_name = f"{timestamp}_{M_steps}M"
ckpt_maneger.save(unreplicated_params, ckpt_name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment