Commit 7edeea56 authored by biluo.shen

Fix for nccl timeout

parent 1e253c14
@@ -112,11 +112,11 @@ class Args:
     """tensorboard log directory"""
     ckpt_dir: str = "./checkpoints"
     """checkpoint directory"""
-    save_interval: int = 1000
+    save_interval: int = 500
     """the number of iterations to save the model"""
-    log_p: float = 0.1
+    log_p: float = 1.0
     """the probability of logging"""
-    port: int = 12355
+    port: int = 12356
     """the port to use for distributed training"""
     # to be filled in runtime
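For context, a minimal sketch of the relevant slice of the Args dataclass with the new defaults (the decorator and the many omitted fields are assumed, not shown in this hunk):

from dataclasses import dataclass

@dataclass
class Args:
    ckpt_dir: str = "./checkpoints"
    """checkpoint directory"""
    save_interval: int = 500   # was 1000: checkpoints are written twice as often
    """the number of iterations to save the model"""
    log_p: float = 1.0         # was 0.1: the coarse logging gate now always passes
    """the probability of logging"""
    port: int = 12356          # was 12355: new rendezvous port for distributed training
    """the port to use for distributed training"""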
@@ -217,8 +217,6 @@ def run(local_rank, world_size):
     if args.embedding_file:
         agent.load_embeddings(embeddings)
-    # if args.compile:
-    #     agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
     optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
     scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)
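As a reminder of how the GradScaler created here is typically driven later in the update loop, a minimal sketch (this is not the file's actual loop; compute_loss and the clipping value are placeholders):

import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler(enabled=True, init_scale=2 ** 8)

def train_step(agent, optimizer, batch):
    optimizer.zero_grad(set_to_none=True)
    with autocast(enabled=True):              # forward pass in reduced precision
        loss = agent.compute_loss(batch)      # hypothetical loss helper
    scaler.scale(loss).backward()             # scale loss to avoid fp16 underflow
    scaler.unscale_(optimizer)                # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
    scaler.step(optimizer)                    # step is skipped if grads overflowed
    scaler.update()                           # adjust the scale factor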
@@ -341,22 +339,25 @@ def run(local_rank, world_size):
                     continue
                 for idx, d in enumerate(next_done_):
-                    if d and random.random() < args.log_p:
+                    if d:
                         episode_length = info['l'][idx]
                         episode_reward = info['r'][idx]
-                        writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
-                        writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
-                        avg_ep_returns.append(episode_reward)
                         winner = 0 if episode_reward > 0 else 1
+                        avg_ep_returns.append(episode_reward)
                         avg_win_rates.append(1 - winner)
-                        print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
-                        if len(avg_win_rates) > 100:
-                            writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
-                            writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
-                            avg_win_rates = []
-                            avg_ep_returns = []
+                        if random.random() < args.log_p:
+                            n = 100
+                            if random.random() < 10/n or iteration <= 2:
+                                writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
+                                writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
+                                print(f"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}")
+                            if len(avg_win_rates) > n:
+                                writer.add_scalar("charts/avg_win_rate", np.mean(avg_win_rates), global_step)
+                                writer.add_scalar("charts/avg_ep_return", np.mean(avg_ep_returns), global_step)
+                                avg_win_rates = []
+                                avg_ep_returns = []

         collect_time = time.time() - collect_start
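Net effect of the reworked gating, assuming log_p = 1.0 and n = 100: the running averages are appended for every finished episode, while the per-episode scalars and the print fire for roughly 10/n = 10% of them (and for all episodes in the first two iterations). A stripped-down sketch of just the gating; the function name and return shape are mine, not the script's:

import random

def gate_episode_logging(log_p: float, iteration: int, n: int = 100):
    """Return (log_averages, log_per_episode) mirroring the gating above."""
    if random.random() >= log_p:       # coarse gate: always passes when log_p == 1.0
        return False, False
    # fine gate: per-episode scalars for ~10/n of episodes, always early on
    per_episode = random.random() < 10 / n or iteration <= 2
    return True, per_episode           # averages are flushed once len(...) > n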
@@ -32,6 +32,7 @@ def setup(backend, rank, world_size, port):
     dist.all_reduce(x, op=dist.ReduceOp.SUM)
     x.mean().item()
     dist.barrier()
+    # print(f"Rank {rank} initialized")


 def mp_start(run):
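Since the commit targets an NCCL timeout, the setup hunk above (warm-up all_reduce followed by a barrier) is the piece that forces every rank to finish initializing before training proceeds. A minimal sketch of such a setup function, assuming the usual MASTER_ADDR/MASTER_PORT bootstrap; the explicit timeout argument is an assumption of this sketch, not something the diff adds:

import os
from datetime import timedelta

import torch
import torch.distributed as dist

def setup(backend, rank, world_size, port):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    # A longer timeout gives slow ranks room before NCCL aborts the collective
    # (assumption: the real script may rely on the default).
    dist.init_process_group(backend, rank=rank, world_size=world_size,
                            timeout=timedelta(minutes=30))
    # Warm-up collective + barrier, as in the hunk above: every rank must
    # reach this point before any of them moves on.
    x = torch.ones(1, device=f"cuda:{rank}" if backend == "nccl" else "cpu")
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    x.mean().item()
    dist.barrier()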
@@ -39,7 +40,7 @@ def mp_start(run):
     if world_size == 1:
         run(local_rank=0, world_size=world_size)
     else:
-        mp.set_start_method('spawn')
+        # mp.set_start_method('spawn')
         children = []
         for i in range(world_size):
             subproc = mp.Process(target=run, args=(i, world_size))
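For completeness, a sketch of the mp_start pattern this last hunk touches, assuming world_size comes from an environment variable and that the parent joins the children (neither detail is shown in the diff); the commented-out set_start_method mirrors the change above:

import os
import torch.multiprocessing as mp

def mp_start(run):
    world_size = int(os.environ.get("WORLD_SIZE", "1"))  # assumed source of world_size
    if world_size == 1:
        run(local_rank=0, world_size=world_size)
    else:
        # mp.set_start_method('spawn')  # left commented out, as in the diff
        children = []
        for i in range(world_size):
            subproc = mp.Process(target=run, args=(i, world_size))
            children.append(subproc)
            subproc.start()
        for child in children:
            child.join()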