Commit 878f6a9e authored by sbl1996@126.com

Fix ckpt

parent b3160234
......@@ -1189,9 +1189,6 @@ def main():
if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
ckpt_steps = tb_global_step // 2**20
step_str = "M"
if ckpt_steps == 0:
ckpt_steps = tb_global_step // 2**10
step_str = "K"
ckpt_name = f"{timestamp}_{ckpt_steps}{step_str}.flax_model"
ckpt_maneger.save(unreplicated_params, ckpt_name)
......
......@@ -40,11 +40,13 @@ class ModelCheckpoint(object):
print(f"Saved model to {path}")
if len(self._saved) > self._n_saved:
path = self._saved.pop(0)
if path.is_dir():
shutil.rmtree(path)
else:
os.remove(path)
to_remove = self._saved.pop(0)
if to_remove != path:
if to_remove.is_dir():
shutil.rmtree(to_remove)
else:
if to_remove.exists():
os.remove(to_remove)
def get_latest(self):
path = self._saved[-1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment