Commit 7edeea56 authored by biluo.shen

Fix for NCCL timeout

parent 1e253c14
@@ -112,11 +112,11 @@ class Args:
     """tensorboard log directory"""
     ckpt_dir: str = "./checkpoints"
     """checkpoint directory"""
-    save_interval: int = 1000
+    save_interval: int = 500
     """the number of iterations to save the model"""
-    log_p: float = 0.1
+    log_p: float = 1.0
     """the probability of logging"""
-    port: int = 12355
+    port: int = 12356
     """the port to use for distributed training"""
     # to be filled in runtime
@@ -217,8 +217,6 @@ def run(local_rank, world_size):
     if args.embedding_file:
         agent.load_embeddings(embeddings)
-    # if args.compile:
-    #     agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
     optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
     scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
@@ -341,18 +339,21 @@ def run(local_rank, world_size):
                     continue
                 for idx, d in enumerate(next_done_):
-                    if d and random.random() < args.log_p:
+                    if d:
                         episode_length = info['l'][idx]
                         episode_reward = info['r'][idx]
-                        writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
-                        writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
-                        avg_ep_returns.append(episode_reward)
                         winner = 0 if episode_reward > 0 else 1
+                        avg_ep_returns.append(episode_reward)
                         avg_win_rates.append(1 - winner)
-                        if random.random() < args.log_p:
+                        n = 100
+                        if random.random() < 10/n or iteration <= 2:
+                            writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
+                            writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
                             print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
-                        if len(avg_win_rates) > 100:
+                        if len(avg_win_rates) > n:
                             writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
                             writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
                             avg_win_rates = []
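The optimization step itself lies outside the hunks above. As a rough, self-contained sketch of how a GradScaler configured like the one here (fp16 enabled, init_scale=2 ** 8) is typically paired with an Adam optimizer, where the model, loss, and clipping value are illustrative placeholders rather than code from this repository:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(128, 1).cuda()              # stand-in for the agent network
optimizer = optim.Adam(model.parameters(), lr=2.5e-4, eps=1e-5)
scaler = GradScaler(enabled=True, init_scale=2 ** 8)

def train_step(batch, targets):
    optimizer.zero_grad(set_to_none=True)
    with autocast(enabled=True):              # forward pass in mixed precision
        loss = nn.functional.mse_loss(model(batch), targets)
    scaler.scale(loss).backward()             # scale loss to avoid fp16 gradient underflow
    scaler.unscale_(optimizer)                # unscale before gradient clipping
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    scaler.step(optimizer)                    # skips the step if inf/nan gradients were found
    scaler.update()                           # adjust the scale factor for the next iteration
    return loss.item()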
@@ -32,6 +32,7 @@ def setup(backend, rank, world_size, port):
     dist.all_reduce(x, op=dist.ReduceOp.SUM)
     x.mean().item()
     dist.barrier()
+    # print(f"Rank {rank} initialized")


 def mp_start(run):
@@ -39,7 +40,7 @@ def mp_start(run):
     if world_size == 1:
         run(local_rank=0, world_size=world_size)
     else:
-        mp.set_start_method('spawn')
+        # mp.set_start_method('spawn')
         children = []
         for i in range(world_size):
            subproc = mp.Process(target=run, args=(i, world_size))
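On the NCCL side, the setup hunk above ends with a warm-up all-reduce followed by a barrier, which forces every rank to finish communicator initialization before training begins. Below is a minimal sketch of such a setup routine plus the process launcher, assuming a single-node run; the explicit timeout argument, the MASTER_ADDR/MASTER_PORT handling, and the world_size parameter of mp_start are illustrative assumptions, not code from this commit:

import os
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(backend, rank, world_size, port):
    os.environ["MASTER_ADDR"] = "127.0.0.1"     # assumed single-node launch
    os.environ["MASTER_PORT"] = str(port)
    # A generous timeout so slow ranks do not trip the NCCL watchdog during startup.
    dist.init_process_group(
        backend, rank=rank, world_size=world_size, timeout=timedelta(minutes=30))
    if backend == "nccl":
        torch.cuda.set_device(rank)             # bind this process to its own GPU
    device = torch.device(f"cuda:{rank}" if backend == "nccl" else "cpu")
    # Warm-up collective: exercises the communicator once, then synchronizes all ranks.
    x = torch.ones(1, device=device)
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    x.mean().item()
    dist.barrier()

def mp_start(run, world_size):
    if world_size == 1:
        run(local_rank=0, world_size=world_size)
    else:
        # The commit keeps the default start method instead of forcing 'spawn'.
        children = [mp.Process(target=run, args=(i, world_size)) for i in range(world_size)]
        for p in children:
            p.start()
        for p in children:
            p.join()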