@@ -1395,6 +1395,32 @@ def create_predictor(
13951395 dtype = predictor_args .dtype ,
13961396 )
13971397 model .eval ()
1398+ elif "qwen2" in config .architectures [0 ].lower ():
1399+ if predictor_args .block_attn :
1400+ config .max_seq_len = predictor_args .total_max_length
1401+ config .block_size = predictor_args .block_size
1402+ from paddlenlp .experimental .transformers import (
1403+ Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel ,
1404+ )
1405+
1406+ model = Qwen2InferenceModel .from_pretrained (
1407+ predictor_args .model_name_or_path ,
1408+ config = config ,
1409+ dtype = predictor_args .dtype ,
1410+ tensor_parallel_degree = tensor_parallel_degree ,
1411+ tensor_parallel_rank = tensor_parallel_rank ,
1412+ )
1413+ else :
1414+ from paddlenlp .experimental .transformers import (
1415+ Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel ,
1416+ )
1417+
1418+ model = Qwen2InferenceModel .from_pretrained (
1419+ predictor_args .model_name_or_path ,
1420+ config = config ,
1421+ dtype = predictor_args .dtype ,
1422+ )
1423+ model .eval ()
13981424 elif "qwen" in config .architectures [0 ].lower ():
13991425 if model_args .model_type == "qwen-img2txt" :
14001426 # we use qwen for img2txt.
@@ -1420,6 +1446,16 @@ def create_predictor(
14201446
14211447 elif predictor_args .mode == "static" :
14221448 config = AutoConfig .from_pretrained (predictor_args .model_name_or_path )
1449+ config .quant_type = predictor_args .quant_type
1450+ config .cachekv_int8_type = predictor_args .cachekv_int8_type
1451+
1452+ if config .quantization_config .quant_type is not None :
1453+ predictor_args .quant_type = config .quantization_config .quant_type
1454+ config .quant_type = config .quantization_config .quant_type
1455+ if "c8" in config .quant_type :
1456+ predictor_args .cachekv_int8_type = "static"
1457+ config .cachekv_int8_type = "static"
1458+
14231459 if "llama" in config .architectures [0 ].lower ():
14241460 if predictor_args .block_attn :
14251461 config .block_size = predictor_args .block_size
@@ -1486,6 +1522,21 @@ def create_predictor(
14861522 cache_kvs_shape = GPTForCausalLMInferenceModel .get_cache_kvs_shape (
14871523 config , predictor_args .batch_size , predictor_args .total_max_length
14881524 )
1525+ elif "qwen2" in config .architectures [0 ].lower ():
1526+ if predictor_args .block_attn :
1527+ config .block_size = predictor_args .block_size
1528+ config .max_seq_len = predictor_args .total_max_length
1529+ from paddlenlp .experimental .transformers import (
1530+ Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel ,
1531+ )
1532+ else :
1533+ from paddlenlp .experimental .transformers import (
1534+ Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel ,
1535+ )
1536+ cache_kvs_shape = Qwen2InferenceModel .get_cache_kvs_shape (
1537+ config , predictor_args .batch_size , predictor_args .total_max_length
1538+ )
1539+
14891540 elif "qwen" in config .architectures [0 ].lower ():
14901541 from paddlenlp .experimental .transformers import (
14911542 QWenForCausalLMInferenceModel ,
0 commit comments