3 files changed: +18 −3 lines changed

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
-from functools import cached_property

 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

@@ -219,11 +218,18 @@ def pp_enabled(self):
     def ep_enabled(self):
         return self.ep > 1

-    @cached_property
+    @property
+    def fsdp_gradient_divide_factor(self) -> int:
+        # This is needed for FSDP-sharded experts when Expert Parallel is enabled.
+        # Although the FSDP sharding of experts is done on a mesh of a different size than
+        # other parameters, the gradient division factor should be consistent with data.
+        return self.dp_replicate * self.dp_shard * self.cp
+
+    @property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp

-    @cached_property
+    @property
     def seq_len_divisor(self):
         # Sequence Parallel requires that seq_len be divisible by TP degree.
         # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001
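For context (not part of the change itself), here is a minimal sketch of the arithmetic behind the new property. The parallel degrees and the assumption that EP borrows its ranks from the dp_shard and cp dimensions are illustrative only:

```python
# Hypothetical degrees, chosen only to illustrate the sizes involved.
dp_replicate, dp_shard, cp, ep = 2, 4, 2, 4

# New property: one divisor for every parameter, whether or not EP is enabled.
fsdp_gradient_divide_factor = dp_replicate * dp_shard * cp  # 16

# Assumed layout: EP borrows its ranks from dp_shard * cp, so the experts'
# FSDP mesh is smaller than the mesh used for the other parameters.
expert_fsdp_mesh_size = dp_replicate * (dp_shard * cp // ep)  # 4

# Left at its default, FSDP would divide expert gradients by 4 instead of 16,
# putting them on a different scale than every other parameter's gradients.
print(fsdp_gradient_divide_factor, expert_fsdp_mesh_size)
```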
@@ -139,6 +139,7 @@ def parallelize_llama(
                 if dp_mod_ep_mesh_dim_names
                 else None
             ),
+            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )

         if parallel_dims.dp_replicate_enabled:
@@ -270,6 +271,7 @@ def apply_fsdp(
     cpu_offload: bool = False,
     reshard_after_forward_policy: str = "default",
     dp_mod_ep_mesh: DeviceMesh | None = None,
+    gradient_divide_factor: int | None = None,
 ):
     """
     Apply data parallelism (via FSDP2) to the model.
@@ -322,6 +324,12 @@ def apply_fsdp(
                 **fsdp_mod_ep_config,
                 reshard_after_forward=reshard_after_forward,
             )
+            # NOTE: Although the FSDP sharding of experts is done on a mesh of
+            # a different size than other parameters, the gradient division
+            # factor should be consistent with data.
+            transformer_block.moe.experts.set_gradient_divide_factor(
+                gradient_divide_factor,
+            )

         fully_shard(
             transformer_block,
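To make the mechanics of the new keyword concrete, here is a small sketch (again not part of the change) of the same pattern with FSDP2. The mesh layout, the 16-rank torchrun launch, and the stand-in module are assumptions; fully_shard and set_gradient_divide_factor are the calls the diff itself uses:

```python
# Sketch assuming a `torchrun --nproc-per-node 16 ...` launch on CUDA.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

world_mesh = init_device_mesh(
    "cuda", (4, 4), mesh_dim_names=("ep", "dp_shard_mod_ep")
)

# Stand-in for transformer_block.moe.experts.
experts = nn.Linear(1024, 1024, device="cuda")

# Experts are sharded only over the 4-rank dp_shard_mod_ep sub-mesh, so FSDP's
# default gradient division would use 4 rather than the 16-way data world.
fully_shard(experts, mesh=world_mesh["dp_shard_mod_ep"])

# The call the diff adds: keep gradient averaging consistent with the
# data-parallel degree used for every other parameter (16 here).
experts.set_gradient_divide_factor(16)
```

Without the override, the two fully_shard groups in apply_fsdp would each divide reduced gradients by the size of their own mesh, and those sizes differ once EP > 1.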
@@ -122,6 +122,7 @@ def parallelize_deepseekv3(
                 if dp_mod_ep_mesh_dim_names
                 else None
             ),
+            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )

         if parallel_dims.dp_replicate_enabled: