@@ -316,15 +316,6 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.head_dim = config.embed_dim // config.num_heads
         assert self.head_dim * config.num_heads == config.embed_dim, "embed_dim must be divisible by num_heads"
 
-        if config.is_moe:
-            self._is_moe = config.is_moe
-            self._moe_every2 = config.moe_every2
-            self._moe_topk = config.moe_topk
-        else:
-            self._is_moe = False
-            self._moe_every2 = False
-            self._moe_topk = 1
-
         # tensor model parallel
         if config.nranks > 1:
             assert config.ring_id != -1
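
The block removed above mirrored three config fields onto private attributes; after this commit the layer reads them from self.config directly, keeping a single source of truth. A minimal runnable sketch of the resulting access pattern (the dataclass below is a hypothetical stand-in for illustration, not the real FusedMultiTransformerConfig):

    from dataclasses import dataclass

    @dataclass
    class MoEConfig:  # hypothetical stand-in for FusedMultiTransformerConfig
        is_moe: bool = False
        moe_every2: bool = False
        moe_topk: int = 1

    class Layer:
        def __init__(self, config: MoEConfig):
            self.config = config  # no mirrored _is_moe/_moe_every2/_moe_topk state

        def use_moe_at(self, i: int) -> bool:
            # Same predicate the diff repeats: MoE on every layer, or only on
            # odd-indexed layers when moe_every2 is set.
            return self.config.is_moe and (not self.config.moe_every2 or i % 2 == 1)
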
@@ -449,24 +440,10 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     is_bias=False,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 )
-            else:
-                gate_weight = self.create_parameter(
-                    shape=[1],
-                    attr=gate_weight_attr,
-                    dtype="float32",
-                    is_bias=False,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                )
-
-            if config.is_moe is False:
-                gate_weight = None
-                self.gate_weights = None
 
             if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
                 ffn1_weight = self.create_parameter(
-                    shape=[config.num_experts, self.embed_dim, self.dim_feedforward * 2]
-                    if self.activation.endswith("glu")
-                    else [config.num_experts, self.embed_dim, self.dim_feedforward],
+                    shape=self.moe_ffn1_weight_shape,
                     attr=ffn1_weight_attr,
                     dtype=self.create_params_type,
                     is_bias=False,
@@ -481,8 +458,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn1_bias = None
             if ffn1_bias_attr:
-                if config.is_moe is True and (
-                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
+                if self.config.is_moe is True and (
+                    (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
                 ):
                     ffn1_bias = self.create_parameter(
                         shape=[config.num_experts, self.dim_feedforward * 2]
@@ -500,9 +477,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
                         is_bias=True,
                     )
 
-            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+            if self.config.is_moe is True and (
+                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
+            ):
                 ffn2_weight = self.create_parameter(
-                    shape=[config.num_experts, self.dim_feedforward, self.embed_dim],
+                    shape=self.moe_ffn2_weight_shape,
                     attr=ffn2_weight_attr,
                     dtype=self.create_params_type,
                     is_bias=False,
@@ -657,6 +636,14 @@ def init_weight_shape(self, config):
657636 )
658637 self .ffn2_weight_shape = [self .dim_feedforward , self .embed_dim ]
659638
639+ if self .config .is_moe is True :
640+ self .moe_ffn1_weight_shape = (
641+ [config .num_experts , self .embed_dim , self .dim_feedforward * 2 ]
642+ if self .activation .endswith ("glu" )
643+ else [config .num_experts , self .embed_dim , self .dim_feedforward ]
644+ )
645+ self .moe_ffn2_weight_shape = [config .num_experts , self .dim_feedforward , self .embed_dim ]
646+
660647 def get_weight_create_dype (self ):
661648 return self ._dtype
662649
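
The inline shape expressions deleted from __init__ are now computed once in init_weight_shape: GLU-family activations pack the gate and up projections into a single matmul, so ffn1's last dimension doubles, while ffn2 projects back to the embedding width. A standalone sketch of the rule, with illustrative sizes that are assumptions, not values from the diff:

    # Illustrative sizes only; the rule matches the one added above.
    num_experts, embed_dim, dim_feedforward = 8, 1024, 4096
    activation = "swiglu"  # any *glu activation fuses gate + up projections

    moe_ffn1_weight_shape = (
        [num_experts, embed_dim, dim_feedforward * 2]  # 2x width for the GLU split
        if activation.endswith("glu")
        else [num_experts, embed_dim, dim_feedforward]
    )
    moe_ffn2_weight_shape = [num_experts, dim_feedforward, embed_dim]  # project back

    print(moe_ffn1_weight_shape)  # [8, 1024, 8192]
    print(moe_ffn2_weight_shape)  # [8, 4096, 1024]
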
@@ -830,7 +817,7 @@ def compute_fused_moe(self, tmp_out, i):
             None,
             None,
             "None",
-            self._moe_topk,
+            self.config.moe_topk,
         )
         return fused_moe_out
 
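
moe_topk, now read straight from config, is the number of experts each token is routed to. A hedged sketch of the routing it controls, written with stock Paddle ops rather than the fused kernel the diff actually calls:

    import paddle

    # Each token keeps its top-k experts and mixes their outputs with
    # softmax-normalized gate scores; the fused op does this internally.
    num_tokens, num_experts, moe_topk = 4, 8, 2
    gate_logits = paddle.randn([num_tokens, num_experts])
    scores, expert_ids = paddle.topk(gate_logits, k=moe_topk, axis=-1)
    mix_weights = paddle.nn.functional.softmax(scores, axis=-1)  # [num_tokens, moe_topk]
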
@@ -975,7 +962,9 @@ def forward(
             # ffn layernorm
             tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
 
-            if self._is_moe is True and ((self._moe_every2 is True and i % 2 == 1) or self._moe_every2 is False):
+            if self.config.is_moe is True and (
+                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
+            ):
                 # fused moe
                 ffn2_out = self.compute_fused_moe(tmp_out, i)
 
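
The condition in forward() implements the moe_every2 schedule: experts on odd-indexed layers only, or on every layer when moe_every2 is off. A quick runnable check (moe_layer_indices is a hypothetical helper, not part of the diff):

    def moe_layer_indices(num_layers, is_moe, moe_every2):
        return [i for i in range(num_layers)
                if is_moe and (not moe_every2 or i % 2 == 1)]

    print(moe_layer_indices(6, True, True))   # [1, 3, 5] -> experts on odd layers
    print(moe_layer_indices(6, True, False))  # [0, 1, 2, 3, 4, 5] -> every layer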