
Commit 8751cff

tsunghsienlee authored and facebook-github-bot committed
Add input argument dist_group into DDPDistributor
Summary: By leveraging the fixed API `distributed.new_subgroups(group=)` introduced in pytorch/pytorch#152765, this diff adds an input argument `dist_group` that enables `DDPDistributor` to operate under `dist_group` instead of always defaulting to the whole world of ranks. This new input argument will also be used to simplify `HSDPDistributor` and `HybridShardDistributor` by passing each replicate group into this new `DDPDistributor`.

Reviewed By: wz337

Differential Revision: D74971276

fbshipit-source-id: 0f0a62abc354806859fe64a886f89457fed5607e
1 parent af62244 commit 8751cff
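For orientation, here is a minimal usage sketch of the new argument. It assumes torch.distributed is already initialized on every participating rank and that a Shampoo parameter-group dict has been built elsewhere; the helper function name and the choice of ranks are hypothetical, while `DDPDistributor`, its `dist_group` parameter, and the module path are taken from the diff below.

# Hypothetical sketch (not part of this commit): run DDPDistributor over an
# explicit process group instead of the default whole-world group.
from typing import Any

import torch.distributed as dist

from distributed_shampoo.distributor.shampoo_ddp_distributor import DDPDistributor


def build_distributor_for_group(
    param_group: dict[str, Any],
    replicate_ranks: list[int],
) -> DDPDistributor:
    # dist.new_group must be called collectively on all ranks with the same
    # arguments; call this helper only on ranks that belong to the group.
    replicate_group = dist.new_group(ranks=replicate_ranks)
    # Leaving dist_group at its default (GroupMember.WORLD) keeps the previous
    # behavior of operating over the whole world of ranks.
    return DDPDistributor(param_group=param_group, dist_group=replicate_group)

Per the summary, HSDPDistributor and HybridShardDistributor would pass their per-replica process groups in the same way.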

File tree

2 files changed: +31 -13 lines

distributed_shampoo/distributor/gpu_tests/shampoo_ddp_distributor_test.py

Lines changed: 11 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 import abc
 import contextlib
+import os
 import re
 import unittest
 

@@ -833,6 +834,16 @@ class ShampooDDPDistributorCPUTest(AbstractTest.ShampooDDPDistributorDeviceTest)
     def _device(self) -> torch.device:
         return torch.device("cpu")
 
+    def setUp(self) -> None:
+        # Set TORCH_GLOO_LAZY_INIT to prevent timeout in test_empty_local_blocked_params.
+        os.environ["TORCH_GLOO_LAZY_INIT"] = "1"
+        super().setUp()
+
+    def tearDown(self) -> None:
+        # Clean up the environment variable after the test.
+        del os.environ["TORCH_GLOO_LAZY_INIT"]
+        return super().tearDown()
+
 
 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
 class ShampooDDPDistributorGPUTest(AbstractTest.ShampooDDPDistributorDeviceTest):

distributed_shampoo/distributor/shampoo_ddp_distributor.py

Lines changed: 20 additions & 13 deletions
@@ -22,6 +22,7 @@
     DISTRIBUTED_CONFIG,
     PARAMS,
 )
+from distributed_shampoo.utils.commons import batched
 from distributed_shampoo.utils.shampoo_utils import (
     compress_list,
     distribute_buffer_sizes,

@@ -30,6 +31,7 @@
 )
 from torch import Tensor
 from torch.distributed import tensor as dtensor
+from torch.distributed.device_mesh import _mesh_resources
 
 from torch.distributed.tensor import zeros as dtensor_zeros
 

@@ -131,10 +133,15 @@ class DDPDistributor(DistributorInterface):
 
     Args:
         param_group (dict[str, Any]): Parameter group containing parameters.
+        dist_group (dist.ProcessGroup | None): Optional process group for distributed operations. (Default: dist.distributed_c10d.GroupMember.WORLD)
 
     """
 
-    def __init__(self, param_group: dict[str, Any]) -> None:
+    def __init__(
+        self,
+        param_group: dict[str, Any],
+        dist_group: dist.ProcessGroup | None = dist.distributed_c10d.GroupMember.WORLD,
+    ) -> None:
         super().__init__(param_group)
         distributed_config: DDPDistributedConfig = param_group[DISTRIBUTED_CONFIG]
 

@@ -145,7 +152,7 @@ def __init__(self, param_group: dict[str, Any]) -> None:
 
         # Check num_trainers_per_group and get global and group sizes.
         # NOTE: If num_trainers_per_group = -1, then we use the global world size.
-        self._global_size: int = dist.get_world_size()
+        self._global_size: int = dist.get_world_size(group=dist_group)
 
         if distributed_config.num_trainers_per_group == -1:
             logger.info(

@@ -161,8 +168,8 @@ def __init__(self, param_group: dict[str, Any]) -> None:
         self._communicate_params: bool = distributed_config.communicate_params
 
         # Initialize _dist_group and _group_rank.
-        self._dist_group: dist.ProcessGroup | None = dist.new_subgroups(
-            group_size=self._group_size
+        self._dist_group: dist.ProcessGroup = dist.new_subgroups(
+            group_size=self._group_size, group=dist_group
         )[0]
         group_rank: int = dist.get_rank(group=self._dist_group)
 

@@ -494,19 +501,19 @@ def _allocate_zeros_distributed_tensor(
             out (Tensor): Desired DTensor.
 
         """
-        device_mesh_ranks = tuple(
-            range(
-                group_source_rank % self._group_size,
-                self._global_size,
-                self._group_size,
-            )
+        ranks_in_group = dist.get_process_group_ranks(group=self._dist_group)
+        device_mesh_2d = get_device_mesh(
+            device_type=device.type,
+            mesh=tuple(batched(iterable=ranks_in_group, n=self._group_size)),
+            mesh_dim_names=("replicate", "shard"),
         )
+        replicated_submesh = _mesh_resources._get_all_submeshes(
+            device_mesh_2d, "replicate"
+        )[group_source_rank]
 
         return dtensor_zeros(
             size,
             dtype=dtype,
-            device_mesh=get_device_mesh(
-                device_type=device.type, mesh=device_mesh_ranks
-            ),
+            device_mesh=replicated_submesh,
             placements=[dtensor.Replicate()],
         )
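To make the new `_allocate_zeros_distributed_tensor` bookkeeping concrete, here is a small torch-free sketch of the rank layout, assuming an 8-rank process group and `group_size = 4`. `batched` is assumed to chunk like `itertools.batched`, and the column selection below is only a stand-in for `_mesh_resources._get_all_submeshes(device_mesh_2d, "replicate")[group_source_rank]`; both assumptions are illustrative rather than taken from the diff.

# Hypothetical illustration (no torch required) of the rank layout built in
# _allocate_zeros_distributed_tensor after this change.
from itertools import batched  # Python 3.12+; stand-in for distributed_shampoo.utils.commons.batched

ranks_in_group = list(range(8))  # e.g. the ranks of the (sub)group passed as dist_group
group_size = 4                   # trainers per shard group
group_source_rank = 1            # rank within a subgroup that owns this buffer

# 2D mesh with dims ("replicate", "shard"): each row is one shard group of group_size trainers.
mesh_2d = tuple(batched(ranks_in_group, group_size))
assert mesh_2d == ((0, 1, 2, 3), (4, 5, 6, 7))

# The "replicate" submesh for group_source_rank is the corresponding column,
# i.e. the ranks holding replicas of the same shard.
replicate_ranks = tuple(row[group_source_rank] for row in mesh_2d)
assert replicate_ranks == (1, 5)

# The removed range-based computation, which assumed the group was the whole
# world 0..N-1, yields the same column in that special case.
old_ranks = tuple(range(group_source_rank % group_size, len(ranks_in_group), group_size))
assert replicate_ranks == old_ranks

Deriving the replicate ranks from the group's actual rank list, instead of striding over the whole world, is what lets the allocation respect a caller-supplied dist_group.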