Commit 1ad7a3d

[python] refactor input parser to support Request
1 parent 1482ace commit 1ad7a3d

18 files changed: +454 -517 lines

engines/python/setup/djl_python/huggingface.py

Lines changed: 46 additions & 62 deletions
@@ -23,17 +23,18 @@
                            StoppingCriteriaList)
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from peft import PeftConfig, PeftModel, PeftModelForCausalLM
-from typing import Tuple, List
+from typing import List, Dict
 
 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
+from djl_python.request_io import RequestInput
 from djl_python.streaming_utils import StreamingUtils
 
 from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
 from djl_python.properties_manager.hf_properties import HuggingFaceProperties
-from djl_python.utils import rolling_batch_inference
-from djl_python.input_parser import ParsedInput, InputFormatConfigs, parse_input_with_formatter
+from djl_python.utils import rolling_batch_inference, get_input_details
+from djl_python.input_parser import parse_input_with_formatter
 
 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -141,8 +142,7 @@ def __init__(self):
         self.adapter_registry = {}
         self.adapters = None
         self.hf_configs = None
-        self.input_format_configs = None
-        self.parsed_input = None
+        self.input_format_args = None
 
     def initialize(self, properties: dict):
         self.hf_configs = HuggingFaceProperties(**properties)
@@ -174,14 +174,19 @@ def initialize(self, properties: dict):
         if "stop_sequence" in properties:
             self.load_stopping_criteria_list(properties["stop_sequence"])
 
-        self.input_format_configs = InputFormatConfigs(
-            is_rolling_batch=is_rolling_batch_enabled(
-                self.hf_configs.rolling_batch),
-            is_adapters_supported=True,
-            output_formatter=self.hf_configs.output_formatter,
-            tokenizer=self.tokenizer)
+        self.input_format_args = self.get_input_format_args()
         self.initialized = True
 
+    def get_input_format_args(self):
+        return {
+            "configs": self.hf_configs,
+            "tokenizer": self.tokenizer,
+            "adapter_registry": self.adapter_registry,
+            "model_config": self.model_config,
+            "peft_config": self.peft_config,
+            "rolling_batch": self.rolling_batch
+        }
+
     @staticmethod
     def parse_stop_sequence_input(stop_sequence):
         """
@@ -216,37 +221,14 @@ def load_stopping_criteria_list(self, stop_sequence):
 
         self.stopping_criteria_list = StoppingCriteriaList(stopwords)
 
-    def parse_input(
-            self, inputs: Input, tokenizer, output_formatter
-    ) -> Tuple[List[str], List[int], List[dict], dict, list]:
-        """
-        Preprocessing function that extracts information from Input objects.
-
-        :param output_formatter: output formatter for the request
-        :param inputs :(Input) a batch of inputs, each corresponding to a new request
-        :param tokenizer: the tokenizer used for inference
-
-        :return input_data (List[str]): a list of strings, each string being the prompt in a new request
-        :return input_size (List[int]): a list of ints being the size of each new request
-        :return parameters (List[dict]): parameters pertaining to each request
-        :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
-        :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
-        """
-
-        self.parsed_input = parse_input_with_formatter(
-            inputs, self.input_format_configs, self.adapter_registry)
-        self.adapters = self.parsed_input.adapters
-        return (self.parsed_input.input_data, self.parsed_input.input_size,
-                self.parsed_input.parameters, self.parsed_input.errors,
-                self.parsed_input.batch)
-
     def inference(self, inputs: Input) -> Output:
         outputs = Output()
-
-        input_data, input_size, parameters, errors, batch = self.parse_input(
-            inputs, self.tokenizer, self.hf_configs.output_formatter)
-        if len(input_data) == 0:
-            for i in range(len(batch)):
+        parsed_input = parse_input_with_formatter(inputs,
+                                                  **self.input_format_args)
+        requests = parsed_input.requests
+        errors = parsed_input.errors
+        if len(requests) == 0:
+            for i in range(len(parsed_input.batch)):
                 err = errors.get(i)
                 if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
                     err = {"data": "", "last": True, "code": 424, "error": err}
@@ -258,28 +240,29 @@ def inference(self, inputs: Input) -> Output:
             return outputs
 
         if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
-            return rolling_batch_inference(self.parsed_input, inputs, outputs,
+            return rolling_batch_inference(parsed_input, inputs, outputs,
                                            self.rolling_batch)
         elif is_streaming_enabled(self.hf_configs.enable_streaming):
-            return self._streaming_inference(batch, input_data, outputs,
-                                             parameters)
+            request_input = requests[0].request_input
+            return self._streaming_inference(parsed_input.batch, request_input,
+                                             outputs)
         else:
-            return self._dynamic_batch_inference(batch, errors, input_data,
-                                                 input_size, inputs, outputs,
-                                                 parameters)
+            return self._dynamic_batch_inference(parsed_input.batch, errors,
+                                                 inputs, outputs, requests)
+
+    def _dynamic_batch_inference(self, batch: List, errors: Dict,
+                                 inputs: Input, outputs: Output,
+                                 requests: List):
+        # Dynamic batching
+        input_data, input_size = get_input_details(requests, errors, batch)
+        parameters = requests[0].request_input.server_parameters
 
-    def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
-                                 inputs, outputs, parameters):
-        if not all(p == parameters[0] for p in parameters):
-            raise ValueError(
-                "In order to enable dynamic batching, all input batches must have the same parameters"
-            )
         if isinstance(self.model, PeftModelForCausalLM):
             if self.adapters is None:
                 # Inference with only base model
                 self.adapters = [""] * len(input_data)
-            parameters[0]["adapters"] = self.adapters
-            prediction = self.hf_pipeline(input_data, **parameters[0])
+            parameters["adapters"] = self.adapters
+            prediction = self.hf_pipeline(input_data, **parameters)
         offset = 0
         for i, item in enumerate(batch):
             content_type = item.get_property("Content-Type")
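
Note (not part of the diff): get_input_details now rebuilds the flat prompt list and per-item prompt counts that the removed parse_input used to return, so _dynamic_batch_inference can keep slicing pipeline output by offset. Its implementation is not shown in this file; the following is only a plausible sketch under that assumption, not the actual djl_python.utils.get_input_details:

    from typing import Dict, List, Tuple

    def get_input_details_sketch(requests: List, errors: Dict,
                                 batch: List) -> Tuple[List[str], List[int]]:
        """Hypothetical stand-in for djl_python.utils.get_input_details."""
        input_data: List[str] = []
        input_size: List[int] = []
        request_iter = iter(requests)
        for i in range(len(batch)):
            if i in errors:
                input_size.append(0)  # errored batch items contribute no prompts
                continue
            text = next(request_iter).request_input.input_text
            prompts = text if isinstance(text, list) else [text]
            input_data.extend(prompts)
            input_size.append(len(prompts))
        return input_data, input_size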
@@ -305,24 +288,25 @@ def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
             offset += input_size[i]
         return outputs
 
-    def _streaming_inference(self, batch, input_data, outputs, parameters):
+    def _streaming_inference(self, batch: List, request_input: RequestInput,
+                             outputs: Output):
         if len(batch) > 1:
             raise NotImplementedError(
                 "Dynamic batch not supported for generic streaming")
         outputs.add_property("content-type", "application/jsonlines")
         if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
             outputs.add_stream_content(
-                StreamingUtils.use_hf_default_streamer(self.model,
-                                                       self.tokenizer,
-                                                       input_data,
-                                                       self.hf_configs.device,
-                                                       **parameters[0]))
+                StreamingUtils.use_hf_default_streamer(
+                    self.model, self.tokenizer, request_input.input_text,
+                    self.hf_configs.device, **request_input.server_parameters))
         else:
             stream_generator = StreamingUtils.get_stream_generator(
                 "Accelerate")
             outputs.add_stream_content(
-                stream_generator(self.model, self.tokenizer, input_data,
-                                 self.hf_configs.device, **parameters[0]))
+                stream_generator(self.model, self.tokenizer,
+                                 request_input.input_text,
+                                 self.hf_configs.device,
+                                 **request_input.server_parameters))
         return outputs
 
     def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
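
Note (not part of the diff): the streaming path now reads everything it needs from a single RequestInput (imported above from djl_python.request_io). Only two attributes are touched in this file; a minimal illustrative shape, not the real class definition:

    from dataclasses import dataclass, field
    from typing import Any, Dict, List, Union

    @dataclass
    class RequestInputSketch:
        # Illustrative only: the attributes _streaming_inference passes to the streamer.
        input_text: Union[str, List[str]] = ""                            # prompt(s) for generation
        server_parameters: Dict[str, Any] = field(default_factory=dict)   # generation kwargs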
