
Commit af28006

[LLM] Support gpt3 fine grained dybatch v1 (#7080)
* support batch_size=1
* support batch_size > 1
* update
* support to_static
* add benchmark for dybatch_preprocess
* fix code style
* fix code style
* fix comment
* update
1 parent d3c2e4c commit af28006

File tree

7 files changed: +537 lines, -12 lines


llm/predictor.py

Lines changed: 24 additions & 8 deletions
@@ -357,6 +357,7 @@ def _preprocess(self, source):
 self.architectures,
 top_p=self.config.top_p,
 temperature=self.config.temperature,
+benchmark=self.config.benchmark,
 )
 for i in range(inputs["input_ids"].shape[0]):
 length = inputs["seq_len_encoder"][i][0]
@@ -375,6 +376,7 @@ def _preprocess(self, source):
 self.architectures,
 top_p=self.config.top_p,
 temperature=self.config.temperature,
+benchmark=self.config.benchmark,
 )
 for i in range(inputs["input_ids"].shape[0]):
 length = inputs["seq_len_encoder"][i][0]
@@ -439,6 +441,7 @@ def _preprocess(self, source):
 top_p=self.config.top_p,
 temperature=self.config.temperature,
 pre_caches_length=pre_caches_length,
+benchmark=self.config.benchmark,
 )

 for i in range(inputs["input_ids"].shape[0]):
@@ -664,8 +667,6 @@ def create_predictor(
 LlamaForCausalLMInferenceModel as LlamaInferenceModel,
 )

-config.tensor_parallel_degree = tensor_parallel_degree
-config.tensor_parallel_rank = tensor_parallel_rank
 config.quant_bits = -1

 if predictor_args.quant_type.startswith("weight_only_int"):
@@ -692,15 +693,26 @@ def create_predictor(
 BloomForCausalLMInferenceModel,
 )

-config.tensor_parallel_degree = tensor_parallel_degree
-config.tensor_parallel_rank = tensor_parallel_rank
 model = BloomForCausalLMInferenceModel.from_pretrained(
 predictor_args.model_name_or_path,
 config=config,
 dtype=predictor_args.dtype,
 )
 cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
 model.eval()
+elif "gpt" in config.architectures[0].lower():
+from paddlenlp.experimental.transformers import (
+GPTForCausalLMInferenceModel,
+)
+
+model = GPTForCausalLMInferenceModel.from_pretrained(
+predictor_args.model_name_or_path,
+config=config,
+dtype=predictor_args.dtype,
+)
+model.eval()
+else:
+raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
 predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
 elif predictor_args.mode == "static":
 config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
@@ -710,7 +722,6 @@ def create_predictor(
 )

 cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
-predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
 elif "chatglm" in config.architectures[0].lower():
 from paddlenlp.experimental.transformers import (
 ChatGLMForCausalLMInferenceModel,
@@ -719,16 +730,21 @@ def create_predictor(
 cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape(
 config, predictor_args.batch_size
 )
-predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
 elif "bloom" in config.architectures[0].lower():
 from paddlenlp.experimental.transformers import (
 BloomForCausalLMInferenceModel,
 )

 cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
-predictor = StaticInferencePredictor(
-predictor_args, cache_kvs_shape=cache_kvs_shape, tokenizer=tokenizer
+elif "gpt" in config.architectures[0].lower():
+from paddlenlp.experimental.transformers import (
+GPTForCausalLMInferenceModel,
 )
+
+cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
+else:
+raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
+predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
 else:
 raise ValueError("the `mode` should be one of [dynamic, static]")
 return predictor
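
The GPT branches added above mirror the existing llama / chatglm / bloom dispatch: dynamic mode instantiates the inference model itself, while static mode only computes the KV-cache shape before the exported program is loaded. A minimal standalone sketch of that usage; the checkpoint name, dtype, and batch size are illustrative placeholders, not values taken from this commit:

```python
from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import GPTForCausalLMInferenceModel

# "gpt-cpm-large-cn", float16 and batch_size=2 are placeholders for this sketch.
config = AutoConfig.from_pretrained("gpt-cpm-large-cn")

# dynamic mode: build the inference model directly (as in the dynamic-mode hunk above)
model = GPTForCausalLMInferenceModel.from_pretrained(
    "gpt-cpm-large-cn", config=config, dtype="float16"
)
model.eval()

# static mode: only the cache shape is needed for the exported graph (last hunk above)
cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(config, 2)
```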

llm/utils.py

Lines changed: 28 additions & 3 deletions
@@ -394,6 +394,7 @@ def dybatch_preprocess(
 benchmark: bool = False,
 ):
 """Pre-process generation inputs."""
+inputs = {}
 if "chatglm" in architectures:
 input_ids = []
 position_ids = []
@@ -403,7 +404,6 @@ def dybatch_preprocess(
 input_ids.append(tokens["input_ids"][0])
 position_ids.append(tokens["position_ids"][0])

-inputs = {}
 pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][0]
 inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
 bs = inputs["input_ids"].shape[0]
@@ -413,9 +413,35 @@ def dybatch_preprocess(
 for i in range(len(position_ids)):
 inst_data_pos.append(np.array([list(inst) + [0] * (max_len - len(inst)) for inst in position_ids[i]]))
 inputs["position_ids"] = paddle.to_tensor(np.array(inst_data_pos))
+elif "gpt" in architectures:
+input_ids = []
+if isinstance(texts, str):
+texts = [texts]
+
+for text in texts:
+tokens = tokenizer(
+text,
+return_tensors="np",
+padding=False,
+max_length=src_length,
+return_attention_mask=False,
+return_token_type_ids=False,
+)
+input_ids.append(tokens["input_ids"][0])
+
+pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
+inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
+bs = inputs["input_ids"].shape[0]
+max_len = max(map(len, input_ids))
+
+position_ids = paddle.arange(sum(seq_len), dtype="int64")
+pre_len = seq_len[0]
+for length in seq_len[1:]:
+position_ids[pre_len : length + pre_len] = position_ids[pre_len : length + pre_len] - pre_len
+pre_len += length
+inputs["position_ids"] = position_ids
 else:
 input_ids = []
-position_ids = []
 if isinstance(texts, str):
 texts = [texts]

@@ -430,7 +456,6 @@ def dybatch_preprocess(
 )
 input_ids.append(tokens["input_ids"][0])

-inputs = {}
 pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
 inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
 bs = inputs["input_ids"].shape[0]
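
The new gpt branch in dybatch_preprocess flattens position ids across the whole batch: it builds a single 1-D arange over the total token count, then shifts every later sequence back so that its positions restart at zero. A numpy-only sketch of that loop, with made-up seq_len values:

```python
import numpy as np

seq_len = [4, 2, 3]  # hypothetical per-sequence token counts
position_ids = np.arange(sum(seq_len), dtype="int64")

pre_len = seq_len[0]
for length in seq_len[1:]:
    # shift this sequence's slice so it counts from 0 again
    position_ids[pre_len : length + pre_len] -= pre_len
    pre_len += length

print(position_ids)  # [0 1 2 3 0 1 0 1 2]
```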

paddlenlp/experimental/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,4 +15,5 @@
 from .bloom import *
 from .chatglm import *
 from .fused_transformer_layers import *
+from .gpt import *
 from .llama import *
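
With this re-export in place, the GPT inference model can be imported from the package root, which is exactly how the predictor.py hunks above consume it:

```python
from paddlenlp.experimental.transformers import GPTForCausalLMInferenceModel
```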

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 3 additions & 1 deletion
@@ -186,8 +186,10 @@ def __init__(
 self.norm_type = norm_type
 if norm_type == "layernorm":
 self.norm_func = fused_layer_norm
-else:
+elif norm_type == "rmsnorm":
 self.norm_func = fused_rms_norm
+else:
+raise NotImplementedError("Only support norm type of [layernorm, rmsnorm]")
 self.use_neox_rotary_style = use_neox_rotary_style
 self._norm_weight_dtype = "float32" if self.norm_type == "layernorm" else self._dtype
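
Before this change, any norm_type other than "layernorm" silently took the RMSNorm path; now an unknown value fails fast. A self-contained sketch of the same dispatch, with strings standing in for the fused ops:

```python
def pick_norm_func(norm_type: str):
    # placeholders for the real fused_layer_norm / fused_rms_norm functions
    if norm_type == "layernorm":
        return "fused_layer_norm"
    elif norm_type == "rmsnorm":
        return "fused_rms_norm"
    raise NotImplementedError("Only support norm type of [layernorm, rmsnorm]")

print(pick_norm_func("rmsnorm"))  # fused_rms_norm
# pick_norm_func("groupnorm")     # raises NotImplementedError
```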

paddlenlp/experimental/transformers/generation_utils.py

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,8 @@ def to_static(self, output_path: str, config: dict):
 shape=[None, None, None], dtype="int64", name="position_ids"
 ) # position_ids
 input_spec[16] = paddle.static.InputSpec(shape=[None, 2, 1], dtype="int64", name="tgt_pos") # tgt_pos
+elif self.config["model_type"] and "gpt" in self.config.model_type:
+input_spec[2] = paddle.static.InputSpec(shape=[None], dtype="int64", name="position_ids") # position_ids
 model = paddle.jit.to_static(self.generate, input_spec=input_spec)
 paddle.jit.save(
 model, output_path, skip_prune_program=True
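
For GPT models, to_static swaps the position_ids spec for a 1-D variable-length int64 tensor, matching the flattened position ids produced in dybatch_preprocess. A minimal sketch of declaring such a spec and tracing a stand-in function against it (dummy_fn is only a placeholder, not the real generate):

```python
import paddle

# 1-D int64 input whose length depends on the total number of tokens in the batch
position_ids_spec = paddle.static.InputSpec(shape=[None], dtype="int64", name="position_ids")

def dummy_fn(position_ids):
    # stand-in for self.generate, only here to show the spec being consumed
    return position_ids + 1

static_fn = paddle.jit.to_static(dummy_fn, input_spec=[position_ids_spec])
out = static_fn(paddle.arange(9, dtype="int64"))
```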
paddlenlp/experimental/transformers/gpt/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .modeling import *
