
Commit 0a5de12

[LLM Inference] Support qwen2 (#8893)
* stage 1
* update
* update
* support qwen2 bf16/wint8
* add qwen2 ptq map
* update
* fix tune_cublaslt_gemm.cu
1 parent 6f3e736 commit 0a5de12

File tree

8 files changed: +1008 −3 lines changed

csrc/generation/tune_cublaslt_gemm.cu

Lines changed: 3 additions & 3 deletions
@@ -327,7 +327,7 @@ void FindAlgo(const cublasLtHandle_t& ltHandle,
                                                        sizeof(customOption)));
         CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
             &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)));
-        int splitK_val = 0;
+        int splitK_val = 1;
         uint32_t redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
         CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
             &algo,
@@ -346,10 +346,10 @@ void FindAlgo(const cublasLtHandle_t& ltHandle,
                 CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
                 &splitKSequenceA[l - 1],
                 sizeof(splitKSequenceA[l - 1])));
-            for (redScheme = 0;
+            for (redScheme = 1;
                  redScheme < (int)CUBLASLT_REDUCTION_SCHEME_MASK &&
                      (AlgoCount < AlgoCombinations);
-                 redScheme++) {
+                 redScheme <<= 1) {
               CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                   &algo,
                   CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
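
A note on what this fix does: the cublasLtReductionScheme_t options are power-of-two bit flags, so stepping with redScheme <<= 1 from 1 enumerates each individual scheme exactly once, while the old redScheme++ from 0 also produced composite values (3, 5, 6) that name no single scheme. The splitK_val default likewise moves from 0 to 1, presumably because 1 is cuBLASLt's "no split-K" baseline. A minimal Python sketch of the enumeration difference; the flag values are assumed from cuBLASLt's documented bit-flag layout:

# Assumed values mirroring cublasLtReductionScheme_t (bit flags).
INPLACE = 0x1       # CUBLASLT_REDUCTION_SCHEME_INPLACE
COMPUTE_TYPE = 0x2  # CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE
OUTPUT_TYPE = 0x4   # CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE
MASK = 0x7          # CUBLASLT_REDUCTION_SCHEME_MASK

# Old loop: increments through every integer below the mask,
# including composites like 3 and 5 that are not single schemes.
old_schemes = list(range(MASK))  # [0, 1, 2, 3, 4, 5, 6]

# New loop: left-shifting visits each flag exactly once.
new_schemes = []
scheme = 1
while scheme < MASK:
    new_schemes.append(scheme)
    scheme <<= 1
print(new_schemes)  # [1, 2, 4]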

llm/predict/predictor.py

Lines changed: 51 additions & 0 deletions
@@ -1395,6 +1395,32 @@ def create_predictor(
                 dtype=predictor_args.dtype,
             )
             model.eval()
+        elif "qwen2" in config.architectures[0].lower():
+            if predictor_args.block_attn:
+                config.max_seq_len = predictor_args.total_max_length
+                config.block_size = predictor_args.block_size
+                from paddlenlp.experimental.transformers import (
+                    Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel,
+                )
+
+                model = Qwen2InferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                    tensor_parallel_degree=tensor_parallel_degree,
+                    tensor_parallel_rank=tensor_parallel_rank,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel,
+                )
+
+                model = Qwen2InferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                )
+            model.eval()
         elif "qwen" in config.architectures[0].lower():
             if model_args.model_type == "qwen-img2txt":
                 # we use qwen for img2txt.
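
One ordering subtlety that the placement of this hunk handles: the architecture string for Qwen2 checkpoints is "Qwen2ForCausalLM", which also contains the substring "qwen", so the new "qwen2" branch must be tested before the pre-existing generic "qwen" branch or it would never be reached. A two-line illustration:

arch = "Qwen2ForCausalLM".lower()
assert "qwen2" in arch  # matched by the new, more specific branch
assert "qwen" in arch   # would also match the generic branch if tested first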
@@ -1420,6 +1446,16 @@ def create_predictor(
 
     elif predictor_args.mode == "static":
         config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
+        config.quant_type = predictor_args.quant_type
+        config.cachekv_int8_type = predictor_args.cachekv_int8_type
+
+        if config.quantization_config.quant_type is not None:
+            predictor_args.quant_type = config.quantization_config.quant_type
+            config.quant_type = config.quantization_config.quant_type
+            if "c8" in config.quant_type:
+                predictor_args.cachekv_int8_type = "static"
+                config.cachekv_int8_type = "static"
+
         if "llama" in config.architectures[0].lower():
             if predictor_args.block_attn:
                 config.block_size = predictor_args.block_size
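
The block above gives the checkpoint's quantization_config precedence over the command-line quant_type, and treats a quant type containing "c8" (cache-KV INT8) as also implying static cache-KV quantization. A standalone sketch of that precedence rule; the helper name and the "a8w8c8" example value are illustrative, not part of the commit:

def resolve_quant_settings(cli_quant_type, ckpt_quant_type, cli_cachekv_int8_type=None):
    # The checkpoint's quantization_config wins over the CLI flag.
    quant_type = ckpt_quant_type if ckpt_quant_type is not None else cli_quant_type
    cachekv_int8_type = cli_cachekv_int8_type
    # A "c8" quant type implies statically calibrated INT8 cache-KV.
    if quant_type is not None and "c8" in quant_type:
        cachekv_int8_type = "static"
    return quant_type, cachekv_int8_type

print(resolve_quant_settings("weight_only_int8", None))      # ('weight_only_int8', None)
print(resolve_quant_settings("weight_only_int8", "a8w8c8"))  # ('a8w8c8', 'static')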
@@ -1486,6 +1522,21 @@ def create_predictor(
             cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
                 config, predictor_args.batch_size, predictor_args.total_max_length
             )
+        elif "qwen2" in config.architectures[0].lower():
+            if predictor_args.block_attn:
+                config.block_size = predictor_args.block_size
+                config.max_seq_len = predictor_args.total_max_length
+                from paddlenlp.experimental.transformers import (
+                    Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel,
+                )
+            cache_kvs_shape = Qwen2InferenceModel.get_cache_kvs_shape(
+                config, predictor_args.batch_size, predictor_args.total_max_length
+            )
+
         elif "qwen" in config.architectures[0].lower():
             from paddlenlp.experimental.transformers import (
                 QWenForCausalLMInferenceModel,
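
In static mode the model class is consulted only for its cache-KV layout; the weights themselves come from the exported program. A minimal usage sketch, assuming a locally available Qwen2 checkpoint (the model id and the block_size/max_seq_len values are illustrative):

from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import (
    Qwen2ForCausalLMBlockInferenceModel,
)

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
config.block_size = 64     # block-attention page size
config.max_seq_len = 8192  # the predictor's total_max_length
# Per-layer cache-KV buffer shapes for batch size 1, mirroring the call in
# the diff: get_cache_kvs_shape(config, batch_size, total_max_length).
cache_kvs_shape = Qwen2ForCausalLMBlockInferenceModel.get_cache_kvs_shape(config, 1, 8192)
print(cache_kvs_shape)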

llm/utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
     AutoTokenizer,
     ChatGLMv2Tokenizer,
     LlamaForCausalLMPipe,
+    PretrainedConfig,
     Qwen2ForCausalLMPipe,
 )
 from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer

paddlenlp/experimental/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
 from .llama import *
 from .opt import *
 from .qwen import *
+from .qwen2 import *
paddlenlp/experimental/transformers/qwen2/__init__.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .modeling import *
