Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions engines/python/setup/djl_python/chat_completions/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Dict

from djl_python.chat_completions.chat_properties import ChatProperties


def is_chat_completions_request(inputs: map) -> bool:
def is_chat_completions_request(inputs: Dict) -> bool:
    """Return True when the request map follows the chat-completions schema.

    The presence of a "messages" key is the discriminator used to route a
    request to chat-completions parsing.
    """
    return "messages" in inputs


def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
tokenizer):
if not is_rolling_batch:
raise ValueError(
Expand All @@ -28,14 +30,14 @@ def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
raise AttributeError(
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")
chat_params = ChatProperties(**inputs)
_param = chat_params.model_dump(by_alias=True, exclude_none=True)
_messages = _param.pop("messages")
_inputs = tokenizer.apply_chat_template(_messages, tokenize=False)
_param[
chat_params = ChatProperties(**input_map)
param = chat_params.model_dump(by_alias=True, exclude_none=True)
messages = param.pop("messages")
inputs = tokenizer.apply_chat_template(messages, tokenize=False)
param[
"do_sample"] = chat_params.temperature is not None and chat_params.temperature > 0.0
_param["details"] = True # Enable details for chat completions
_param[
param["details"] = True # Enable details for chat completions
param[
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"

return _inputs, _param
return inputs, param
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@

from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, ParsedInput, rolling_batch_inference
from djl_python.utils import rolling_batch_inference
from djl_python.input_parser import ParsedInput, InputFormatConfigs, parse_input_with_formatter

ARCHITECTURES_2_TASK = {
"TapasForQuestionAnswering": "table-question-answering",
Expand Down
176 changes: 176 additions & 0 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple, Union

from djl_python import Input
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from djl_python.encode_decode import decode
from djl_python.three_p.three_p_utils import is_3p_request, parse_3p_request


@dataclass
class ParsedInput:
    """Result of parsing a batch of inference requests.

    input_size and errors are index-aligned with batch: input_size[i] is the
    number of prompts extracted from batch item i (0 when parsing failed), and
    errors maps a failed batch index to its error message.
    """
    input_data: List[str]  # flattened prompts across the whole batch
    input_size: List[int]  # prompts contributed by each batch item
    parameters: List[dict]  # one parameters dict per prompt in input_data
    errors: dict  # batch index -> error string for items that failed to parse
    batch: list  # raw batch items as returned by Input.get_batches()
    # True at index i when batch item i supplied a list of prompts
    # (client-side batching; only meaningful for dynamic batch)
    is_client_side_batch: list = field(default_factory=list)
    # resolved adapters per prompt, or None when adapters are not in use
    adapters: Optional[list] = None


@dataclass
class InputFormatConfigs:
    """Configuration controlling how incoming requests are parsed."""
    is_rolling_batch: bool = False  # rolling (continuous) batching enabled
    is_adapters_supported: bool = False  # whether adapters may be requested
    # default output formatter name or callable; None lets the handler decide
    output_formatter: Optional[Union[str, Callable]] = None
    # tokenizer used for chat templates; required for chat-completions requests
    tokenizer: Any = None


def parse_input_with_formatter(inputs: Input,
                               input_format_configs: InputFormatConfigs,
                               adapter_registry: Optional[dict] = None
                               ) -> ParsedInput:
    """
    Preprocessing function that extracts information from Input objects.

    :param inputs: (Input) a batch of inputs, each corresponding to a new request
    :param input_format_configs: format configurations for the input.
    :param adapter_registry: mapping of adapter name -> registered adapter;
        defaults to an empty registry.

    :return parsed_input: object of data class that contains all parsed input details
    """
    # Normalize inside the function: the previous `adapter_registry: dict = {}`
    # signature shared one mutable dict across every call (a classic pitfall).
    if adapter_registry is None:
        adapter_registry = {}

    input_data = []
    input_size = []
    parameters = []
    adapters = []
    errors = {}
    found_adapters = False
    batch = inputs.get_batches()
    # only for dynamic batch
    is_client_side_batch = [False for _ in range(len(batch))]
    for i, item in enumerate(batch):
        try:
            content_type = item.get_property("Content-Type")
            invoke_type = item.get_property("X-Amzn-SageMaker-Forwarded-Api")
            input_map = decode(item, content_type)
            _inputs, _param, is_client_side_batch[i] = _parse_inputs_params(
                input_map, item, input_format_configs, invoke_type)
            if input_format_configs.is_adapters_supported:
                adapters_per_item, found_adapter_per_item = _parse_adapters(
                    _inputs, input_map, item, adapter_registry)
        except Exception as e:  # pylint: disable=broad-except
            # Record the failure for this batch item and keep processing the
            # rest; include the traceback so the root cause is diagnosable.
            logging.warning(f"Parse input failed: {i}", exc_info=True)
            input_size.append(0)
            errors[i] = str(e)
            continue

        input_data.extend(_inputs)
        input_size.append(len(_inputs))

        if input_format_configs.is_adapters_supported:
            adapters.extend(adapters_per_item)
            found_adapters = found_adapter_per_item or found_adapters

        # Replicate the per-item parameters once per extracted prompt.
        # NOTE(review): the same dict object is appended for every prompt of an
        # item, so downstream mutation of one entry affects its siblings.
        for _ in range(input_size[i]):
            parameters.append(_param)

    if found_adapters:
        # resolve adapter names to registered adapters; unknown names map to
        # None (validation already happened in _parse_adapters)
        adapter_data = [
            adapter_registry.get(adapter, None) for adapter in adapters
        ]
    else:
        adapter_data = None

    return ParsedInput(input_data=input_data,
                       input_size=input_size,
                       parameters=parameters,
                       errors=errors,
                       batch=batch,
                       is_client_side_batch=is_client_side_batch,
                       adapters=adapter_data)


def _parse_inputs_params(input_map, item, input_format_configs, invoke_type):
    """Extract (inputs, parameters, is_client_side_batch) from one request map.

    Dispatches on the request flavor: chat completions ("messages" key),
    3P requests (identified by the forwarded invoke type), or the plain
    inputs/parameters schema.

    :return: (_inputs, _param, is_client_side_batch) where _inputs is always a
        list and the flag is True when the client batched multiple prompts.
    """
    if is_chat_completions_request(input_map):
        _inputs, _param = parse_chat_completions_request(
            input_map, input_format_configs.is_rolling_batch,
            input_format_configs.tokenizer)
    elif is_3p_request(invoke_type):
        _inputs, _param = parse_3p_request(
            input_map, input_format_configs.is_rolling_batch,
            input_format_configs.tokenizer, invoke_type)
    else:
        # plain schema: {"inputs": ..., "parameters": {...}}; when "inputs" is
        # absent the whole map is treated as the input payload
        _inputs = input_map.pop("inputs", input_map)
        _param = input_map.pop("parameters", {})

    # Add some additional parameters that are necessary.
    # Per request streaming is only supported by rolling batch
    if input_format_configs.is_rolling_batch:
        _param["stream"] = input_map.pop("stream", _param.get("stream", False))

    if "cached_prompt" in input_map:
        _param["cached_prompt"] = input_map.pop("cached_prompt")
    if "seed" not in _param:
        # set server provided seed if seed is not part of request
        if item.contains_key("seed"):
            _param["seed"] = item.get_as_string(key="seed")
    if "output_formatter" not in _param:
        _param["output_formatter"] = input_format_configs.output_formatter

    if isinstance(_inputs, list):
        return _inputs, _param, True
    else:
        return [_inputs], _param, False


def _parse_adapters(_inputs, input_map, item,
                    adapter_registry) -> Tuple[List, bool]:
    """Determine the adapter to use for each input of one batch item.

    :return: (adapters_per_item, found_adapter_per_item) where the list holds
        exactly one adapter name per input ("" means base-model inference).
    :raises ValueError: if a named adapter is not registered, or if the number
        of adapters does not match the number of inputs.
    """
    adapters_per_item = _fetch_adapters_from_input(input_map, item)
    found_adapter_per_item = False
    if adapters_per_item:
        _validate_adapters(adapters_per_item, adapter_registry)
        found_adapter_per_item = True
    else:
        # inference with just the base model.
        adapters_per_item = [""] * len(_inputs)

    if len(_inputs) != len(adapters_per_item):
        raise ValueError(
            f"Number of adapters is not equal to the number of inputs")
    return adapters_per_item, found_adapter_per_item


def _fetch_adapters_from_input(input_map: dict, inputs: Input):
    """Collect the adapter name(s) requested for one batch item.

    Sources are consulted in order and later ones win: the request body's
    "adapters" field, then an "adapter" content key (workflow approach), then
    an "adapter" property (header). Always returns a list.
    """
    found = []
    if "adapters" in input_map:
        found = input_map.pop("adapters", [])

    # check content, possible in workflow approach
    if inputs.contains_key("adapter"):
        found = inputs.get_as_string("adapter")

    # check properties, possible from header
    properties = inputs.get_properties()
    if "adapter" in properties:
        found = properties["adapter"]

    return found if isinstance(found, list) else [found]


def _validate_adapters(adapters_per_item, adapter_registry):
for adapter_name in adapters_per_item:
if adapter_name and adapter_name not in adapter_registry:
raise ValueError(f"Adapter {adapter_name} is not registered")
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from djl_python.rolling_batch.trtllm_rolling_batch import TRTLLMRollingBatch
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
from djl_python.tensorrt_llm_python import TRTLLMPythonService
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, rolling_batch_inference
from djl_python.utils import rolling_batch_inference
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter
from typing import List, Tuple


Expand Down
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/tensorrt_llm_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from djl_python.encode_decode import encode
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.utils import parse_input_with_formatter, InputFormatConfigs
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter


def _get_value_based_on_tensor(value, index=None):
Expand Down
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/transformers_neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled
from djl_python.neuron_utils.model_loader import TNXModelLoader, OptimumModelLoader
from djl_python.neuron_utils.utils import task_from_config, build_vllm_rb_properties
from djl_python.utils import InputFormatConfigs, parse_input_with_formatter, rolling_batch_inference
from djl_python.utils import rolling_batch_inference
from djl_python.input_parser import InputFormatConfigs, parse_input_with_formatter
from typing import Tuple, List

model = None
Expand Down
Loading