Skip to content

Commit 2224ca3

Browse files
authored
Update predict_generation.py
for chatglm2
1 parent 9eb3cd7 commit 2224ca3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

llm/chatglm/predict_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
chatglm_pad_attention_mask,
2222
chatglm_postprocess_past_key_value,
2323
)
24-
from paddlenlp.transformers import ChatGLMConfig, ChatGLMForCausalLM, ChatGLMTokenizer
24+
from paddlenlp.transformers import ChatGLMConfig, AutoTokenizer, AutoModelForCausalLM
2525

2626

2727
def parse_arguments():
@@ -62,7 +62,7 @@ def __init__(self, args=None, tokenizer=None, model=None, **kwargs):
6262
self.src_length = kwargs["src_length"]
6363
self.tgt_length = kwargs["tgt_length"]
6464
else:
65-
self.tokenizer = ChatGLMTokenizer.from_pretrained(args.model_name_or_path)
65+
self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
6666
self.batch_size = args.batch_size
6767
self.args = args
6868
self.src_length = self.args.src_length
@@ -92,7 +92,7 @@ def __init__(self, args=None, tokenizer=None, model=None, **kwargs):
9292
config = ChatGLMConfig.from_pretrained(args.model_name_or_path)
9393
dtype = config.dtype if config.dtype is not None else config.paddle_dtype
9494

95-
self.model = ChatGLMForCausalLM.from_pretrained(
95+
self.model = AutoModelForCausalLM.from_pretrained(
9696
args.model_name_or_path,
9797
tensor_parallel_degree=tensor_parallel_degree,
9898
tensor_parallel_rank=tensor_parallel_rank,

0 commit comments

Comments (0)