 AutoModelForQuestionAnswering, StoppingCriteria, StoppingCriteriaList)
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from peft import PeftConfig, PeftModel, PeftModelForCausalLM
-from typing import Tuple, List
+from typing import Tuple, List, Callable, Dict

 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
+from djl_python.rolling_batch.rolling_batch import RollingBatch
 from djl_python.streaming_utils import StreamingUtils

 from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
 from djl_python.properties_manager.hf_properties import HuggingFaceProperties
-from djl_python.utils import parse_input_with_formatter, InputFormatConfigs
+from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, ParsedInput, rolling_batch_inference

 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -140,6 +141,7 @@ def __init__(self):
         self.adapters = None
         self.hf_configs = None
         self.input_format_configs = None
+        self.parsed_input = None

     def initialize(self, properties: dict):
         self.hf_configs = HuggingFaceProperties(**properties)
@@ -230,13 +232,14 @@ def parse_input(
         :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
         """

-        parsed_input = parse_input_with_formatter(inputs,
-                                                  self.input_format_configs,
-                                                  self.adapter_registry)
-        self.adapters = parsed_input.adapters if parsed_input.found_adapters else None
-        return parsed_input.input_data, parsed_input.input_size, parsed_input.parameters, parsed_input.errors, parsed_input.batch
+        self.parsed_input = parse_input_with_formatter(
+            inputs, self.input_format_configs, self.adapter_registry)
+        self.adapters = self.parsed_input.adapters
+        return (self.parsed_input.input_data, self.parsed_input.input_size,
+                self.parsed_input.parameters, self.parsed_input.errors,
+                self.parsed_input.batch)

-    def inference(self, inputs):
+    def inference(self, inputs: Input) -> Output:
         outputs = Output()

         input_data, input_size, parameters, errors, batch = self.parse_input(
@@ -254,70 +257,28 @@ def inference(self, inputs):
             return outputs

         if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
-            if inputs.get_property("reset_rollingbatch"):
-                self.rolling_batch.reset()
-            if self.adapters is not None:
-                adapter_data = [
-                    self.adapter_registry.get(adapter, None)
-                    for adapter in self.adapters
-                ]
-            else:
-                adapter_data = None
-            result = self.rolling_batch.inference(input_data,
-                                                  parameters,
-                                                  adapters=adapter_data)
-            idx = 0
-            for i in range(len(batch)):
-                err = errors.get(i)
-                if err:
-                    err = {"data": "", "last": True, "code": 424, "error": err}
-                    outputs.add(Output.binary_encode(err),
-                                key="data",
-                                batch_index=i)
-                    outputs.add_property(f"batch_{i}_Content-Type",
-                                         "application/json")
-                else:
-                    content_type = result[idx].pop("content_type")
-                    outputs.add(Output.binary_encode(result[idx]),
-                                key="data",
-                                batch_index=i)
-                    if content_type is not None:
-                        outputs.add_property(f"batch_{i}_Content-Type",
-                                             content_type)
-                    idx += 1
-
-            return outputs
+            return rolling_batch_inference(self.parsed_input, inputs, outputs,
+                                           self.rolling_batch)
         elif is_streaming_enabled(self.hf_configs.enable_streaming):
-            if len(batch) > 1:
-                raise NotImplementedError(
-                    "Dynamic batch not supported for generic streaming")
-            outputs.add_property("content-type", "application/jsonlines")
-            if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
-                outputs.add_stream_content(
-                    StreamingUtils.use_hf_default_streamer(
-                        self.model, self.tokenizer, input_data,
-                        self.hf_configs.device, **parameters[0]))
-            else:
-                stream_generator = StreamingUtils.get_stream_generator(
-                    "Accelerate")
-                outputs.add_stream_content(
-                    stream_generator(self.model, self.tokenizer, input_data,
-                                     self.hf_configs.device, **parameters[0]))
-            return outputs
+            return self._streaming_inference(batch, input_data, outputs,
+                                             parameters)
+        else:
+            return self._dynamic_batch_inference(batch, errors, input_data,
+                                                 input_size, inputs, outputs,
+                                                 parameters)

+    def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
+                                 inputs, outputs, parameters):
         if not all(p == parameters[0] for p in parameters):
             raise ValueError(
                 "In order to enable dynamic batching, all input batches must have the same parameters"
             )
-
         if isinstance(self.model, PeftModelForCausalLM):
             if self.adapters is None:
                 # Inference with only base model
                 self.adapters = [""] * len(input_data)
             parameters[0]["adapters"] = self.adapters
-
         prediction = self.hf_pipeline(input_data, **parameters[0])
-
         offset = 0
         for i, item in enumerate(batch):
             content_type = item.get_property("Content-Type")
@@ -341,7 +302,26 @@ def inference(self, inputs):
                    accept,
                    key=inputs.get_content().key_at(i))
             offset += input_size[i]
+        return outputs

+    def _streaming_inference(self, batch, input_data, outputs, parameters):
+        if len(batch) > 1:
+            raise NotImplementedError(
+                "Dynamic batch not supported for generic streaming")
+        outputs.add_property("content-type", "application/jsonlines")
+        if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
+            outputs.add_stream_content(
+                StreamingUtils.use_hf_default_streamer(self.model,
+                                                       self.tokenizer,
+                                                       input_data,
+                                                       self.hf_configs.device,
+                                                       **parameters[0]))
+        else:
+            stream_generator = StreamingUtils.get_stream_generator(
+                "Accelerate")
+            outputs.add_stream_content(
+                stream_generator(self.model, self.tokenizer, input_data,
+                                 self.hf_configs.device, **parameters[0]))
         return outputs

     def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
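A minimal sketch of what the relocated rolling-batch loop might look like inside djl_python.utils.rolling_batch_inference, reconstructed from the code this commit removes from inference(). The ParsedInput field names and the assumption that adapters arrive already resolved (the adapter registry is no longer passed to the helper) are guesses, not the actual implementation.

from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.rolling_batch.rolling_batch import RollingBatch


def rolling_batch_inference(parsed_input, inputs: Input, outputs: Output,
                            rolling_batch: RollingBatch) -> Output:
    # Mirror of the loop removed above; the parsed_input field names are assumed.
    if inputs.get_property("reset_rollingbatch"):
        rolling_batch.reset()
    result = rolling_batch.inference(parsed_input.input_data,
                                     parsed_input.parameters,
                                     adapters=parsed_input.adapters)
    idx = 0
    for i in range(len(parsed_input.batch)):
        err = parsed_input.errors.get(i)
        if err:
            # Failed requests get a JSON error payload instead of generated text.
            err = {"data": "", "last": True, "code": 424, "error": err}
            outputs.add(Output.binary_encode(err), key="data", batch_index=i)
            outputs.add_property(f"batch_{i}_Content-Type", "application/json")
        else:
            content_type = result[idx].pop("content_type")
            outputs.add(Output.binary_encode(result[idx]),
                        key="data",
                        batch_index=i)
            if content_type is not None:
                outputs.add_property(f"batch_{i}_Content-Type", content_type)
            idx += 1
    return outputs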