
Commit 5b4384c

[Feature] Fused Mixtral support
1 parent c4d1abf commit 5b4384c

5 files changed: +1509 -34 lines changed


llm/predict/predictor.py

Lines changed: 47 additions & 4 deletions
@@ -118,7 +118,9 @@ class PredictorArgument:
     block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."})
     cachekv_int8_type: str = field(
         default=None,
-        metadata={"help": "If cachekv_int8_type set as `dynamic`, cache kv would be quantized to int8 dynamically. If cachekv_int8_type set as `static`, cache kv would be quantized to int8 Statically."},
+        metadata={
+            "help": "If cachekv_int8_type set as `dynamic`, cache kv would be quantized to int8 dynamically. If cachekv_int8_type set as `static`, cache kv would be quantized to int8 Statically."
+        },
     )
 
     chat_template: str = field(
@@ -1090,9 +1092,7 @@ def __init__(
         if config.cachekv_int8_type is not None:
             cachekv_dtype = "uint8"
         for i in range(len(self.cache_kvs_shape) // 2):
-            self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(
-                self.cache_kvs_shape[2 * i], dtype=cachekv_dtype
-            )
+            self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(self.cache_kvs_shape[2 * i], dtype=cachekv_dtype)
             self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros(
                 self.cache_kvs_shape[2 * i + 1], dtype=cachekv_dtype
             )
@@ -1413,6 +1413,35 @@ def create_predictor(
                 )
             model.eval()
 
+        elif "mixtral" in config.architectures[0].lower():
+            if predictor_args.block_attn:
+                config.max_seq_len = predictor_args.total_max_length
+                config.block_size = predictor_args.block_size
+                from paddlenlp.experimental.transformers import (
+                    MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
+                )
+
+                model = MixtralInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                    tensor_parallel_degree=tensor_parallel_degree,
+                    tensor_parallel_rank=tensor_parallel_rank,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    MixtralForCausalLMInferenceModel as MixtralInferenceModel,
+                )
+
+                model = MixtralInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                    tensor_parallel_degree=tensor_parallel_degree,
+                    tensor_parallel_rank=tensor_parallel_rank,
+                )
+            model.eval()
+
         elif "opt" in config.architectures[0].lower():
             if model_args.model_type == "opt-img2txt":
                 # we use opt for img2txt.
@@ -1525,6 +1554,20 @@ def create_predictor(
             cache_kvs_shape = LlamaInferenceModel.get_cache_kvs_shape(
                 config, predictor_args.batch_size, predictor_args.total_max_length
             )
+        elif "mixtral" in config.architectures[0].lower():
+            if predictor_args.block_attn:
+                config.block_size = predictor_args.block_size
+                config.max_seq_len = predictor_args.total_max_length
+                from paddlenlp.experimental.transformers import (
+                    MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    MixtralForCausalLMInferenceModel as MixtralInferenceModel,
+                )
+            cache_kvs_shape = MixtralInferenceModel.get_cache_kvs_shape(
+                config, predictor_args.batch_size, predictor_args.total_max_length
+            )
         elif "chatglmv2forcausallm" in config.architectures[0].lower():
             from paddlenlp.experimental.transformers import (
                 ChatGLMv2ForCausalLMInferenceModel,
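
For context, a minimal usage sketch of the new dispatch when loading a Mixtral checkpoint. Only the class names, the config attributes, and the from_pretrained/get_cache_kvs_shape arguments are taken from the diff above; the checkpoint identifier, the dtype, and the standalone AutoConfig call are illustrative assumptions rather than part of this commit.

    from paddlenlp.transformers import AutoConfig

    model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # illustrative checkpoint id
    config = AutoConfig.from_pretrained(model_name)
    use_block_attn = True  # plays the role of predictor_args.block_attn

    if "mixtral" in config.architectures[0].lower():
        if use_block_attn:
            config.max_seq_len = 2048  # predictor_args.total_max_length
            config.block_size = 64     # predictor_args.block_size
            from paddlenlp.experimental.transformers import (
                MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
            )
        else:
            from paddlenlp.experimental.transformers import (
                MixtralForCausalLMInferenceModel as MixtralInferenceModel,
            )

        model = MixtralInferenceModel.from_pretrained(model_name, config=config, dtype="bfloat16")
        model.eval()
        # batch size and total max length are passed positionally, as in the diff
        cache_kvs_shape = MixtralInferenceModel.get_cache_kvs_shape(config, 1, 2048)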

paddlenlp/experimental/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,5 +18,6 @@
 from .fused_transformer_layers import *
 from .gpt import *
 from .llama import *
+from .mixtral import *
 from .opt import *
 from .qwen import *

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 137 additions & 30 deletions
@@ -18,6 +18,7 @@
 from paddle.framework import LayerHelper, core, in_dynamic_mode
 from paddle.incubate.nn.functional import (
     fused_layer_norm,
+    fused_moe,
     fused_rms_norm,
     masked_multihead_attention,
     variable_length_memory_efficient_attention,
@@ -167,6 +168,7 @@ def __init__(
         linear_bias_attrs=None,
         ffn_ln_scale_attrs=None,
         ffn_ln_bias_attrs=None,
+        gate_weight_attrs=None,
         ffn1_weight_attrs=None,
         ffn1_weight_scale_attrs=None,
         ffn1_bias_attrs=None,
@@ -197,12 +199,15 @@ def __init__(
         kv_num_heads=-1,
         cachekv_int8_type=None,
         rank_id=-1,
+        is_moe=False,
+        moe_every2=False,
+        moe_topk=2,
+        num_experts=1,
     ):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         if kv_num_heads > 0:
             self.kv_num_heads = kv_num_heads
-            assert nranks == 1, "nranks should be 1 for kv_num_heads > 0"
         else:
             self.kv_num_heads = num_heads
         self.dim_feedforward = dim_feedforward
@@ -222,6 +227,7 @@ def __init__(
         self.linear_bias_attrs = linear_bias_attrs
         self.ffn_ln_scale_attrs = ffn_ln_scale_attrs
         self.ffn_ln_bias_attrs = ffn_ln_bias_attrs
+        self.gate_weight_attrs = gate_weight_attrs
         self.ffn1_weight_attrs = ffn1_weight_attrs
         self.ffn1_weight_scale_attrs = ffn1_weight_scale_attrs
         self.ffn1_bias_attrs = ffn1_bias_attrs
@@ -255,6 +261,10 @@ def __init__(
         self.rank_id = rank_id
         self.trans_qkvw = trans_qkvw
         self.ring_id = ring_id
+        self.is_moe = is_moe
+        self.moe_every2 = moe_every2
+        self.moe_topk = moe_topk
+        self.num_experts = num_experts
 
 
 class FusedMultiTransformerBase(Layer):
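
To make the new configuration knobs concrete, here is a hedged sketch of building a FusedMultiTransformerConfig with the MoE fields introduced above. Only is_moe, moe_every2, moe_topk and num_experts come from this commit; the remaining arguments and the Mixtral-8x7B-style sizes are assumptions for illustration.

    from paddlenlp.experimental.transformers.fused_transformer_layers import (
        FusedMultiTransformerConfig,
    )

    config = FusedMultiTransformerConfig(
        embed_dim=4096,         # hidden size (pre-existing argument, value assumed)
        num_heads=32,           # attention heads (pre-existing argument, value assumed)
        dim_feedforward=14336,  # per-expert FFN width (pre-existing argument, value assumed)
        kv_num_heads=8,         # grouped-query attention, as in Mixtral-8x7B
        is_moe=True,            # new: route the FFN through fused_moe
        moe_every2=False,       # new: True would make only every second layer a MoE layer
        moe_topk=2,             # new: experts activated per token
        num_experts=8,          # new: experts per MoE layer
    )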
@@ -294,6 +304,10 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.head_dim = config.embed_dim // config.num_heads
         assert self.head_dim * config.num_heads == config.embed_dim, "embed_dim must be divisible by num_heads"
 
+        self._is_moe = config.is_moe
+        self._moe_every2 = config.moe_every2
+        self._moe_topk = config.moe_topk
+
         # tensor model parallel
         if config.nranks > 1:
             assert config.ring_id != -1
@@ -316,6 +330,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.qkv_weights, self.qkv_biases = [], []
         self.linear_weights, self.linear_biases = [], []
         self.ffn_ln_scales, self.ffn_ln_biases = [], []
+        self.gate_weights = []
         self.ffn1_weights, self.ffn1_biases = [], []
         self.ffn2_weights, self.ffn2_biases = [], []
         self.cache_k_scales, self.cache_v_scales = [], []
@@ -327,6 +342,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
             qkv_weight_attr = self.get_attr(config.qkv_weight_attrs, i)
 
             qkv_bias_attr = self.get_attr(config.qkv_bias_attrs, i)
+            gate_weight_attr = self.get_attr(config.gate_weight_attrs, i)
             linear_weight_attr = self.get_attr(config.linear_weight_attrs, i)
             linear_bias_attr = self.get_attr(config.linear_bias_attrs, i)
 
@@ -407,37 +423,99 @@ def __init__(self, config: FusedMultiTransformerConfig):
                 dtype=self._norm_weight_dtype,
             )
 
-            ffn1_weight = self.create_parameter(
-                shape=self.ffn1_weight_shape,
-                attr=ffn1_weight_attr,
-                dtype=self.create_params_type,
-                is_bias=False,
-            )
+            gate_weight = None
+            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+                gate_weight = self.create_parameter(
+                    shape=[config.embed_dim, config.num_experts],
+                    attr=gate_weight_attr,
+                    dtype="float32",
+                    is_bias=False,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+            else:
+                gate_weight = self.create_parameter(
+                    shape=[1],
+                    attr=gate_weight_attr,
+                    dtype="float32",
+                    is_bias=False,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+
+            if config.is_moe is False:
+                gate_weight = None
+                self.gate_weights = None
+
+            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+                ffn1_weight = self.create_parameter(
+                    shape=[config.num_experts, self.embed_dim, self.dim_feedforward * 2]
+                    if self.activation.endswith("glu")
+                    else [config.num_experts, self.embed_dim, self.dim_feedforward],
+                    attr=ffn1_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
+                )
+            else:
+                ffn1_weight = self.create_parameter(
+                    shape=self.ffn1_weight_shape,
+                    attr=ffn1_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
+                )
 
             ffn1_bias = None
             if ffn1_bias_attr:
-                ffn1_bias = self.create_parameter(
-                    shape=[dim_feedforward * 2] if config.activation.endswith("glu") else [dim_feedforward],
-                    attr=ffn1_bias_attr,
-                    dtype=self._dtype,
-                    is_bias=True,
+                if config.is_moe is True and (
+                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
+                ):
+                    ffn1_bias = self.create_parameter(
+                        shape=[config.num_experts, self.dim_feedforward * 2]
+                        if self.activation.endswith("glu")
+                        else [config.num_experts, self.dim_feedforward],
+                        attr=ffn1_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
+                else:
+                    ffn1_bias = self.create_parameter(
+                        shape=[dim_feedforward * 2] if self.activation.endswith("glu") else [dim_feedforward],
+                        attr=ffn1_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
+
+            if config.is_moe is True and ((config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False):
+                ffn2_weight = self.create_parameter(
+                    shape=[config.num_experts, self.dim_feedforward, self.embed_dim],
+                    attr=ffn2_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
+                )
+            else:
+                ffn2_weight = self.create_parameter(
+                    shape=self.ffn2_weight_shape,
+                    attr=ffn2_weight_attr,
+                    dtype=self.create_params_type,
+                    is_bias=False,
                 )
-
-            ffn2_weight = self.create_parameter(
-                shape=self.ffn2_weight_shape,
-                attr=ffn2_weight_attr,
-                dtype=self.create_params_type,
-                is_bias=False,
-            )
 
             ffn2_bias = None
             if ffn2_bias_attr:
-                ffn2_bias = self.create_parameter(
-                    shape=[config.embed_dim],
-                    attr=ffn2_bias_attr,
-                    dtype=self._dtype,
-                    is_bias=True,
-                )
+                if config.is_moe is True and (
+                    (config.moe_every2 is True and i % 2 == 1) or config.moe_every2 is False
+                ):
+                    ffn2_bias = self.create_parameter(
+                        shape=[config.num_experts, config.embed_dim],
+                        attr=ffn2_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
+                else:
+                    ffn2_bias = self.create_parameter(
+                        shape=[config.embed_dim],
+                        attr=ffn2_bias_attr,
+                        dtype=self._dtype,
+                        is_bias=True,
+                    )
 
             cache_k_scale = None
             if cache_k_scale_attr:
@@ -495,6 +573,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             self.ffn_ln_scales.append(ffn_ln_scale)
             self.ffn_ln_biases.append(ffn_ln_bias)
+            if gate_weight is not None:
+                self.gate_weights.append(gate_weight)
             self.ffn1_weights.append(ffn1_weight)
             self.ffn1_biases.append(ffn1_bias)
             self.ffn2_weights.append(ffn2_weight)
@@ -713,6 +793,28 @@ def compute_ffn_layernorm(self, out_linear_out, residual_input, i):
 
         return tmp_out, residual_input
 
+    def compute_fused_moe(self, tmp_out, i):
+        # todo[xinhw]: make bias optional
+        if self.ffn1_biases[i] is None:
+            shape1 = paddle.to_tensor([self.ffn1_weights[i].shape[0], 1, self.dim_feedforward * 2])
+            self.ffn1_biases[i] = paddle.zeros(shape1)
+        if self.ffn2_biases[i] is None:
+            shape2 = paddle.to_tensor([self.ffn1_weights[i].shape[0], 1, self.embed_dim])
+            self.ffn2_biases[i] = paddle.zeros(shape2)
+        fused_moe_out = fused_moe(
+            tmp_out,
+            self.gate_weights[i],
+            self.ffn1_weights[i],
+            self.ffn1_biases[i],
+            self.ffn2_weights[i],
+            self.ffn2_biases[i],
+            None,
+            None,
+            "None",
+            self._moe_topk,
+        )
+        return fused_moe_out
+
     def compute_activation(self, ffn1_out, i):
         return fused_act_bias_wrapper(ffn1_out, self.ffn1_biases[i], act_method=self.activation)
 
@@ -854,12 +956,17 @@ def forward(
             # ffn layernorm
             tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
 
-            # ffn1 matmul
-            ffn1_out = self.compute_ffn1(tmp_out, i)
-            ffn1_out = self.compute_activation(ffn1_out, i)
+            if self._is_moe is True and ((self._moe_every2 is True and i % 2 == 1) or self._moe_every2 is False):
+                # fused moe
+                ffn2_out = self.compute_fused_moe(tmp_out, i)
+
+            else:
+                # ffn1 matmul
+                ffn1_out = self.compute_ffn1(tmp_out, i)
+                ffn1_out = self.compute_activation(ffn1_out, i)
 
-            # ffn2 matmul
-            ffn2_out = self.compute_ffn2(ffn1_out, i)
+                # ffn2 matmul
+                ffn2_out = self.compute_ffn2(ffn1_out, i)
 
             # all_reduce
             if self.nranks > 1:
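
Pulled out for illustration, the layer-selection predicate that the constructor and forward pass above use to decide which layers take the fused MoE path; this is a standalone restatement of the condition in the diff, not code from the commit.

    def uses_fused_moe(i, is_moe, moe_every2):
        # moe_every2=False: every layer is a MoE layer; moe_every2=True: only odd-indexed layers.
        return is_moe and ((moe_every2 and i % 2 == 1) or not moe_every2)

    # An 8-layer stack with moe_every2=True alternates dense and MoE feed-forward blocks.
    assert [uses_fused_moe(i, True, True) for i in range(8)] == [False, True] * 4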
paddlenlp/experimental/transformers/mixtral/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .modeling import *
