Commit 9b7c056

[Refactor] refine code
1 parent 213ada5 commit 9b7c056

File tree: 1 file changed (+23, -30)

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 23 additions & 30 deletions
@@ -157,6 +157,12 @@ class MoeConfig:
     norm_topk_prob: bool = True
     moe_every2: bool = False
 
+    def has_moe(self) -> bool:
+        return self.num_experts > 1
+
+    def use_moe(self, i: int) -> bool:
+        return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))
+
 
 class FusedMultiTransformerConfig:
     def __init__(
@@ -270,13 +276,8 @@ def __init__(
         self.rank_id = rank_id
         self.trans_qkvw = trans_qkvw
         self.ring_id = ring_id
-        if moe_config:
-            self.is_moe = True
-            self.moe_every2 = moe_config.moe_every2
-            self.moe_topk = moe_config.top_k
-            self.num_experts = moe_config.num_experts
-        else:
-            self.is_moe = False
+
+        self.moe_config = moe_config
 
 
 class FusedMultiTransformerBase(Layer):
@@ -432,16 +433,16 @@ def __init__(self, config: FusedMultiTransformerConfig):
             )
 
             gate_weight = None
-            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+            if config.moe_config.use_moe(i):
                 gate_weight = self.create_parameter(
-                    shape=[config.embed_dim, config.num_experts],
+                    shape=[config.embed_dim, self.config.moe_config.num_experts],
                     attr=gate_weight_attr,
                     dtype="float32",
                     is_bias=False,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 )
 
-            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+            if config.moe_config.use_moe(i):
                 ffn1_weight = self.create_parameter(
                     shape=self.moe_ffn1_weight_shape,
                     attr=ffn1_weight_attr,
@@ -458,13 +459,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn1_bias = None
             if ffn1_bias_attr:
-                if self.config.is_moe is True and (
-                    (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
-                ):
+                if self.config.moe_config.use_moe(i):
                     ffn1_bias = self.create_parameter(
-                        shape=[config.num_experts, self.dim_feedforward * 2]
+                        shape=[self.config.moe_config.num_experts, self.dim_feedforward * 2]
                         if self.activation.endswith("glu")
-                        else [config.num_experts, self.dim_feedforward],
+                        else [self.config.moe_config.num_experts, self.dim_feedforward],
                         attr=ffn1_bias_attr,
                         dtype=self._dtype,
                         is_bias=True,
@@ -477,9 +476,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
                         is_bias=True,
                     )
 
-            if self.config.is_moe is True and (
-                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
-            ):
+            if self.config.moe_config.use_moe(i):
                 ffn2_weight = self.create_parameter(
                     shape=self.moe_ffn2_weight_shape,
                     attr=ffn2_weight_attr,
@@ -496,11 +493,9 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn2_bias = None
             if ffn2_bias_attr:
-                if config.is_moe is True and (
-                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
-                ):
+                if config.moe_config.use_moe(i):
                     ffn2_bias = self.create_parameter(
-                        shape=[config.num_experts, config.embed_dim],
+                        shape=[self.config.moe_config.num_experts, config.embed_dim],
                         attr=ffn2_bias_attr,
                         dtype=self._dtype,
                         is_bias=True,
@@ -636,13 +631,13 @@ def init_weight_shape(self, config):
         )
         self.ffn2_weight_shape = [self.dim_feedforward, self.embed_dim]
 
-        if self.config.is_moe is True:
+        if self.config.moe_config.has_moe() is True:
             self.moe_ffn1_weight_shape = (
-                [config.num_experts, self.embed_dim, self.dim_feedforward * 2]
+                [self.config.moe_config.num_experts, self.embed_dim, self.dim_feedforward * 2]
                 if self.activation.endswith("glu")
-                else [config.num_experts, self.embed_dim, self.dim_feedforward]
+                else [self.config.moe_config.num_experts, self.embed_dim, self.dim_feedforward]
             )
-            self.moe_ffn2_weight_shape = [config.num_experts, self.dim_feedforward, self.embed_dim]
+            self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]
 
     def get_weight_create_dype(self):
         return self._dtype
@@ -817,7 +812,7 @@ def compute_fused_moe(self, tmp_out, i):
             None,
             None,
             "None",
-            self.config.moe_topk,
+            self.config.moe_config.top_k,
         )
         return fused_moe_out
 
@@ -962,9 +957,7 @@ def forward(
             # ffn layernorm
             tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
 
-            if self.config.is_moe is True and (
-                (self.config.moe_every2 is True and i % 2 == 1) or self.config.moe_every2 is False
-            ):
+            if self.config.moe_config.use_moe(i):
                 # fused moe
                 ffn2_out = self.compute_fused_moe(tmp_out, i)
 
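For context, here is a minimal standalone sketch of the MoeConfig helpers this commit introduces and how they gate MoE per layer. The two methods and the norm_topk_prob/moe_every2 defaults are copied from the diff; the num_experts and top_k defaults and the @dataclass decoration are assumptions for illustration, not necessarily how the class is declared in fused_transformer_layers.py.

# Standalone sketch of the new MoeConfig helpers (defaults for num_experts/top_k are assumed).
from dataclasses import dataclass


@dataclass
class MoeConfig:
    num_experts: int = 1      # assumed default; MoE is active only when > 1
    top_k: int = 2            # assumed default
    norm_topk_prob: bool = True
    moe_every2: bool = False

    def has_moe(self) -> bool:
        return self.num_experts > 1

    def use_moe(self, i: int) -> bool:
        # Every layer uses MoE unless moe_every2 is set, in which case only odd-indexed layers do.
        return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))


cfg = MoeConfig(num_experts=8, moe_every2=True)
print([i for i in range(6) if cfg.use_moe(i)])  # -> [1, 3, 5]

With moe_every2=True only odd-indexed layers take the fused-MoE branch; with moe_every2=False every layer does, provided num_experts > 1. Centralizing this check in use_moe(i) is what lets the repeated is_moe/moe_every2 conditionals in the layer code collapse to single calls.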