@@ -18,6 +18,7 @@
 from paddle.framework import LayerHelper, core, in_dynamic_mode
 from paddle.incubate.nn.functional import (
     fused_layer_norm,
+    fused_moe,
     fused_rms_norm,
     masked_multihead_attention,
     variable_length_memory_efficient_attention,
@@ -167,6 +168,7 @@ def __init__(
         linear_bias_attrs=None,
         ffn_ln_scale_attrs=None,
         ffn_ln_bias_attrs=None,
+        gate_weight_attrs=None,
         ffn1_weight_attrs=None,
         ffn1_weight_scale_attrs=None,
         ffn1_bias_attrs=None,
@@ -197,12 +199,15 @@ def __init__(
         kv_num_heads=-1,
         cachekv_int8_type=None,
         rank_id=-1,
+        is_moe=False,
+        moe_every2=False,
+        moe_topk=2,
+        num_experts=1,
     ):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         if kv_num_heads > 0:
             self.kv_num_heads = kv_num_heads
-            assert nranks == 1, "nranks should be 1 for kv_num_heads > 0"
         else:
             self.kv_num_heads = num_heads
         self.dim_feedforward = dim_feedforward
@@ -222,6 +227,7 @@ def __init__(
         self.linear_bias_attrs = linear_bias_attrs
         self.ffn_ln_scale_attrs = ffn_ln_scale_attrs
         self.ffn_ln_bias_attrs = ffn_ln_bias_attrs
+        self.gate_weight_attrs = gate_weight_attrs
         self.ffn1_weight_attrs = ffn1_weight_attrs
         self.ffn1_weight_scale_attrs = ffn1_weight_scale_attrs
         self.ffn1_bias_attrs = ffn1_bias_attrs
@@ -255,6 +261,10 @@ def __init__(
         self.rank_id = rank_id
         self.trans_qkvw = trans_qkvw
         self.ring_id = ring_id
+        self.is_moe = is_moe
+        self.moe_every2 = moe_every2
+        self.moe_topk = moe_topk
+        self.num_experts = num_experts
 
 
 class FusedMultiTransformerBase(Layer):
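For reference, the four new config fields combine like this when building a layer. This is a minimal usage sketch: only `is_moe`, `moe_every2`, `moe_topk`, and `num_experts` come from this diff; the remaining argument names follow the existing `FusedMultiTransformerConfig` signature and their values are illustrative.

```python
# Hypothetical usage sketch -- sizes are illustrative, not from this PR.
config = FusedMultiTransformerConfig(
    embed_dim=4096,
    num_heads=32,
    dim_feedforward=11008,
    is_moe=True,      # enable the fused-MoE FFN path (new in this diff)
    moe_every2=True,  # restrict MoE to every second layer (odd indices)
    moe_topk=2,       # route each token to its top-2 experts
    num_experts=8,    # experts per MoE layer
)
layer = FusedMultiTransformerBase(config)
```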
@@ -294,6 +304,10 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.head_dim = config.embed_dim // config.num_heads
         assert self.head_dim * config.num_heads == config.embed_dim, "embed_dim must be divisible by num_heads"
 
+        self._is_moe = config.is_moe
+        self._moe_every2 = config.moe_every2
+        self._moe_topk = config.moe_topk
+
         # tensor model parallel
         if config.nranks > 1:
             assert config.ring_id != -1
@@ -316,6 +330,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.qkv_weights, self.qkv_biases = [], []
         self.linear_weights, self.linear_biases = [], []
         self.ffn_ln_scales, self.ffn_ln_biases = [], []
+        self.gate_weights = []
         self.ffn1_weights, self.ffn1_biases = [], []
         self.ffn2_weights, self.ffn2_biases = [], []
         self.cache_k_scales, self.cache_v_scales = [], []
@@ -327,6 +342,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
             qkv_weight_attr = self.get_attr(config.qkv_weight_attrs, i)
 
             qkv_bias_attr = self.get_attr(config.qkv_bias_attrs, i)
+            gate_weight_attr = self.get_attr(config.gate_weight_attrs, i)
             linear_weight_attr = self.get_attr(config.linear_weight_attrs, i)
             linear_bias_attr = self.get_attr(config.linear_bias_attrs, i)
 
@@ -407,37 +423,99 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     dtype=self._norm_weight_dtype,
                 )
 
-            ffn1_weight = self.create_parameter(
-                shape=self.ffn1_weight_shape,
-                attr=ffn1_weight_attr,
-                dtype=self.create_params_type,
-                is_bias=False,
-            )
+            gate_weight = None
+            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+                gate_weight = self.create_parameter(
+                    shape=[config.embed_dim, config.num_experts],
+                    attr=gate_weight_attr,
+                    dtype="float32",
+                    is_bias=False,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+            else:
+                gate_weight = self.create_parameter(
+                    shape=[1],
+                    attr=gate_weight_attr,
+                    dtype="float32",
+                    is_bias=False,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+
+            if config.is_moe is False:
+                gate_weight = None
+                self.gate_weights = None
+
+            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+                ffn1_weight = self.create_parameter(
+                    shape=[config.num_experts, self.embed_dim, self.dim_feedforward * 2]
+                    if self.activation.endswith("glu")
+                    else [config.num_experts, self.embed_dim, self.dim_feedforward],
+                    attr=ffn1_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
+                )
+            else:
+                ffn1_weight = self.create_parameter(
+                    shape=self.ffn1_weight_shape,
+                    attr=ffn1_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
+                )
 
             ffn1_bias = None
             if ffn1_bias_attr:
-                ffn1_bias = self.create_parameter(
-                    shape=[dim_feedforward * 2] if config.activation.endswith("glu") else [dim_feedforward],
-                    attr=ffn1_bias_attr,
-                    dtype=self._dtype,
-                    is_bias=True,
+                if config.is_moe is True and (
+                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
+                ):
+                    ffn1_bias = self.create_parameter(
+                        shape=[config.num_experts, self.dim_feedforward * 2]
+                        if self.activation.endswith("glu")
+                        else [config.num_experts, self.dim_feedforward],
+                        attr=ffn1_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
+                else:
+                    ffn1_bias = self.create_parameter(
+                        shape=[dim_feedforward * 2] if self.activation.endswith("glu") else [dim_feedforward],
+                        attr=ffn1_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
+
+            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+                ffn2_weight = self.create_parameter(
+                    shape=[config.num_experts, self.dim_feedforward, self.embed_dim],
+                    attr=ffn2_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
+                )
+            else:
+                ffn2_weight = self.create_parameter(
+                    shape=self.ffn2_weight_shape,
+                    attr=ffn2_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
                 )
-
-            ffn2_weight = self.create_parameter(
-                shape=self.ffn2_weight_shape,
-                attr=ffn2_weight_attr,
-                dtype=self.create_params_type,
-                is_bias=False,
-            )
 
             ffn2_bias = None
             if ffn2_bias_attr:
-                ffn2_bias = self.create_parameter(
-                    shape=[config.embed_dim],
-                    attr=ffn2_bias_attr,
-                    dtype=self._dtype,
-                    is_bias=True,
-                )
+                if config.is_moe is True and (
+                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
+                ):
+                    ffn2_bias = self.create_parameter(
+                        shape=[config.num_experts, config.embed_dim],
+                        attr=ffn2_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
+                else:
+                    ffn2_bias = self.create_parameter(
+                        shape=[config.embed_dim],
+                        attr=ffn2_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
 
             cache_k_scale = None
             if cache_k_scale_attr:
@@ -495,6 +573,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             self.ffn_ln_scales.append(ffn_ln_scale)
             self.ffn_ln_biases.append(ffn_ln_bias)
+            if gate_weight is not None:
+                self.gate_weights.append(gate_weight)
             self.ffn1_weights.append(ffn1_weight)
             self.ffn1_biases.append(ffn1_bias)
             self.ffn2_weights.append(ffn2_weight)
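In the MoE branches above, every expert gets its own FFN, so the per-layer parameters gain a leading `num_experts` dimension while the router (`gate_weight`) is a plain float32 matrix. The shapes below are read directly off the `create_parameter` calls; the concrete sizes are illustrative, and the `* 2` applies only when the activation name ends in `"glu"`:

```python
# Illustrative sizes; the shapes mirror the create_parameter calls above.
embed_dim, dim_feedforward, num_experts = 4096, 11008, 8

gate_weight_shape = [embed_dim, num_experts]                       # float32 router weights
ffn1_weight_shape = [num_experts, embed_dim, dim_feedforward * 2]  # "glu" packs two projections
ffn2_weight_shape = [num_experts, dim_feedforward, embed_dim]      # down projection per expert
# Non-MoE layers keep the existing 2-D self.ffn1_weight_shape / self.ffn2_weight_shape.
```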
@@ -713,6 +793,28 @@ def compute_ffn_layernorm(self, out_linear_out, residual_input, i):
 
         return tmp_out, residual_input
 
+    def compute_fused_moe(self, tmp_out, i):
+        # todo[xinhw]: make bias optional
+        if self.ffn1_biases[i] is None:
+            shape1 = paddle.to_tensor([self.ffn1_weights[i].shape[0], 1, self.dim_feedforward * 2])
+            self.ffn1_biases[i] = paddle.zeros(shape1)
+        if self.ffn2_biases[i] is None:
+            shape2 = paddle.to_tensor([self.ffn1_weights[i].shape[0], 1, self.embed_dim])
+            self.ffn2_biases[i] = paddle.zeros(shape2)
+        fused_moe_out = fused_moe(
+            tmp_out,
+            self.gate_weights[i],
+            self.ffn1_weights[i],
+            self.ffn1_biases[i],
+            self.ffn2_weights[i],
+            self.ffn2_biases[i],
+            None,
+            None,
+            "None",
+            self._moe_topk,
+        )
+        return fused_moe_out
+
     def compute_activation(self, ffn1_out, i):
         return fused_act_bias_wrapper(ffn1_out, self.ffn1_biases[i], act_method=self.activation)
 
@@ -854,12 +956,17 @@ def forward(
             # ffn layernorm
             tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
 
-            # ffn1 matmul
-            ffn1_out = self.compute_ffn1(tmp_out, i)
-            ffn1_out = self.compute_activation(ffn1_out, i)
+            if self._is_moe is True and ((self._moe_every2 is True and i % 2 == 1) or self._moe_every2 is False):
+                # fused moe
+                ffn2_out = self.compute_fused_moe(tmp_out, i)
+
+            else:
+                # ffn1 matmul
+                ffn1_out = self.compute_ffn1(tmp_out, i)
+                ffn1_out = self.compute_activation(ffn1_out, i)
 
-            # ffn2 matmul
-            ffn2_out = self.compute_ffn2(ffn1_out, i)
+                # ffn2 matmul
+                ffn2_out = self.compute_ffn2(ffn1_out, i)
 
             # all_reduce
             if self.nranks > 1:
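The gating condition repeated throughout the diff reduces to a small per-layer predicate; this sketch makes the resulting layer pattern explicit:

```python
def layer_uses_moe(i, is_moe, moe_every2):
    # Equivalent to the diff's condition:
    # is_moe and ((moe_every2 and i % 2 == 1) or not moe_every2)
    return is_moe and (i % 2 == 1 if moe_every2 else True)

# With moe_every2=True only odd layer indices run the MoE FFN; otherwise all do.
assert [i for i in range(6) if layer_uses_moe(i, True, True)] == [1, 3, 5]
assert [i for i in range(6) if layer_uses_moe(i, True, False)] == [0, 1, 2, 3, 4, 5]
```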