Commit 907b51bc authored by sbl1996@126.com

Use 2**20 instead of 1e6 as M for the step count in checkpoint names (better for saving)

parent d0a27000
......@@ -535,7 +535,7 @@ if __name__ == "__main__":
f.write(flax.serialization.to_bytes(obj))
ckpt_maneger = ModelCheckpoint(
args.ckpt_dir, save_fn, n_saved=3)
args.ckpt_dir, save_fn, n_saved=2)
# seeding
seed_offset = args.local_rank * 10000
......@@ -871,7 +871,7 @@ if __name__ == "__main__":
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = args.batch_size * learner_policy_version // 1e6
M_steps = args.batch_size * learner_policy_version // 2**20
ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name)
if args.gcs_bucket is not None:
......
......@@ -545,7 +545,7 @@ if __name__ == "__main__":
f.write(flax.serialization.to_bytes(obj))
ckpt_maneger = ModelCheckpoint(
args.ckpt_dir, save_fn, n_saved=3)
args.ckpt_dir, save_fn, n_saved=2)
# seeding
seed_offset = args.local_rank * 10000
......@@ -897,7 +897,7 @@ if __name__ == "__main__":
writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = args.batch_size * learner_policy_version // 1e6
M_steps = args.batch_size * learner_policy_version // 2**20
ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name)
if args.gcs_bucket is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment