
Commit 490e354

merge code from fastdeploy
1 parent d24e2ef commit 490e354


6 files changed: +250 -46 lines changed


llm/server/server/server/data/processor.py

Lines changed: 44 additions & 1 deletion
@@ -143,6 +143,9 @@ def process_request(self, request, max_seq_len=None):
         request["eos_token_ids"] = []
         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
 
+        if "stop_seqs" not in request or (isinstance(request["stop_seqs"], (list, tuple)) and len(request["stop_seqs"]) == 0):
+            self.update_stop_seq(request)
+
         if "input_ids" not in request or \
             (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
             if "text" in request:
@@ -282,7 +285,7 @@ def _load_tokenizer(self):
         """
         if self.config.use_hf_tokenizer:
             from transformers import AutoTokenizer
-            return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
+            return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
         else:
             from paddlenlp.transformers import AutoTokenizer
             return AutoTokenizer.from_pretrained(self.config.model_dir)
@@ -334,3 +337,43 @@ def get_pad_id(self):
         if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
             return self.tokenizer.eos_token
         return self.tokenizer.pad_token_id
+
+    def pad_batch_data(self, insts, pad_id=0, return_seq_len=False, return_array=True, pad_style="right"):
+        """Pad the instances to the max sequence length in batch."""
+        if len(insts) == 0:
+            padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
+            if return_seq_len:
+                seq_len = np.array([], dtype=np.int64) if return_array else []
+                return padded_insts, seq_len
+            return padded_insts
+
+        max_len = max(map(len, insts))
+        if pad_style == "left":
+            padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
+        else:
+            padded_insts = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
+        if return_array:
+            padded_insts = np.array(padded_insts, dtype=np.int64).reshape([-1, max_len])
+
+        if return_seq_len:
+            seq_len = [len(inst) for inst in insts]
+            if return_array:
+                seq_len = np.array(seq_len, dtype=np.int64).reshape(-1, 1)
+            return padded_insts, seq_len
+        return padded_insts
+
+    def update_stop_seq(self, request):
+        """
+        Update stop sequences from request.
+        """
+        stop_seqs = []
+        for seq in request.get("stop_sequences", []):
+            if seq != self.tokenizer.eos_token_id:
+                stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
+        request["stop_seqs"], request["stop_seqs_len"] = self.pad_batch_data(
+            stop_seqs,
+            pad_id=-1,
+            return_seq_len=True,
+            return_array=False
+        )
+        data_processor_logger.debug(f"processed request: {request['stop_seqs'], request['stop_seqs_len']}")
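
Below is a minimal sketch (not part of the commit) of what the new update_stop_seq / pad_batch_data path stores on a request: each stop string is tokenized, the end-of-sequence token is skipped, and the resulting id lists are right-padded with -1 while their true lengths are kept. The token ids here are invented for illustration.

    # Hypothetical token ids for two stop strings of different lengths.
    stop_seqs = [[553, 874], [991]]

    # Right-pad to the longest sequence with pad_id=-1 and keep plain lists,
    # mirroring pad_batch_data(..., pad_id=-1, return_seq_len=True, return_array=False) above.
    max_len = max(map(len, stop_seqs))
    padded = [seq + [-1] * (max_len - len(seq)) for seq in stop_seqs]
    seq_len = [len(seq) for seq in stop_seqs]

    # request["stop_seqs"]     -> [[553, 874], [991, -1]]
    # request["stop_seqs_len"] -> [2, 1]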

llm/server/server/server/engine/config.py

Lines changed: 29 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 from paddlenlp.generation import GenerationConfig
 from server.utils import model_server_logger
+from dataclasses import dataclass
 
 
 class Config:
@@ -203,6 +204,27 @@ def get_model_config(self):
         model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
         return model_config_json
 
+    def get_speculate_config(self):
+        """
+        get speculate_decoding related config
+
+        Returns:
+            SpeculateConfig: the speculate related config
+        """
+        speculate_config = SpeculateConfig()
+        model_cfg = self.get_model_config()
+        if model_cfg.get("speculate_method", "None") != "None":
+            speculate_config.speculate_method = str(model_cfg["speculate_method"])
+            speculate_config.speculate_max_draft_token_num = model_cfg[
+                "speculate_max_draft_token_num"]
+            speculate_config.speculate_max_ngram_size = model_cfg[
+                "speculate_max_ngram_size"]
+
+        if speculate_config.speculate_method not in ["None", "inference_with_reference"]:
+            model_server_logger.error(f"Unsupport speculate method: {speculate_config.speculate_method}")
+
+        return speculate_config
+
     def read_from_config(self):
         """
         reset model config from json file
@@ -234,3 +256,10 @@ def get_unique_name(self, name):
 
     def __str__(self) -> str:
         return json.dumps(self.__dict__, indent=4)
+
+
+@dataclass
+class SpeculateConfig:
+    speculate_method: str = "None"
+    speculate_max_draft_token_num: int = 1
+    speculate_max_ngram_size: int = 1
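
For reference, a small self-contained sketch (not part of the commit) of the keys get_speculate_config() reads from the model config JSON and what it produces; the values are illustrative only.

    from dataclasses import dataclass

    @dataclass
    class SpeculateConfig:  # mirrors the dataclass added above
        speculate_method: str = "None"
        speculate_max_draft_token_num: int = 1
        speculate_max_ngram_size: int = 1

    # Illustrative model config entries; only these three keys are read.
    model_cfg = {
        "speculate_method": "inference_with_reference",
        "speculate_max_draft_token_num": 5,
        "speculate_max_ngram_size": 2,
    }

    cfg = SpeculateConfig()
    if model_cfg.get("speculate_method", "None") != "None":
        cfg.speculate_method = str(model_cfg["speculate_method"])
        cfg.speculate_max_draft_token_num = model_cfg["speculate_max_draft_token_num"]
        cfg.speculate_max_ngram_size = model_cfg["speculate_max_ngram_size"]

    # cfg == SpeculateConfig(speculate_method='inference_with_reference',
    #                        speculate_max_draft_token_num=5,
    #                        speculate_max_ngram_size=2)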

llm/server/server/server/engine/infer.py

Lines changed: 82 additions & 2 deletions
@@ -29,6 +29,7 @@
 from paddlenlp_ops import step_paddle
 from server.data.processor import DataProcessor
 from server.engine.config import Config
+from paddlenlp.experimental.transformers import InferenceWithReferenceProposer
 from server.utils import get_logger
 from task_queue_manager import TaskQueueManager
 
@@ -46,12 +47,19 @@ def __init__(self, args):
 
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.speculate_config = self.config.get_speculate_config()
+        self.is_speculate_decoding = self.speculate_config.speculate_method != "None"
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
         self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
         self.args.hidden_size = self.model_cfg["hidden_size"]
 
+        self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
+
+        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
+        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
+
         self.nranks = dist.get_world_size()
         self.init_dist_env()
         self.rank = fleet.worker_index()
@@ -62,6 +70,17 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()
 
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.')
+            if self.speculate_config.speculate_method == "inference_with_reference":
+                self.proposer = InferenceWithReferenceProposer(
+                    self.speculate_config.speculate_max_draft_token_num,
+                    self.speculate_config.speculate_max_ngram_size,
+                    self.args.max_batch_size,
+                    self.args.max_seq_len)
+        else:
+            self.proposer = None
+
         self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
 
         model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
@@ -246,6 +265,31 @@ def init_inputs(self):
         self.share_inputs['free_list_len'] = paddle.full(
             shape=[1], fill_value=self.free_list_len, dtype="int32")
 
+        self.share_inputs['stop_seqs_len'] = paddle.full(shape=[self.max_stop_seqs_num,],
+                                                         fill_value=0,
+                                                         dtype="int32")
+        self.share_inputs['stop_seqs'] = paddle.full(shape=[self.max_stop_seqs_num, self.stop_seqs_max_len],
+                                                     fill_value=-1,
+                                                     dtype="int64")
+
+        if self.reduce_dialogue_repetition:
+            self.share_inputs["first_token_ids"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
+            self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+        # speculate decoding input
+        if self.is_speculate_decoding:
+            self.share_inputs["accept_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
+            self.share_inputs["draft_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["actual_draft_token_num"] = paddle.full(
+                shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32"
+            )
+
     def dy_input_preprocess(self, tasks):
         """
         dynamic insertion
@@ -279,6 +323,10 @@ def dy_input_preprocess(self, tasks):
             self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
             self.share_inputs['stop_flags'][idx:idx + 1] = False
 
+            if self.reduce_dialogue_repetition:
+                self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
+                self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length
+
             if "infer_seed" in task:
                 self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
 
@@ -288,10 +336,29 @@ def dy_input_preprocess(self, tasks):
                 self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
                     task['block_tables'], dtype="int32")
 
+            if "stop_seqs_len" in task:
+                stop_seqs_num = len(task["stop_seqs_len"])
+                for i in range(stop_seqs_num, self.max_stop_seqs_num):
+                    task["stop_seqs_len"].append(0)
+                self.share_inputs['stop_seqs_len'][:] = np.array(
+                    task["stop_seqs_len"], dtype="int32")
+                self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
+                    task["stop_seqs"], dtype="int64")
+
+            if self.is_speculate_decoding:
+                self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1])
+                self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num])
+
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
         """
+        # whether speculate decoding
+        if self.is_speculate_decoding:
+            speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1
+        else:
+            speculate_step_token_num = 0
+
         step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                     self.share_inputs['step_seq_lens_encoder'],
                     self.share_inputs['seq_lens_encoder'],
@@ -304,7 +371,8 @@ def step_cuda(self, seq_lens_this_time):
                     self.share_inputs['free_list'], self.share_inputs['free_list_len'],
                     self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
                     self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
-                    self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
+                    self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
+                    speculate_step_token_num)
 
     def initialize_engine_ready_check_flag(self):
         """
@@ -429,6 +497,13 @@ def run(self):
                 time.sleep(0.001)
                 continue
 
+            if self.proposer is not None:
+                self.proposer.run(
+                    self.share_inputs,
+                    real_batch_size=seq_lens_this_time.shape[0],
+                    seq_lens_this_time=seq_lens_this_time,
+                )
+
             self.infer_engine.predictor.run()
             self.share_inputs['infer_seed'].add_(infer_seed_increment)
             self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
@@ -474,6 +549,11 @@ def _init_predictor(self):
         config.switch_ir_optim(False)
         config.enable_use_gpu(100, device_id)
 
+        pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
+        if pir_flag == 1:
+            config.enable_new_executor()
+            config.enable_new_ir()
+
         # distributed config
         if self.mp_degree > 1:
             trainer_endpoints = fleet.worker_endpoints()
@@ -528,7 +608,7 @@ def parse_args():
     """
     parse args from command line
    """
-    parser = argparse.ArgumentParser("Deploy LLM Inference")
+    parser = argparse.ArgumentParser("FastDeploy LLM Inference")
    parser.add_argument('-m',
                        '--model_dir',
                        type=str,
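
The engine changes above are driven by a few environment toggles plus one derived value; the following sketch (not part of the commit; the numbers shown are only the fall-back defaults or illustrative values) summarizes how they are read and used.

    import os

    # Buffers for first-token ids / original encoder lengths are only allocated when this is set.
    reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))

    # Shape of the shared stop-sequence buffers: rows x columns of token ids, padded with -1.
    max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
    stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))

    # The predictor switches to the new executor / new IR when PIR is enabled.
    use_pir = int(os.environ.get("FLAGS_enable_pir_api", 0)) == 1

    # step_cuda now passes an extra argument to step_paddle: the number of tokens one
    # speculative step may emit (draft tokens + 1), or 0 when speculate decoding is off.
    speculate_max_draft_token_num = 5  # illustrative
    is_speculate_decoding = True
    speculate_step_token_num = (speculate_max_draft_token_num + 1) if is_speculate_decoding else 0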
