 from paddlenlp_ops import step_paddle
 from server.data.processor import DataProcessor
 from server.engine.config import Config
+from paddlenlp.experimental.transformers import InferenceWithReferenceProposer
 from server.utils import get_logger
 from task_queue_manager import TaskQueueManager
 
@@ -46,12 +47,19 @@ def __init__(self, args):
 
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.speculate_config = self.config.get_speculate_config()
+        self.is_speculate_decoding = self.speculate_config.speculate_method != "None"
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
         self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
         self.args.hidden_size = self.model_cfg["hidden_size"]
 
+        self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
+
+        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
+        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
+
         self.nranks = dist.get_world_size()
         self.init_dist_env()
         self.rank = fleet.worker_index()
@@ -62,6 +70,17 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()
 
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.')
+            if self.speculate_config.speculate_method == "inference_with_reference":
+                self.proposer = InferenceWithReferenceProposer(
+                    self.speculate_config.speculate_max_draft_token_num,
+                    self.speculate_config.speculate_max_ngram_size,
+                    self.args.max_batch_size,
+                    self.args.max_seq_len)
+        else:
+            self.proposer = None
+
         self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
 
         model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
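
For orientation (not part of the commit), here is a rough sketch of how the proposer constructed above fits into the serving loop. The constructor and run() arguments are copied from this diff; the concrete values and the standalone variable names are assumptions for illustration and require paddlenlp with the experimental transformers module installed.

    # Sketch only: placeholder values, not the repo's defaults.
    from paddlenlp.experimental.transformers import InferenceWithReferenceProposer

    max_batch_size = 8       # assumed engine batch size
    max_seq_len = 8192       # assumed engine sequence limit
    max_draft_token_num = 5  # assumed speculate_config.speculate_max_draft_token_num
    max_ngram_size = 2       # assumed speculate_config.speculate_max_ngram_size

    proposer = InferenceWithReferenceProposer(
        max_draft_token_num, max_ngram_size, max_batch_size, max_seq_len)

    # Each decode step, before predictor.run(), the engine invokes the proposer
    # to fill the draft-token buffers (see the run() hunk further down):
    # proposer.run(share_inputs,
    #              real_batch_size=seq_lens_this_time.shape[0],
    #              seq_lens_this_time=seq_lens_this_time)
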
@@ -246,6 +265,31 @@ def init_inputs(self):
         self.share_inputs['free_list_len'] = paddle.full(
             shape=[1], fill_value=self.free_list_len, dtype="int32")
 
+        self.share_inputs['stop_seqs_len'] = paddle.full(shape=[self.max_stop_seqs_num,],
+                                                         fill_value=0,
+                                                         dtype="int32")
+        self.share_inputs['stop_seqs'] = paddle.full(shape=[self.max_stop_seqs_num, self.stop_seqs_max_len],
+                                                     fill_value=-1,
+                                                     dtype="int64")
+
+        if self.reduce_dialogue_repetition:
+            self.share_inputs["first_token_ids"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
+            self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+        # speculate decoding input
+        if self.is_speculate_decoding:
+            self.share_inputs["accept_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
+            self.share_inputs["draft_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["actual_draft_token_num"] = paddle.full(
+                shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32"
+            )
+
     def dy_input_preprocess(self, tasks):
         """
         dynamic insertion
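
As a side note, the token buffers registered above for speculative decoding (accept_tokens and draft_tokens) reserve speculate_max_draft_token_num + 1 positions per query. A small standalone illustration of the resulting shapes, with made-up sizes and numpy standing in for the paddle tensors:

    import numpy as np

    max_batch_size = 8                  # assumed args.max_batch_size
    speculate_max_draft_token_num = 5   # assumed speculate_config value

    # Mirrors the paddle.full(...) calls above.
    accept_tokens = np.zeros((max_batch_size, speculate_max_draft_token_num + 1), dtype="int64")
    accept_num = np.zeros((max_batch_size,), dtype="int32")
    draft_tokens = np.zeros((max_batch_size, speculate_max_draft_token_num + 1), dtype="int64")
    actual_draft_token_num = np.full((max_batch_size,), speculate_max_draft_token_num, dtype="int32")

    print(accept_tokens.shape)  # (8, 6): one row per batch slot, draft tokens plus one extra position
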
@@ -279,6 +323,10 @@ def dy_input_preprocess(self, tasks):
             self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
             self.share_inputs['stop_flags'][idx:idx + 1] = False
 
+            if self.reduce_dialogue_repetition:
+                self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
+                self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length
+
             if "infer_seed" in task:
                 self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
 
@@ -288,10 +336,29 @@ def dy_input_preprocess(self, tasks):
                 self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
                     task['block_tables'], dtype="int32")
 
+            if "stop_seqs_len" in task:
+                stop_seqs_num = len(task["stop_seqs_len"])
+                for i in range(stop_seqs_num, self.max_stop_seqs_num):
+                    task["stop_seqs_len"].append(0)
+                self.share_inputs['stop_seqs_len'][:] = np.array(
+                    task["stop_seqs_len"], dtype="int32")
+                self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
+                    task["stop_seqs"], dtype="int64")
+
+            if self.is_speculate_decoding:
+                self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1])
+                self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num])
+
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
         """
+        # whether speculate decoding
+        if self.is_speculate_decoding:
+            speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1
+        else:
+            speculate_step_token_num = 0
+
         step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                     self.share_inputs['step_seq_lens_encoder'],
                     self.share_inputs['seq_lens_encoder'],
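
The stop-sequence branch added in dy_input_preprocess above pads a request's stop_seqs_len list out to the MAX_STOP_SEQS_NUM limit before copying it into the shared buffers. Below is a self-contained sketch of that padding with a hypothetical task dict (the token ids and env-derived limits are invented for illustration):

    import numpy as np

    max_stop_seqs_num = 5   # assumed MAX_STOP_SEQS_NUM
    stop_seqs_max_len = 8   # assumed STOP_SEQS_MAX_LEN

    stop_seqs_len = np.zeros((max_stop_seqs_num,), dtype="int32")
    stop_seqs = np.full((max_stop_seqs_num, stop_seqs_max_len), -1, dtype="int64")

    # Hypothetical request with two stop sequences of three token ids each.
    task = {"stop_seqs_len": [3, 3], "stop_seqs": [[11, 22, 33], [44, 55, 66]]}

    stop_seqs_num = len(task["stop_seqs_len"])
    for _ in range(stop_seqs_num, max_stop_seqs_num):
        task["stop_seqs_len"].append(0)   # pad unused slots with length 0

    stop_seqs_len[:] = np.array(task["stop_seqs_len"], dtype="int32")
    stop_seqs[:stop_seqs_num, :len(task["stop_seqs"][0])] = np.array(task["stop_seqs"], dtype="int64")

    print(stop_seqs_len)  # [3 3 0 0 0]
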
@@ -304,7 +371,8 @@ def step_cuda(self, seq_lens_this_time):
                     self.share_inputs['free_list'], self.share_inputs['free_list_len'],
                     self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
                     self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
-                    self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
+                    self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
+                    speculate_step_token_num)
 
     def initialize_engine_ready_check_flag(self):
         """
@@ -429,6 +497,13 @@ def run(self):
                 time.sleep(0.001)
                 continue
 
+            if self.proposer is not None:
+                self.proposer.run(
+                    self.share_inputs,
+                    real_batch_size=seq_lens_this_time.shape[0],
+                    seq_lens_this_time=seq_lens_this_time,
+                )
+
             self.infer_engine.predictor.run()
             self.share_inputs['infer_seed'].add_(infer_seed_increment)
             self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
@@ -474,6 +549,11 @@ def _init_predictor(self):
         config.switch_ir_optim(False)
         config.enable_use_gpu(100, device_id)
 
+        pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
+        if pir_flag == 1:
+            config.enable_new_executor()
+            config.enable_new_ir()
+
         # distributed config
         if self.mp_degree > 1:
             trainer_endpoints = fleet.worker_endpoints()
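
The PIR branch above is gated on an environment flag read at predictor-construction time, so it only changes behavior when the process is launched with that flag set. A minimal way to opt in (shown in Python; an exported shell variable works the same way):

    import os

    # Must be set before the predictor is created; with the flag on,
    # _init_predictor() enables the new executor and the new IR on the config.
    os.environ["FLAGS_enable_pir_api"] = "1"
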
@@ -528,7 +608,7 @@ def parse_args():
     """
     parse args from command line
     """
-    parser = argparse.ArgumentParser("Deploy LLM Inference")
+    parser = argparse.ArgumentParser("FastDeploy LLM Inference")
     parser.add_argument('-m',
                         '--model_dir',
                         type=str,