Commit 907b51bc authored by sbl1996@126.com

Use 2**20 instead of 1e6 as the divisor for M (the step count in checkpoint file names); better for saving

parent d0a27000
...@@ -535,7 +535,7 @@ if __name__ == "__main__": ...@@ -535,7 +535,7 @@ if __name__ == "__main__":
f.write(flax.serialization.to_bytes(obj)) f.write(flax.serialization.to_bytes(obj))
ckpt_maneger = ModelCheckpoint( ckpt_maneger = ModelCheckpoint(
args.ckpt_dir, save_fn, n_saved=3) args.ckpt_dir, save_fn, n_saved=2)
# seeding # seeding
seed_offset = args.local_rank * 10000 seed_offset = args.local_rank * 10000
...@@ -871,7 +871,7 @@ if __name__ == "__main__": ...@@ -871,7 +871,7 @@ if __name__ == "__main__":
writer.add_scalar("losses/loss", loss, global_step) writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug: if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = args.batch_size * learner_policy_version // 1e6 M_steps = args.batch_size * learner_policy_version // 2**20
ckpt_name = f"{timestamp}_{M_steps}M.flax_model" ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name) ckpt_maneger.save(unreplicated_params, ckpt_name)
if args.gcs_bucket is not None: if args.gcs_bucket is not None:
......
...@@ -545,7 +545,7 @@ if __name__ == "__main__": ...@@ -545,7 +545,7 @@ if __name__ == "__main__":
f.write(flax.serialization.to_bytes(obj)) f.write(flax.serialization.to_bytes(obj))
ckpt_maneger = ModelCheckpoint( ckpt_maneger = ModelCheckpoint(
args.ckpt_dir, save_fn, n_saved=3) args.ckpt_dir, save_fn, n_saved=2)
# seeding # seeding
seed_offset = args.local_rank * 10000 seed_offset = args.local_rank * 10000
...@@ -897,7 +897,7 @@ if __name__ == "__main__": ...@@ -897,7 +897,7 @@ if __name__ == "__main__":
writer.add_scalar("losses/loss", loss, global_step) writer.add_scalar("losses/loss", loss, global_step)
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug: if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
M_steps = args.batch_size * learner_policy_version // 1e6 M_steps = args.batch_size * learner_policy_version // 2**20
ckpt_name = f"{timestamp}_{M_steps}M.flax_model" ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name) ckpt_maneger.save(unreplicated_params, ckpt_name)
if args.gcs_bucket is not None: if args.gcs_bucket is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment