                          StoppingCriteriaList)
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from peft import PeftConfig, PeftModel, PeftModelForCausalLM
-from typing import Tuple, List
+from typing import List, Dict

 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
+from djl_python.request_io import RequestInput
 from djl_python.streaming_utils import StreamingUtils

 from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
 from djl_python.properties_manager.hf_properties import HuggingFaceProperties
-from djl_python.utils import rolling_batch_inference
-from djl_python.input_parser import ParsedInput, InputFormatConfigs, parse_input_with_formatter
+from djl_python.utils import rolling_batch_inference, get_input_details
+from djl_python.input_parser import parse_input_with_formatter

 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -141,8 +142,7 @@ def __init__(self):
         self.adapter_registry = {}
         self.adapters = None
         self.hf_configs = None
-        self.input_format_configs = None
-        self.parsed_input = None
+        self.input_format_args = None

     def initialize(self, properties: dict):
         self.hf_configs = HuggingFaceProperties(**properties)
@@ -174,14 +174,19 @@ def initialize(self, properties: dict):
         if "stop_sequence" in properties:
             self.load_stopping_criteria_list(properties["stop_sequence"])

-        self.input_format_configs = InputFormatConfigs(
-            is_rolling_batch=is_rolling_batch_enabled(
-                self.hf_configs.rolling_batch),
-            is_adapters_supported=True,
-            output_formatter=self.hf_configs.output_formatter,
-            tokenizer=self.tokenizer)
+        self.input_format_args = self.get_input_format_args()
         self.initialized = True

+    def get_input_format_args(self):
+        return {
+            "configs": self.hf_configs,
+            "tokenizer": self.tokenizer,
+            "adapter_registry": self.adapter_registry,
+            "model_config": self.model_config,
+            "peft_config": self.peft_config,
+            "rolling_batch": self.rolling_batch
+        }
+
     @staticmethod
     def parse_stop_sequence_input(stop_sequence):
         """
@@ -216,37 +221,14 @@ def load_stopping_criteria_list(self, stop_sequence):

         self.stopping_criteria_list = StoppingCriteriaList(stopwords)

-    def parse_input(
-            self, inputs: Input, tokenizer, output_formatter
-    ) -> Tuple[List[str], List[int], List[dict], dict, list]:
-        """
-        Preprocessing function that extracts information from Input objects.
-
-        :param output_formatter: output formatter for the request
-        :param inputs: (Input) a batch of inputs, each corresponding to a new request
-        :param tokenizer: the tokenizer used for inference
-
-        :return input_data (List[str]): a list of strings, each string being the prompt in a new request
-        :return input_size (List[int]): a list of ints being the size of each new request
-        :return parameters (List[dict]): parameters pertaining to each request
-        :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
-        :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
-        """
-
-        self.parsed_input = parse_input_with_formatter(
-            inputs, self.input_format_configs, self.adapter_registry)
-        self.adapters = self.parsed_input.adapters
-        return (self.parsed_input.input_data, self.parsed_input.input_size,
-                self.parsed_input.parameters, self.parsed_input.errors,
-                self.parsed_input.batch)
-
     def inference(self, inputs: Input) -> Output:
         outputs = Output()
-
-        input_data, input_size, parameters, errors, batch = self.parse_input(
-            inputs, self.tokenizer, self.hf_configs.output_formatter)
-        if len(input_data) == 0:
-            for i in range(len(batch)):
+        parsed_input = parse_input_with_formatter(inputs,
+                                                  **self.input_format_args)
+        requests = parsed_input.requests
+        errors = parsed_input.errors
+        if len(requests) == 0:
+            for i in range(len(parsed_input.batch)):
                 err = errors.get(i)
                 if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
                     err = {"data": "", "last": True, "code": 424, "error": err}
@@ -258,28 +240,29 @@ def inference(self, inputs: Input) -> Output:
             return outputs

         if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
-            return rolling_batch_inference(self.parsed_input, inputs, outputs,
+            return rolling_batch_inference(parsed_input, inputs, outputs,
                                            self.rolling_batch)
         elif is_streaming_enabled(self.hf_configs.enable_streaming):
-            return self._streaming_inference(batch, input_data, outputs,
-                                             parameters)
+            request_input = requests[0].request_input
+            return self._streaming_inference(parsed_input.batch, request_input,
+                                             outputs)
         else:
-            return self._dynamic_batch_inference(batch, errors, input_data,
-                                                 input_size, inputs, outputs,
-                                                 parameters)
+            return self._dynamic_batch_inference(parsed_input.batch, errors,
+                                                 inputs, outputs, requests)
+
+    def _dynamic_batch_inference(self, batch: List, errors: Dict,
+                                 inputs: Input, outputs: Output,
+                                 requests: List):
+        # Dynamic batching
+        input_data, input_size = get_input_details(requests, errors, batch)
+        parameters = requests[0].request_input.server_parameters

-    def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
-                                 inputs, outputs, parameters):
-        if not all(p == parameters[0] for p in parameters):
-            raise ValueError(
-                "In order to enable dynamic batching, all input batches must have the same parameters"
-            )
         if isinstance(self.model, PeftModelForCausalLM):
             if self.adapters is None:
                 # Inference with only base model
                 self.adapters = [""] * len(input_data)
-            parameters[0]["adapters"] = self.adapters
-        prediction = self.hf_pipeline(input_data, **parameters[0])
+            parameters["adapters"] = self.adapters
+        prediction = self.hf_pipeline(input_data, **parameters)
         offset = 0
         for i, item in enumerate(batch):
             content_type = item.get_property("Content-Type")
@@ -305,24 +288,25 @@ def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
             offset += input_size[i]
         return outputs

-    def _streaming_inference(self, batch, input_data, outputs, parameters):
+    def _streaming_inference(self, batch: List, request_input: RequestInput,
+                             outputs: Output):
         if len(batch) > 1:
             raise NotImplementedError(
                 "Dynamic batch not supported for generic streaming")
         outputs.add_property("content-type", "application/jsonlines")
         if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
             outputs.add_stream_content(
-                StreamingUtils.use_hf_default_streamer(self.model,
-                                                       self.tokenizer,
-                                                       input_data,
-                                                       self.hf_configs.device,
-                                                       **parameters[0]))
+                StreamingUtils.use_hf_default_streamer(
+                    self.model, self.tokenizer, request_input.input_text,
+                    self.hf_configs.device, **request_input.server_parameters))
         else:
             stream_generator = StreamingUtils.get_stream_generator(
                 "Accelerate")
             outputs.add_stream_content(
-                stream_generator(self.model, self.tokenizer, input_data,
-                                 self.hf_configs.device, **parameters[0]))
+                stream_generator(self.model, self.tokenizer,
+                                 request_input.input_text,
+                                 self.hf_configs.device,
+                                 **request_input.server_parameters))
         return outputs

     def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
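Taken together, the refactored request flow can be read as the sketch below, assembled from the hunks above. The one assumption beyond what the diff shows is the contract of get_input_details (imported from djl_python.utils but not shown here): it is taken to return the flattened prompt list and per-request sizes that the removed parse_input used to provide.

    # Condensed view of the new inference() flow; error handling omitted.
    parsed_input = parse_input_with_formatter(inputs, **self.input_format_args)
    if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
        return rolling_batch_inference(parsed_input, inputs, outputs, self.rolling_batch)
    elif is_streaming_enabled(self.hf_configs.enable_streaming):
        return self._streaming_inference(parsed_input.batch,
                                         parsed_input.requests[0].request_input, outputs)
    else:
        # Dynamic batching: get_input_details(requests, errors, batch) yields the prompts
        # and sizes, and requests[0].request_input.server_parameters is applied to the
        # whole batch (the old all-parameters-equal check is gone).
        return self._dynamic_batch_inference(parsed_input.batch, parsed_input.errors,
                                             inputs, outputs, parsed_input.requests)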