@@ -1,4 +1,5 @@
 import os
+import threading
 from pathlib import Path
 from time import perf_counter
 
@@ -16,6 +17,7 @@
     HF_REVISION,
     HF_TASK,
 )
+from huggingface_inference_toolkit.env_utils import api_inference_compat
 from huggingface_inference_toolkit.handler import (
     get_inference_handler_either_custom_or_default_handler,
 )
@@ -28,9 +30,11 @@
 )
 from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
 
+INFERENCE_HANDLERS = {}
+INFERENCE_HANDLERS_LOCK = threading.Lock()
 
 async def prepare_model_artifacts():
-    global inference_handler
+    global INFERENCE_HANDLERS
     # 1. check if model artifacts available in HF_MODEL_DIR
     if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0:
         # 2. if not available, try to load from HF_MODEL_ID
@@ -62,6 +66,7 @@ async def prepare_model_artifacts():
     inference_handler = get_inference_handler_either_custom_or_default_handler(
         HF_MODEL_DIR, task=HF_TASK
     )
+    INFERENCE_HANDLERS[HF_TASK] = inference_handler
     logger.info("Model initialized successfully")
 
 
@@ -82,6 +87,7 @@ async def metrics(request):
 
 
 async def predict(request):
+    global INFERENCE_HANDLERS
     try:
         # extracts content from request
         content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE")).lower()
@@ -101,6 +107,17 @@ async def predict(request):
             dict(request.query_params)
         )
 
+        # Lazily load pipelines for alternative tasks
+        task = request.path_params.get("task", HF_TASK)
+        inference_handler = INFERENCE_HANDLERS.get(task)
+        if not inference_handler:
+            with INFERENCE_HANDLERS_LOCK:
+                if task not in INFERENCE_HANDLERS:
+                    inference_handler = get_inference_handler_either_custom_or_default_handler(
+                        HF_MODEL_DIR, task=task)
+                    INFERENCE_HANDLERS[task] = inference_handler
+                else:
+                    inference_handler = INFERENCE_HANDLERS[task]
         # tracks request time
         start_time = perf_counter()
         # run async not blocking call
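
For readers unfamiliar with the pattern, the block added above is double-checked locking: a lock-free dictionary read on the hot path, then a second check inside the lock so that two requests racing on the same uncached task load its pipeline only once. A minimal standalone sketch of the same idea (all names here are hypothetical stand-ins, not the toolkit's API):

```python
import threading

_HANDLERS = {}                     # hypothetical cache: task name -> handler
_HANDLERS_LOCK = threading.Lock()

def _load_handler(task):
    # Hypothetical stand-in for the expensive pipeline construction.
    return f"handler-for-{task}"

def get_handler(task):
    handler = _HANDLERS.get(task)  # fast path: no lock once cached
    if handler is None:
        with _HANDLERS_LOCK:
            # Re-check under the lock: another thread may have
            # populated the cache while we were waiting.
            if task not in _HANDLERS:
                _HANDLERS[task] = _load_handler(task)
            handler = _HANDLERS[task]
    return handler
```
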
@@ -149,14 +166,19 @@ async def predict(request):
         on_startup=[prepare_model_artifacts],
     )
 else:
+    routes = [
+        Route("/", health, methods=["GET"]),
+        Route("/health", health, methods=["GET"]),
+        Route("/", predict, methods=["POST"]),
+        Route("/predict", predict, methods=["POST"]),
+        Route("/metrics", metrics, methods=["GET"]),
+    ]
+    if api_inference_compat():
+        routes.append(
+            Route("/pipeline/{task:path}", predict, methods=["POST"])
+        )
     app = Starlette(
         debug=False,
-        routes=[
-            Route("/", health, methods=["GET"]),
-            Route("/health", health, methods=["GET"]),
-            Route("/", predict, methods=["POST"]),
-            Route("/predict", predict, methods=["POST"]),
-            Route("/metrics", metrics, methods=["GET"]),
-        ],
+        routes=routes,
         on_startup=[prepare_model_artifacts],
     )
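
With `api_inference_compat()` enabled, the extra route makes per-task inference addressable by path, and the matching pipeline is loaded on first use. As a rough usage sketch, assuming the server is reachable on localhost port 5000 (host, port, and payload are assumptions, not part of this diff):

```python
import requests  # third-party HTTP client, assumed installed

# The {task:path} segment selects which pipeline to (lazily) load.
resp = requests.post(
    "http://localhost:5000/pipeline/text-classification",
    json={"inputs": "I love this movie!"},
)
print(resp.status_code, resp.json())
```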