@@ -354,7 +354,7 @@ def forward(
354 354         )
355 355
356 356
357     -class MoEOld(nn.Module):
    357 +class MoE(nn.Module):
358 358     def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
359 359         super().__init__()
360 360
@@ -430,7 +430,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
430 430         with torch.no_grad():
431 431             self.tokens_per_expert.add_(num_tokens_per_expert)
432 432
433     -        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
    433 +        # top_scores shape (bs*slen,top_k)
    434 +        # token_indices_experts_sorted shape (bs*slen*top_k,)
434 435         # num_tokens_per_expert shape (num_experts,)
435 436         # NOTE: the reason we need to compute num_tokens_per_expert again is:
436 437         # 1st computation in router is to update self.tokens_per_expert
@@ -443,17 +444,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
443 444             token_indices_experts_sorted,
444 445             num_tokens_per_expert,
445 446         ) = self.reorderer(top_scores, selected_experts_indices)
446     -        token_indices_experts_sorted = (
447     -            token_indices_experts_sorted // self.reorderer.top_k
448     -        )
449 447
450 448         # shape (bs*slen*top_k, dim)
451     -        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
452     -            -1, 1
453     -        ).expand(-1, dim)
454     -
455     -        # shape (bs*slen*top_k, dim)
456     -        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
    449 +        routed_input = x[token_indices_experts_sorted // self.router.top_k]
457 450
458 451         if self.score_before_experts:
459 452             routed_input = (
@@ -467,21 +460,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
467 460         # shared expert
468 461         # Note: we execute the shared expert before scoring the output of the routed expert
469 462         # to "implicitly" overlap the shared expert compute with token combine communication
470     -        if self.shared_experts is not None:
471     -            out = self.shared_experts(x)
472     -        else:
473     -            out = torch.zeros_like(x)
    463 +        out = self.shared_experts(x) if self.shared_experts is not None else None
474 464
475 465         if not self.score_before_experts:
476     -            routed_output = (
477     -                routed_output.to(torch.float32)
478     -                * top_scores_experts_sorted.reshape(-1, 1)
479     -            ).to(x.dtype)
    466 +            # Unsort scores and routed outputs. Also save some allocations: store unsorted scores
    467 +            # and outputs in top_scores and routed_input, respectively.
    468 +            top_scores = top_scores.flatten()
    469 +            top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
    470 +            routed_input[token_indices_experts_sorted] = routed_output
    471 +            routed_input = routed_input.reshape(bs * slen, -1, dim)
    472 +            top_scores = top_scores.reshape(bs * slen, 1, -1)
    473 +            out_experts = (
    474 +                torch.bmm(top_scores, routed_input.float()).to(x.dtype).squeeze(1)
    475 +            )
    476 +        else:
    477 +            # Unsort routed outputs and save an allocation: store unsorted outputs in routed_input
    478 +            routed_input[token_indices_experts_sorted] = routed_output
    479 +            out_experts = routed_input.reshape(bs * slen, -1, dim).sum(dim=1)
480 480
481     -        out = out.scatter_add(
482     -            dim=0, index=token_indices_experts_sorted, src=routed_output
483     -        )
484     -        out = out.reshape(bs, slen, dim)
    481 +        if out is None:
    482 +            return out_experts.reshape(bs, slen, dim)
    483 +        out = (out + out_experts).reshape(bs, slen, dim)
485 484         return out
486 485
487 486     def init_weights(
@@ -504,7 +503,7 @@ def init_weights(
504 503         )
505 504
506 505
507     -class MoE(nn.Module):
    506 +class MoEOld(nn.Module):
508 507     def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
509 508         super().__init__()
510 509
@@ -580,8 +579,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
580 579         with torch.no_grad():
581 580             self.tokens_per_expert.add_(num_tokens_per_expert)
582 581
583     -        # top_scores shape (bs*slen,top_k)
584     -        # token_indices_experts_sorted shape (bs*slen*top_k,)
    582 +        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
585 583         # num_tokens_per_expert shape (num_experts,)
586 584         # NOTE: the reason we need to compute num_tokens_per_expert again is:
587 585         # 1st computation in router is to update self.tokens_per_expert
@@ -594,9 +592,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
594 592             token_indices_experts_sorted,
595 593             num_tokens_per_expert,
596 594         ) = self.reorderer(top_scores, selected_experts_indices)
    595 +        token_indices_experts_sorted = (
    596 +            token_indices_experts_sorted // self.reorderer.top_k
    597 +        )
597 598
598 599         # shape (bs*slen*top_k, dim)
599     -        routed_input = x[token_indices_experts_sorted // self.router.top_k]
    600 +        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
    601 +            -1, 1
    602 +        ).expand(-1, dim)
    603 +
    604 +        # shape (bs*slen*top_k, dim)
    605 +        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
600 606
601 607         if self.score_before_experts:
602 608             routed_input = (
@@ -610,27 +616,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
610 616         # shared expert
611 617         # Note: we execute the shared expert before scoring the output of the routed expert
612 618         # to "implicitly" overlap the shared expert compute with token combine communication
613     -        out = self.shared_experts(x) if self.shared_experts is not None else None
    619 +        if self.shared_experts is not None:
    620 +            out = self.shared_experts(x)
    621 +        else:
    622 +            out = torch.zeros_like(x)
614 623
615 624         if not self.score_before_experts:
616     -            # Unsort scores and routed outputs. Also save some allocations: store unsorted scores
617     -            # and outputs in top_scores and routed_input, respectively.
618     -            top_scores = top_scores.flatten()
619     -            top_scores[token_indices_experts_sorted] = top_scores_experts_sorted
620     -            routed_input[token_indices_experts_sorted] = routed_output
621     -            routed_input = routed_input.reshape(bs * slen, -1, dim)
622     -            top_scores = top_scores.reshape(bs * slen, 1, -1)
623     -            out_experts = (
624     -                torch.bmm(top_scores, routed_input.float()).to(x.dtype).squeeze(1)
625     -            )
626     -        else:
627     -            # Unsort routed outputs and save an allocation: store unsorted outputs in routed_input
628     -            routed_input[token_indices_experts_sorted] = routed_output
629     -            out_experts = routed_input.reshape(bs * slen, -1, dim).sum(dim=1)
    625 +            routed_output = (
    626 +                routed_output.to(torch.float32)
    627 +                * top_scores_experts_sorted.reshape(-1, 1)
    628 +            ).to(x.dtype)
630 629
631     -        if out is None:
632     -            return out_experts.reshape(bs, slen, dim)
633     -        out = (out + out_experts).reshape(bs, slen, dim)
    630 +        out = out.scatter_add(
    631 +            dim=0, index=token_indices_experts_sorted, src=routed_output
    632 +        )
    633 +        out = out.reshape(bs, slen, dim)
634 634         return out
635 635
636 636     def init_weights(
0 commit comments