Skip to content

Commit 9c04a01

Browse files
committed
[python] move parse input functions to input_parser.py
1 parent c9e7222 commit 9c04a01

File tree

7 files changed

+196
-175
lines changed

7 files changed

+196
-175
lines changed

engines/python/setup/djl_python/chat_completions/chat_utils.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
1111
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
1212
# the specific language governing permissions and limitations under the License.
13+
from typing import Dict
14+
1315
from djl_python.chat_completions.chat_properties import ChatProperties
1416

1517

def is_chat_completions_request(inputs: Dict) -> bool:
    """Return True when the decoded payload follows the chat-completions schema.

    A request is treated as a chat-completions request whenever it carries a
    ``messages`` field, regardless of that field's value.
    """
    return "messages" in inputs
1820

1921

20-
def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
22+
def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
2123
tokenizer):
2224
if not is_rolling_batch:
2325
raise ValueError(
@@ -28,14 +30,14 @@ def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
2830
raise AttributeError(
2931
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
3032
f"please ensure that your tokenizer supports chat templates.")
31-
chat_params = ChatProperties(**inputs)
32-
_param = chat_params.model_dump(by_alias=True, exclude_none=True)
33-
_messages = _param.pop("messages")
34-
_inputs = tokenizer.apply_chat_template(_messages, tokenize=False)
35-
_param[
33+
chat_params = ChatProperties(**input_map)
34+
param = chat_params.model_dump(by_alias=True, exclude_none=True)
35+
messages = param.pop("messages")
36+
inputs = tokenizer.apply_chat_template(messages, tokenize=False)
37+
param[
3638
"do_sample"] = chat_params.temperature is not None and chat_params.temperature > 0.0
37-
_param["details"] = True # Enable details for chat completions
38-
_param[
39+
param["details"] = True # Enable details for chat completions
40+
param[
3941
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
4042

41-
return _inputs, _param
43+
return inputs, param

engines/python/setup/djl_python/huggingface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232

3333
from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
3434
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
35-
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, ParsedInput, rolling_batch_inference
35+
from djl_python.utils import rolling_batch_inference
36+
from djl_python.input_parser import ParsedInput, InputFormatConfigs, parse_input_with_formatter
3637

3738
ARCHITECTURES_2_TASK = {
3839
"TapasForQuestionAnswering": "table-question-answering",
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, List, Tuple, Union

from djl_python import Input
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from djl_python.encode_decode import decode
from djl_python.three_p.three_p_utils import is_3p_request, parse_3p_request
22+
23+
@dataclass
class ParsedInput:
    """Aggregated result of parsing one batch of incoming requests."""
    input_data: List[str]  # flattened prompts across all requests in the batch
    input_size: List[int]  # number of prompts contributed by each request (0 on parse failure)
    parameters: List[dict]  # generation parameters, one entry per prompt
    errors: dict  # request index -> error message for requests that failed to parse
    batch: list  # the raw request items, as returned by inputs.get_batches()
    # Per-request flag: True when the client sent a list payload (client-side batch).
    # Use the class itself as the factory instead of a throwaway lambda.
    is_client_side_batch: list = field(default_factory=list)
    adapters: list = None  # resolved adapter objects, or None when no adapters were found
32+
33+
34+
@dataclass
class InputFormatConfigs:
    """Configuration describing how incoming requests should be parsed."""
    is_rolling_batch: bool = False  # rolling-batch vs. dynamic-batch serving mode
    is_adapters_supported: bool = False  # whether adapter (e.g. LoRA) lookup is enabled
    # Default output formatter applied when a request does not specify one.
    # Annotated with an explicit None member: implicit Optional (= None with a
    # non-Optional annotation) is rejected by modern type checkers.
    output_formatter: Union[str, Callable, None] = None
    tokenizer: Any = None  # tokenizer used for chat-template rendering, if any
40+
41+
42+
def parse_input_with_formatter(inputs: Input,
                               input_format_configs: InputFormatConfigs,
                               adapter_registry: dict = None) -> ParsedInput:
    """
    Preprocessing function that extracts information from Input objects.

    :param inputs: (Input) a batch of inputs, each corresponding to a new request
    :param input_format_configs: format configurations for the input.
    :param adapter_registry: mapping of adapter name -> registered adapter;
        treated as empty when not provided.

    :return parsed_input: object of data class that contains all parsed input details
    """
    # Use None as the default instead of a mutable `{}` default argument,
    # which would be shared across calls.
    if adapter_registry is None:
        adapter_registry = {}

    input_data = []
    input_size = []
    parameters = []
    adapters = []
    errors = {}
    found_adapters = False
    batch = inputs.get_batches()
    # only for dynamic batch
    is_client_side_batch = [False] * len(batch)
    for i, item in enumerate(batch):
        try:
            content_type = item.get_property("Content-Type")
            invoke_type = item.get_property("X-Amzn-SageMaker-Forwarded-Api")
            input_map = decode(item, content_type)
            _inputs, _param, is_client_side_batch[i] = _parse_inputs_params(
                input_map, item, input_format_configs, invoke_type)
            if input_format_configs.is_adapters_supported:
                adapters_per_item, found_adapter_per_item = _parse_adapters(
                    _inputs, input_map, item, adapter_registry)
        except Exception as e:  # pylint: disable=broad-except
            # Record the failure for this request and continue with the rest
            # of the batch; append 0 so input_size stays index-aligned.
            logging.warning(f"Parse input failed: {i}")
            input_size.append(0)
            errors[i] = str(e)
            continue

        input_data.extend(_inputs)
        input_size.append(len(_inputs))

        if input_format_configs.is_adapters_supported:
            adapters.extend(adapters_per_item)
            found_adapters = found_adapter_per_item or found_adapters

        # NOTE(review): the same _param dict object is appended once per prompt
        # of a client-side batch, so entries share state — confirm downstream
        # consumers never mutate individual entries.
        for _ in range(input_size[i]):
            parameters.append(_param)

    # `adapters` is always a list here, so the original `adapters is not None`
    # check was redundant.
    if found_adapters:
        adapter_data = [
            adapter_registry.get(adapter, None) for adapter in adapters
        ]
    else:
        adapter_data = None

    return ParsedInput(input_data=input_data,
                       input_size=input_size,
                       parameters=parameters,
                       errors=errors,
                       batch=batch,
                       is_client_side_batch=is_client_side_batch,
                       adapters=adapter_data)
102+
103+
104+
def _parse_inputs_params(input_map, item, input_format_configs, invoke_type):
    """Extract the prompt(s) and generation parameters from one decoded payload.

    Dispatches on the payload shape: chat-completions requests, 3p (SageMaker
    forwarded) requests, or the default ``{"inputs": ..., "parameters": ...}``
    schema.

    :param input_map: decoded request body (dict)
    :param item: the raw request item, used to read server-side properties
    :param input_format_configs: parsing configuration for this model
    :param invoke_type: value of the X-Amzn-SageMaker-Forwarded-Api header
    :return: tuple of (list of prompts, parameters dict, is_client_side_batch)
    """
    if is_chat_completions_request(input_map):
        _inputs, _param = parse_chat_completions_request(
            input_map, input_format_configs.is_rolling_batch,
            input_format_configs.tokenizer)
    elif is_3p_request(invoke_type):
        _inputs, _param = parse_3p_request(
            input_map, input_format_configs.is_rolling_batch,
            input_format_configs.tokenizer, invoke_type)
    else:
        # Fall back to treating the whole map as the prompt when "inputs"
        # is absent (matches the original behavior).
        _inputs = input_map.pop("inputs", input_map)
        _param = input_map.pop("parameters", {})

    # Add some additional parameters that are necessary.
    # Per request streaming is only supported by rolling batch
    if input_format_configs.is_rolling_batch:
        _param["stream"] = input_map.pop("stream", _param.get("stream", False))

    if "cached_prompt" in input_map:
        _param["cached_prompt"] = input_map.pop("cached_prompt")
    if "seed" not in _param:
        # set server provided seed if seed is not part of request
        if item.contains_key("seed"):
            _param["seed"] = item.get_as_string(key="seed")
    # Idiomatic membership test: `not in` instead of `not ... in`.
    if "output_formatter" not in _param:
        _param["output_formatter"] = input_format_configs.output_formatter

    # A list payload means the client batched multiple prompts in one request.
    if isinstance(_inputs, list):
        return _inputs, _param, True
    else:
        return [_inputs], _param, False
135+
136+
137+
def _parse_adapters(_inputs, input_map, item,
                    adapter_registry) -> Tuple[List, bool]:
    """Resolve the adapter name for each prompt of a request.

    :param _inputs: the list of prompts extracted from the request
    :param input_map: decoded request body (may carry an "adapters" field)
    :param item: raw request item, checked for adapter content/properties
    :param adapter_registry: mapping of registered adapter names
    :return: (adapter name per prompt, whether an explicit adapter was found)
    :raises ValueError: if an adapter is unregistered, or the number of
        adapters does not match the number of prompts
    """
    # Annotation fixed: `(List, bool)` evaluated to a tuple *object*, not a
    # type; `Tuple[List, bool]` is the correct return annotation.
    adapters_per_item = _fetch_adapters_from_input(input_map, item)
    found_adapter_per_item = False
    if adapters_per_item:
        _validate_adapters(adapters_per_item, adapter_registry)
        found_adapter_per_item = True
    else:
        # inference with just base model.
        adapters_per_item = [""] * len(_inputs)

    if len(_inputs) != len(adapters_per_item):
        # Plain string literal: the original used an f-string with no
        # placeholders (same message text at runtime).
        raise ValueError(
            "Number of adapters is not equal to the number of inputs")
    return adapters_per_item, found_adapter_per_item
152+
153+
154+
def _fetch_adapters_from_input(input_map: dict, inputs: Input):
155+
adapters_per_item = []
156+
if "adapters" in input_map:
157+
adapters_per_item = input_map.pop("adapters", [])
158+
159+
# check content, possible in workflow approach
160+
if inputs.contains_key("adapter"):
161+
adapters_per_item = inputs.get_as_string("adapter")
162+
163+
# check properties, possible from header
164+
if "adapter" in inputs.get_properties():
165+
adapters_per_item = inputs.get_properties()["adapter"]
166+
167+
if not isinstance(adapters_per_item, list):
168+
adapters_per_item = [adapters_per_item]
169+
170+
return adapters_per_item
171+
172+
173+
def _validate_adapters(adapters_per_item, adapter_registry):
174+
for adapter_name in adapters_per_item:
175+
if adapter_name and adapter_name not in adapter_registry:
176+
raise ValueError(f"Adapter {adapter_name} is not registered")

engines/python/setup/djl_python/tensorrt_llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from djl_python.rolling_batch.trtllm_rolling_batch import TRTLLMRollingBatch
1717
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
1818
from djl_python.tensorrt_llm_python import TRTLLMPythonService
19-
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, rolling_batch_inference
19+
from djl_python.utils import rolling_batch_inference
20+
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter
2021
from typing import List, Tuple
2122

2223

engines/python/setup/djl_python/tensorrt_llm_python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from djl_python.encode_decode import encode
1313
from djl_python.inputs import Input
1414
from djl_python.outputs import Output
15-
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs
15+
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter
1616

1717

1818
def _get_value_based_on_tensor(value, index=None):

engines/python/setup/djl_python/transformers_neuronx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled
2626
from djl_python.neuron_utils.model_loader import TNXModelLoader, OptimumModelLoader
2727
from djl_python.neuron_utils.utils import task_from_config, build_vllm_rb_properties
28-
from djl_python.utils import InputFormatConfigs, parse_input_with_formatter, rolling_batch_inference
28+
from djl_python.utils import rolling_batch_inference
29+
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter
2930
from typing import Tuple, List
3031

3132
model = None

0 commit comments

Comments
 (0)