@@ -27,7 +27,7 @@
 from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api
 from paddle.distributed import fleet
 
-from llm.speculate_decoding.proposer import InferenceWithReferenceProposer
+from llm.speculate_decoding.proposers import InferenceWithReferenceProposer
 from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
 from paddlenlp.taskflow.utils import static_mode_guard
@@ -49,6 +49,8 @@
 
 # Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
 MAX_BSZ = 512
+# Note(@Wanglongzhi2001): SPECULATE_MAX_BSZ must be the same as definition in speculate_get_output / speculate_save_output
+SPECULATE_MAX_BSZ = 256
 MAX_DRAFT_TOKENS = 6
 
 
@@ -106,7 +108,7 @@ class PredictorArgument:
         default="fp16",
         metadata={"help": "avx cachekv type. Supported values: fp16,int8"},
     )
-    batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
+    batch_size: int = field(default=10, metadata={"help": "The batch size of data."})
     benchmark: bool = field(
         default=False,
         metadata={
@@ -142,7 +144,7 @@ class PredictorArgument:
142144 "help" : "speculate method, it should be one of ['None', 'autoregressive', 'inference_with_reference']"
143145 },
144146 )
145- speculate_max_draft_tokens : int = field (
147+ speculate_max_draft_token_num : int = field (
146148 default = 1 ,
147149 metadata = {"help" : "the max length of draft tokens for speculate method." },
148150 )
@@ -1180,7 +1182,7 @@ def __init__(
         # init speculate components
         if config.speculate_method == "inference_with_reference":
             self.proposer = InferenceWithReferenceProposer(
-                config.speculate_max_draft_tokens,
+                config.speculate_max_draft_token_num,
                 config.speculate_max_ngram_size,
                 config.batch_size,
                 config.max_length,
@@ -1192,7 +1194,7 @@ def __init__(
     def predict(self, input_texts: list[str], return_tokens=False):
         self._preprocess(input_texts)
 
-        # Parameters such as seq_lens_encoder have been set in the preprocessor function, 
+        # Parameters such as seq_lens_encoder have been set in the preprocessor function,
         # then we use them to init the proposer's args
         self.init_proposer_args()
 
@@ -1206,7 +1208,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
             read_res_process.start()
 
         output_tensor = paddle.full(
-            shape=[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
+            shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
         ).cpu()
         tensor_queue.put(output_tensor)
         if self.tensor_parallel_rank == 0:
@@ -1250,14 +1252,14 @@ def _preprocess(self, input_text: list[str]):
 
     def init_proposer_args(self):
         self.model_inputs["accept_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["accept_num"] = paddle.full(shape=[self.config.batch_size], fill_value=0, dtype="int32")
         self.model_inputs["draft_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["actual_draft_token_num"] = paddle.full(
-            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_tokens, dtype="int32"
+            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
         )
         if self.config.speculate_method == "inference_with_reference":
             self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
@@ -1275,7 +1277,7 @@ def __init__(
         # init speculate components
         if config.speculate_method == "inference_with_reference":
             self.proposer = InferenceWithReferenceProposer(
-                config.speculate_max_draft_tokens,
+                config.speculate_max_draft_token_num,
                 config.speculate_max_ngram_size,
                 config.batch_size,
                 config.max_length,
@@ -1285,14 +1287,14 @@ def __init__(
 
     def init_proposer_args(self):
         self.model_inputs["accept_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["accept_num"] = paddle.full(shape=[self.config.batch_size], fill_value=0, dtype="int32")
         self.model_inputs["draft_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["actual_draft_token_num"] = paddle.full(
-            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_tokens, dtype="int32"
+            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
         )
         if self.config.speculate_method == "inference_with_reference":
             self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
@@ -1309,7 +1311,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         s_time = time.time()
         self._preprocess(input_texts)
 
-        # Parameters such as seq_lens_encoder have been set in the preprocessor function, 
+        # Parameters such as seq_lens_encoder have been set in the preprocessor function,
         # then we use them to init the proposer's args
         self.init_proposer_args()
         logger.info(f"preprocess spend {time.time() - s_time}")
@@ -1332,7 +1334,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         if self.tensor_parallel_rank == 0:
             read_res_process.start()
         output_tensor = paddle.full(
-            shape=[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
+            shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
         ).cpu()
         tensor_queue.put(output_tensor)
         if self.tensor_parallel_rank == 0:
@@ -1505,7 +1507,7 @@ def create_predictor(
         elif predictor_args.speculate_method is not None:
             config.max_seq_len = predictor_args.total_max_length
             config.block_size = predictor_args.block_size
-            config.speculate_max_draft_tokens = predictor_args.speculate_max_draft_tokens
+            config.speculate_max_draft_token_num = predictor_args.speculate_max_draft_token_num
             config.speculate_max_ngram_size = predictor_args.speculate_max_ngram_size
             config.speculate_verify_window = predictor_args.speculate_verify_window
             config.speculate_max_candidate_len = predictor_args.speculate_max_candidate_len
@@ -1738,7 +1740,7 @@ def create_predictor(
         elif predictor_args.speculate_method is not None:
             config.max_seq_len = predictor_args.total_max_length
             config.block_size = predictor_args.block_size
-            config.speculate_max_draft_tokens = predictor_args.speculate_max_draft_tokens
+            config.speculate_max_draft_token_num = predictor_args.speculate_max_draft_token_num
             config.speculate_max_ngram_size = predictor_args.speculate_max_ngram_size
             config.speculate_verify_window = predictor_args.speculate_verify_window
             config.speculate_max_candidate_len = predictor_args.speculate_max_candidate_len
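
To sanity-check the buffer math this diff touches: a minimal, self-contained sketch of how the speculative-decoding output tensor and the per-request proposer buffers are sized. The slot-by-slot interpretation in the comments is an assumption inferred from the shape expressions above, not something the diff confirms, and `speculate_output_numel` is an illustrative helper, not a PaddleNLP API.

import paddle

# Mirrored from the diff; these must stay in sync with the values compiled
# into the speculate_get_output / speculate_save_output custom ops.
SPECULATE_MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6

def speculate_output_numel() -> int:
    # Flat int64 buffer, per the shape expression in predict(): up to
    # MAX_DRAFT_TOKENS token slots per sequence, one slot per sequence
    # (plausibly the accept count), plus 2 header slots (an assumption).
    return SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2

# Same construction as the diff's predict() methods: a CPU tensor handed to
# the reader process through tensor_queue.
output_tensor = paddle.full(shape=[speculate_output_numel(), 1], fill_value=2, dtype="int64").cpu()
assert output_tensor.shape == [1794, 1]  # 256 * 6 + 256 + 2

# The per-request proposer buffers from init_proposer_args() scale with the
# configurable speculate_max_draft_token_num rather than the compile-time cap;
# the values below are the post-PR defaults from PredictorArgument.
batch_size, speculate_max_draft_token_num = 10, 1
draft_tokens = paddle.full(
    shape=[batch_size, speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)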