Commit 157f440c authored by biluo.shen's avatar biluo.shen

add torchrun_setup

parent 92898eae
...@@ -24,6 +24,20 @@ def reduce_gradidents(params, world_size): ...@@ -24,6 +24,20 @@ def reduce_gradidents(params, world_size):
offset += param.numel() offset += param.numel()
def test_nccl(local_rank):
# manual init nccl
x = torch.rand(4, device=f'cuda:{local_rank}')
dist.all_reduce(x, op=dist.ReduceOp.SUM)
x.mean().item()
dist.barrier()
def torchrun_setup(backend, local_rank):
dist.init_process_group(
backend, timeout=datetime.timedelta(seconds=60 * 30))
test_nccl(local_rank)
def setup(backend, rank, world_size, port): def setup(backend, rank, world_size, port):
os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(port) os.environ['MASTER_PORT'] = str(port)
...@@ -31,12 +45,7 @@ def setup(backend, rank, world_size, port): ...@@ -31,12 +45,7 @@ def setup(backend, rank, world_size, port):
backend, rank=rank, world_size=world_size, backend, rank=rank, world_size=world_size,
timeout=datetime.timedelta(seconds=60 * 30)) timeout=datetime.timedelta(seconds=60 * 30))
# manual init nccl test_nccl(rank)
x = torch.rand(4, device=f'cuda:{rank}')
dist.all_reduce(x, op=dist.ReduceOp.SUM)
x.mean().item()
dist.barrier()
# print(f"Rank {rank} initialized")
def mp_start(run): def mp_start(run):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment