Commit 4463081

Move fc2_latent_proj into combine method to make sure tensors are the right shape

Signed-off-by: Deepak Narayanan <[email protected]>
1 parent 8a9f127

File changed: megatron/core/transformer/moe/moe_layer.py (4 additions, 4 deletions)
@@ -269,6 +269,10 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
         """
         output = self.token_dispatcher.token_combine(output)
         output = self.token_dispatcher.combine_postprocess(output)
+        # Project the output back from latent dimension to hidden dimension after combine
+        # in latent dimension.
+        if self.config.moe_latent_size:
+            output, _ = self.fc2_latent_proj(output)
         if shared_expert_output is not None:
             output = output + shared_expert_output
         return output

@@ -313,10 +317,6 @@ def custom_forward(hidden_states):
         output = output + mlp_bias
         mlp_bias = None
     output = self.combine(output, shared_expert_output)
-    # Project the output back from latent dimension to hidden dimension after combine
-    # in latent dimension.
-    if self.config.moe_latent_size:
-        output, _ = self.fc2_latent_proj(output)

     return output, mlp_bias
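
For intuition, below is a minimal, hedged sketch of the shape issue the commit message points at ("to make sure tensors are the right shape"). It is not the Megatron-LM implementation: only the names taken from the diff (`combine`, `fc2_latent_proj`, `moe_latent_size`, `shared_expert_output`) come from the source, and the `nn.Linear` stand-in, class name, and tensor shapes are illustrative assumptions. One plausible reading of the change is that when `moe_latent_size` is set, the routed-expert output arrives in the latent dimension while the shared-expert output is in the hidden dimension, so projecting back to hidden inside `combine`, before the shared-expert addition, keeps the two tensors the same shape.

```python
import torch
import torch.nn as nn
from typing import Optional


class MoELayerShapeSketch(nn.Module):
    """Illustrative sketch only (assumed shapes, not the Megatron-LM code):
    shows why projecting latent -> hidden inside combine(), before the
    shared-expert addition, keeps the add well-shaped."""

    def __init__(self, hidden_size: int, moe_latent_size: Optional[int] = None):
        super().__init__()
        self.moe_latent_size = moe_latent_size
        if moe_latent_size:
            # Hypothetical stand-in for fc2_latent_proj: maps latent -> hidden.
            self.fc2_latent_proj = nn.Linear(moe_latent_size, hidden_size)

    def combine(
        self, output: torch.Tensor, shared_expert_output: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Assumption: `output` arrives in the latent dimension when
        # moe_latent_size is set (routed-expert path), while the
        # shared-expert output is already in the hidden dimension.
        if self.moe_latent_size:
            output = self.fc2_latent_proj(output)  # [..., latent] -> [..., hidden]
        if shared_expert_output is not None:
            output = output + shared_expert_output  # shapes now agree
        return output


# Usage sketch: with latent=64 and hidden=128 the addition only works
# because the projection runs before it.
layer = MoELayerShapeSketch(hidden_size=128, moe_latent_size=64)
routed = torch.randn(4, 64)    # latent-dimension routed-expert output
shared = torch.randn(4, 128)   # hidden-dimension shared-expert output
print(layer.combine(routed, shared).shape)  # torch.Size([4, 128])
```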
