Commit 213ada5

[Refactor] refine code
1 parent bb936da commit 213ada5

File tree

2 files changed: +19, -31 lines

2 files changed

+19
-31
lines changed

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 19 additions & 30 deletions
@@ -316,15 +316,6 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.head_dim = config.embed_dim // config.num_heads
         assert self.head_dim * config.num_heads == config.embed_dim, "embed_dim must be divisible by num_heads"
 
-        if config.is_moe:
-            self._is_moe = config.is_moe
-            self._moe_every2 = config.moe_every2
-            self._moe_topk = config.moe_topk
-        else:
-            self._is_moe = False
-            self._moe_every2 = False
-            self._moe_topk = 1
-
         # tensor model parallel
         if config.nranks > 1:
             assert config.ring_id != -1
@@ -449,24 +440,10 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     is_bias=False,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 )
-            else:
-                gate_weight = self.create_parameter(
-                    shape=[1],
-                    attr=gate_weight_attr,
-                    dtype="float32",
-                    is_bias=False,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                )
-
-            if config.is_moe is False:
-                gate_weight = None
-                self.gate_weights = None
 
             if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
                 ffn1_weight = self.create_parameter(
-                    shape=[config.num_experts, self.embed_dim, self.dim_feedforward * 2]
-                    if self.activation.endswith("glu")
-                    else [config.num_experts, self.embed_dim, self.dim_feedforward],
+                    shape=self.moe_ffn1_weight_shape,
                     attr=ffn1_weight_attr,
                     dtype=self.create_params_type,
                     is_bias=False,
@@ -481,8 +458,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn1_bias = None
             if ffn1_bias_attr:
-                if config.is_moe is True and (
-                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
+                if self.config.is_moe is True and (
+                    (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
                 ):
                     ffn1_bias = self.create_parameter(
                         shape=[config.num_experts, self.dim_feedforward * 2]
@@ -500,9 +477,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     is_bias=True,
                 )
 
-            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+            if self.config.is_moe is True and (
+                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
+            ):
                 ffn2_weight = self.create_parameter(
-                    shape=[config.num_experts, self.dim_feedforward, self.embed_dim],
+                    shape=self.moe_ffn2_weight_shape,
                     attr=ffn2_weight_attr,
                     dtype=self.create_params_type,
                     is_bias=False,
@@ -657,6 +636,14 @@ def init_weight_shape(self, config):
             )
         self.ffn2_weight_shape = [self.dim_feedforward, self.embed_dim]
 
+        if self.config.is_moe is True:
+            self.moe_ffn1_weight_shape = (
+                [config.num_experts, self.embed_dim, self.dim_feedforward * 2]
+                if self.activation.endswith("glu")
+                else [config.num_experts, self.embed_dim, self.dim_feedforward]
+            )
+            self.moe_ffn2_weight_shape = [config.num_experts, self.dim_feedforward, self.embed_dim]
+
     def get_weight_create_dype(self):
         return self._dtype
 
@@ -830,7 +817,7 @@ def compute_fused_moe(self, tmp_out, i):
             None,
             None,
             "None",
-            self._moe_topk,
+            self.config.moe_topk,
         )
        return fused_moe_out
 
@@ -975,7 +962,9 @@ def forward(
            # ffn layernorm
            tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
 
-            if self._is_moe is True and ((self._moe_every2 is True and i % 2 == 1) or self._moe_every2 is False):
+            if self.config.is_moe is True and (
+                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
+            ):
                # fused moe
                ffn2_out = self.compute_fused_moe(tmp_out, i)

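Taken together, the hunks above do two things: the per-layer MoE gate now reads config.is_moe / config.moe_every2 / config.moe_topk directly from self.config instead of the deleted cached attributes (self._is_moe, self._moe_every2, self._moe_topk), and the expert weight shapes are precomputed once in init_weight_shape as self.moe_ffn1_weight_shape / self.moe_ffn2_weight_shape. The following is a minimal sketch of that logic in isolation; MoeConfigSketch, is_moe_layer, and moe_weight_shapes are illustrative stand-ins, not PaddleNLP APIs.

```python
from dataclasses import dataclass


@dataclass
class MoeConfigSketch:  # hypothetical stand-in for the fields FusedMultiTransformerConfig provides
    is_moe: bool = True
    moe_every2: bool = True      # apply the MoE FFN only on every second layer
    moe_topk: int = 2
    num_experts: int = 8
    embed_dim: int = 4096
    dim_feedforward: int = 14336


def is_moe_layer(config: MoeConfigSketch, i: int) -> bool:
    # Per-layer gate used throughout the diff: layer i takes the fused-MoE path when
    # MoE is enabled and either every layer is MoE (moe_every2 False) or i is odd.
    return config.is_moe and ((config.moe_every2 and i % 2 == 1) or not config.moe_every2)


def moe_weight_shapes(config: MoeConfigSketch, activation: str = "swiglu"):
    # Mirrors what init_weight_shape now precomputes: ffn1 doubles its output
    # width for *glu activations; ffn2 projects back to embed_dim.
    if activation.endswith("glu"):
        ffn1 = [config.num_experts, config.embed_dim, config.dim_feedforward * 2]
    else:
        ffn1 = [config.num_experts, config.embed_dim, config.dim_feedforward]
    ffn2 = [config.num_experts, config.dim_feedforward, config.embed_dim]
    return ffn1, ffn2


if __name__ == "__main__":
    cfg = MoeConfigSketch()
    print([i for i in range(8) if is_moe_layer(cfg, i)])  # -> [1, 3, 5, 7]
    print(moe_weight_shapes(cfg))  # -> ([8, 4096, 28672], [8, 14336, 4096])
```

With moe_every2=True only odd-indexed layers take the fused-MoE FFN, which matches the repeated i % 2 == 1 test in the diff; with moe_every2=False every layer does.
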
paddlenlp/experimental/transformers/mixtral/modeling.py

Lines changed: 0 additions & 1 deletion
@@ -774,7 +774,6 @@ def set_state_dict(self, state_dict):
         act_scale_map_dict = scale_map_dict["act_scale"]
         weight_scale_map_dict = scale_map_dict["weight_scale"]
         cache_scale_map_dict = scale_map_dict["cachekv_scale"]
-        # TODO(RichardWooSJTU): support multi-cards
 
         act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
         weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
