@@ -14,6 +14,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from typing import List, Optional
 
 import paddle
 import paddle.distributed as dist
@@ -157,12 +158,25 @@ class MoeConfig:
     norm_topk_prob: bool = True
     moe_every2: bool = False
 
+    shared_expert_intermediate_size: int = 0
+    shared_expert_ffn1_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn1_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn2_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn2_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_gate_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+
     def has_moe(self) -> bool:
         return self.num_experts > 1
 
     def use_moe(self, i: int) -> bool:
         return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))
 
+    def has_shared_expert(self) -> bool:
+        return self.has_moe() and self.shared_expert_intermediate_size > 0
+
+    def use_shared_expert(self, i: int) -> bool:
+        return self.use_moe(i) and self.shared_expert_intermediate_size > 0
+
 
 class FusedMultiTransformerConfig:
     def __init__(
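For reference, a minimal usage sketch of the new config surface (values are hypothetical; `num_experts` is assumed to be a keyword field of the dataclass, as `has_moe()` implies):

    # Hypothetical MoeConfig: 64 routed experts on every second layer,
    # plus a shared expert with intermediate size 5632.
    cfg = MoeConfig(
        num_experts=64,
        moe_every2=True,
        shared_expert_intermediate_size=5632,
    )
    assert cfg.has_moe()                 # num_experts > 1
    assert cfg.has_shared_expert()       # MoE present and intermediate size > 0
    assert not cfg.use_shared_expert(0)  # with moe_every2, even layers stay dense
    assert cfg.use_shared_expert(1)      # odd layers run MoE + shared expert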
@@ -342,9 +356,15 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.gate_weights = []
         self.ffn1_weights, self.ffn1_biases = [], []
         self.ffn2_weights, self.ffn2_biases = [], []
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_gate_weights = []
+            self.shared_expert_ffn1_weights = []
+            self.shared_expert_ffn2_weights = []
         self.cache_k_scales, self.cache_v_scales = [], []
         self.cache_k_out_scales, self.cache_v_out_scales = [], []
 
+        self.init_weight_shape(config)
+
         for i in range(self.num_layers):
             ln_scale_attr = self.get_attr(config.ln_scale_attrs, i)
             ln_bias_attr = self.get_attr(config.ln_bias_attrs, i)
@@ -362,6 +382,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i)
             ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i)
 
+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_gate_weight_attr = self.get_attr(config.moe_config.shared_expert_gate_weight_attrs, i)
+                shared_expert_ffn1_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn1_weight_attrs, i)
+                shared_expert_ffn2_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn2_weight_attrs, i)
+
             cache_k_scale_attr = self.get_attr(config.cache_k_scale_attrs, i)
             cache_v_scale_attr = self.get_attr(config.cache_v_scale_attrs, i)
             cache_k_out_scale_attr = self.get_attr(config.cache_k_out_scale_attrs, i)
@@ -381,7 +406,6 @@ def __init__(self, config: FusedMultiTransformerConfig):
                 is_bias=True,
                 dtype=self._norm_weight_dtype,
             )
-            self.init_weight_shape(config)
 
             qkv_weight = self.create_parameter(
                 shape=self.qkv_weight_shape,
@@ -433,7 +457,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
             )
 
             gate_weight = None
-            if config.moe_config.use_moe(i):
+
+            if self.config.moe_config.use_moe(i):
                 gate_weight = self.create_parameter(
                     shape=[config.embed_dim, self.config.moe_config.num_experts],
                     attr=gate_weight_attr,
@@ -442,7 +467,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     default_initializer=paddle.nn.initializer.Constant(0),
                 )
 
-            if config.moe_config.use_moe(i):
+            if self.config.moe_config.use_moe(i):
                 ffn1_weight = self.create_parameter(
                     shape=self.moe_ffn1_weight_shape,
                     attr=ffn1_weight_attr,
@@ -493,7 +518,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn2_bias = None
             if ffn2_bias_attr:
-                if config.moe_config.use_moe(i):
+                if self.config.moe_config.use_moe(i):
                     ffn2_bias = self.create_parameter(
                         shape=[self.config.moe_config.num_experts, config.embed_dim],
                         attr=ffn2_bias_attr,
@@ -508,6 +533,23 @@ def __init__(self, config: FusedMultiTransformerConfig):
                         is_bias=True,
                     )
 
+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight = self.create_parameter(
+                    shape=self.shared_expert_ffn1_weight_shape,
+                    attr=shared_expert_ffn1_weight_attr,
+                    dtype=self.create_params_type,
+                )
+                shared_expert_ffn2_weight = self.create_parameter(
+                    shape=self.shared_expert_ffn2_weight_shape,
+                    attr=shared_expert_ffn2_weight_attr,
+                    dtype=self.create_params_type,
+                )
+                shared_expert_gate_weight = self.create_parameter(
+                    shape=self.shared_expert_gate_weight_shape,
+                    attr=shared_expert_gate_weight_attr,
+                    dtype=self._helper.get_default_dtype(),
+                )
+
             cache_k_scale = None
             if cache_k_scale_attr:
                 cache_k_scale = self.create_parameter(
@@ -571,6 +613,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             self.ffn2_weights.append(ffn2_weight)
             self.ffn2_biases.append(ffn2_bias)
 
+            if self.config.moe_config.use_shared_expert(i):
+                self.shared_expert_ffn1_weights.append(shared_expert_ffn1_weight)
+                self.shared_expert_ffn2_weights.append(shared_expert_ffn2_weight)
+                self.shared_expert_gate_weights.append(shared_expert_gate_weight)
+
             self.cache_k_scales.append(cache_k_scale)
             self.cache_v_scales.append(cache_v_scale)
             self.cache_k_out_scales.append(cache_k_out_scale)
@@ -592,6 +639,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             self._add_parameter(ffn2_weight)
             self._add_parameter(ffn2_bias)
 
+            if self.config.moe_config.use_shared_expert(i):
+                self._add_parameter(shared_expert_ffn1_weight)
+                self._add_parameter(shared_expert_ffn2_weight)
+                self._add_parameter(shared_expert_gate_weight)
+
             self._add_parameter(cache_k_scale)
             self._add_parameter(cache_v_scale)
             self._add_parameter(cache_k_out_scale)
@@ -624,6 +676,7 @@ def init_weight_shape(self, config):
             else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
         )
         self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
+
         self.ffn1_weight_shape = (
             [self.embed_dim, self.dim_feedforward * 2]
             if self.activation.endswith("glu")
@@ -639,6 +692,20 @@ def init_weight_shape(self, config):
         )
         self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]
 
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weight_shape = [
+                self.embed_dim,
+                self.config.moe_config.shared_expert_intermediate_size * 2,
+            ]
+            self.shared_expert_ffn2_weight_shape = [
+                self.config.moe_config.shared_expert_intermediate_size,
+                self.embed_dim,
+            ]
+            self.shared_expert_gate_weight_shape = [
+                self.embed_dim,
+                1,
+            ]
+
     def get_weight_create_dype(self):
         return self._dtype
 
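The ffn1 shape carries `shared_expert_intermediate_size * 2` columns because this path, like the dense `ffn1_weight_shape` above, assumes a `*_glu` activation that fuses the gate and up projections into one GEMM; the gate weight is a single `embed_dim -> 1` column. A shape sketch with hypothetical sizes:

    # Hypothetical sizes: embed_dim=4096, shared_expert_intermediate_size=5632.
    embed_dim, inter = 4096, 5632
    ffn1_shape = [embed_dim, inter * 2]  # gate + up projections fused (GLU)
    ffn2_shape = [inter, embed_dim]      # down projection back to model width
    gate_shape = [embed_dim, 1]          # one sigmoid gate logit per token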
@@ -851,6 +918,15 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layers):
             )[0]
         return tmp_out, residual_input
 
+    def compute_shared_expert(self, tmp_out, i):
+        ffn1_out = paddle.matmul(tmp_out, self.shared_expert_ffn1_weights[i])
+        ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)
+        ffn2_out = paddle.matmul(ffn1_out, self.shared_expert_ffn2_weights[i])
+        gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
+        gate_out = paddle.nn.functional.sigmoid(gate_out)
+        shared_expert_output = gate_out * ffn2_out
+        return shared_expert_output
+
     def pre_process(self, **kwargs):
         pass
 
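Unlike the routed experts, the shared expert processes every token, and its output is scaled by a per-token sigmoid gate. A self-contained sketch of the same math in plain `paddle` ops (shapes are illustrative, and a swiglu-style split is assumed for the `*_glu` activation that `fused_act_bias_wrapper` handles in the real code):

    import paddle
    import paddle.nn.functional as F

    def shared_expert_ref(x, w1, w2, wg):
        # x: [tokens, d], w1: [d, 2*m], w2: [m, d], wg: [d, 1]
        h = paddle.matmul(x, w1)                # fused gate/up projection
        gate, up = paddle.chunk(h, 2, axis=-1)  # GLU split (assumed layout)
        h = F.silu(gate) * up                   # swiglu-style activation
        out = paddle.matmul(h, w2)              # down projection
        g = F.sigmoid(paddle.matmul(x, wg))     # scalar gate per token
        return g * out

    x = paddle.randn([8, 64])
    y = shared_expert_ref(
        x, paddle.randn([64, 256]), paddle.randn([128, 64]), paddle.randn([64, 1])
    )
    print(y.shape)  # [8, 64]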
@@ -962,6 +1038,10 @@ def forward(
                 # fused moe
                 ffn2_out = self.compute_fused_moe(tmp_out, i)
 
+                # shared_expert
+                if self.config.moe_config.use_shared_expert(i):
+                    shared_expert_out = self.compute_shared_expert(tmp_out, i)
+                    ffn2_out = ffn2_out + shared_expert_out
             else:
                 # ffn1 matmul
                 ffn1_out = self.compute_ffn1(tmp_out, i)
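On MoE layers the output is therefore the sum of the routed result and the always-on shared branch, while dense layers fall through to the ordinary ffn1/ffn2 path. Reusing the hypothetical `shared_expert_ref` from the sketch above:

    # `route` stands in for compute_fused_moe (hypothetical helper).
    def moe_layer_ref(x, route, w1, w2, wg):
        return route(x) + shared_expert_ref(x, w1, w2, wg)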
@@ -1046,13 +1126,25 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.ffn1_weights_scale = []
         self.ffn2_weights_scale = []
 
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weights_scale = []
+            self.shared_expert_ffn2_weights_scale = []
+
         for i in range(self.num_layers):
 
             qkv_weight_scale_attr = self.get_attr(config.qkv_weight_scale_attrs, i)
             linear_weight_scale_attr = self.get_attr(config.linear_weight_scale_attrs, i)
             ffn1_weight_scale_attr = self.get_attr(config.ffn1_weight_scale_attrs, i)
             ffn2_weight_scale_attr = self.get_attr(config.ffn2_weight_scale_attrs, i)
 
+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight_scale_attr = self.get_attr(
+                    config.moe_config.shared_expert_ffn1_weight_scale_attrs, i
+                )
+                shared_expert_ffn2_weight_scale_attr = self.get_attr(
+                    config.moe_config.shared_expert_ffn2_weight_scale_attrs, i
+                )
+
             qkv_weight_scale = self.create_parameter(
                 shape=[(self.num_heads + 2 * self.kv_num_heads) * self.head_dim],
                 attr=qkv_weight_scale_attr,
@@ -1069,9 +1161,9 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             if self.config.moe_config.use_moe(i):
                 ffn1_weight_scale = self.create_parameter(
-                    shape=[config.moe_config.num_experts, self.dim_feedforward * 2]
+                    shape=[self.config.moe_config.num_experts, self.dim_feedforward * 2]
                     if config.activation.endswith("glu")
-                    else [config.moe_config.num_experts, self.dim_feedforward],
+                    else [self.config.moe_config.num_experts, self.dim_feedforward],
                     attr=ffn1_weight_scale_attr,
                     dtype=self.weight_scale_dtype,
                     is_bias=False,
@@ -1086,7 +1178,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             if self.config.moe_config.use_moe(i):
                 ffn2_weight_scale = self.create_parameter(
-                    shape=[config.moe_config.num_experts, self.embed_dim],
+                    shape=[self.config.moe_config.num_experts, self.embed_dim],
                     attr=ffn2_weight_scale_attr,
                     dtype=self.weight_scale_dtype,
                     is_bias=False,
11011193
1194+ if self .config .moe_config .use_shared_expert (i ):
1195+ shared_expert_ffn1_weight_scale = self .create_parameter (
1196+ shape = [self .config .moe_config .shared_expert_intermediate_size * 2 ],
1197+ attr = shared_expert_ffn1_weight_scale_attr ,
1198+ dtype = self .weight_scale_dtype ,
1199+ is_bias = False ,
1200+ )
1201+ shared_expert_ffn2_weight_scale = self .create_parameter (
1202+ shape = [self .embed_dim ],
1203+ attr = shared_expert_ffn2_weight_scale_attr ,
1204+ dtype = self .weight_scale_dtype ,
1205+ is_bias = False ,
1206+ )
1207+
11021208 self .qkv_weights_scale .append (qkv_weight_scale )
11031209 self .linear_weights_scale .append (linear_weight_scale )
11041210 self .ffn1_weights_scale .append (ffn1_weight_scale )
11051211 self .ffn2_weights_scale .append (ffn2_weight_scale )
11061212
1213+ if self .config .moe_config .use_shared_expert (i ):
1214+ self .shared_expert_ffn1_weights_scale .append (shared_expert_ffn1_weight_scale )
1215+ self .shared_expert_ffn2_weights_scale .append (shared_expert_ffn2_weight_scale )
1216+
11071217 self ._add_parameter (qkv_weight_scale )
11081218 self ._add_parameter (linear_weight_scale )
11091219 self ._add_parameter (ffn1_weight_scale )
11101220 self ._add_parameter (ffn2_weight_scale )
11111221
1222+ if self .config .moe_config .use_shared_expert (i ):
1223+ self ._add_parameter (shared_expert_ffn1_weight_scale )
1224+ self ._add_parameter (shared_expert_ffn2_weight_scale )
1225+
11121226 def get_weight_create_dype (self ):
11131227 return "int8" # If use weightonly int4, params dtype is int8, and one of the dimension will be half.
11141228
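These scales are per output channel: the shared-expert ffn1 scale has `shared_expert_intermediate_size * 2` entries and the ffn2 scale `embed_dim` entries. A sketch of how such weight/scale pairs are typically produced with Paddle's weight-only utilities (a recent Paddle GPU build is assumed; sizes are hypothetical):

    import paddle
    from paddle.nn.quant import weight_only_linear, weight_quantize

    w = paddle.randn([64, 256], dtype="float16")  # [in, out] fp weight
    qw, scale = weight_quantize(w, algo="weight_only_int8")
    print(scale.shape)  # [256]: one scale per output channel

    x = paddle.randn([8, 64], dtype="float16")
    y = weight_only_linear(x, qw, weight_scale=scale, weight_dtype="int8")
    print(y.shape)  # [8, 256]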
@@ -1141,6 +1255,20 @@ def init_weight_shape(self, config):
             self.moe_ffn1_weight_shape[2] //= 2
             self.moe_ffn2_weight_shape[2] //= 2
 
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weight_shape = [
+                self.config.moe_config.shared_expert_intermediate_size * 2,
+                self.embed_dim,
+            ]
+            self.shared_expert_ffn2_weight_shape = [
+                self.embed_dim,
+                self.config.moe_config.shared_expert_intermediate_size,
+            ]
+            self.shared_expert_gate_weight_shape = [
+                self.embed_dim,
+                1,
+            ]
+
     def compute_qkv_linear(self, ln_out, i):
         return weight_only_linear(
             ln_out,
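Note the layout flip relative to the dynamic-graph class: there ffn1 is `[embed_dim, intermediate * 2]` ([in, out]), while the weight-only shapes are stored transposed as [out, in], the layout `weight_only_linear` consumes. A trivial sanity check of that relationship:

    # Hypothetical sizes; the weight-only shape is the fp shape reversed.
    embed_dim, inter = 4096, 5632
    fp_ffn1 = [embed_dim, inter * 2]  # dynamic-graph layout: [in, out]
    wo_ffn1 = [inter * 2, embed_dim]  # weight-only layout: [out, in]
    assert wo_ffn1 == fp_ffn1[::-1]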
@@ -1197,6 +1325,29 @@ def compute_ffn2(self, ffn1_out, i):
             weight_dtype=self.weight_dtype,
         )
 
+    def compute_shared_expert(self, tmp_out, i):
+        ffn1_out = weight_only_linear(
+            tmp_out,
+            weight=self.shared_expert_ffn1_weights[i],
+            weight_scale=self.shared_expert_ffn1_weights_scale[i],
+            weight_dtype=self.weight_dtype,
+        )
+
+        ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)
+
+        ffn2_out = weight_only_linear(
+            ffn1_out,
+            weight=self.shared_expert_ffn2_weights[i],
+            weight_scale=self.shared_expert_ffn2_weights_scale[i],
+            weight_dtype=self.weight_dtype,
+        )
+
+        gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
+        gate_out = paddle.nn.functional.sigmoid(gate_out)
+
+        shared_expert_output = gate_out * ffn2_out
+        return shared_expert_output
+
 
 class FusedMultiTransformerWeightOnlyPostLayernorm(
     FusedMultiTransformerWeightOnly, FusedMultiTransformerPostLayernorm