@@ -1262,6 +1262,35 @@ def create_predictor(
                     )
                 model.eval()
 
+            elif "mixtral" in config.architectures[0].lower():
+                if predictor_args.block_attn:
+                    config.max_seq_len = predictor_args.total_max_length
+                    config.block_size = predictor_args.block_size
+                    from paddlenlp.experimental.transformers import (
+                        MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
+                    )
+
+                    model = MixtralInferenceModel.from_pretrained(
+                        predictor_args.model_name_or_path,
+                        config=config,
+                        dtype=predictor_args.dtype,
+                        tensor_parallel_degree=tensor_parallel_degree,
+                        tensor_parallel_rank=tensor_parallel_rank,
+                    )
+                else:
+                    from paddlenlp.experimental.transformers import (
+                        MixtralForCausalLMInferenceModel as MixtralInferenceModel,
+                    )
+
+                    model = MixtralInferenceModel.from_pretrained(
+                        predictor_args.model_name_or_path,
+                        config=config,
+                        dtype=predictor_args.dtype,
+                        tensor_parallel_degree=tensor_parallel_degree,
+                        tensor_parallel_rank=tensor_parallel_rank,
+                    )
+                model.eval()
+
             elif "opt" in config.architectures[0].lower():
                 if model_args.model_type == "opt-img2txt":
                     # we use opt for img2txt.
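For reference, this first hunk mirrors the existing Llama setup path: `predictor_args.block_attn` selects the block-attention variant of the experimental Mixtral inference model (which also needs `block_size` and `max_seq_len` stamped onto the config first), and the standard variant is used otherwise; both are loaded with identical `from_pretrained` arguments. A minimal standalone sketch of the same selection logic, assuming a local Mixtral checkpoint and a PaddleNLP build that ships these experimental classes (the path, sizes, and parallel degrees below are placeholders, not part of the patch):

```python
# Standalone sketch of the branch added above; values marked as
# placeholders are assumptions for illustration, not from the PR.
from paddlenlp.transformers import AutoConfig

model_path = "./mixtral-8x7b"  # placeholder local checkpoint path
config = AutoConfig.from_pretrained(model_path)

block_attn = True  # mirrors predictor_args.block_attn
if block_attn:
    # The block-attention variant reads these two fields off the config,
    # so they must be set before the model is constructed.
    config.max_seq_len = 8192  # mirrors predictor_args.total_max_length
    config.block_size = 64     # mirrors predictor_args.block_size
    from paddlenlp.experimental.transformers import (
        MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
    )
else:
    from paddlenlp.experimental.transformers import (
        MixtralForCausalLMInferenceModel as MixtralInferenceModel,
    )

model = MixtralInferenceModel.from_pretrained(
    model_path,
    config=config,
    dtype="float16",           # mirrors predictor_args.dtype
    tensor_parallel_degree=1,  # placeholder: single card; >1 shards across GPUs
    tensor_parallel_rank=0,    # placeholder: this process's shard index
)
model.eval()  # inference only, as in the diff
```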
@@ -1405,6 +1434,20 @@ def create_predictor(
                 cache_kvs_shape = LlamaInferenceModel.get_cache_kvs_shape(
                     config, predictor_args.batch_size, predictor_args.total_max_length
                 )
+            elif "mixtral" in config.architectures[0].lower():
+                if predictor_args.block_attn:
+                    config.block_size = predictor_args.block_size
+                    config.max_seq_len = predictor_args.total_max_length
+                    from paddlenlp.experimental.transformers import (
+                        MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
+                    )
+                else:
+                    from paddlenlp.experimental.transformers import (
+                        MixtralForCausalLMInferenceModel as MixtralInferenceModel,
+                    )
+                cache_kvs_shape = MixtralInferenceModel.get_cache_kvs_shape(
+                    config, predictor_args.batch_size, predictor_args.total_max_length
+                )
             elif "chatglmv2forcausallm" in config.architectures[0].lower():
                 from paddlenlp.experimental.transformers import (
                     ChatGLMv2ForCausalLMInferenceModel,
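The second hunk wires Mixtral into the cache pre-allocation path: the predictor asks the selected inference class for its key/value cache shapes via `get_cache_kvs_shape` so the caches can be allocated once up front and reused across decoding steps. A hedged sketch of that call in isolation, again with a placeholder checkpoint path and placeholder sizes:

```python
# Standalone sketch of the cache-shape hook added above; the path and
# numeric values are assumptions for illustration, not from the PR.
from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import (
    MixtralForCausalLMBlockInferenceModel as MixtralInferenceModel,
)

config = AutoConfig.from_pretrained("./mixtral-8x7b")  # placeholder path
config.block_size = 64     # mirrors predictor_args.block_size
config.max_seq_len = 8192  # mirrors predictor_args.total_max_length

# Signature mirrors the diff: (config, batch_size, total_max_length).
# The result describes per-layer key/value cache tensors that the
# predictor allocates before running any decoding.
cache_kvs_shape = MixtralInferenceModel.get_cache_kvs_shape(config, 8, 8192)
print(cache_kvs_shape)
```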