
Commit 7903bcc

Support chatglm fine grained dybatch v1. (#6798)
* support_chatglm_fine
* unify_fused_layer
* fix_utils
* delete_model_type
* delete_set_state_dict_external
* fix_useless_log
* add_update_kwags_self

Co-authored-by: zhengzekang <>
1 parent 71354f8 commit 7903bcc

File tree

8 files changed: +906 -104 lines changed


llm/README.md

Lines changed: 1 addition & 1 deletion
@@ -244,7 +244,7 @@ python predictor.py \
 python export_model.py \
     --model_name_or_path meta-llama/Llama-2-7b-chat \
     --output_path ./inference \
-    --dtype float16 \
+    --dtype float16


 # Static graph model inference
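The only change here is dropping the trailing backslash after `--dtype float16`, which is now the final argument of the `export_model.py` command; leaving a line continuation on the last line invites the following blank line or comment to be parsed as part of the command when the snippet is copied.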

llm/predictor.py

Lines changed: 120 additions & 44 deletions
@@ -271,20 +271,32 @@ def __init__(
         else:
             raise ValueError("Please specific the model dtype.")

+        self.model_config = AutoConfig.from_pretrained(config.model_name_or_path)
         self.dtype = dtype
-
+        self.architectures = self.model_config.architectures[0].lower()
         self.cache_kvs = [paddle.zeros(shape, dtype=dtype) for shape in cache_kv_shapes]
         self.pre_ids = paddle.full([config.batch_size, config.max_length + 1], -1, dtype="int64")
-        self.attention_mask = paddle.zeros(
-            shape=(config.batch_size, 1, config.max_length, config.max_length),
-            dtype=dtype,
-        )
+
+        if "chatglm" in self.architectures:
+            self.attention_mask = paddle.ones(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+            self.tgt_pos = paddle.ones(
+                shape=[config.batch_size, 2, 1],
+                dtype="int64",
+            )
+        else:
+            self.attention_mask = paddle.zeros(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+
         self.tgt_generation_mask = paddle.zeros(
             shape=[config.batch_size, 1, 1, config.max_length + 1],
             dtype=dtype,
         )
         self.predictor = self._create_predictor(config)
-        self.model_config = AutoConfig.from_pretrained(config.model_name_or_path)

     def _create_predictor(self, predictor_args: PredictorArgument):
         if not is_paddlenlp_ops_available():
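The constructor now prepares two different mask layouts: for ChatGLM the buffer starts from `paddle.ones` and an extra `tgt_pos` tensor is allocated, while other models keep the zero-initialised causal buffer. Below is a minimal, self-contained sketch of the per-sequence fill that `_preprocess` later applies to these buffers. The sizes are illustrative, `float32` is used only so the toy runs on CPU, and the 0/1 convention is whatever the fused inference kernels expect, so the comments simply restate the assignments from the diff:

```python
import paddle

max_length, length = 6, 4  # illustrative sizes only

# ChatGLM-style buffer: created with paddle.ones, then the prompt window is rewritten.
chatglm_mask = paddle.ones([1, 1, max_length, max_length], dtype="float32")
chatglm_mask[0, 0, :length, :length] = 0          # prompt block zeroed
chatglm_mask[0, 0, : length - 1, length - 1] = 1  # last prompt column set back to 1 for earlier rows

tgt_pos = paddle.ones([1, 2, 1], dtype="int64")
tgt_pos[0, 0, 0] = length                         # position index fed to the first decoded token

# Default (causal) buffer: created with paddle.zeros, then a lower-triangular prompt block is filled.
causal_mask = paddle.zeros([1, 1, max_length, max_length], dtype="float32")
causal_mask[0, 0, :length, :length] = paddle.tril(paddle.ones([length, length], dtype="float32"))

print(chatglm_mask[0, 0].numpy())
print(causal_mask[0, 0].numpy())
```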
@@ -327,16 +339,30 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         return predictor

     def _preprocess(self, source):
-        inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length)
-        for i in range(inputs["input_ids"].shape[0]):
-            length = inputs["seq_len_encoder"][i][0]
-            self.attention_mask[i, 0, :length, :length] = paddle.tril(
-                paddle.ones(shape=(length, length), dtype="float16")
-            )
-            self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+        if "chatglm" in self.architectures:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = 0
+                self.attention_mask[i, 0, : length - 1, length - 1] = 1
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+                self.tgt_pos[i, 0, 0] = paddle.to_tensor([length], dtype="int64")
+
+            inputs["attention_mask"] = self.attention_mask
+            inputs["tgt_generation_mask"] = self.tgt_generation_mask
+            inputs["tgt_pos"] = self.tgt_pos.numpy()
+        else:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = paddle.tril(
+                    paddle.ones(shape=(length, length), dtype="float16")
+                )
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")

-        inputs["attention_mask"] = self.attention_mask
-        inputs["tgt_generation_mask"] = self.tgt_generation_mask
+            inputs["attention_mask"] = self.attention_mask
+            inputs["tgt_generation_mask"] = self.tgt_generation_mask
         return inputs

     @paddle.no_grad()
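Both branches of `_preprocess` now forward `self.architectures` to `dybatch_preprocess`, so tokenisation and position ids are built per model family; the ChatGLM branch additionally exports `tgt_pos`, converted with `.numpy()`, which fits how this predictor hands plain arrays to the compiled inference program.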
@@ -387,33 +413,61 @@ def __init__(
             raise ValueError("Please specific the model dtype.")

         self.dtype = dtype
+        self.architectures = self.model.config.architectures[0].lower()

         self.cache_kvs = [
             paddle.zeros(shape, dtype=dtype)
             for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size)
         ]
         self.pre_ids = paddle.full([config.max_batch_size, config.max_length], -1, dtype="int64")
-        self.attention_mask = paddle.zeros(
-            shape=(config.max_batch_size, 1, config.max_length, config.max_length),
-            dtype=dtype,
-        )
+        if "chatglm" in self.architectures:
+            self.attention_mask = paddle.ones(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+            self.tgt_pos = paddle.ones(
+                shape=[config.batch_size, 2, 1],
+                dtype="int64",
+            )
+        else:
+            self.attention_mask = paddle.zeros(
+                shape=(config.batch_size, 1, config.max_length, config.max_length),
+                dtype=dtype,
+            )
+
         self.tgt_generation_mask = paddle.zeros(
             shape=[config.max_batch_size, 1, 1, config.max_length],
             dtype=dtype,
         )

     def _preprocess(self, source):
-        inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length)
-        for i in range(inputs["input_ids"].shape[0]):
-            length = inputs["seq_len_encoder"][i][0]
-            self.attention_mask[i, 0, :length, :length] = paddle.tril(
-                paddle.ones(shape=(length, length), dtype="float16")
-            )
+        if "chatglm" in self.architectures:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = 0
+                self.attention_mask[i, 0, : length - 1, length - 1] = 1
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+                self.tgt_pos[i, 0, 0] = paddle.to_tensor([length], dtype="int64")

             inputs["attention_mask"] = self.attention_mask
-            self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
             inputs["tgt_generation_mask"] = self.tgt_generation_mask
-        inputs["cache_kvs"] = self.cache_kvs
-        inputs["pre_ids"] = self.pre_ids
+            inputs["cache_kvs"] = self.cache_kvs
+            inputs["pre_ids"] = self.pre_ids
+            inputs["tgt_pos"] = self.tgt_pos
+        else:
+            inputs = dybatch_preprocess(self.tokenizer, source, self.config.max_length, self.architectures)
+            for i in range(inputs["input_ids"].shape[0]):
+                length = inputs["seq_len_encoder"][i][0]
+                self.attention_mask[i, 0, :length, :length] = paddle.tril(
+                    paddle.ones(shape=(length, length), dtype="float16")
+                )
+                inputs["attention_mask"] = self.attention_mask
+                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype="float16")
+                inputs["tgt_generation_mask"] = self.tgt_generation_mask
+            inputs["cache_kvs"] = self.cache_kvs
+            inputs["pre_ids"] = self.pre_ids

         inputs_tensor = {}
         for key, value in inputs.items():
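The dygraph predictor applies the same split in its `_preprocess`, but keeps `cache_kvs`, `pre_ids` and `tgt_pos` as live Paddle tensors in the inputs dict instead of converting them, since here the model is invoked eagerly rather than through a compiled program.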
@@ -497,29 +551,51 @@ def create_predictor(
     else:
         if predictor_args.mode == "dynamic":
             # TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
-            assert (
-                "llama" in predictor_args.model_name_or_path
-            ), "only support llama inference model in dygraph-inference predictor"
-            from paddlenlp.experimental.transformers import (
-                LlamaForCausalLMInferenceModel,
-            )
-
             config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
+            if "llama" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    LlamaForCausalLMInferenceModel,
+                )
+
+                config.tensor_parallel_degree = tensor_parallel_degree
+                config.tensor_parallel_rank = tensor_parallel_rank
+                model = LlamaForCausalLMInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
+                )
+                model.eval()
+            elif "chatglm" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMForCausalLMInferenceModel,
+                )
+
+                config.tensor_parallel_degree = tensor_parallel_degree
+                config.tensor_parallel_rank = tensor_parallel_rank

-            config.tensor_parallel_degree = tensor_parallel_degree
-            config.tensor_parallel_rank = tensor_parallel_rank
-            model = LlamaForCausalLMInferenceModel.from_pretrained(predictor_args.model_name_or_path, config=config)
+                model = ChatGLMForCausalLMInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                )
+                model.eval()
             predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
         elif predictor_args.mode == "static":
             config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
+            if "llama" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    LlamaForCausalLMInferenceModel,
+                )

-            # only support llama inference model currently
-            from paddlenlp.experimental.transformers import (
-                LlamaForCausalLMInferenceModel,
-            )
+                cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
+                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
+            elif "chatglm" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    ChatGLMForCausalLMInferenceModel,
+                )

-            cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
-            predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
+                cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape(
+                    config, predictor_args.batch_size
+                )
+                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
         else:
             raise ValueError("the `mode` should be one of [dynamic, static]")
     return predictor
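Both modes now pick the inference-model class from `config.architectures[0]` instead of asserting that "llama" appears in the model path. A compact, table-driven sketch of the same dispatch follows; the registry dict and helper function are illustrative additions, not part of the patch:

```python
from paddlenlp.experimental.transformers import (
    ChatGLMForCausalLMInferenceModel,
    LlamaForCausalLMInferenceModel,
)
from paddlenlp.transformers import AutoConfig

# Hypothetical registry keyed by a substring of the lower-cased architecture name.
_INFERENCE_MODEL_CLASSES = {
    "llama": LlamaForCausalLMInferenceModel,
    "chatglm": ChatGLMForCausalLMInferenceModel,
}


def resolve_inference_model_class(model_name_or_path: str):
    """Return (config, inference-model class) for a checkpoint based on its architectures field."""
    config = AutoConfig.from_pretrained(model_name_or_path)
    architecture = config.architectures[0].lower()
    for keyword, model_cls in _INFERENCE_MODEL_CLASSES.items():
        if keyword in architecture:
            return config, model_cls
    raise ValueError(f"No inference model registered for architecture: {architecture}")
```

With a registry like this, each branch of `create_predictor` would reduce to `model_cls.from_pretrained(..., config=config, dtype=predictor_args.dtype)` in dynamic mode or `model_cls.get_cache_kvs_shape(config, predictor_args.batch_size)` in static mode.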

llm/utils.py

Lines changed: 47 additions & 25 deletions
@@ -337,32 +337,54 @@ def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
     return inst_data.astype("int64").reshape([-1, max_len])


-def dybatch_preprocess(tokenizer, texts: list[str], max_length: int):
+def dybatch_preprocess(tokenizer, texts: list[str], max_length: int, architectures: str):
     """Pre-process generation inputs."""
-    input_ids = []
-    if isinstance(texts, str):
-        texts = [texts]
-
-    for text in texts:
-        tokens = tokenizer(
-            text,
-            return_tensors="np",
-            padding=False,
-            return_attention_mask=False,
-            return_token_type_ids=False,
-        )
-        input_ids.append(tokens["input_ids"][0])
-
-    inputs = {}
-    pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
-    inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
-    bs = inputs["input_ids"].shape[0]
-    max_len = max(map(len, input_ids))
-
-    position_ids = paddle.zeros(shape=[bs, max_len], dtype="int64")
-    for i in range(bs):
-        position_ids[i, : seq_len[i]] = paddle.arange(seq_len[i])
-    inputs["position_ids"] = position_ids
+    if "chatglm" in architectures:
+        input_ids = []
+        position_ids = []
+
+        for text in texts:
+            tokens = tokenizer(text, return_tensors="np", padding=True)
+            input_ids.append(tokens["input_ids"][0])
+            position_ids.append(tokens["position_ids"][0])
+
+        inputs = {}
+        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][0]
+
+        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
+        bs = inputs["input_ids"].shape[0]
+        max_len = max(map(len, input_ids))
+
+        inst_data_pos = []
+        for i in range(len(position_ids)):
+            inst_data_pos.append(np.array([list(inst) + [0] * (max_len - len(inst)) for inst in position_ids[i]]))
+        inputs["position_ids"] = paddle.to_tensor(np.array(inst_data_pos))
+    else:
+        input_ids = []
+        position_ids = []
+        if isinstance(texts, str):
+            texts = [texts]
+
+        for text in texts:
+            tokens = tokenizer(
+                text,
+                return_tensors="np",
+                padding=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+            )
+            input_ids.append(tokens["input_ids"][0])
+
+        inputs = {}
+        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
+        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
+        bs = inputs["input_ids"].shape[0]
+        max_len = max(map(len, input_ids))
+
+        position_ids = paddle.zeros(shape=[bs, max_len], dtype="int64")
+        for i in range(bs):
+            position_ids[i, : seq_len[i]] = paddle.arange(seq_len[i])
+        inputs["position_ids"] = position_ids

     tgt_ids = [input[-1:] for input in input_ids]
     tgt_pos = []
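For ChatGLM, the tokenizer returns a 2-D `position_ids` array per sequence (a position row and a block-position row), so `dybatch_preprocess` pads each row to the batch maximum with zeros rather than rebuilding a 1-D `arange`. A small self-contained sketch of just that padding step, with hand-written toy position ids standing in for tokenizer output:

```python
import numpy as np

# Toy 2-D position ids for two sequences of lengths 3 and 5;
# row 0 ~ token positions, row 1 ~ block positions (ChatGLM-style layout).
position_ids = [
    np.array([[0, 1, 2], [0, 0, 1]]),
    np.array([[0, 1, 2, 3, 4], [0, 0, 0, 0, 1]]),
]
max_len = max(p.shape[1] for p in position_ids)

inst_data_pos = []
for pos in position_ids:
    # Right-pad every row with zeros up to the batch-wide maximum length.
    inst_data_pos.append(np.array([list(row) + [0] * (max_len - len(row)) for row in pos]))

batched = np.array(inst_data_pos)
print(batched.shape)  # (2, 2, 5): [batch_size, 2, max_len]
```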

paddlenlp/experimental/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .chatglm import *
 from .fused_transformer_layers import *
 from .llama import *
paddlenlp/experimental/transformers/chatglm/__init__.py (new file; the path is not shown in this view and is inferred from the `from .chatglm import *` re-export above)

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .modeling import *
