@@ -357,6 +357,7 @@ def _preprocess(self, source):
             self.architectures,
             top_p=self.config.top_p,
             temperature=self.config.temperature,
+            benchmark=self.config.benchmark,
         )
         for i in range(inputs["input_ids"].shape[0]):
             length = inputs["seq_len_encoder"][i][0]
@@ -375,6 +376,7 @@ def _preprocess(self, source):
             self.architectures,
             top_p=self.config.top_p,
             temperature=self.config.temperature,
+            benchmark=self.config.benchmark,
         )
         for i in range(inputs["input_ids"].shape[0]):
             length = inputs["seq_len_encoder"][i][0]
@@ -439,6 +441,7 @@ def _preprocess(self, source):
             top_p=self.config.top_p,
             temperature=self.config.temperature,
             pre_caches_length=pre_caches_length,
+            benchmark=self.config.benchmark,
         )
 
         for i in range(inputs["input_ids"].shape[0]):
@@ -664,8 +667,6 @@ def create_predictor(
                    LlamaForCausalLMInferenceModel as LlamaInferenceModel,
                )
 
-                config.tensor_parallel_degree = tensor_parallel_degree
-                config.tensor_parallel_rank = tensor_parallel_rank
                 config.quant_bits = -1
 
                 if predictor_args.quant_type.startswith("weight_only_int"):
@@ -692,15 +693,26 @@ def create_predictor(
                    BloomForCausalLMInferenceModel,
                )
 
-                config.tensor_parallel_degree = tensor_parallel_degree
-                config.tensor_parallel_rank = tensor_parallel_rank
                 model = BloomForCausalLMInferenceModel.from_pretrained(
                     predictor_args.model_name_or_path,
                     config=config,
                     dtype=predictor_args.dtype,
                 )
                 cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
                 model.eval()
+            elif "gpt" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    GPTForCausalLMInferenceModel,
+                )
+
+                model = GPTForCausalLMInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                )
+                model.eval()
+            else:
+                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
             predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
         elif predictor_args.mode == "static":
             config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
@@ -710,7 +722,6 @@ def create_predictor(
                )
 
                 cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
-                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
             elif "chatglm" in config.architectures[0].lower():
                 from paddlenlp.experimental.transformers import (
                     ChatGLMForCausalLMInferenceModel,
@@ -719,16 +730,21 @@ def create_predictor(
                 cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape(
                     config, predictor_args.batch_size
                 )
-                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
             elif "bloom" in config.architectures[0].lower():
                 from paddlenlp.experimental.transformers import (
                     BloomForCausalLMInferenceModel,
                 )
 
                 cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
-                predictor = StaticInferencePredictor(
-                    predictor_args, cache_kvs_shape=cache_kvs_shape, tokenizer=tokenizer
+            elif "gpt" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    GPTForCausalLMInferenceModel,
                 )
+
+                cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
+            else:
+                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
+            predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
         else:
             raise ValueError("the `mode` should be one of [dynamic, static]")
         return predictor
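Taken together, the static-mode hunks refactor the architecture branches so that each one only computes cache_kvs_shape, a single StaticInferencePredictor call follows the dispatch, and unsupported architectures raise an explicit error. A condensed sketch of that resulting control flow, for reference only: the standalone helper name build_static_predictor is hypothetical, while the model classes, get_cache_kvs_shape, and StaticInferencePredictor are the names that appear in the patch itself (StaticInferencePredictor is assumed to be available in the same module, as in create_predictor).

    def build_static_predictor(predictor_args, config, tokenizer):
        # Dispatch on the architecture string, as the patched create_predictor does.
        arch = config.architectures[0].lower()
        if "llama" in arch:
            from paddlenlp.experimental.transformers import LlamaForCausalLMInferenceModel as Model
        elif "chatglm" in arch:
            from paddlenlp.experimental.transformers import ChatGLMForCausalLMInferenceModel as Model
        elif "bloom" in arch:
            from paddlenlp.experimental.transformers import BloomForCausalLMInferenceModel as Model
        elif "gpt" in arch:
            from paddlenlp.experimental.transformers import GPTForCausalLMInferenceModel as Model
        else:
            raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")

        # Each branch only needs the cache shape; the predictor is built once afterwards.
        cache_kvs_shape = Model.get_cache_kvs_shape(config, predictor_args.batch_size)
        return StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)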