Commit 8a4147a

try to minimize diff

1 parent 2e5f2fb commit 8a4147a
File tree

1 file changed (+45, -45 lines)

torchtitan/models/moe/moe.py

Lines changed: 45 additions & 45 deletions
@@ -354,7 +354,7 @@ def forward(
         )
 
 
-class MoEOld(nn.Module):
+class MoE(nn.Module):
     def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         super().__init__()
 
@@ -430,7 +430,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             self.tokens_per_expert.add_(num_tokens_per_expert)
 
-        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
+        # top_scores shape (bs*slen,top_k)
+        # token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
         # NOTE: the reason we need to compute num_tokens_per_expert again is:
         #       1st computation in router is to update self.tokens_per_expert
@@ -443,17 +444,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             token_indices_experts_sorted,
             num_tokens_per_expert,
         ) = self.reorderer(top_scores, selected_experts_indices)
-        token_indices_experts_sorted = (
-            token_indices_experts_sorted // self.reorderer.top_k
-        )
 
         # shape (bs*slen*top_k, dim)
-        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
-            -1, 1
-        ).expand(-1, dim)
-
-        # shape (bs*slen*top_k, dim)
-        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
+        routed_input = x[token_indices_experts_sorted // self.router.top_k]
 
         if self.score_before_experts:
             routed_input = (
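
The substantive change in this hunk: plain advanced indexing replaces the reshape/expand-plus-torch.gather construction, dividing by top_k to map each sorted slot back to its source token row. A minimal standalone sketch (not part of the commit; sizes and data are made up) checking that the two paths select the same rows:

import torch

# Illustrative sizes; names mirror the diff but the data is made up.
bs_slen, dim, top_k = 6, 4, 2
x = torch.randn(bs_slen, dim)  # token activations, shape (bs*slen, dim)
token_indices_experts_sorted = torch.randperm(bs_slen * top_k)

# New path (this commit): advanced indexing gathers the rows directly.
routed_new = x[token_indices_experts_sorted // top_k]

# Old path: expand row indices to (bs*slen*top_k, dim), then torch.gather.
idx = (token_indices_experts_sorted // top_k).reshape(-1, 1).expand(-1, dim)
routed_old = torch.gather(x, dim=0, index=idx)

assert torch.equal(routed_new, routed_old)
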
@@ -467,21 +460,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # shared expert
         # Note: we execute the shared expert before scoring the output of the routed expert
         # to "implicitly" overlap the shared expert compute with token combine communication
-        if self.shared_experts is not None:
-            out = self.shared_experts(x)
-        else:
-            out = torch.zeros_like(x)
+        out = self.shared_experts(x) if self.shared_experts is not None else None
 
         if not self.score_before_experts:
-            routed_output = (
-                routed_output.to(torch.float32)
-                * top_scores_experts_sorted.reshape(-1, 1)
-            ).to(x.dtype)
+            # Unsort scores and routed outputs. Also save some allocations: store unsorted scores
+            # and outputs in top_scores and routed_input, respectively.
+            top_scores = top_scores.flatten()
+            top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
+            routed_input[token_indices_experts_sorted] = routed_output
+            routed_input = routed_input.reshape(bs * slen, -1, dim)
+            top_scores = top_scores.reshape(bs * slen, 1, -1)
+            out_experts = (
+                torch.bmm(top_scores, routed_input.float()).to(x.dtype).squeeze(1)
+            )
+        else:
+            # Unsort routed outputs and save an allocation: store unsorted outputs in routed_input
+            routed_input[token_indices_experts_sorted] = routed_output
+            out_experts = routed_input.reshape(bs * slen, -1, dim).sum(dim=1)
 
-        out = out.scatter_add(
-            dim=0, index=token_indices_experts_sorted, src=routed_output
-        )
-        out = out.reshape(bs, slen, dim)
+        if out is None:
+            return out_experts.reshape(bs, slen, dim)
+        out = (out + out_experts).reshape(bs, slen, dim)
         return out
 
     def init_weights(
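
The rewritten combine drops scatter_add: routed outputs are written back to their unsorted slots in place, and the per-token weighting by routing scores becomes one batched matmul over a (tokens, top_k, dim) view. A small self-contained sketch (illustrative shapes, not from the commit) of why the torch.bmm form equals the scale-and-sum it replaces:

import torch

# Illustrative sizes; the real code reuses top_scores/routed_input in place.
tokens, top_k, dim = 6, 2, 4
routed = torch.randn(tokens, top_k, dim)  # unsorted expert outputs per token
scores = torch.rand(tokens, top_k)        # routing weights per token

# bmm path, as in the diff: (tokens, 1, top_k) @ (tokens, top_k, dim)
out_bmm = torch.bmm(scores.unsqueeze(1), routed.float()).squeeze(1)

# reference: scale each expert output by its score, then sum over top_k
out_ref = (routed.float() * scores.unsqueeze(-1)).sum(dim=1)

assert torch.allclose(out_bmm, out_ref, atol=1e-6)
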
@@ -504,7 +503,7 @@ def init_weights(
         )
 
 
-class MoE(nn.Module):
+class MoEOld(nn.Module):
     def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         super().__init__()
 
@@ -580,8 +579,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             self.tokens_per_expert.add_(num_tokens_per_expert)
 
-        # top_scores shape (bs*slen,top_k)
-        # token_indices_experts_sorted shape (bs*slen*top_k,)
+        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
         # NOTE: the reason we need to compute num_tokens_per_expert again is:
         #       1st computation in router is to update self.tokens_per_expert
@@ -594,9 +592,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             token_indices_experts_sorted,
             num_tokens_per_expert,
         ) = self.reorderer(top_scores, selected_experts_indices)
+        token_indices_experts_sorted = (
+            token_indices_experts_sorted // self.reorderer.top_k
+        )
 
         # shape (bs*slen*top_k, dim)
-        routed_input = x[token_indices_experts_sorted // self.router.top_k]
+        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
+            -1, 1
+        ).expand(-1, dim)
+
+        # shape (bs*slen*top_k, dim)
+        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
 
         if self.score_before_experts:
             routed_input = (
@@ -610,27 +616,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # shared expert
         # Note: we execute the shared expert before scoring the output of the routed expert
         # to "implicitly" overlap the shared expert compute with token combine communication
-        out = self.shared_experts(x) if self.shared_experts is not None else None
+        if self.shared_experts is not None:
+            out = self.shared_experts(x)
+        else:
+            out = torch.zeros_like(x)
 
         if not self.score_before_experts:
-            # Unsort scores and routed outputs. Also save some allocations: store unsorted scores
-            # and outputs in top_scores and routed_input, respectively.
-            top_scores = top_scores.flatten()
-            top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
-            routed_input[token_indices_experts_sorted] = routed_output
-            routed_input = routed_input.reshape(bs * slen, -1, dim)
-            top_scores = top_scores.reshape(bs * slen, 1, -1)
-            out_experts = (
-                torch.bmm(top_scores, routed_input.float()).to(x.dtype).squeeze(1)
-            )
-        else:
-            # Unsort routed outputs and save an allocation: store unsorted outputs in routed_input
-            routed_input[token_indices_experts_sorted] = routed_output
-            out_experts = routed_input.reshape(bs * slen, -1, dim).sum(dim=1)
+            routed_output = (
+                routed_output.to(torch.float32)
+                * top_scores_experts_sorted.reshape(-1, 1)
+            ).to(x.dtype)
 
-        if out is None:
-            return out_experts.reshape(bs, slen, dim)
-        out = (out + out_experts).reshape(bs, slen, dim)
+        out = out.scatter_add(
+            dim=0, index=token_indices_experts_sorted, src=routed_output
+        )
+        out = out.reshape(bs, slen, dim)
         return out
 
     def init_weights(

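For contrast, the path restored into MoEOld above combines expert outputs by expanding token indices and accumulating with scatter_add. A standalone sketch (the index here is a made-up stand-in for the reorderer's output, not from the commit) showing this is the usual index-based row accumulation:

import torch

# Illustrative stand-ins: in MoEOld the index comes from the reorderer.
tokens, top_k, dim = 6, 2, 4
routed_output = torch.randn(tokens * top_k, dim)           # already scored
src_token = torch.arange(tokens).repeat_interleave(top_k)  # row each output belongs to

# scatter_add path, as in the diff: indices expanded to match src's shape
index = src_token.reshape(-1, 1).expand(-1, dim)
out = torch.zeros(tokens, dim).scatter_add(dim=0, index=index, src=routed_output)

# reference: index_add_ performs the same per-row accumulation
ref = torch.zeros(tokens, dim).index_add_(0, src_token, routed_output)
assert torch.allclose(out, ref)
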