 import numpy as np
 import paddle
 import paddle.incubate.multiprocessing as mp
+from env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api
 from paddle.distributed import fleet
+from proposers import InferenceWithReferenceProposer
 
-from llm.speculate_decoding.proposers import InferenceWithReferenceProposer
 from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
 from paddlenlp.taskflow.utils import static_mode_guard
@@ -47,12 +48,6 @@
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
 
-# Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
-MAX_BSZ = 512
-# Note(@Wanglongzhi2001): SPECULATE_MAX_BSZ must be the same as definition in speculate_get_output / speculate_save_output
-SPECULATE_MAX_BSZ = 256
-MAX_DRAFT_TOKENS = 6
-
 
 @dataclass
 class PredictorArgument:
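The constants deleted above now come from a new top-level env module. A minimal sketch of what that module is assumed to provide, reusing the removed values as defaults (the real module may define them differently, e.g. as plain constants rather than environment overrides):

# env.py -- hypothetical sketch, not the actual module from this change.
import os

# Note: MAX_BSZ must stay in sync with get_output / save_output.
MAX_BSZ = int(os.getenv("MAX_BSZ", 512))
# Note: SPECULATE_MAX_BSZ must stay in sync with speculate_get_output / speculate_save_output.
SPECULATE_MAX_BSZ = int(os.getenv("SPECULATE_MAX_BSZ", 256))
MAX_DRAFT_TOKENS = int(os.getenv("MAX_DRAFT_TOKENS", 6))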
@@ -108,7 +103,7 @@ class PredictorArgument:
         default="fp16",
         metadata={"help": "avx cachekv type. Supported values: fp16,int8"},
     )
-    batch_size: int = field(default=10, metadata={"help": "The batch size of data."})
+    batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
     benchmark: bool = field(
         default=False,
         metadata={
@@ -1242,15 +1237,11 @@ def predict(self, input_texts: list[str], return_tokens=False):
         else:
             return outputs
 
-    def _preprocess(self, input_text: list[str]):
-        super()._preprocess(input_text)
-
+    def init_proposer_args(self):
         for bid in range(self.config.batch_size):
             self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
                 self.model_inputs["seq_lens_this_time"][bid] - 1
             ]  # get the last token before padding of this batch
-
-    def init_proposer_args(self):
         self.model_inputs["accept_tokens"] = paddle.full(
             shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
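For reference, a tiny self-contained illustration (with made-up token values) of the pre_ids seeding kept above: slot 0 of each row receives the last real prompt token, i.e. the token just before the padding.

import numpy as np

input_ids = np.array([[11, 22, 33, 0, 0],   # row 0: 3 real tokens, then padding
                      [44, 55,  0, 0, 0]])  # row 1: 2 real tokens, then padding
seq_lens_this_time = np.array([3, 2])
pre_ids = np.zeros((2, 5), dtype=np.int64)

for bid in range(2):
    # same indexing as in init_proposer_args: last token before the padding
    pre_ids[bid, 0] = input_ids[bid][seq_lens_this_time[bid] - 1]

print(pre_ids[:, 0])  # [33 55]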
@@ -1286,6 +1277,11 @@ def __init__(
         self.proposer = None
 
     def init_proposer_args(self):
+        for bid in range(self.config.batch_size):
+            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
+                self.model_inputs["seq_lens_this_time"][bid] - 1
+            ]  # get the last token before padding of this batch
+
         self.model_inputs["accept_tokens"] = paddle.full(
             shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
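The accept_tokens buffer above is sized batch_size x (speculate_max_draft_token_num + 1); the extra slot presumably holds the token the target model appends after verifying the drafts. A standalone sketch with illustrative sizes:

import paddle

batch_size = 2
speculate_max_draft_token_num = 4  # illustrative value, not a default from this change

accept_tokens = paddle.full(
    shape=[batch_size, speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)
print(accept_tokens.shape)  # [2, 5]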
@@ -1299,14 +1295,6 @@ def init_proposer_args(self):
         if self.config.speculate_method == "inference_with_reference":
             self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
 
-    def _preprocess(self, input_text: list[str]):
-        super()._preprocess(input_text)
-
-        for bid in range(self.config.batch_size):
-            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
-                self.model_inputs["seq_lens_this_time"][bid] - 1
-            ]  # get the last token before padding of this batch
-
     def predict(self, input_texts: list[str], return_tokens=False):
         s_time = time.time()
         self._preprocess(input_texts)