
Commit 59e57a4

fix EP fsdp gradient divide factor (#1551)
Issue pointed out in #1534 (comment); see pytorch/pytorch#160285. Solution given by @rakkit in #1534 (comment).
1 parent 2c8b594 commit 59e57a4

File tree (3 files changed: +18 −3 lines)

torchtitan/distributed/parallel_dims.py
torchtitan/experiments/llama4/infra/parallelize.py
torchtitan/models/deepseek_v3/infra/parallelize.py

torchtitan/distributed/parallel_dims.py

Lines changed: 9 additions & 3 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
-from functools import cached_property

 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

@@ -219,11 +218,18 @@ def pp_enabled(self):
     def ep_enabled(self):
         return self.ep > 1

-    @cached_property
+    @property
+    def fsdp_gradient_divide_factor(self) -> int:
+        # This is needed for FSDP-sharded experts when Expert Parallel is enabled.
+        # Although the FSDP sharding of experts is done on a mesh of a different size than
+        # other parameters, the gradient division factor should be consistent with data.
+        return self.dp_replicate * self.dp_shard * self.cp
+
+    @property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp

-    @cached_property
+    @property
     def seq_len_divisor(self):
         # Sequence Parallel requires that seq_len be divisible by TP degree.
         # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 8 additions & 0 deletions
@@ -139,6 +139,7 @@ def parallelize_llama(
                 if dp_mod_ep_mesh_dim_names
                 else None
             ),
+            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )

     if parallel_dims.dp_replicate_enabled:

@@ -270,6 +271,7 @@ def apply_fsdp(
     cpu_offload: bool = False,
     reshard_after_forward_policy: str = "default",
     dp_mod_ep_mesh: DeviceMesh | None = None,
+    gradient_divide_factor: int | None = None,
 ):
     """
     Apply data parallelism (via FSDP2) to the model.

@@ -322,6 +324,12 @@ def apply_fsdp(
                 **fsdp_mod_ep_config,
                 reshard_after_forward=reshard_after_forward,
             )
+            # NOTE: Although the FSDP sharding of experts is done on a mesh of
+            # a different size than other parameters, the gradient division
+            # factor should be consistent with data.
+            transformer_block.moe.experts.set_gradient_divide_factor(
+                gradient_divide_factor,
+            )

             fully_shard(
                 transformer_block,
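
As a usage sketch of the new apply_fsdp argument: the helper name and mesh setup below are hypothetical, and it assumes a recent PyTorch where fully_shard is importable from torch.distributed.fsdp and the wrapped module exposes set_gradient_divide_factor (the method the commit calls); only the call pattern mirrors the diff above.

# Hypothetical helper (not from the commit) showing how the factor is applied.
# `transformer_block.moe.experts` and `dp_mod_ep_mesh` follow the diff above;
# everything else is an assumption for illustration.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard


def shard_experts_with_consistent_grad_divide(
    transformer_block: nn.Module,
    dp_mod_ep_mesh: DeviceMesh,
    gradient_divide_factor: int | None,
) -> None:
    # Shard the experts on the (smaller) dp_mod_ep mesh.
    fully_shard(
        transformer_block.moe.experts,
        mesh=dp_mod_ep_mesh,
        reshard_after_forward=True,
    )
    # Override the gradient division so it matches dp_replicate * dp_shard * cp
    # rather than the size of the expert mesh.
    transformer_block.moe.experts.set_gradient_divide_factor(gradient_divide_factor)

In the commit itself, the factor is passed down from parallelize_llama and parallelize_deepseekv3 via parallel_dims.fsdp_gradient_divide_factor, so both models apply the same consistent division.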

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ def parallelize_deepseekv3(
                 if dp_mod_ep_mesh_dim_names
                 else None
             ),
+            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )

     if parallel_dims.dp_replicate_enabled:
