Commit b3160234 authored by sbl1996@126.com's avatar sbl1996@126.com

Fix checkpoint

parent 5cfbe008
......@@ -1187,8 +1187,12 @@ def main():
writer.add_scalar("losses/loss", loss, tb_global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = tb_global_step // 2**20
ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_steps = tb_global_step // 2**20
step_str = "M"
if ckpt_steps == 0:
ckpt_steps = tb_global_step // 2**10
step_str = "K"
ckpt_name = f"{timestamp}_{ckpt_steps}{step_str}.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name)
if learner_policy_version >= args.num_updates:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment