Commit 3ef1bb7

Merge branch 'master' into v10_neo_patches
2 parents c1556ec + c578e6f commit 3ef1bb7

25 files changed: +904 −1560 lines changed


.github/workflows/client-test.yml

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ jobs:
           cd tests
           djl-serving -m test::Python=file://$PWD/python &> output.log &
           sleep 15
-          python test_client.py
+          python integration/test_client.py
           jobs
           kill %1
       - name: On failure step
@@ -109,7 +109,7 @@ jobs:
           ./gradlew --stop
           ./gradlew :serving:run --args="-m test::Python=file:$(pwd -W)/tests/python" &> output.log &
           sleep 30
-          cd tests/ && python test_client.py
+          cd tests/ && python integration/test_client.py
       - name: On failure step
         if: ${{ failure() }}
         shell: bash

.github/workflows/llm_inf2_integration.yml

Lines changed: 0 additions & 463 deletions
This file was deleted.

.github/workflows/llm_integration.yml

Lines changed: 96 additions & 896 deletions
Large diffs are not rendered by default.

engines/python/setup/djl_python/huggingface.py

Lines changed: 40 additions & 60 deletions
@@ -22,16 +22,17 @@
     AutoModelForQuestionAnswering, StoppingCriteria, StoppingCriteriaList)
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from peft import PeftConfig, PeftModel, PeftModelForCausalLM
-from typing import Tuple, List
+from typing import Tuple, List, Callable, Dict

 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
+from djl_python.rolling_batch.rolling_batch import RollingBatch
 from djl_python.streaming_utils import StreamingUtils

 from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
 from djl_python.properties_manager.hf_properties import HuggingFaceProperties
-from djl_python.utils import parse_input_with_formatter, InputFormatConfigs
+from djl_python.utils import parse_input_with_formatter, InputFormatConfigs, ParsedInput, rolling_batch_inference

 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
@@ -140,6 +141,7 @@ def __init__(self):
         self.adapters = None
         self.hf_configs = None
         self.input_format_configs = None
+        self.parsed_input = None

     def initialize(self, properties: dict):
         self.hf_configs = HuggingFaceProperties(**properties)
@@ -230,13 +232,14 @@ def parse_input(
        :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
        """

-        parsed_input = parse_input_with_formatter(inputs,
-                                                  self.input_format_configs,
-                                                  self.adapter_registry)
-        self.adapters = parsed_input.adapters if parsed_input.found_adapters else None
-        return parsed_input.input_data, parsed_input.input_size, parsed_input.parameters, parsed_input.errors, parsed_input.batch
+        self.parsed_input = parse_input_with_formatter(
+            inputs, self.input_format_configs, self.adapter_registry)
+        self.adapters = self.parsed_input.adapters
+        return (self.parsed_input.input_data, self.parsed_input.input_size,
+                self.parsed_input.parameters, self.parsed_input.errors,
+                self.parsed_input.batch)

-    def inference(self, inputs):
+    def inference(self, inputs: Input) -> Output:
         outputs = Output()

         input_data, input_size, parameters, errors, batch = self.parse_input(
@@ -254,70 +257,28 @@ def inference(self, inputs):
             return outputs

         if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
-            if inputs.get_property("reset_rollingbatch"):
-                self.rolling_batch.reset()
-            if self.adapters is not None:
-                adapter_data = [
-                    self.adapter_registry.get(adapter, None)
-                    for adapter in self.adapters
-                ]
-            else:
-                adapter_data = None
-            result = self.rolling_batch.inference(input_data,
-                                                  parameters,
-                                                  adapters=adapter_data)
-            idx = 0
-            for i in range(len(batch)):
-                err = errors.get(i)
-                if err:
-                    err = {"data": "", "last": True, "code": 424, "error": err}
-                    outputs.add(Output.binary_encode(err),
-                                key="data",
-                                batch_index=i)
-                    outputs.add_property(f"batch_{i}_Content-Type",
-                                         "application/json")
-                else:
-                    content_type = result[idx].pop("content_type")
-                    outputs.add(Output.binary_encode(result[idx]),
-                                key="data",
-                                batch_index=i)
-                    if content_type is not None:
-                        outputs.add_property(f"batch_{i}_Content-Type",
-                                             content_type)
-                    idx += 1
-
-            return outputs
+            return rolling_batch_inference(self.parsed_input, inputs, outputs,
+                                           self.rolling_batch)
         elif is_streaming_enabled(self.hf_configs.enable_streaming):
-            if len(batch) > 1:
-                raise NotImplementedError(
-                    "Dynamic batch not supported for generic streaming")
-            outputs.add_property("content-type", "application/jsonlines")
-            if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
-                outputs.add_stream_content(
-                    StreamingUtils.use_hf_default_streamer(
-                        self.model, self.tokenizer, input_data,
-                        self.hf_configs.device, **parameters[0]))
-            else:
-                stream_generator = StreamingUtils.get_stream_generator(
-                    "Accelerate")
-                outputs.add_stream_content(
-                    stream_generator(self.model, self.tokenizer, input_data,
-                                     self.hf_configs.device, **parameters[0]))
-            return outputs
+            return self._streaming_inference(batch, input_data, outputs,
+                                             parameters)
+        else:
+            return self._dynamic_batch_inference(batch, errors, input_data,
+                                                 input_size, inputs, outputs,
+                                                 parameters)

+    def _dynamic_batch_inference(self, batch, errors, input_data, input_size,
+                                 inputs, outputs, parameters):
         if not all(p == parameters[0] for p in parameters):
             raise ValueError(
                 "In order to enable dynamic batching, all input batches must have the same parameters"
             )
-
         if isinstance(self.model, PeftModelForCausalLM):
             if self.adapters is None:
                 # Inference with only base model
                 self.adapters = [""] * len(input_data)
             parameters[0]["adapters"] = self.adapters
-
         prediction = self.hf_pipeline(input_data, **parameters[0])
-
         offset = 0
         for i, item in enumerate(batch):
             content_type = item.get_property("Content-Type")
@@ -341,7 +302,26 @@ def inference(self, inputs):
                         accept,
                         key=inputs.get_content().key_at(i))
             offset += input_size[i]
+        return outputs

+    def _streaming_inference(self, batch, input_data, outputs, parameters):
+        if len(batch) > 1:
+            raise NotImplementedError(
+                "Dynamic batch not supported for generic streaming")
+        outputs.add_property("content-type", "application/jsonlines")
+        if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
+            outputs.add_stream_content(
+                StreamingUtils.use_hf_default_streamer(self.model,
+                                                       self.tokenizer,
+                                                       input_data,
+                                                       self.hf_configs.device,
+                                                       **parameters[0]))
+        else:
+            stream_generator = StreamingUtils.get_stream_generator(
+                "Accelerate")
+            outputs.add_stream_content(
+                stream_generator(self.model, self.tokenizer, input_data,
+                                 self.hf_configs.device, **parameters[0]))
         return outputs

     def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
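
In this refactor, the per-request packing of rolling-batch results into the Output object moves out of huggingface.py into a shared rolling_batch_inference helper in djl_python.utils. That helper's real implementation is not part of this commit; the sketch below is inferred from the call site and from the inline code deleted above, with adapter resolution omitted, so the exact signature and ParsedInput fields should be treated as assumptions.

# Hypothetical sketch of djl_python.utils.rolling_batch_inference, inferred
# from the deleted inline code above; the actual helper may differ.
from djl_python.outputs import Output


def rolling_batch_inference(parsed_input, inputs, outputs, rolling_batch):
    if inputs.get_property("reset_rollingbatch"):
        rolling_batch.reset()
    # Run one rolling-batch step for every request in the batch.
    result = rolling_batch.inference(parsed_input.input_data,
                                     parsed_input.parameters)
    idx = 0
    for i in range(len(parsed_input.batch)):
        err = parsed_input.errors.get(i)
        if err:
            # Requests that failed parsing get a per-item 424 error payload.
            err = {"data": "", "last": True, "code": 424, "error": err}
            outputs.add(Output.binary_encode(err), key="data", batch_index=i)
            outputs.add_property(f"batch_{i}_Content-Type", "application/json")
        else:
            content_type = result[idx].pop("content_type")
            outputs.add(Output.binary_encode(result[idx]),
                        key="data",
                        batch_index=i)
            if content_type is not None:
                outputs.add_property(f"batch_{i}_Content-Type", content_type)
            idx += 1
    return outputs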

engines/python/setup/djl_python/neuron_utils/model_loader.py

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,11 @@
 from djl_python.neuron_utils.utils import NeuronXModelAdapter, get_neuronxcc_version
 from huggingface_hub import hf_hub_download

+# Temporary Fix: These loggers are disabled during vLLM import.
+# Remove when fixed in vLLM
+logging.getLogger("NEURON_CC_WRAPPER").disabled = False
+logging.getLogger("NEURON_CACHE").disabled = False
+

 class ModelLoader(ABC):
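
The two added lines undo the side effect called out in the comment: something in the vLLM import path sets disabled = True on these named loggers, which silently drops every record they emit. Re-enabling them is plain standard-library behavior; a minimal sketch (logger name taken from the diff):

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("NEURON_CC_WRAPPER")

log.disabled = True
log.info("compiling model...")   # dropped: a disabled logger emits nothing

log.disabled = False             # what the patch does after the vLLM import
log.info("compiling model...")   # emitted again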

engines/python/setup/djl_python/properties_manager/lmi_dist_rb_properties.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ class LmiDistRbProperties(Properties):
     speculative_length: int = 5
     draft_model_tp_size: int = 1
     record_acceptance_rate: Optional[bool] = False
+    speculative_telemetry: Optional[bool] = True
     enable_lora: Optional[bool] = False
     max_loras: Optional[int] = 4
     max_lora_rank: Optional[int] = 16

engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py

Lines changed: 17 additions & 8 deletions
@@ -11,6 +11,7 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 import logging
+import os
 from typing import List
 from collections import OrderedDict, defaultdict

@@ -26,6 +27,7 @@
     get_speculative_decoding_metrics_record, update_request_cache_with_output,
     supports_speculative_decoding, get_lora_request_params, DTYPE_MAPPER,
     FINISH_REASON_MAPPER)
+from djl_python.telemetry import telemetry_manager
 from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties

 _WARMUP_PREFILL_TOKENS = 4096
@@ -187,14 +189,21 @@ def inference(self,
                 self.request_cache, request_output, self.get_tokenizer())
             # Record SD metrics
             completion_output = request_output.outputs[0]
-            if self.lmi_dist_config.record_acceptance_rate and request_output.finished:
-                if self.supports_speculative_decoding and completion_output.acceptance_history:
-                    record = get_speculative_decoding_metrics_record(
-                        completion_output, request_output)
-                    logging.info(f"Speculative Decoding {record}")
-                else:
-                    logging.warning(
-                        f"Ignoring logging speculative decoding metrics")
+            if (
+                    self.lmi_dist_config.record_acceptance_rate
+                    or self.lmi_dist_config.speculative_telemetry
+            ) and self.lmi_dist_config.speculative_draft_model and request_output.finished:
+                try:
+                    if self.supports_speculative_decoding and completion_output.acceptance_history:
+                        record = get_speculative_decoding_metrics_record(
+                            completion_output, request_output)
+                        if self.lmi_dist_config.record_acceptance_rate:
+                            logging.info(f"Speculative Decoding {record}")
+                        if self.lmi_dist_config.speculative_telemetry and os.environ.get(
+                                "SAGEMAKER_SECURE_MODE") == "true":
+                            telemetry_manager.record_speculative(record)
+                except:
+                    logging.debug("SD telemetry collection failed, ignore")

         for request in self.active_requests:
             request_output = request.request_output
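
The rewritten condition records metrics only for finished requests when a draft model is actually configured, and it sends telemetry only when SAGEMAKER_SECURE_MODE is set; any failure in the telemetry path is swallowed by the bare except so it cannot break inference. A small sketch of that gating as standalone predicates (config field names from the diff, function names hypothetical):

import os


def should_record_speculative(config, request_output) -> bool:
    # Metrics are considered only when a draft model is configured and the
    # request has finished generating.
    wants_metrics = (config.record_acceptance_rate
                     or config.speculative_telemetry)
    return (wants_metrics and bool(config.speculative_draft_model)
            and request_output.finished)


def should_send_telemetry(config) -> bool:
    # Telemetry is emitted only inside SageMaker secure mode.
    return (config.speculative_telemetry
            and os.environ.get("SAGEMAKER_SECURE_MODE") == "true")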

engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py

Lines changed: 2 additions & 0 deletions
@@ -166,6 +166,8 @@ def get_speculative_decoding_metrics_record(
             completion_output.acceptance_history)
     else:
         record["mean_acceptance"] = 0
+    record["acceptance_history_len"] = len(
+        completion_output.acceptance_history)
     record["prompt_size"] = len(request_output.prompt_token_ids)
     record["output_size"] = len(completion_output.token_ids)
     return record
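
With acceptance_history_len added, the record now carries both the mean acceptance and the number of observations behind it, which is what the weighted aggregation in the new telemetry module needs. An illustrative record (field names from this function, values made up):

# Illustrative metrics record; the numbers are invented, not from a real run.
record = {
    "mean_acceptance": 3.2,        # presumably the mean over completion_output.acceptance_history
    "acceptance_history_len": 17,  # new: how many entries that mean covers
    "prompt_size": 412,            # len(request_output.prompt_token_ids)
    "output_size": 256,            # len(completion_output.token_ids)
}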

engines/python/setup/djl_python/sm_log_filter.py

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,10 @@

 # https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/logging-and-monitoring.html
 class SMLogFilter(logging.Filter):
-    sm_log_markers = ['ModelServerError', 'UserScriptError', 'SysHealth']
+    sm_log_markers = [
+        'ModelServerError', 'UserScriptError', 'SysHealth',
+        'ModelServerTelemetry'
+    ]
     counter = defaultdict(int)

     def filter(self, record):
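
Adding 'ModelServerTelemetry' to sm_log_markers lets the SageMaker log filter recognize the telemetry line emitted by the new TelemetryManager below. The filter body is not part of this diff, but attaching a logging.Filter to a handler is the standard mechanism; a minimal sketch, assuming the filter is wired to whatever handler serves the SageMaker log stream:

import logging

from djl_python.sm_log_filter import SMLogFilter

handler = logging.StreamHandler()
handler.addFilter(SMLogFilter())          # the filter sees every record on this handler
logging.getLogger().addHandler(handler)

# This message contains one of SMLogFilter.sm_log_markers, so the filter will
# recognize it (exact handling depends on filter(), which is not shown here).
logging.getLogger().info(
    "ModelServerTelemetry: Speculative Decoding Mean Acceptance: 2.5 rate")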
engines/python/setup/djl_python/telemetry.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+import logging
+import time
+
+SPECULATIVE_FREQUENCY_SEC = 30.0
+
+
+class TelemetryManager:
+
+    def __init__(self):
+        self.reset_speculative()
+
+    def record_speculative(self, data):
+        self.speculative_acceptance_rate_count = self.speculative_acceptance_rate_count + data[
+            "acceptance_history_len"]
+        self.speculative_acceptance_rate_total = self.speculative_acceptance_rate_total + data[
+            "mean_acceptance"] * data["acceptance_history_len"]
+        if time.time(
+        ) - self.speculative_sent_time > SPECULATIVE_FREQUENCY_SEC:
+            mean_acceptance = 1.0 * self.speculative_acceptance_rate_total / self.speculative_acceptance_rate_count
+            logging.info(
+                f"ModelServerTelemetry: Speculative Decoding Mean Acceptance: {mean_acceptance} rate"
+            )
+            self.reset_speculative()
+
+    def reset_speculative(self):
+        self.speculative_sent_time = time.time()
+        self.speculative_acceptance_rate_count = 0
+        self.speculative_acceptance_rate_total = 0.0
+
+
+telemetry_manager = TelemetryManager()
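
A short worked example of how the manager aggregates the records produced above into a weighted running mean and flushes it at most every SPECULATIVE_FREQUENCY_SEC seconds (numbers made up):

from djl_python.telemetry import TelemetryManager

tm = TelemetryManager()
# Two records arriving within one 30-second window:
tm.record_speculative({"mean_acceptance": 4.0, "acceptance_history_len": 10})
tm.record_speculative({"mean_acceptance": 2.0, "acceptance_history_len": 30})
# Running totals: count = 10 + 30 = 40, total = 4.0*10 + 2.0*30 = 100.0
# Once more than SPECULATIVE_FREQUENCY_SEC (30 s) has elapsed since the last
# flush, the next record_speculative() call logs
#   "ModelServerTelemetry: Speculative Decoding Mean Acceptance: 2.5 rate"
# and resets the counters.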
