Commit eba0e134 authored by biluo.shen

Change ckpt_dir

parent 25e8c58b
@@ -146,8 +146,6 @@ def run(local_rank, world_size):
     if args.world_size > 1:
         setup(args.backend, local_rank, args.world_size, args.port)
-    os.makedirs(args.ckpt_dir, exist_ok=True)
     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
     writer = None
@@ -159,6 +157,9 @@ def run(local_rank, world_size):
         "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
     )
+    ckpt_dir = os.path.join(args.ckpt_dir, run_name)
+    os.makedirs(ckpt_dir, exist_ok=True)
     # TRY NOT TO MODIFY: seeding
     # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
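Net effect of the two hunks above: instead of writing every checkpoint directly into args.ckpt_dir, the script now creates one subdirectory per run, named after the run. A minimal sketch of the composed path, with made-up argument values standing in for the script's CLI arguments:

    import os
    import time

    env_id, exp_name, seed = "CartPole-v1", "ppo", 1  # hypothetical stand-ins for args.env_id etc.
    ckpt_root = "checkpoints"                         # stands in for args.ckpt_dir

    timestamp = int(time.time())
    run_name = f"{env_id}__{exp_name}__{seed}__{timestamp}"
    ckpt_dir = os.path.join(ckpt_root, run_name)      # e.g. checkpoints/CartPole-v1__ppo__1__1700000000
    os.makedirs(ckpt_dir, exist_ok=True)              # safe to call even if the directory already exists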
@@ -394,7 +395,7 @@ def run(local_rank, world_size):
         # TRY NOT TO MODIFY: record rewards for plotting purposes
         if local_rank == 0:
             if iteration % args.save_interval == 0:
-                torch.save(agent.state_dict(), os.path.join(args.ckpt_dir, f"ppo_{iteration}.pth"))
+                torch.save(agent.state_dict(), os.path.join(ckpt_dir, f"{iteration}.pth"))
             writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
             writer.add_scalar("losses/value_loss", v_loss.item(), global_step)