Commit 25e8c58b authored by biluo.shen

Improve nccl

parent 3c7a7080
...@@ -95,6 +95,8 @@ class Args: ...@@ -95,6 +95,8 @@ class Args:
compile: bool = True compile: bool = True
"""whether to use torch.compile to compile the model and functions""" """whether to use torch.compile to compile the model and functions"""
compile_mode: Optional[str] = None
"""the mode to use for torch.compile"""
torch_threads: Optional[int] = None torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size""" """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = None env_threads: Optional[int] = None
...@@ -102,6 +104,10 @@ class Args: ...@@ -102,6 +104,10 @@ class Args:
tb_dir: str = "./runs" tb_dir: str = "./runs"
"""tensorboard log directory""" """tensorboard log directory"""
ckpt_dir: str = "./checkpoints"
"""checkpoint directory"""
save_interval: int = 100
"""the number of iterations to save the model"""
port: int = 12355 port: int = 12355
"""the port to use for distributed training""" """the port to use for distributed training"""
...@@ -140,6 +146,8 @@ def run(local_rank, world_size): ...@@ -140,6 +146,8 @@ def run(local_rank, world_size):
if args.world_size > 1: if args.world_size > 1:
setup(args.backend, local_rank, args.world_size, args.port) setup(args.backend, local_rank, args.world_size, args.port)
os.makedirs(args.ckpt_dir, exist_ok=True)
timestamp = int(time.time()) timestamp = int(time.time())
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}" run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer = None writer = None
...@@ -196,7 +204,7 @@ def run(local_rank, world_size): ...@@ -196,7 +204,7 @@ def run(local_rank, world_size):
agent.load_embeddings(embeddings) agent.load_embeddings(embeddings)
if args.compile: if args.compile:
agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode='reduce-overhead') agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
def masked_mean(x, valid): def masked_mean(x, valid):
...@@ -243,12 +251,11 @@ def run(local_rank, world_size): ...@@ -243,12 +251,11 @@ def run(local_rank, world_size):
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
reduce_gradidents(agent, args.world_size) reduce_gradidents(agent, args.world_size)
return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss
if args.compile: if args.compile:
train_step = torch.compile(train_step, mode='reduce-overhead') train_step = torch.compile(train_step, mode=args.compile_mode)
def to_tensor(x, dtype=torch.float32): def to_tensor(x, dtype=torch.float32):
return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x) return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)
...@@ -368,7 +375,6 @@ def run(local_rank, world_size): ...@@ -368,7 +375,6 @@ def run(local_rank, world_size):
old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \ old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss = \
train_step(agent, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], train_step(agent, mb_obs, b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds],
b_returns[mb_inds], b_values[mb_inds]) b_returns[mb_inds], b_values[mb_inds])
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step() optimizer.step()
clipfracs.append(clipfrac.item()) clipfracs.append(clipfrac.item())
...@@ -387,6 +393,9 @@ def run(local_rank, world_size): ...@@ -387,6 +393,9 @@ def run(local_rank, world_size):
# TRY NOT TO MODIFY: record rewards for plotting purposes # TRY NOT TO MODIFY: record rewards for plotting purposes
if local_rank == 0: if local_rank == 0:
if iteration % args.save_interval == 0:
torch.save(agent.state_dict(), os.path.join(args.ckpt_dir, f"ppo_{iteration}.pth"))
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
......
...@@ -27,6 +27,12 @@ def setup(backend, rank, world_size, port): ...@@ -27,6 +27,12 @@ def setup(backend, rank, world_size, port):
os.environ['MASTER_PORT'] = str(port) os.environ['MASTER_PORT'] = str(port)
dist.init_process_group(backend, rank=rank, world_size=world_size) dist.init_process_group(backend, rank=rank, world_size=world_size)
# manual init nccl
x = torch.rand(4, device=f'cuda:{rank}')
dist.all_reduce(x, op=dist.ReduceOp.SUM)
x.mean().item()
dist.barrier()
def mp_start(run): def mp_start(run):
world_size = int(os.getenv("WORLD_SIZE", "1")) world_size = int(os.getenv("WORLD_SIZE", "1"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment