
Commit c0a0e42

feat(hf-inference): fork for hf-inference optim (overcommit) and widget compat
* Env var settings:
  - customize the default number of inference steps via env var
  - default content-type env var
  - default accept env var
  - Diffusers, txt2img (and img2img when supported): make sure guidance scale defaults to 0 when num steps <= 4

* Content-type / accept / serialization fixes:
  - ignore case when matching content types
  - fix: content-type and accept parsing, more flexibility than an exact string match since there can be some additional params
  - application/octet-stream support in content-type deserialization, no reason not to accept it
  - fix: avoid returning None as a serializer, return an error instead
  - fix: de/serializer is not optional, do not support content types we do not know what to do with
  - fix: explicit error message when no content-type is provided

* HF inference specificities:
  - multi-task support + /pipeline/<task> support for api-inference backward compat (see the example request after this list)
  - api-inference compat responses
  - fix(api-inference): compat for text-classification and token-classification
  - fix: token-classification api-inference compat
  - fix: image segmentation on hf-inference
  - zero-shot classification: api-inference compat
  - substitute /pipeline/sentence-embeddings for /pipeline/feature-extraction for sentence transformers
  - fix(api-inference): feature-extraction, flatten the array, discard the batch size dim
  - feat(hf-inference): disable the custom handler

* Build:
  - add timm and hf_xet dependencies (for object detection and xethub support)
  - Dockerfile refactor: split requirements and source code layers to optimize build time and enhance layer reuse

* Memory footprint + kick and respawn (primary memory gc):
  - feat(memory): reduce memory footprint on an idle service, backported and adapted from https://github.com/huggingface/api-inference-community/blob/main/docker_images/diffusers/app/idle.py
    1. use gunicorn instead of uvicorn so that wsgi/asgi workers can easily be suppressed when idle without stopping the entire service -> an easy way to release memory without digging into the depths of the imported modules
    2. lazy load of memory-consuming libs (transformers, diffusers, sentence_transformers)
    3. lazy load of the pipeline as well
  - The first 'cold start' request tends to be a bit slower than the others, but the footprint is reduced to the minimum when idle.
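As a rough illustration of the api-inference backward-compatible routing called out above, a client could pin the task explicitly through the /pipeline/<task> path. Only the path pattern comes from the commit message; the host, port, and payload below are assumptions for this sketch.

import requests  # assumes a locally running instance of this toolkit

resp = requests.post(
    "http://localhost:5000/pipeline/feature-extraction",  # hypothetical host/port
    headers={"Content-Type": "application/json"},
    json={"inputs": "A sentence to embed"},
)
print(resp.status_code)
# With API_INFERENCE_COMPAT enabled, a single string input should come back as a
# 2D [n_tokens, hidden_dim] list (batch dim flattened), per the handler changes below.
print(resp.json())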
1 parent 833b5a2 commit c0a0e42

21 files changed, +584 -323 lines changed

dockerfiles/pytorch/Dockerfile

Lines changed: 6 additions & 2 deletions
@@ -32,8 +32,7 @@ RUN apt-get update && \
     && apt-get clean autoremove --yes \
     && rm -rf /var/lib/{apt,dpkg,cache,log}
 
-# Copying only necessary files as filtered by .dockerignore
-COPY . .
+RUN mkdir -p /var/lib/dpkg && touch /var/lib/dpkg/status
 
 # Set Python 3.11 as the default python version
 RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 && \
@@ -47,6 +46,11 @@ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
 # Upgrade pip
 RUN pip install --no-cache-dir --upgrade pip
 
+COPY requirements.txt .
+RUN pip install -r requirements.txt && rm -rf /root/.cache
+
+# Copying only necessary files as filtered by .dockerignore
+COPY . .
 # Install wheel and setuptools
 RUN pip install --no-cache-dir --upgrade pip ".[torch,st,diffusers]"

requirements.txt

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+kenlm@ git+https://github.com/kpu/kenlm@ba83eafdce6553addd885ed3da461bb0d60f8df7
+transformers[audio,sentencepiece,sklearn,vision]==4.51.3
+huggingface_hub[hf_transfer,hf_xet]==0.31.1
+Pillow
+librosa
+pyctcdecode>=0.3.0
+phonemizer
+ffmpeg
+starlette
+uvicorn
+gunicorn
+pandas
+orjson
+einops
+timm
+sentence_transformers==4.0.2
+diffusers==0.33.1
+accelerate==1.6.0
+torch==2.5.1
+torchvision
+torchaudio
+peft==0.15.1

scripts/entrypoint.sh

Lines changed: 1 addition & 1 deletion
@@ -59,4 +59,4 @@ if [[ ! -z "${HF_MODEL_DIR}" ]]; then
 fi
 
 # Start the server
-exec uvicorn webservice_starlette:app --host 0.0.0.0 --port ${PORT}
+exec gunicorn webservice_starlette:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-1} --bind 0.0.0.0:${PORT}
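Switching the entrypoint to gunicorn is what makes the idle "kick and respawn" strategy from the commit message possible: a worker can simply exit once it has been idle for a while and the master respawns a fresh one, which stays small thanks to the lazy imports. A minimal sketch of that idea, with illustrative names and timings (the real logic lives in the adapted idle.py, which is not part of this diff):

import asyncio
import os
import signal
import time

LAST_REQUEST = time.monotonic()
IDLE_TIMEOUT = float(os.environ.get("IDLE_TIMEOUT", "300"))  # hypothetical env var

def mark_request() -> None:
    # Call this at the start of every handled request.
    global LAST_REQUEST
    LAST_REQUEST = time.monotonic()

async def idle_watchdog() -> None:
    # Runs inside the ASGI worker and exits the process after a long enough idle period.
    while True:
        await asyncio.sleep(30)
        if time.monotonic() - LAST_REQUEST > IDLE_TIMEOUT:
            # Graceful worker exit: the gunicorn master respawns a fresh worker, which
            # stays lightweight until the lazily-imported pipeline is needed again.
            os.kill(os.getpid(), signal.SIGTERM)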

setup.py

Lines changed: 14 additions & 35 deletions
@@ -1,7 +1,19 @@
 from __future__ import absolute_import
-
+import os
 from setuptools import find_packages, setup
 
+lib_folder = os.path.dirname(os.path.realpath(__file__))
+requirements_path = f"{lib_folder}/requirements.txt"
+install_requires = []  # Here we'll add: ["gunicorn", "docutils>=0.3", "lxml==0.5a7"]
+if os.path.isfile(requirements_path):
+    with open(requirements_path) as f:
+        install_requires = f.read().splitlines()
+
+test_requirements_path = f"{lib_folder}/test-requirements.txt"
+if os.path.isfile(test_requirements_path):
+    with open(test_requirements_path) as f:
+        test_requirements = f.read().splitlines()
+
 # We don't declare our dependency on transformers here because we build with
 # different packages for different variants
 
@@ -12,47 +24,14 @@
 # ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
 # libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg
 
-install_requires = [
-    # Due to an error affecting kenlm and cmake (see https://github.com/kpu/kenlm/pull/464)
-    # Also see the transformers patch for it https://github.com/huggingface/transformers/pull/37091
-    "kenlm@git+https://github.com/kpu/kenlm@ba83eafdce6553addd885ed3da461bb0d60f8df7",
-    "transformers[sklearn,sentencepiece,audio,vision]==4.51.3",
-    "huggingface_hub[hf_transfer]==0.30.2",
-    # vision
-    "Pillow",
-    "librosa",
-    # speech + torchaudio
-    "pyctcdecode>=0.3.0",
-    "phonemizer",
-    "ffmpeg",
-    # web api
-    "starlette",
-    "uvicorn",
-    "pandas",
-    "orjson",
-    "einops",
-]
-
 extras = {}
-
 extras["st"] = ["sentence_transformers==4.0.2"]
 extras["diffusers"] = ["diffusers==0.33.1", "accelerate==1.6.0"]
 # Includes `peft` as PEFT requires `torch` so having `peft` as a core dependency
 # means that `torch` will be installed even if the `torch` extra is not specified.
 extras["torch"] = ["torch==2.5.1", "torchvision", "torchaudio", "peft==0.15.1"]
-extras["test"] = [
-    "pytest==7.2.1",
-    "pytest-xdist",
-    "parameterized",
-    "psutil",
-    "datasets",
-    "pytest-sugar",
-    "mock==2.0.0",
-    "docker",
-    "requests",
-    "tenacity",
-]
 extras["quality"] = ["isort", "ruff"]
+extras["test"] = test_requirements
 extras["inf2"] = ["optimum-neuron"]
 extras["google"] = ["google-cloud-storage", "crcmod==1.7"]
5837

src/huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,5 @@
 import importlib.util
+import os
 from typing import Union
 
 from transformers.utils.import_utils import is_torch_bf16_gpu_available
@@ -63,6 +64,16 @@ def __call__(
             kwargs.pop("num_images_per_prompt")
             logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")
 
+        if "num_inference_steps" not in kwargs:
+            default_num_steps = os.environ.get("DEFAULT_NUM_INFERENCE_STEPS")
+            if default_num_steps:
+                kwargs["num_inference_steps"] = int(default_num_steps)
+
+        if "guidance_scale" not in kwargs:
+            guidance_scale = os.environ.get("DEFAULT_GUIDANCE_SCALE")
+            if guidance_scale is not None:
+                kwargs["guidance_scale"] = float(guidance_scale)
+
         if "target_size" in kwargs:
             kwargs["height"] = kwargs["target_size"].pop("height")
             kwargs["width"] = kwargs["target_size"].pop("width")

src/huggingface_inference_toolkit/env_utils.py

Lines changed: 11 additions & 0 deletions
@@ -1,3 +1,6 @@
+import os
+
+
 def strtobool(val: str) -> bool:
     """Convert a string representation of truth to True or False booleans.
     True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
@@ -20,3 +23,11 @@ def strtobool(val: str) -> bool:
         raise ValueError(
             f"Invalid truth value, it should be a string but {val} was provided instead."
         )
+
+
+def api_inference_compat():
+    return strtobool(os.getenv("API_INFERENCE_COMPAT", "false"))
+
+
+def ignore_custom_handler():
+    return strtobool(os.getenv("IGNORE_CUSTOM_HANDLER", "false"))

src/huggingface_inference_toolkit/handler.py

Lines changed: 89 additions & 10 deletions
@@ -2,12 +2,10 @@
 from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union
 
+from huggingface_inference_toolkit import logging
 from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
-from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
-from huggingface_inference_toolkit.utils import (
-    check_and_register_custom_pipeline_from_directory,
-    get_pipeline,
-)
+from huggingface_inference_toolkit.env_utils import api_inference_compat, ignore_custom_handler
+from huggingface_inference_toolkit.utils import check_and_register_custom_pipeline_from_directory
 
 
 class HuggingFaceHandler:
@@ -19,20 +17,25 @@ class HuggingFaceHandler:
     def __init__(
         self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt"
     ) -> None:
+        from huggingface_inference_toolkit.heavy_utils import get_pipeline
         self.pipeline = get_pipeline(
             model_dir=model_dir,  # type: ignore
             task=task,  # type: ignore
            framework=framework,
            trust_remote_code=HF_TRUST_REMOTE_CODE,
        )
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, data: Dict[str, Any]):
         """
         Handles an inference request with input data and makes a prediction.
         Args:
            :data: (obj): the raw request body data.
            :return: prediction output
         """
+
+        # import as late as possible to reduce the footprint
+        from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
+
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
 
@@ -101,9 +104,82 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
                 "or `candidateLabels`."
             )
 
-        return (
-            self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters)  # type: ignore
-        )
+        if api_inference_compat():
+            if self.pipeline.task == "text-classification" and isinstance(inputs, str):
+                inputs = [inputs]
+                parameters.setdefault("top_k", os.environ.get("DEFAULT_TOP_K", 5))
+            if self.pipeline.task == "token-classification":
+                parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple"))
+
+        resp = self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else \
+            self.pipeline(inputs, **parameters)
+
+        if api_inference_compat():
+            if self.pipeline.task == "text-classification":
+                # We don't want to return {} but [{}] in any case
+                if isinstance(resp, list) and len(resp) > 0:
+                    if not isinstance(resp[0], list):
+                        return [resp]
+                return resp
+            if self.pipeline.task == "feature-extraction":
+                # If the library used is Transformers then the feature-extraction is returning the headless encoder
+                # outputs as embeddings. The shape is a 3D or 4D array
+                # [n_inputs, batch_size = 1, n_sentence_tokens, num_hidden_dim].
+                # Let's just discard the batch size dim that always seems to be 1 and return a 2D/3D array
+                # https://github.com/huggingface/transformers/blob/5c47d08b0d6835b8d8fc1c06d9a1bc71f6e78ace/src/transformers/pipelines/feature_extraction.py#L27
+                # for api inference (reason: mainly display)
+                new_resp = []
+                if isinstance(inputs, list):
+                    if isinstance(resp, list) and len(resp) == len(inputs):
+                        for it in resp:
+                            # Batch size dim is the first it level, discard it
+                            if isinstance(it, list) and len(it) == 1:
+                                new_resp.append(it[0])
+                            else:
+                                logging.logger.warning("One of the output batch size differs from 1: %d", len(it))
+                                return resp
+                        return new_resp
+                    else:
+                        logging.logger.warning("Inputs and resp len differ (or resp is not a list, type %s)",
+                                               type(resp))
+                        return resp
+                elif isinstance(inputs, str):
+                    if isinstance(resp, list) and len(resp) == 1:
+                        return resp[0]
+                    else:
+                        logging.logger.warning("The output batch size differs from 1: %d", len(resp))
+                        return resp
+                else:
+                    logging.logger.warning("Output unexpected type %s", type(resp))
+                    return resp
+            if self.pipeline.task == "image-segmentation":
+                if isinstance(resp, list):
+                    new_resp = []
+                    for el in resp:
+                        if isinstance(el, dict) and el.get("score") is None:
+                            el["score"] = 1
+                        new_resp.append(el)
+                    resp = new_resp
+            if self.pipeline.task == "zero-shot-classification":
+                try:
+                    if isinstance(resp, dict):
+                        if 'labels' in resp and 'scores' in resp:
+                            labels = resp['labels']
+                            scores = resp['scores']
+                            if len(labels) == len(scores):
+                                new_resp = []
+                                for label, score in zip(labels, scores):
+                                    new_resp.append({"label": label, "score": score})
+                                resp = new_resp
+                            else:
+                                raise Exception("labels and scores do not have the same len, {} != {}".format(
+                                    len(labels), len(scores)))
+                        else:
+                            raise Exception("Missing labels or scores key in response dict {}".format(resp))
+                except Exception as e:
+                    logging.logger.warning("Unable to remap response for api inference compat")
+                    logging.logger.exception(e)
+        return resp
 
 
 class VertexAIHandler(HuggingFaceHandler):
@@ -149,7 +225,10 @@ def get_inference_handler_either_custom_or_default_handler(model_dir: Path, task
     Returns:
         InferenceHandler: The appropriate inference handler based on the given model directory and task.
     """
-    custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir)
+    if ignore_custom_handler():
+        custom_pipeline = None
+    else:
+        custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir)
     if custom_pipeline is not None:
         return custom_pipeline
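To make the api-inference-compat remapping above concrete, here is what it does on toy values; the shapes follow the comments in the diff, while the data itself is invented for illustration.

# text-classification, single string input: a flat list of label dicts gets wrapped into a batch.
raw_cls = [{"label": "POSITIVE", "score": 0.99}, {"label": "NEGATIVE", "score": 0.01}]
wrapped = [raw_cls]  # the handler always returns one inner list per input

# feature-extraction, single string input: the batch-size-1 dimension is dropped.
raw_feat = [[[0.1, 0.2], [0.3, 0.4]]]  # [batch=1, n_tokens=2, hidden=2]
flattened = raw_feat[0]                # -> [n_tokens, hidden]

# zero-shot-classification: parallel labels/scores lists become a list of label/score dicts.
raw_zsc = {"sequence": "some text", "labels": ["urgent", "not urgent"], "scores": [0.9, 0.1]}
remapped = [{"label": l, "score": s} for l, s in zip(raw_zsc["labels"], raw_zsc["scores"])]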
