@@ -271,20 +271,32 @@ def __init__(
         else:
             raise ValueError("Please specific the model dtype.")
 
+        self.model_config = AutoConfig.from_pretrained(config.model_name_or_path)
         self.dtype = dtype
-
+        self.architectures = self.model_config.architectures[0].lower()
         self.cache_kvs = [paddle.zeros(shape, dtype=dtype) for shape in cache_kv_shapes]
         self.pre_ids = paddle.full([config.batch_size, config.max_length + 1], -1, dtype="int64")
-        self.attention_mask = paddle.zeros(
-            shape=(config.batch_size, 1, config.max_length, config.max_length),
-            dtype=dtype,
-        )
+
+        if "chatglm" in self.architectures:
+            self.attention_mask = paddle.ones(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+            self.tgt_pos = paddle.ones(
+                shape=[config.batch_size, 2, 1],
+                dtype="int64",
+            )
+        else:
+            self.attention_mask = paddle.zeros(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+
         self.tgt_generation_mask = paddle.zeros(
             shape=[config.batch_size, 1, 1, config.max_length + 1],
             dtype=dtype,
         )
         self.predictor = self._create_predictor(config)
-        self.model_config = AutoConfig.from_pretrained(config.model_name_or_path)
 
     def _create_predictor(self, predictor_args: PredictorArgument):
         if not is_paddlenlp_ops_available():
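The constructor now branches on the lowercased `architectures` string read from the model config. A minimal sketch of the two initialization conventions, assuming only that paddle is importable (the helper name `init_masks` is hypothetical, not part of this PR): ChatGLM appears to follow the convention that 1 marks a masked position, so its buffer starts fully masked, and it carries 2-D position ids, hence the extra axis in `tgt_pos`.

import paddle

def init_masks(batch_size, max_length, architectures, dtype="float16"):
    # Hypothetical helper mirroring the branch added in __init__ above, for illustration only.
    if "chatglm" in architectures:
        # ChatGLM branch: start fully masked (1 = masked in this convention) and
        # allocate a [batch, 2, 1] buffer for ChatGLM's 2-D position ids
        # (position id + block position id).
        attention_mask = paddle.ones((batch_size, 1, max_length, max_length), dtype=dtype)
        tgt_pos = paddle.ones([batch_size, 2, 1], dtype="int64")
        return attention_mask, tgt_pos
    # Default (e.g. llama) branch: start with zeros; _preprocess later writes a
    # lower-triangular block of ones (1 = visible) for each prompt.
    attention_mask = paddle.zeros((batch_size, 1, max_length, max_length), dtype=dtype)
    return attention_mask, None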
@@ -327,16 +339,30 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         return predictor
 
     def _preprocess(self, source):
-        inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length)
-        for i in range(inputs["input_ids"].shape[0]):
-            length = inputs["seq_len_encoder"][i][0]
-            self.attention_mask[i, 0, :length, :length] = paddle.tril(
-                paddle.ones(shape=(length, length), dtype="float16")
-            )
-            self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+        if "chatglm" in self.architectures:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = 0
+                self.attention_mask[i, 0, : length - 1, length - 1] = 1
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+                self.tgt_pos[i, 0, 0] = paddle.to_tensor([length], dtype="int64")
+
+            inputs["attention_mask"] = self.attention_mask
+            inputs["tgt_generation_mask"] = self.tgt_generation_mask
+            inputs["tgt_pos"] = self.tgt_pos.numpy()
+        else:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = paddle.tril(
+                    paddle.ones(shape=(length, length), dtype="float16")
+                )
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
 
-        inputs["attention_mask"] = self.attention_mask
-        inputs["tgt_generation_mask"] = self.tgt_generation_mask
+            inputs["attention_mask"] = self.attention_mask
+            inputs["tgt_generation_mask"] = self.tgt_generation_mask
         return inputs
 
     @paddle.no_grad()
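For a concrete picture of what the ChatGLM branch writes into `attention_mask`, here is a toy reproduction with numpy (max_length=6 and a prompt length of 4 are made-up values): prompt tokens attend bidirectionally to one another, except that earlier tokens cannot see the final prompt position, which behaves causally.

import numpy as np

max_length, length = 6, 4
mask = np.ones((max_length, max_length), dtype="float16")  # start fully masked (1 = masked)
mask[:length, :length] = 0                                  # prompt attends bidirectionally to itself
mask[: length - 1, length - 1] = 1                          # ...but no one before the last prompt token may see it
print(mask)
# [[0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 0. 1. 1.]
#  [1. 1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1. 1.]]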
@@ -387,33 +413,61 @@ def __init__(
             raise ValueError("Please specific the model dtype.")
 
         self.dtype = dtype
+        self.architectures = self.model.config.architectures[0].lower()
 
         self.cache_kvs = [
             paddle.zeros(shape, dtype=dtype)
             for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size)
         ]
         self.pre_ids = paddle.full([config.max_batch_size, config.max_length], -1, dtype="int64")
-        self.attention_mask = paddle.zeros(
-            shape=(config.max_batch_size, 1, config.max_length, config.max_length),
-            dtype=dtype,
-        )
+        if "chatglm" in self.architectures:
+            self.attention_mask = paddle.ones(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+            self.tgt_pos = paddle.ones(
+                shape=[config.batch_size, 2, 1],
+                dtype="int64",
+            )
+        else:
+            self.attention_mask = paddle.zeros(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+
         self.tgt_generation_mask = paddle.zeros(
             shape=[config.max_batch_size, 1, 1, config.max_length],
             dtype=dtype,
         )
 
     def _preprocess(self, source):
-        inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length)
-        for i in range(inputs["input_ids"].shape[0]):
-            length = inputs["seq_len_encoder"][i][0]
-            self.attention_mask[i, 0, :length, :length] = paddle.tril(
-                paddle.ones(shape=(length, length), dtype="float16")
-            )
+        if "chatglm" in self.architectures:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = 0
+                self.attention_mask[i, 0, : length - 1, length - 1] = 1
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+                self.tgt_pos[i, 0, 0] = paddle.to_tensor([length], dtype="int64")
+
             inputs["attention_mask"] = self.attention_mask
-            self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
             inputs["tgt_generation_mask"] = self.tgt_generation_mask
-        inputs["cache_kvs"] = self.cache_kvs
-        inputs["pre_ids"] = self.pre_ids
+            inputs["cache_kvs"] = self.cache_kvs
+            inputs["pre_ids"] = self.pre_ids
+            inputs["tgt_pos"] = self.tgt_pos
+        else:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = paddle.tril(
+                    paddle.ones(shape=(length, length), dtype="float16")
+                )
+                inputs["attention_mask"] = self.attention_mask
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+                inputs["tgt_generation_mask"] = self.tgt_generation_mask
+            inputs["cache_kvs"] = self.cache_kvs
+            inputs["pre_ids"] = self.pre_ids
 
         inputs_tensor = {}
         for key, value in inputs.items():
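Both predictors key their branches off the lowercased first entry of `config.architectures`. A sketch of where that dispatch key comes from; the checkpoint name and the exact architecture string are assumptions for illustration, not taken from this PR:

from paddlenlp.transformers import AutoConfig

# Hypothetical checkpoint name, used only to show how the dispatch key is derived.
config = AutoConfig.from_pretrained("THUDM/chatglm-6b")
architectures = config.architectures[0].lower()  # e.g. something like "chatglmforcausallm"
print("chatglm" in architectures)                # True -> the ChatGLM branches above are taken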
@@ -497,29 +551,51 @@ def create_predictor(
     else:
         if predictor_args.mode == "dynamic":
             # TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
-            assert (
-                "llama" in predictor_args.model_name_or_path
-            ), "only support llama inference model in dygraph-inference predictor"
-            from paddlenlp.experimental.transformers import (
-                LlamaForCausalLMInferenceModel,
-            )
-
             config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
+            if "llama" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    LlamaForCausalLMInferenceModel,
+                )
+
+                config.tensor_parallel_degree = tensor_parallel_degree
+                config.tensor_parallel_rank = tensor_parallel_rank
+                model = LlamaForCausalLMInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
+                )
+                model.eval()
+            elif "chatglm" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMForCausalLMInferenceModel,
+                )
+
+                config.tensor_parallel_degree = tensor_parallel_degree
+                config.tensor_parallel_rank = tensor_parallel_rank
 
-            config.tensor_parallel_degree = tensor_parallel_degree
-            config.tensor_parallel_rank = tensor_parallel_rank
-            model = LlamaForCausalLMInferenceModel.from_pretrained(predictor_args.model_name_or_path, config=config)
+                model = ChatGLMForCausalLMInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                )
+                model.eval()
             predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
         elif predictor_args.mode == "static":
             config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
+            if "llama" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    LlamaForCausalLMInferenceModel,
+                )
 
-            # only support llama inference model currently
-            from paddlenlp.experimental.transformers import (
-                LlamaForCausalLMInferenceModel,
-            )
+                cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
+                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
+            elif "chatglm" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMForCausalLMInferenceModel,
+                )
 
-            cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
-            predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
+                cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape(
+                    config, predictor_args.batch_size
+                )
+                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
         else:
             raise ValueError("the `mode` should be one of [dynamic, static]")
         return predictor
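The same if/elif dispatch now appears in both the dynamic and the static paths. One possible way to keep the architecture-to-model mapping in a single place (a sketch of a refactor, not code from this PR) is a small registry keyed on architecture substrings:

from paddlenlp.experimental.transformers import (
    ChatGLMForCausalLMInferenceModel,
    LlamaForCausalLMInferenceModel,
)

# Hypothetical registry: maps an architecture substring to its inference-model class.
INFERENCE_MODEL_BY_ARCH = {
    "llama": LlamaForCausalLMInferenceModel,
    "chatglm": ChatGLMForCausalLMInferenceModel,
}

def resolve_inference_model(architectures: str):
    """Return the first registered inference-model class whose key appears in the architecture string."""
    for key, cls in INFERENCE_MODEL_BY_ARCH.items():
        if key in architectures.lower():
            return cls
    raise ValueError(f"No inference model registered for architectures: {architectures}")

With such a registry, supporting a new architecture would only require one new dictionary entry instead of parallel elif branches in the dynamic and static code paths.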