
Commit b3e3595

minor fixes
1 parent: 1ad7a3d

8 files changed: +74 -57 lines changed

engines/python/setup/djl_python/huggingface.py

Lines changed: 10 additions & 8 deletions
@@ -140,7 +140,6 @@ def __init__(self):
         self.peft_config = None
         self.stopping_criteria_list = None
         self.adapter_registry = {}
-        self.adapters = None
         self.hf_configs = None
         self.input_format_args = None

@@ -254,14 +253,13 @@ def _dynamic_batch_inference(self, batch: List, errors: Dict,
                                  inputs: Input, outputs: Output,
                                  requests: List):
         # Dynamic batching
-        input_data, input_size = get_input_details(requests, errors, batch)
-        parameters = requests[0].request_input.server_parameters
+        input_data, input_size, parameters, adapters = get_input_details(requests, errors, batch)

         if isinstance(self.model, PeftModelForCausalLM):
-            if self.adapters is None:
+            if adapters is None:
                 # Inference with only base model
-                self.adapters = [""] * len(input_data)
-            parameters["adapters"] = self.adapters
+                adapters = [""] * len(input_data)
+            parameters["adapters"] = adapters
         prediction = self.hf_pipeline(input_data, **parameters)
         offset = 0
         for i, item in enumerate(batch):
@@ -293,20 +291,24 @@ def _streaming_inference(self, batch: List, request_input: RequestInput,
         if len(batch) > 1:
             raise NotImplementedError(
                 "Dynamic batch not supported for generic streaming")
+
+        parameters = request_input.server_parameters
+        if isinstance(parameters, list):
+            parameters = parameters[0]
         outputs.add_property("content-type", "application/jsonlines")
         if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
             outputs.add_stream_content(
                 StreamingUtils.use_hf_default_streamer(
                     self.model, self.tokenizer, request_input.input_text,
-                    self.hf_configs.device, **request_input.server_parameters))
+                    self.hf_configs.device, **parameters))
         else:
             stream_generator = StreamingUtils.get_stream_generator(
                 "Accelerate")
             outputs.add_stream_content(
                 stream_generator(self.model, self.tokenizer,
                                  request_input.input_text,
                                  self.hf_configs.device,
-                                 **request_input.server_parameters))
+                                 **parameters))
         return outputs

     def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
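
Note: _streaming_inference now normalizes server_parameters before use, because client-side batching can leave it as a list of per-item dicts while streaming serves a single request. A minimal standalone sketch of that normalization (the example values are hypothetical):

    # server_parameters may be a plain dict, or a list of per-item copies for a batched payload
    server_parameters = [{"max_new_tokens": 128, "do_sample": True}]  # hypothetical example

    parameters = server_parameters
    if isinstance(parameters, list):
        # streaming handles one request at a time, so the first copy is enough
        parameters = parameters[0]

    print(parameters)  # {'max_new_tokens': 128, 'do_sample': True}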

engines/python/setup/djl_python/input_parser.py

Lines changed: 37 additions & 24 deletions
@@ -44,6 +44,8 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
     request_id_counter = get_req_id_counter(kwargs)
     for i, input_item in enumerate(batch):
         try:
+            kwargs["is_rolling_batch"] = is_rolling_batch_enabled(
+                kwargs.get("configs").rolling_batch)
             request_id = request_id_counter.next_id(
             ) if request_id_counter else i
             # TODO: Decide whether it is a text input based on content-type
@@ -70,7 +72,7 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
 
 def get_req_id_counter(kwargs):
     req_id_counter = None
-    if is_rolling_batch_enabled(kwargs.get("configs").rolling_batch):
+    if kwargs.get("is_rolling_batch"):
         req_id_counter = kwargs.get("rolling_batch").req_id_counter
     return req_id_counter
 
@@ -89,26 +91,29 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
     invoke_type = input_item.get_property("X-Amzn-SageMaker-Forwarded-Api")
     tokenizer = kwargs.get("tokenizer")
     if is_chat_completions_request(input_map):
-        _inputs, _param = parse_chat_completions_request(
+        inputs, param = parse_chat_completions_request(
             input_map, kwargs.get("is_rolling_batch"), tokenizer)
     elif is_3p_request(invoke_type):
-        _inputs, _param = parse_3p_request(input_map,
-                                           kwargs.get("is_rolling_batch"),
-                                           tokenizer, invoke_type)
+        inputs, param = parse_3p_request(input_map,
+                                         kwargs.get("is_rolling_batch"),
+                                         tokenizer, invoke_type)
     else:
-        _inputs = input_map.pop("inputs", input_map)
-        _param = input_map.pop("parameters", {})
-
-    request_input.input_text = _inputs
-    request_input.parameters = _param
-    # assign input_ids
-    if kwargs.get("tokenizer"):
+        inputs = input_map.pop("inputs", input_map)
+        param = input_map.pop("parameters", {})
+
+    request_input.input_text = inputs
+    request_input.parameters = param
+    # assigns input_ids
+    # TODO: for dynamic batching, or HF pipeline, tokenizer is applied differently.
+    if kwargs.get("tokenizer") and kwargs.get("is_rolling_batch"):
         request_input.input_ids = tokenizer.encode(request_input.input_text)
 
+    # TODO: Instead of modifying user parameters, maintain this in server_parameters.
+    # Added here for backward compatibility
     # re-organize the parameters
-    if is_rolling_batch_enabled(kwargs.get("configs").rolling_batch):
+    if kwargs.get("is_rolling_batch"):
         if "stream" in input_map:
-            request_input.parameters["stream"] = input_map.pop("stream")
+            request_input.stream = input_map.pop("stream")
         if "cached_prompt" in input_map:
             request_input.parameters["cached_prompt"] = input_map.pop(
                 "cached_prompt")
@@ -124,18 +129,20 @@ def add_server_maintained_params(request_input: TextInput, input_item: Input,
     if input_item.contains_key("seed"):
         request_input.server_parameters["seed"] = input_item.get_as_string(
             key="seed")
+
+    # setting the output formatter
+    output_formatter = request_input.server_parameters.pop("output_formatter", None)
     if not "output_formatter" in request_input.server_parameters:
-        request_input.server_parameters["output_formatter"] = kwargs.get(
-            "configs").output_formatter
+        output_formatter = kwargs.get("configs").output_formatter
 
-    request_input.output_formatter = request_input.server_parameters.get(
-        "output_formatter")
+    request_input.output_formatter = output_formatter
 
     if request_input.output_formatter == "json" or request_input.output_formatter == "sse":
-        request_input.tgi_compat = kwargs.get("configs").tgi_compat
+        request_input.tgi_compat = kwargs.get("configs").tgi_compat
 
     # duplicating parameters for client side batching
-    if isinstance(request_input.input_text, list):
+    if isinstance(request_input.input_text, list) and len(
+            request_input.input_text) > 1:
         parameters = []
         for _ in range(len(request_input.input_text)):
             parameters.append(request_input.server_parameters.copy())
@@ -147,22 +154,28 @@ def parse_adapters(request_input: TextInput, input_item: Input,
     adapter_registry = kwargs.get("adapter_registry")
     # if adapter registry exists and not empty, then we assume, peft is supported for the incoming
     if adapter_registry:
+        input_len = len(request_input.input_text) if isinstance(
+            request_input.input_text, list) else 1
         adapters_per_item = _fetch_adapters_from_input(input_map, input_item)
         if adapters_per_item:
            _validate_adapters(adapters_per_item,
                               kwargs.get("adapter_registry"))
        else:
            # inference with just base model.
-            adapters_per_item = [""] * len(request_input.input_text)
+            adapters_per_item = [""] * input_len
 
-        if len(request_input.input_text) != len(adapters_per_item):
+        if input_len != len(adapters_per_item):
            raise ValueError(
                f"Number of adapters is not equal to the number of inputs")
        # lookup the adapter registry to get the adapter details of the registered adapter.
-        request_input.adapters = [
+        adapters_data = [
            kwargs.get("adapter_registry").get(adapter, None)
-            for adapter in adapter_registry
+            for adapter in adapters_per_item
        ]
+        if len(adapters_data) == 1:
+            adapters_data = adapters_data[0]
+
+        request_input.adapters = adapters_data
 
 
 def _fetch_adapters_from_input(input_map: dict, input_item: Input):
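
The adapter handling above reads as a small recipe: derive the input count, fall back to base-model inference when no adapters are given, validate the count, then resolve registry entries per item and collapse a single-element result to a scalar. A self-contained sketch with hypothetical registry contents:

    # hypothetical adapter registry and inputs, for illustration only
    adapter_registry = {"fr": {"name": "fr", "src": "/adapters/fr"}}
    input_text = ["bonjour", "merci"]    # may also be a single string
    adapters_per_item = ["fr", "fr"]     # what _fetch_adapters_from_input might yield

    input_len = len(input_text) if isinstance(input_text, list) else 1
    if not adapters_per_item:
        adapters_per_item = [""] * input_len     # inference with just the base model
    if input_len != len(adapters_per_item):
        raise ValueError("Number of adapters is not equal to the number of inputs")

    # look up the registered adapter details for each item
    adapters_data = [adapter_registry.get(adapter, None) for adapter in adapters_per_item]
    if len(adapters_data) == 1:
        adapters_data = adapters_data[0]         # a single request keeps a scalar, not a list

    print(adapters_data)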

engines/python/setup/djl_python/request.py

Lines changed: 2 additions & 2 deletions
@@ -42,9 +42,9 @@ def __init__(self, request_input: TextInput = None):
         self.adapter = request_input.adapters
 
         # output formatter
-        stream = self.request_input.parameters.get("stream", False)
         self.output_formatter, self.content_type = get_output_formatter(
-            request_input.output_formatter, stream, request_input.tgi_compat)
+            request_input.output_formatter, request_input.stream, request_input.tgi_compat)
+        request_input.output_formatter = self.output_formatter
         self.legacy_formatter = self._is_output_formatter_legacy()
 
         self.request_output = TextGenerationOutput(request_id=self.id,

engines/python/setup/djl_python/request_io.py

Lines changed: 4 additions & 3 deletions
@@ -124,10 +124,12 @@ class RequestInput:
     Attributes:
         request_id: The request ID.
         output_formatter: Output formatter of the request
-        parameters: parameters in the request payload, will be used in the output formatter
+        parameters: parameters in the request payload
+        server_parameters: parameters that are modified by the built-in handlers to support backend engines.
     """
     request_id: int
     output_formatter: Union[Callable, str] = None
+    stream: Optional[bool] = False
     parameters: Dict = field(default_factory=lambda: {})
     server_parameters: Dict = field(default_factory=lambda: {})
     tgi_compat: bool = False
@@ -142,11 +144,10 @@ class TextInput(RequestInput):
         adapters: adapter used for the request.
         tokenizer: tokenizer used for the request.
     """
-    input_text: str = None
+    input_text: Union[str, List[str]] = None
     input_ids: List[int] = field(default_factory=lambda: [])
     adapters: Optional[Any] = None
     tokenizer: Optional[Any] = None
-    found_adapters: bool = False
 
     def prompt_tokens_length(self) -> int:
         return len(self.input_ids)
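
Condensed, the request I/O dataclasses after this change look roughly like the sketch below (only the fields touched here; docstrings and the remaining fields are omitted):

    from dataclasses import dataclass, field
    from typing import Any, Callable, Dict, List, Optional, Union


    @dataclass
    class RequestInput:
        request_id: int
        output_formatter: Union[Callable, str] = None
        stream: Optional[bool] = False           # new: stream flag kept out of the user parameters
        parameters: Dict = field(default_factory=dict)         # raw payload parameters
        server_parameters: Dict = field(default_factory=dict)  # handler-adjusted copy for backend engines
        tgi_compat: bool = False


    @dataclass
    class TextInput(RequestInput):
        input_text: Union[str, List[str]] = None  # widened to cover client-side batched lists
        input_ids: List[int] = field(default_factory=list)
        adapters: Optional[Any] = None
        tokenizer: Optional[Any] = None

        def prompt_tokens_length(self) -> int:
            return len(self.input_ids)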

engines/python/setup/djl_python/tensorrt_llm.py

Lines changed: 5 additions & 5 deletions
@@ -29,6 +29,7 @@ class TRTLLMService(object):
     """
 
     def __init__(self):
+        self.input_format_args = None
         self.initialized = False
         self.trt_configs = None
         self.rolling_batch = None
@@ -40,6 +41,7 @@ def initialize(self, properties: dict):
         self.rolling_batch = TRTLLMRollingBatch(
             self.trt_configs.model_id_or_path, properties, self.trt_configs)
         self.tokenizer = self.rolling_batch.get_tokenizer()
+        self.input_format_args = self.get_input_format_args()
         self.initialized = True
         return
 
@@ -54,16 +56,14 @@ def inference(self, inputs: Input) -> Output:
         """
         Does preprocessing and sends new requests to the rolling batch script for inference
 
-        :param inputs (Input): a batch of inputs, each corresponding to a new request
+        :param inputs: (Input) a batch of inputs, each corresponding to a new request
 
         :return outputs (Output): a batch of outputs that contain status code, output text, and other information
         """
         outputs = Output()
-        kwargs = self.__dict__
-        kwargs[
-            "configs"] = self.trt_configs  # TODO: Rename it to configs, so it would uniform in all handlers
 
-        parsed_input = parse_input_with_formatter(inputs, **kwargs)
+        parsed_input = parse_input_with_formatter(inputs,
+                                                  **self.input_format_args)
         if len(parsed_input.requests) == 0:
             for i in range(len(parsed_input.batch)):
                 err = parsed_input.errors.get(i)
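
The same pattern appears in transformers_neuronx.py below: build the parser kwargs once during initialize() instead of handing self.__dict__ to parse_input_with_formatter on every call. A rough sketch of that shape (field names beyond those in the diff are illustrative, not the real service):

    class ServiceSketch:
        # illustrative shape only, not the real TRTLLMService

        def __init__(self):
            self.input_format_args = None
            self.configs = None
            self.tokenizer = None
            self.initialized = False

        def get_input_format_args(self):
            # a small explicit dict is safer than passing the whole instance __dict__
            return {"configs": self.configs, "tokenizer": self.tokenizer}

        def initialize(self, properties: dict):
            self.configs = properties
            self.tokenizer = None  # the real service takes this from the rolling batch
            self.input_format_args = self.get_input_format_args()
            self.initialized = True

        def inference(self, inputs):
            # parsed_input = parse_input_with_formatter(inputs, **self.input_format_args)
            return self.input_format_args


    svc = ServiceSketch()
    svc.initialize({"model_id_or_path": "my-model"})
    print(svc.inference(None))  # {'configs': {...}, 'tokenizer': None}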

engines/python/setup/djl_python/tensorrt_llm_python.py

Lines changed: 8 additions & 9 deletions
@@ -55,12 +55,12 @@ def _get_generation_result_from_python_backend(generations, inputs_size):
             log_prob = curr_cum_log_prob - cum_log_probs[i]
             token_result = {
                 'id':
-                _get_value_based_on_tensor(generation[i].token_id,
-                                           index=0),
+                    _get_value_based_on_tensor(generation[i].token_id,
+                                               index=0),
                 'text':
-                generation[i].token_text,
+                    generation[i].token_text,
                 'log_prob':
-                log_prob if i < len(tokens_results) else curr_cum_log_prob,
+                    log_prob if i < len(tokens_results) else curr_cum_log_prob,
             }
             cum_log_probs[i] = curr_cum_log_prob
             tokens_results[i].append(token_result)
@@ -100,7 +100,7 @@ def get_input_format_args(self):
         return {
             "configs": self.trt_configs,
             "tokenizer":
-            None,  # tokenizer, for chat completions is not supported for python backend.
+                None,  # tokenizer, for chat completions is not supported for python backend.
         }
 
     def inference(self, inputs: Input) -> Output:
@@ -121,10 +121,9 @@ def inference(self, inputs: Input) -> Output:
                 outputs.add(err, key="data", batch_index=i)
             return outputs
 
-        input_data, input_size = get_input_details(parsed_input.requests,
-                                                   parsed_input.errors,
-                                                   parsed_input.batch)
-        params = parsed_input.requests[0].request_input.server_parameters
+        input_data, input_size, params, _ = get_input_details(parsed_input.requests,
+                                                              parsed_input.errors,
+                                                              parsed_input.batch)
 
         if "output_formatter" in params:
             # output formatter is not supported for TensorRT-LLM python backend.

engines/python/setup/djl_python/transformers_neuronx.py

Lines changed: 3 additions & 3 deletions
@@ -219,7 +219,8 @@ def partition(self, properties: dict):
         self.initialized = True
 
     def inference(self, inputs: Input) -> Output:
-        parsed_input = parse_input_with_formatter(inputs, **self.__dict__)
+        parsed_input = parse_input_with_formatter(inputs,
+                                                  **self.input_format_args)
         errors = parsed_input.errors
         requests = parsed_input.requests
         outputs = Output()
@@ -229,8 +230,7 @@ def inference(self, inputs: Input) -> Output:
                                   self.rolling_batch)
 
         batch = parsed_input.batch
-        input_data, input_size = get_input_details(requests, errors, batch)
-        parameters = parsed_input.requests[0].request_input.server_parameters
+        input_data, input_size, parameters, _ = get_input_details(requests, errors, batch)
         # Remove rolling batch default parameters
         parameters.pop("output_formatter", None)
         parameters.pop("stream", None)

engines/python/setup/djl_python/utils.py

Lines changed: 5 additions & 3 deletions
@@ -114,8 +114,10 @@ def get_input_details(requests, errors, batch):
     input_size = []
     adapters = []
     idx = 0
-    request_input = requests[0].request_input
-    parameters = request_input.server_parameters
+    parameters = requests[0].request_input.server_parameters
+    if isinstance(parameters, list):
+        parameters = parameters[0]
+
     for i in range(len(batch)):
         if i in errors:
             input_size.append(0)
@@ -134,4 +136,4 @@ def get_input_details(requests, errors, batch):
 
         idx += 1
     adapters = adapters if adapters else None
-    return input_data, input_size, adapters
+    return input_data, input_size, parameters, adapters
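
The new return signature ripples through every caller above; a minimal sketch of the contract with a stubbed request object (the stub only mirrors the return shape, not the full error handling):

    from types import SimpleNamespace


    def get_input_details_stub(requests, errors, batch):
        # stub with the same return shape: (input_data, input_size, parameters, adapters)
        parameters = requests[0].request_input.server_parameters
        if isinstance(parameters, list):  # client-side batching duplicates parameters per item
            parameters = parameters[0]
        input_data = [req.request_input.input_text for req in requests]
        input_size = [1] * len(input_data)
        return input_data, input_size, parameters, None


    request = SimpleNamespace(request_input=SimpleNamespace(
        input_text="hello", server_parameters={"max_new_tokens": 64}))

    # callers that do not need adapters discard the last element, as in tensorrt_llm_python.py
    input_data, input_size, params, _ = get_input_details_stub([request], {}, [0])
    print(input_data, input_size, params)  # ['hello'] [1] {'max_new_tokens': 64}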
