
Commit b6bc5ca

[python] refactor input parser to support Request (#2145)
1 parent 70aca0c commit b6bc5ca

19 files changed: +508 -528 lines changed

engines/python/setup/djl_python/huggingface.py

Lines changed: 49 additions & 65 deletions
@@ -23,17 +23,18 @@
                           StoppingCriteriaList)
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from peft import PeftConfig, PeftModel, PeftModelForCausalLM
-from typing import Tuple, List
+from typing import List, Dict
 
 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
+from djl_python.request_io import RequestInput
 from djl_python.streaming_utils import StreamingUtils
 
 from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
 from djl_python.properties_manager.hf_properties import HuggingFaceProperties
-from djl_python.utils import rolling_batch_inference
-from djl_python.input_parser import ParsedInput, InputFormatConfigs, parse_input_with_formatter
+from djl_python.utils import rolling_batch_inference, get_input_details
+from djl_python.input_parser import parse_input_with_formatter
 
 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -139,10 +140,8 @@ def __init__(self):
         self.peft_config = None
         self.stopping_criteria_list = None
         self.adapter_registry = {}
-        self.adapters = None
         self.hf_configs = None
-        self.input_format_configs = None
-        self.parsed_input = None
+        self.input_format_args = None
 
     def initialize(self, properties: dict):
         self.hf_configs = HuggingFaceProperties(**properties)
@@ -174,14 +173,19 @@ def initialize(self, properties: dict):
         if "stop_sequence" in properties:
             self.load_stopping_criteria_list(properties["stop_sequence"])
 
-        self.input_format_configs = InputFormatConfigs(
-            is_rolling_batch=is_rolling_batch_enabled(
-                self.hf_configs.rolling_batch),
-            is_adapters_supported=True,
-            output_formatter=self.hf_configs.output_formatter,
-            tokenizer=self.tokenizer)
+        self.input_format_args = self.get_input_format_args()
         self.initialized = True
 
+    def get_input_format_args(self):
+        return {
+            "configs": self.hf_configs,
+            "tokenizer": self.tokenizer,
+            "adapter_registry": self.adapter_registry,
+            "model_config": self.model_config,
+            "peft_config": self.peft_config,
+            "rolling_batch": self.rolling_batch
+        }
+
     @staticmethod
     def parse_stop_sequence_input(stop_sequence):
         """
@@ -216,37 +220,14 @@ def load_stopping_criteria_list(self, stop_sequence):
 
         self.stopping_criteria_list = StoppingCriteriaList(stopwords)
 
-    def parse_input(
-            self, inputs: Input, tokenizer, output_formatter
-    ) -> Tuple[List[str], List[int], List[dict], dict, list]:
-        """
-        Preprocessing function that extracts information from Input objects.
-
-        :param output_formatter: output formatter for the request
-        :param inputs :(Input) a batch of inputs, each corresponding to a new request
-        :param tokenizer: the tokenizer used for inference
-
-        :return input_data (List[str]): a list of strings, each string being the prompt in a new request
-        :return input_size (List[int]): a list of ints being the size of each new request
-        :return parameters (List[dict]): parameters pertaining to each request
-        :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
-        :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
-        """
-
-        self.parsed_input = parse_input_with_formatter(
-            inputs, self.input_format_configs, self.adapter_registry)
-        self.adapters = self.parsed_input.adapters
-        return (self.parsed_input.input_data, self.parsed_input.input_size,
-                self.parsed_input.parameters, self.parsed_input.errors,
-                self.parsed_input.batch)
-
     def inference(self, inputs: Input) -> Output:
         outputs = Output()
-
-        input_data, input_size, parameters, errors, batch = self.parse_input(
-            inputs, self.tokenizer, self.hf_configs.output_formatter)
-        if len(input_data) == 0:
-            for i in range(len(batch)):
+        parsed_input = parse_input_with_formatter(inputs,
+                                                  **self.input_format_args)
+        requests = parsed_input.requests
+        errors = parsed_input.errors
+        if len(requests) == 0:
+            for i in range(len(parsed_input.batch)):
                 err = errors.get(i)
                 if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
                     err = {"data": "", "last": True, "code": 424, "error": err}
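
In the new flow the handler builds its parser configuration once (the get_input_format_args dict above), expands it with **self.input_format_args at each call, and receives back an object carrying requests, errors, and batch instead of parallel lists. The snippet below is a minimal, self-contained sketch of that error branch; ParsedInputSketch and collect_errors are illustrative stand-ins, not the real djl_python types.

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ParsedInputSketch:
    """Simplified stand-in for the object returned by parse_input_with_formatter."""
    requests: List = field(default_factory=list)
    errors: Dict[int, str] = field(default_factory=dict)
    batch: List = field(default_factory=list)


def collect_errors(parsed_input: ParsedInputSketch, rolling_batch_enabled: bool):
    # Mirrors the error branch in the hunk above: when nothing parsed
    # successfully, report one error per original batch item; rolling batch
    # wraps it in a 424 payload.
    results = []
    if len(parsed_input.requests) == 0:
        for i in range(len(parsed_input.batch)):
            err = parsed_input.errors.get(i)
            if rolling_batch_enabled:
                err = {"data": "", "last": True, "code": 424, "error": err}
            results.append(err)
    return results


parsed = ParsedInputSketch(batch=["request-0"], errors={0: "invalid payload"})
print(collect_errors(parsed, rolling_batch_enabled=True))
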
@@ -258,28 +239,29 @@ def inference(self, inputs: Input) -> Output:
             return outputs
 
         if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
-            return rolling_batch_inference(self.parsed_input, inputs, outputs,
+            return rolling_batch_inference(parsed_input, inputs, outputs,
                                            self.rolling_batch)
         elif is_streaming_enabled(self.hf_configs.enable_streaming):
-            return self._streaming_inference(batch, input_data, outputs,
-                                             parameters)
+            request_input = requests[0].request_input
+            return self._streaming_inference(parsed_input.batch, request_input,
+                                             outputs)
         else:
-            return self._dynamic_batch_inference(batch, errors, input_data,
-                                                 input_size, inputs, outputs,
-                                                 parameters)
+            return self._dynamic_batch_inference(parsed_input.batch, errors,
+                                                 inputs, outputs, requests)
+
+    def _dynamic_batch_inference(self, batch: List, errors: Dict,
+                                 inputs: Input, outputs: Output,
+                                 requests: List):
+        # Dynamic batching
+        input_data, input_size, parameters, adapters = get_input_details(
+            requests, errors, batch)
 
-    def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
-                                 inputs, outputs, parameters):
-        if not all(p == parameters[0] for p in parameters):
-            raise ValueError(
-                "In order to enable dynamic batching, all input batches must have the same parameters"
-            )
         if isinstance(self.model, PeftModelForCausalLM):
-            if self.adapters is None:
+            if adapters is None:
                 # Inference with only base model
-                self.adapters = [""] * len(input_data)
-            parameters[0]["adapters"] = self.adapters
-        prediction = self.hf_pipeline(input_data, **parameters[0])
+                adapters = [""] * len(input_data)
+            parameters["adapters"] = adapters
+        prediction = self.hf_pipeline(input_data, **parameters)
         offset = 0
         for i, item in enumerate(batch):
             content_type = item.get_property("Content-Type")
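
The per-request parameter check that used to live in _dynamic_batch_inference is no longer visible here; the new get_input_details(requests, errors, batch) call returns flattened input_data, input_size, a single parameters dict, and adapters. The toy helper below only guesses at that flattening to show the shape of the returned tuple. It is not the real djl_python.utils.get_input_details, and it assumes each request exposes a request_input (as in the streaming branch above) with input_text and server_parameters fields (used in the _streaming_inference hunk below).

from types import SimpleNamespace
from typing import Dict, List, Tuple


def get_input_details_sketch(requests: List, errors: Dict,
                             batch: List) -> Tuple[List[str], List[int], Dict, List]:
    """Toy illustration only, not the real djl_python.utils.get_input_details.

    errors and batch are accepted to mirror the call shape in the diff but are
    unused in this simplified sketch.
    """
    input_data: List[str] = []
    input_size: List[int] = []
    adapters: List[str] = []
    parameters: Dict = {}
    for req in requests:
        texts = req.request_input.input_text
        texts = texts if isinstance(texts, list) else [texts]
        input_data.extend(texts)
        input_size.append(len(texts))
        adapters.append(getattr(req, "adapter", ""))
        # Dynamic batching issues a single pipeline call, so requests are
        # assumed to share one parameters dict (the removed code raised a
        # ValueError when per-request parameters differed).
        parameters = req.request_input.server_parameters
    return input_data, input_size, parameters, adapters


req = SimpleNamespace(request_input=SimpleNamespace(
    input_text="Hello world", server_parameters={"max_new_tokens": 32}))
print(get_input_details_sketch([req], errors={}, batch=["raw-input-0"]))
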
@@ -305,24 +287,26 @@ def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
             offset += input_size[i]
         return outputs
 
-    def _streaming_inference(self, batch, input_data, outputs, parameters):
+    def _streaming_inference(self, batch: List, request_input: RequestInput,
+                             outputs: Output):
         if len(batch) > 1:
             raise NotImplementedError(
                 "Dynamic batch not supported for generic streaming")
+
+        parameters = request_input.server_parameters
         outputs.add_property("content-type", "application/jsonlines")
         if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
             outputs.add_stream_content(
-                StreamingUtils.use_hf_default_streamer(self.model,
-                                                       self.tokenizer,
-                                                       input_data,
-                                                       self.hf_configs.device,
-                                                       **parameters[0]))
+                StreamingUtils.use_hf_default_streamer(
+                    self.model, self.tokenizer, request_input.input_text,
+                    self.hf_configs.device, **parameters))
         else:
             stream_generator = StreamingUtils.get_stream_generator(
                 "Accelerate")
             outputs.add_stream_content(
-                stream_generator(self.model, self.tokenizer, input_data,
-                                 self.hf_configs.device, **parameters[0]))
+                stream_generator(self.model, self.tokenizer,
+                                 request_input.input_text,
+                                 self.hf_configs.device, **parameters))
         return outputs
 
     def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
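
On the streaming path the method now receives a single RequestInput and reads input_text and server_parameters from it, rather than positional input_data/parameters[0] lists. Below is a minimal usage sketch with a dummy generator standing in for StreamingUtils, which is not reproduced here; the attribute names come from the diff, everything else is illustrative.

from types import SimpleNamespace

# Dummy request_input; the real class is djl_python.request_io.RequestInput.
request_input = SimpleNamespace(
    input_text="Translate to French: cheese",
    server_parameters={"max_new_tokens": 64, "do_sample": False},
)


def fake_stream_generator(model, tokenizer, text, device, **parameters):
    # Stand-in for a StreamingUtils generator: yields jsonlines-style chunks.
    for token in text.split():
        yield {"data": token, "last": False}
    yield {"data": "", "last": True}


for chunk in fake_stream_generator(None, None, request_input.input_text,
                                   "cpu", **request_input.server_parameters):
    print(chunk)
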
