
Commit 31cc283

[Feature] Fused Mixtral support (#8901)
* [Feature] Fused Mixtral support
* [Refactor] add MoeConfig and fix static graph export problem
* [Bugfix] fix small bug
* [Bugfix] fix moe_config bug
* [Bugfix] fix moe_config bug
* [Refactor] refine code
* [Refactor] refine code
* [Refactor] refine code
* [Refactor] match fused moe api change
* [Feature] wint8 support
1 parent 9f6b486 commit 31cc283

File tree

5 files changed: +1593 additions, -43 deletions


llm/predict/predictor.py

Lines changed: 43 additions & 0 deletions
@@ -1262,6 +1262,35 @@ def create_predictor(
         )
         model.eval()

+    elif "mixtral" in config.architectures[0].lower():
+        if predictor_args.block_attn:
+            config.max_seq_len = predictor_args.total_max_length
+            config.block_size = predictor_args.block_size
+            from paddlenlp.experimental.transformers import (
+                MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
+            )
+
+            model = MixtralInferenceModel.from_pretrained(
+                predictor_args.model_name_or_path,
+                config=config,
+                dtype=predictor_args.dtype,
+                tensor_parallel_degree=tensor_parallel_degree,
+                tensor_parallel_rank=tensor_parallel_rank,
+            )
+        else:
+            from paddlenlp.experimental.transformers import (
+                MixtralForCausalLMInferenceModel as MixtralInferenceModel,
+            )
+
+            model = MixtralInferenceModel.from_pretrained(
+                predictor_args.model_name_or_path,
+                config=config,
+                dtype=predictor_args.dtype,
+                tensor_parallel_degree=tensor_parallel_degree,
+                tensor_parallel_rank=tensor_parallel_rank,
+            )
+        model.eval()
+
     elif "opt" in config.architectures[0].lower():
         if model_args.model_type == "opt-img2txt":
             # we use opt for img2txt.
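For reference, exercising the new dynamic-graph branch directly looks roughly like the following. This is a minimal sketch, not code from the commit: the checkpoint path, dtype, and numeric values are illustrative, and it assumes a Mixtral checkpoint that AutoConfig can resolve.

# Minimal sketch of the block-attention path added above.
# The checkpoint path and the numeric values are illustrative.
from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import (
    MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
)

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config.max_seq_len = 4096  # plays the role of predictor_args.total_max_length
config.block_size = 64     # plays the role of predictor_args.block_size

model = MixtralInferenceModel.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    config=config,
    dtype="float16",
    tensor_parallel_degree=1,
    tensor_parallel_rank=0,
)
model.eval()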
@@ -1405,6 +1434,20 @@ def create_predictor(
             cache_kvs_shape = LlamaInferenceModel.get_cache_kvs_shape(
                 config, predictor_args.batch_size, predictor_args.total_max_length
             )
+        elif "mixtral" in config.architectures[0].lower():
+            if predictor_args.block_attn:
+                config.block_size = predictor_args.block_size
+                config.max_seq_len = predictor_args.total_max_length
+                from paddlenlp.experimental.transformers import (
+                    MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    MixtralForCausalLMInferenceModel as MixtralInferenceModel,
+                )
+            cache_kvs_shape = MixtralInferenceModel.get_cache_kvs_shape(
+                config, predictor_args.batch_size, predictor_args.total_max_length
+            )
         elif "chatglmv2forcausallm" in config.architectures[0].lower():
             from paddlenlp.experimental.transformers import (
                 ChatGLMv2ForCausalLMInferenceModel,
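Note that `get_cache_kvs_shape` is called on the class before any instance exists, so cache buffers can be sized ahead of weight loading. A sketch of how the returned shapes might be consumed, assuming the method yields one shape per cache tensor as the surrounding predictor code implies (batch size, length, and dtype are illustrative):

import paddle

# Sketch: preallocate KV-cache buffers from the shapes the class reports.
# Assumes get_cache_kvs_shape returns an iterable of tensor shapes, as the
# surrounding predictor code implies; arguments mirror the diff's positional
# call (config, batch_size, total_max_length).
cache_kvs_shape = MixtralInferenceModel.get_cache_kvs_shape(config, 1, 4096)
cache_kvs = [paddle.zeros(shape, dtype="float16") for shape in cache_kvs_shape]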

paddlenlp/experimental/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from .fused_transformer_layers import *
 from .gpt import *
 from .llama import *
+from .mixtral import *
 from .opt import *
 from .qwen import *
 from .qwen2 import *
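With the star import in place, both Mixtral inference classes resolve from the package root, which is exactly how `llm/predict/predictor.py` imports them above:

# Both classes are re-exported at the package root, matching the imports
# used in llm/predict/predictor.py above.
from paddlenlp.experimental.transformers import (
    MixtralForCausalLMBlockInferenceModel,
    MixtralForCausalLMInferenceModel,
)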
