@@ -157,6 +157,12 @@ class MoeConfig:
     norm_topk_prob: bool = True
     moe_every2: bool = False
 
+    def has_moe(self) -> bool:
+        return self.num_experts > 1
+
+    def use_moe(self, i: int) -> bool:
+        return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))
+
 
 class FusedMultiTransformerConfig:
     def __init__(
@@ -270,13 +276,8 @@ def __init__(
         self.rank_id = rank_id
         self.trans_qkvw = trans_qkvw
         self.ring_id = ring_id
-        if moe_config:
-            self.is_moe = True
-            self.moe_every2 = moe_config.moe_every2
-            self.moe_topk = moe_config.top_k
-            self.num_experts = moe_config.num_experts
-        else:
-            self.is_moe = False
+
+        self.moe_config = moe_config
 
 
 class FusedMultiTransformerBase(Layer):
@@ -432,16 +433,16 @@ def __init__(self, config: FusedMultiTransformerConfig):
                 )
 
             gate_weight = None
-            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+            if config.moe_config.use_moe(i):
                 gate_weight = self.create_parameter(
-                    shape=[config.embed_dim, config.num_experts],
+                    shape=[config.embed_dim, self.config.moe_config.num_experts],
                     attr=gate_weight_attr,
                     dtype="float32",
                     is_bias=False,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 )
 
-            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+            if config.moe_config.use_moe(i):
                 ffn1_weight = self.create_parameter(
                     shape=self.moe_ffn1_weight_shape,
                     attr=ffn1_weight_attr,
@@ -458,13 +459,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn1_bias = None
             if ffn1_bias_attr:
-                if self.config.is_moe is True and (
-                    (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
-                ):
+                if self.config.moe_config.use_moe(i):
                     ffn1_bias = self.create_parameter(
-                        shape=[config.num_experts, self.dim_feedforward * 2]
+                        shape=[self.config.moe_config.num_experts, self.dim_feedforward * 2]
                         if self.activation.endswith("glu")
-                        else [config.num_experts, self.dim_feedforward],
+                        else [self.config.moe_config.num_experts, self.dim_feedforward],
                         attr=ffn1_bias_attr,
                         dtype=self._dtype,
                         is_bias=True,
@@ -477,9 +476,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
                         is_bias=True,
                     )
 
-            if self.config.is_moe is True and (
-                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
-            ):
+            if self.config.moe_config.use_moe(i):
                 ffn2_weight = self.create_parameter(
                     shape=self.moe_ffn2_weight_shape,
                     attr=ffn2_weight_attr,
@@ -496,11 +493,9 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn2_bias = None
             if ffn2_bias_attr:
-                if config.is_moe is True and (
-                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
-                ):
+                if config.moe_config.use_moe(i):
                     ffn2_bias = self.create_parameter(
-                        shape=[config.num_experts, config.embed_dim],
+                        shape=[self.config.moe_config.num_experts, config.embed_dim],
                         attr=ffn2_bias_attr,
                         dtype=self._dtype,
                         is_bias=True,
@@ -636,13 +631,13 @@ def init_weight_shape(self, config):
         )
         self.ffn2_weight_shape = [self.dim_feedforward, self.embed_dim]
 
-        if self.config.is_moe is True:
+        if self.config.moe_config.has_moe() is True:
             self.moe_ffn1_weight_shape = (
-                [config.num_experts, self.embed_dim, self.dim_feedforward * 2]
+                [self.config.moe_config.num_experts, self.embed_dim, self.dim_feedforward * 2]
                 if self.activation.endswith("glu")
-                else [config.num_experts, self.embed_dim, self.dim_feedforward]
+                else [self.config.moe_config.num_experts, self.embed_dim, self.dim_feedforward]
             )
-            self.moe_ffn2_weight_shape = [config.num_experts, self.dim_feedforward, self.embed_dim]
+            self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]
 
     def get_weight_create_dype(self):
         return self._dtype
@@ -817,7 +812,7 @@ def compute_fused_moe(self, tmp_out, i):
             None,
             None,
             "None",
-            self.config.moe_topk,
+            self.config.moe_config.top_k,
         )
         return fused_moe_out
 
@@ -962,9 +957,7 @@ def forward(
             # ffn layernorm
             tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
 
-            if self.config.is_moe is True and (
-                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
-            ):
+            if self.config.moe_config.use_moe(i):
                 # fused moe
                 ffn2_out = self.compute_fused_moe(tmp_out, i)
 