                           StoppingCriteriaList)
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from peft import PeftConfig, PeftModel, PeftModelForCausalLM
-from typing import Tuple, List
+from typing import List, Dict
 
 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
+from djl_python.request_io import RequestInput
 from djl_python.streaming_utils import StreamingUtils
 
 from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
 from djl_python.properties_manager.hf_properties import HuggingFaceProperties
-from djl_python.utils import rolling_batch_inference
-from djl_python.input_parser import ParsedInput, InputFormatConfigs, parse_input_with_formatter
+from djl_python.utils import rolling_batch_inference, get_input_details
+from djl_python.input_parser import parse_input_with_formatter
 
 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -139,10 +140,8 @@ def __init__(self):
         self.peft_config = None
         self.stopping_criteria_list = None
         self.adapter_registry = {}
-        self.adapters = None
         self.hf_configs = None
-        self.input_format_configs = None
-        self.parsed_input = None
+        self.input_format_args = None
 
     def initialize(self, properties: dict):
         self.hf_configs = HuggingFaceProperties(**properties)
@@ -174,14 +173,19 @@ def initialize(self, properties: dict):
         if "stop_sequence" in properties:
             self.load_stopping_criteria_list(properties["stop_sequence"])
 
-        self.input_format_configs = InputFormatConfigs(
-            is_rolling_batch=is_rolling_batch_enabled(
-                self.hf_configs.rolling_batch),
-            is_adapters_supported=True,
-            output_formatter=self.hf_configs.output_formatter,
-            tokenizer=self.tokenizer)
+        self.input_format_args = self.get_input_format_args()
         self.initialized = True
 
+    def get_input_format_args(self):
+        return {
+            "configs": self.hf_configs,
+            "tokenizer": self.tokenizer,
+            "adapter_registry": self.adapter_registry,
+            "model_config": self.model_config,
+            "peft_config": self.peft_config,
+            "rolling_batch": self.rolling_batch
+        }
+
     @staticmethod
     def parse_stop_sequence_input(stop_sequence):
         """
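
Note: `get_input_format_args()` replaces the old `InputFormatConfigs` object with a plain dict that `inference` unpacks as keyword arguments into `parse_input_with_formatter` (see the next hunk). A signature along the following lines would be compatible with that call; this is an illustrative sketch only, not the actual implementation in `djl_python/input_parser.py`:

```python
# Illustrative sketch only: a parse_input_with_formatter signature compatible
# with the kwargs dict built by get_input_format_args(). The real helper in
# djl_python/input_parser.py may accept additional or differently named args.
def parse_input_with_formatter(inputs,
                               configs=None,
                               tokenizer=None,
                               adapter_registry=None,
                               model_config=None,
                               peft_config=None,
                               rolling_batch=None,
                               **kwargs):
    """Parse a batch of Input objects into per-request structures."""
    ...
```
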
@@ -216,37 +220,14 @@ def load_stopping_criteria_list(self, stop_sequence):
 
         self.stopping_criteria_list = StoppingCriteriaList(stopwords)
 
-    def parse_input(
-            self, inputs: Input, tokenizer, output_formatter
-    ) -> Tuple[List[str], List[int], List[dict], dict, list]:
-        """
-        Preprocessing function that extracts information from Input objects.
-
-        :param output_formatter: output formatter for the request
-        :param inputs: (Input) a batch of inputs, each corresponding to a new request
-        :param tokenizer: the tokenizer used for inference
-
-        :return input_data (List[str]): a list of strings, each string being the prompt in a new request
-        :return input_size (List[int]): a list of ints being the size of each new request
-        :return parameters (List[dict]): parameters pertaining to each request
-        :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
-        :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
-        """
-
-        self.parsed_input = parse_input_with_formatter(
-            inputs, self.input_format_configs, self.adapter_registry)
-        self.adapters = self.parsed_input.adapters
-        return (self.parsed_input.input_data, self.parsed_input.input_size,
-                self.parsed_input.parameters, self.parsed_input.errors,
-                self.parsed_input.batch)
-
     def inference(self, inputs: Input) -> Output:
         outputs = Output()
-
-        input_data, input_size, parameters, errors, batch = self.parse_input(
-            inputs, self.tokenizer, self.hf_configs.output_formatter)
-        if len(input_data) == 0:
-            for i in range(len(batch)):
+        parsed_input = parse_input_with_formatter(inputs,
+                                                  **self.input_format_args)
+        requests = parsed_input.requests
+        errors = parsed_input.errors
+        if len(requests) == 0:
+            for i in range(len(parsed_input.batch)):
                 err = errors.get(i)
                 if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
                     err = {"data": "", "last": True, "code": 424, "error": err}
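
The refactored `inference` no longer works with parallel lists; it relies only on three attributes of the parsed result (`requests`, `errors`, `batch`), plus each request's `request_input` in the streaming path below. The shape assumed by this diff is roughly the following (field types are illustrative, inferred from usage rather than taken from `djl_python`):

```python
# Assumed shape of the object returned by parse_input_with_formatter, inferred
# from how this diff consumes it; details beyond requests/errors/batch are guesses.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ParsedInputShape:
    requests: List = field(default_factory=list)  # successfully parsed requests
    errors: Dict[int, str] = field(default_factory=dict)  # batch index -> error message
    batch: List = field(default_factory=list)  # original Input items in the client batch
```

Each entry of `requests` is expected to expose a `request_input` (a `RequestInput`) carrying at least `input_text` and `server_parameters`, which the streaming branch reads directly.
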
@@ -258,28 +239,29 @@ def inference(self, inputs: Input) -> Output:
             return outputs
 
         if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
-            return rolling_batch_inference(self.parsed_input, inputs, outputs,
+            return rolling_batch_inference(parsed_input, inputs, outputs,
                                            self.rolling_batch)
         elif is_streaming_enabled(self.hf_configs.enable_streaming):
-            return self._streaming_inference(batch, input_data, outputs,
-                                             parameters)
+            request_input = requests[0].request_input
+            return self._streaming_inference(parsed_input.batch, request_input,
+                                             outputs)
         else:
-            return self._dynamic_batch_inference(batch, errors, input_data,
-                                                 input_size, inputs, outputs,
-                                                 parameters)
+            return self._dynamic_batch_inference(parsed_input.batch, errors,
+                                                 inputs, outputs, requests)
+
+    def _dynamic_batch_inference(self, batch: List, errors: Dict,
+                                 inputs: Input, outputs: Output,
+                                 requests: List):
+        # Dynamic batching
+        input_data, input_size, parameters, adapters = get_input_details(
+            requests, errors, batch)
 
-    def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
-                                 inputs, outputs, parameters):
-        if not all(p == parameters[0] for p in parameters):
-            raise ValueError(
-                "In order to enable dynamic batching, all input batches must have the same parameters"
-            )
         if isinstance(self.model, PeftModelForCausalLM):
-            if self.adapters is None:
+            if adapters is None:
                 # Inference with only base model
-                self.adapters = [""] * len(input_data)
-            parameters[0]["adapters"] = self.adapters
-            prediction = self.hf_pipeline(input_data, **parameters[0])
+                adapters = [""] * len(input_data)
+            parameters["adapters"] = adapters
+            prediction = self.hf_pipeline(input_data, **parameters)
         offset = 0
         for i, item in enumerate(batch):
             content_type = item.get_property("Content-Type")
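
`get_input_details`, newly imported from `djl_python.utils`, takes over the flattening that `parse_input` used to do. From its call site above it is expected to return the flattened prompt list, the per-request prompt counts, a single parameter dict shared by the whole dynamic batch, and the adapters (or `None` when none are in use). A simplified sketch of that contract, under those assumptions and not the actual utility, might be:

```python
# Simplified sketch of the contract implied by the call
#   input_data, input_size, parameters, adapters = get_input_details(requests, errors, batch)
# The real helper in djl_python.utils may differ (error handling, adapter collection).
def get_input_details(requests, errors, batch):
    input_data, input_size = [], []   # flattened prompts, prompts-per-request
    parameters, adapters = {}, None   # dynamic batching shares one parameter dict
    for request in requests:
        req_in = request.request_input
        prompts = req_in.input_text if isinstance(req_in.input_text, list) else [req_in.input_text]
        input_data.extend(prompts)
        input_size.append(len(prompts))
        parameters = req_in.server_parameters
    # Adapters would be collected from per-request adapter info when LoRA
    # adapters are registered; left as None in this sketch.
    return input_data, input_size, parameters, adapters
```
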
@@ -305,24 +287,26 @@ def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
             offset += input_size[i]
         return outputs
 
-    def _streaming_inference(self, batch, input_data, outputs, parameters):
+    def _streaming_inference(self, batch: List, request_input: RequestInput,
+                             outputs: Output):
         if len(batch) > 1:
             raise NotImplementedError(
                 "Dynamic batch not supported for generic streaming")
+
+        parameters = request_input.server_parameters
         outputs.add_property("content-type", "application/jsonlines")
         if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
             outputs.add_stream_content(
-                StreamingUtils.use_hf_default_streamer(self.model,
-                                                       self.tokenizer,
-                                                       input_data,
-                                                       self.hf_configs.device,
-                                                       **parameters[0]))
+                StreamingUtils.use_hf_default_streamer(
+                    self.model, self.tokenizer, request_input.input_text,
+                    self.hf_configs.device, **parameters))
         else:
             stream_generator = StreamingUtils.get_stream_generator(
                 "Accelerate")
             outputs.add_stream_content(
-                stream_generator(self.model, self.tokenizer, input_data,
-                                 self.hf_configs.device, **parameters[0]))
+                stream_generator(self.model, self.tokenizer,
+                                 request_input.input_text,
+                                 self.hf_configs.device, **parameters))
         return outputs
 
     def get_pipeline(self, task: str, model_id_or_path: str, kwargs):