Commit 157f440c authored by biluo.shen

add torchrun_setup

parent 92898eae
......@@ -24,6 +24,20 @@ def reduce_gradidents(params, world_size):
offset += param.numel()
def test_nccl(local_rank):
    """Run one small collective round-trip on this process's CUDA device.

    Creates a 4-element random tensor on ``cuda:{local_rank}``, sum-reduces
    it across the default process group, reads the mean back as a Python
    scalar, and then waits on a barrier. The original inline comment
    describes the purpose as "manual init nccl".

    The process group must already be initialised before calling this.

    Args:
        local_rank: Index of the CUDA device to allocate the probe tensor on.
    """
    device = f'cuda:{local_rank}'
    probe = torch.rand(4, device=device)
    dist.all_reduce(probe, op=dist.ReduceOp.SUM)
    # The value is discarded; .item() is only called to pull the reduced
    # result back to the host.
    probe.mean().item()
    dist.barrier()
def torchrun_setup(backend, local_rank):
    """Initialise the default process group, then smoke-test it.

    No rank or world size is passed to ``init_process_group`` here, so they
    are presumably taken from the environment prepared by the launcher
    (torchrun, going by the function name) -- confirm against the caller.

    Args:
        backend: Backend name forwarded to ``dist.init_process_group``.
        local_rank: CUDA device index forwarded to ``test_nccl``.
    """
    # 30-minute timeout for the process group.
    group_timeout = datetime.timedelta(seconds=60 * 30)
    dist.init_process_group(backend, timeout=group_timeout)
    test_nccl(local_rank)
def setup(backend, rank, world_size, port):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(port)
......@@ -31,12 +45,7 @@ def setup(backend, rank, world_size, port):
backend, rank=rank, world_size=world_size,
timeout=datetime.timedelta(seconds=60 * 30))
# manual init nccl
x = torch.rand(4, device=f'cuda:{rank}')
dist.all_reduce(x, op=dist.ReduceOp.SUM)
x.mean().item()
dist.barrier()
# print(f"Rank {rank} initialized")
test_nccl(rank)
def mp_start(run):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment