@@ -27,6 +27,7 @@
 from env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api
 from paddle.distributed import fleet
+from paddlenlp_ops import speculate_update_input_ids_cpu
 from proposers import InferenceWithReferenceProposer

 from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
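Note on the new import: speculate_update_input_ids_cpu is a custom op shipped in paddlenlp_ops, and this diff only shows its call sites. Judging from the arguments passed below (the host-side buffer, one request's token list, the batch index, and max_length), it appears to copy a request's prompt tokens into one row of the CPU buffer. A minimal pure-Python sketch of that assumed behavior, not the op's real implementation:

    def speculate_update_input_ids_cpu_sketch(input_ids_cpu, tokens, bid, max_length):
        # Assumption: write this request's prompt tokens into row `bid` of the
        # host-side copy of input_ids, truncated to max_length. The real op is
        # a compiled kernel registered in paddlenlp_ops.
        for i, tok in enumerate(tokens[:max_length]):
            input_ids_cpu[bid, i] = tok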
@@ -103,7 +104,7 @@ class PredictorArgument:
         default="fp16",
         metadata={"help": "avx cachekv type. Supported values: fp16,int8"},
     )
-    batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
+    batch_size: int = field(default=10, metadata={"help": "The batch size of data."})
     benchmark: bool = field(
         default=False,
         metadata={
@@ -964,6 +965,8 @@ def _preprocess(self, input_text: list[str]):
         for k, v in self.model_inputs.items():
             v.name = k

+        return seq_lens
+

 class DygraphBlockInferencePredictor(BlockInferencePredictorMixin):
     def __init__(
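With this change, _preprocess returns the per-request sequence lengths instead of only writing them into self.model_inputs; both predict methods below capture the return value so init_proposer_args can later hand each prompt's length to the proposer. A sketch of the assumed call pattern (the names come from the diff; the exact contents of seq_lens are inferred):

    # presumably one prompt length per batch slot, e.g. [37, 112, ...]
    self.seq_lens = self._preprocess(input_texts)
    self.proposer.update(bid, self.seq_lens[bid])  # consumed later, per batch id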
@@ -1187,7 +1190,7 @@ def __init__(

     @paddle.no_grad()
     def predict(self, input_texts: list[str], return_tokens=False):
-        self._preprocess(input_texts)
+        self.seq_lens = self._preprocess(input_texts)

         # Parameters such as seq_lens_encoder have been set in the preprocessor function,
         # then we use them to init the proposer's args
@@ -1238,10 +1241,6 @@ def predict(self, input_texts: list[str], return_tokens=False):
         return outputs

     def init_proposer_args(self):
-        for bid in range(self.config.batch_size):
-            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
-                self.model_inputs["seq_lens_this_time"][bid] - 1
-            ]  # get the last token before padding of this batch
         self.model_inputs["accept_tokens"] = paddle.full(
             shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
@@ -1252,9 +1251,23 @@ def init_proposer_args(self):
         self.model_inputs["actual_draft_token_num"] = paddle.full(
             shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
         )
-        if self.config.speculate_method == "inference_with_reference":
-            self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
-            self.proposer.input_ids_len = self.model_inputs["seq_lens_encoder"].astype("int64").cpu()
+        self.model_inputs["input_ids_cpu"] = paddle.full(
+            shape=[self.config.batch_size, self.config.total_max_length], fill_value=1, dtype="int64"
+        ).cpu()
+
+        for bid in range(self.config.batch_size):
+            if self.config.speculate_method == "inference_with_reference":
+                speculate_update_input_ids_cpu(
+                    self.model_inputs["input_ids_cpu"],
+                    self.model_inputs["input_ids"][bid].cpu().tolist(),
+                    bid,
+                    self.config.max_length,
+                )
+                self.proposer.update(bid, self.seq_lens[bid])
+
+            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
+                self.model_inputs["seq_lens_this_time"][bid] - 1
+            ]  # get the last token before padding of this batch


 class StaticSpeculateInferencePredictor(StaticBlockInferencePredictor):
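The rewritten init_proposer_args replaces the old direct attribute writes (proposer.input_ids_cpu / proposer.input_ids_len) with a model-owned input_ids_cpu buffer plus a per-batch proposer.update call, and the pre_ids seeding loop that used to sit at the top of the method is folded into the same pass. InferenceWithReferenceProposer.update is defined in proposers.py and is not shown in this diff; a hypothetical sketch of what it likely records, based on the attributes it supersedes:

    # hypothetical sketch of InferenceWithReferenceProposer.update (proposers.py,
    # not part of this diff): remember each request's prompt length so matching
    # against the reference prompt stays within real tokens, not padding.
    def update(self, bid, seq_len):
        self.input_ids_len[bid] = seq_len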
@@ -1277,11 +1290,6 @@ def __init__(
         self.proposer = None

     def init_proposer_args(self):
-        for bid in range(self.config.batch_size):
-            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
-                self.model_inputs["seq_lens_this_time"][bid] - 1
-            ]  # get the last token before padding of this batch
-
         self.model_inputs["accept_tokens"] = paddle.full(
             shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
@@ -1292,12 +1300,27 @@ def init_proposer_args(self):
         self.model_inputs["actual_draft_token_num"] = paddle.full(
             shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
         )
-        if self.config.speculate_method == "inference_with_reference":
-            self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
+        self.model_inputs["input_ids_cpu"] = paddle.full(
+            shape=[self.config.batch_size, self.config.total_max_length], fill_value=1, dtype="int64"
+        ).cpu()
+
+        for bid in range(self.config.batch_size):
+            if self.config.speculate_method == "inference_with_reference":
+                speculate_update_input_ids_cpu(
+                    self.model_inputs["input_ids_cpu"],
+                    self.model_inputs["input_ids"][bid].cpu().tolist(),
+                    bid,
+                    self.config.max_length,
+                )
+                self.proposer.update(bid, self.seq_lens[bid])
+
+            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
+                self.model_inputs["seq_lens_this_time"][bid] - 1
+            ]  # get the last token before padding of this batch

     def predict(self, input_texts: list[str], return_tokens=False):
         s_time = time.time()
-        self._preprocess(input_texts)
+        self.seq_lens = self._preprocess(input_texts)

         # Parameters such as seq_lens_encoder have been set in the preprocessor function,
         # then we use them to init the proposer's args