Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 4c5be67

Browse files
aguschinmike0sv
andauthored
Add middlewares and one specific to expose /metrics endpoint for prometheus (#629)
did this in spare time instead of #591 using https://github.com/trallnag/prometheus-fastapi-instrumentator these middlewares can be enabled via CLI if referenced by their `type` ClassVar: ``` $ mlem --tb serve fastapi --model ../emoji/lyrics2emoji --middlewares.0 prometheus_fastapi ``` --------- Co-authored-by: mike0sv <[email protected]>
1 parent 19b1066 commit 4c5be67

File tree

12 files changed

+275
-16
lines changed

12 files changed

+275
-16
lines changed

mlem/contrib/fastapi.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
FastAPIServer implementation
55
"""
66
import logging
7+
from abc import ABC, abstractmethod
78
from collections.abc import Callable
89
from types import ModuleType
910
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type
@@ -24,6 +25,7 @@
2425
InterfaceArgument,
2526
InterfaceMethod,
2627
)
28+
from mlem.runtime.middleware import Middleware, Middlewares
2729
from mlem.runtime.server import Server
2830
from mlem.ui import EMOJI_NAILS, echo
2931
from mlem.utils.module import get_object_requirements
@@ -48,6 +50,12 @@ def _create_schema_route(app: FastAPI, interface: Interface):
4850
app.add_api_route("/interface.json", lambda: schema, tags=["schema"])
4951

5052

53+
class FastAPIMiddleware(Middleware, ABC):
54+
@abstractmethod
55+
def on_app_init(self, app: FastAPI):
56+
raise NotImplementedError
57+
58+
5159
class FastAPIServer(Server, LibRequirementsMixin):
5260
"""Serves model with http"""
5361

@@ -70,6 +78,7 @@ def _create_handler_executor(
7078
arg_serializers: Dict[str, DataTypeSerializer],
7179
executor: Callable,
7280
response_serializer: DataTypeSerializer,
81+
middlewares: Middlewares,
7382
):
7483
deserialized_model = create_model(
7584
"Model", **{a: (Any, ...) for a in args}
@@ -99,7 +108,9 @@ def serializer_validator(_, values):
99108

100109
def bin_handler(model: schema_model): # type: ignore[valid-type]
101110
values = {a: getattr(model, a) for a in args}
111+
values = middlewares.on_request(values)
102112
result = executor(**values)
113+
result = middlewares.on_response(values, result)
103114
with response_serializer.dump(result) as buffer:
104115
return StreamingResponse(
105116
buffer, media_type="application/octet-stream"
@@ -113,7 +124,9 @@ def bin_handler(model: schema_model): # type: ignore[valid-type]
113124

114125
def handler(model: schema_model): # type: ignore[valid-type]
115126
values = {a: getattr(model, a) for a in args}
127+
values = middlewares.on_request(values)
116128
result = executor(**values)
129+
result = middlewares.on_response(values, result)
117130
response = response_serializer.serialize(result)
118131
return parse_obj_as(response_model, response)
119132

@@ -127,12 +140,15 @@ def _create_handler_executor_binary(
127140
arg_name: str,
128141
executor: Callable,
129142
response_serializer: DataTypeSerializer,
143+
middlewares: Middlewares,
130144
):
131145
if response_serializer.serializer.is_binary:
132146

133147
def bin_handler(file: UploadFile):
134148
arg = serializer.deserialize(_SpooledFileIOWrapper(file.file))
149+
arg = middlewares.on_request(arg)
135150
result = executor(**{arg_name: arg})
151+
result = middlewares.on_response(arg, result)
136152
with response_serializer.dump(result) as buffer:
137153
return StreamingResponse(
138154
buffer, media_type="application/octet-stream"
@@ -146,15 +162,20 @@ def bin_handler(file: UploadFile):
146162

147163
def handler(file: UploadFile):
148164
arg = serializer.deserialize(file.file)
165+
arg = middlewares.on_request(arg)
149166
result = executor(**{arg_name: arg})
150-
167+
result = middlewares.on_response(arg, result)
151168
response = response_serializer.serialize(result)
152169
return parse_obj_as(response_model, response)
153170

154171
return handler, response_model, None
155172

156173
def _create_handler(
157-
self, method_name: str, signature: InterfaceMethod, executor: Callable
174+
self,
175+
method_name: str,
176+
signature: InterfaceMethod,
177+
executor: Callable,
178+
middlewares: Middlewares,
158179
) -> Tuple[Optional[Callable], Optional[Type], Optional[Response]]:
159180
serializers, response_serializer = self._get_serializers(signature)
160181
echo(EMOJI_NAILS + f"Adding route for /{method_name}")
@@ -170,13 +191,15 @@ def _create_handler(
170191
arg_name,
171192
executor,
172193
response_serializer,
194+
middlewares,
173195
)
174196
return self._create_handler_executor(
175197
method_name,
176198
{a.name: a for a in signature.args},
177199
serializers,
178200
executor,
179201
response_serializer,
202+
middlewares,
180203
)
181204

182205
def app_init(self, interface: Interface):
@@ -185,11 +208,15 @@ def app_init(self, interface: Interface):
185208
app.add_api_route(
186209
"/", lambda: RedirectResponse("/docs"), include_in_schema=False
187210
)
211+
for mid in self.middlewares.__root__:
212+
mid.on_init()
213+
if isinstance(mid, FastAPIMiddleware):
214+
mid.on_app_init(app)
188215

189216
for method, signature in interface.iter_methods():
190217
executor = interface.get_method_executor(method)
191218
handler, response_model, response_class = self._create_handler(
192-
method, signature, executor
219+
method, signature, executor, self.middlewares
193220
)
194221

195222
app.add_api_route(

mlem/contrib/prometheus.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Instrumenting FastAPI app to expose metrics for prometheus
2+
Extension type: middleware
3+
4+
Exposes /metrics endpoint
5+
"""
6+
from typing import ClassVar, List, Optional
7+
8+
from fastapi import FastAPI
9+
from prometheus_fastapi_instrumentator import Instrumentator
10+
11+
from mlem.contrib.fastapi import FastAPIMiddleware
12+
from mlem.utils.importing import import_string_with_local
13+
from mlem.utils.module import get_object_requirements
14+
15+
16+
class PrometheusFastAPIMiddleware(FastAPIMiddleware):
17+
"""Middleware for FastAPI server that exposes /metrics endpoint to be scraped by Prometheus"""
18+
19+
type: ClassVar = "prometheus_fastapi"
20+
21+
metrics: List[str] = []
22+
"""Instrumentator instance to use. If not provided, a new one will be created"""
23+
instrumentator_cache: Optional[Instrumentator] = None
24+
25+
class Config:
26+
arbitrary_types_allowed = True
27+
exclude = {"instrumentator_cache"}
28+
29+
@property
30+
def instrumentator(self):
31+
if self.instrumentator_cache is None:
32+
self.instrumentator_cache = self.get_instrumentator()
33+
return self.instrumentator_cache
34+
35+
def on_app_init(self, app: FastAPI):
36+
@app.on_event("startup")
37+
async def _startup():
38+
self.instrumentator.expose(app)
39+
40+
def on_init(self):
41+
pass
42+
43+
def on_request(self, request):
44+
return request
45+
46+
def on_response(self, request, response):
47+
return response
48+
49+
def get_instrumentator(self):
50+
instrumentator = Instrumentator()
51+
for metric in self._iter_metric_objects():
52+
# todo: check object type
53+
instrumentator.add(metric)
54+
return instrumentator
55+
56+
def _iter_metric_objects(self):
57+
for metric in self.metrics:
58+
# todo: meaningful error on import error
59+
yield import_string_with_local(metric)
60+
61+
def get_requirements(self):
62+
reqs = super().get_requirements()
63+
for metric in self._iter_metric_objects():
64+
reqs += get_object_requirements(metric)
65+
return reqs

mlem/contrib/sagemaker/runtime.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def app_init(self, interface: Interface):
6262
"invocations",
6363
interface.get_method_signature(self.method),
6464
interface.get_method_executor(self.method),
65+
self.middlewares,
6566
)
6667
app.add_api_route(
6768
"/invocations",

mlem/core/base.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import shlex
2-
import sys
32
from collections import defaultdict
43
from inspect import isabstract
54
from typing import (
@@ -22,7 +21,7 @@
2221

2322
from mlem.core.errors import ExtensionRequirementError, UnknownImplementation
2423
from mlem.polydantic import PolyModel
25-
from mlem.utils.importing import import_string
24+
from mlem.utils.importing import import_string_with_local
2625
from mlem.utils.path import make_posix
2726

2827

@@ -64,18 +63,12 @@ def load_impl_ext(
6463

6564
if type_name is not None and "." in type_name:
6665
try:
67-
# this is needed because if run from cli curdir is not checked for
68-
# modules to import
69-
sys.path.append(".")
70-
71-
obj = import_string(type_name)
66+
obj = import_string_with_local(type_name)
7267
if not issubclass(obj, MlemABC):
7368
raise ValueError(f"{obj} is not subclass of MlemABC")
7469
return obj
7570
except ImportError:
7671
pass
77-
finally:
78-
sys.path.remove(".")
7972

8073
eps = load_entrypoints()
8174
for ep in eps.values():

mlem/ext.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ class ExtensionLoader:
107107
Extension("mlem.contrib.xgboost", ["xgboost"], False),
108108
Extension("mlem.contrib.docker", ["docker"], False),
109109
Extension("mlem.contrib.fastapi", ["fastapi", "uvicorn"], False),
110+
Extension(
111+
"mlem.contrib.prometheus",
112+
["prometheus-fastapi-instrumentator"],
113+
False,
114+
),
110115
Extension("mlem.contrib.callable", [], True),
111116
Extension("mlem.contrib.rabbitmq", ["pika"], False, extra="rmq"),
112117
Extension("mlem.contrib.github", [], True),

mlem/runtime/middleware.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from abc import abstractmethod
2+
from typing import ClassVar, List
3+
4+
from pydantic import BaseModel
5+
6+
from mlem.core.base import MlemABC
7+
from mlem.core.requirements import Requirements, WithRequirements
8+
9+
10+
class Middleware(MlemABC, WithRequirements):
11+
abs_name: ClassVar = "middleware"
12+
13+
class Config:
14+
type_root = True
15+
16+
@abstractmethod
17+
def on_init(self):
18+
raise NotImplementedError
19+
20+
@abstractmethod
21+
def on_request(self, request):
22+
raise NotImplementedError
23+
24+
@abstractmethod
25+
def on_response(self, request, response):
26+
raise NotImplementedError
27+
28+
29+
class Middlewares(BaseModel):
30+
__root__: List[Middleware] = []
31+
"""Middlewares to add to server"""
32+
33+
def on_init(self):
34+
for middleware in self.__root__:
35+
middleware.on_init()
36+
37+
def on_request(self, request):
38+
for middleware in self.__root__:
39+
request = middleware.on_request(request)
40+
return request
41+
42+
def on_response(self, request, response):
43+
for middleware in reversed(self.__root__):
44+
response = middleware.on_response(request, response)
45+
return response
46+
47+
def get_requirements(self) -> Requirements:
48+
reqs = Requirements.new()
49+
for m in self.__root__:
50+
reqs += m.get_requirements()
51+
return reqs

mlem/runtime/server.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
InterfaceDescriptor,
1616
InterfaceMethod,
1717
)
18+
from mlem.runtime.middleware import Middlewares
1819
from mlem.utils.module import get_object_requirements
1920

2021
MethodMapping = Dict[str, str]
@@ -120,6 +121,9 @@ class Config:
120121
additional_source_files: ClassVar[Optional[List[str]]] = None
121122
port_field: ClassVar[Optional[str]] = None
122123

124+
middlewares: Middlewares = Middlewares()
125+
"""Middlewares to add to server"""
126+
123127
# @validator("interface")
124128
# @classmethod
125129
# def validate_interface(cls, value):
@@ -155,8 +159,16 @@ def _get_serializers(
155159
return arg_serializers, returns
156160

157161
def get_requirements(self) -> Requirements:
158-
return super().get_requirements() + get_object_requirements(
159-
[self.request_serializer, self.response_serializer, self.methods]
162+
return (
163+
super().get_requirements()
164+
+ get_object_requirements(
165+
[
166+
self.request_serializer,
167+
self.response_serializer,
168+
self.methods,
169+
]
170+
)
171+
+ self.middlewares.get_requirements()
160172
)
161173

162174
def get_ports(self) -> List[int]:

mlem/utils/importing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ def module_imported(module_name):
4646
return sys.modules.get(module_name) is not None
4747

4848

49+
def import_string_with_local(path):
50+
try:
51+
# this is needed because if run from cli curdir is not checked for
52+
# modules to import
53+
sys.path.append(".")
54+
return import_string(path)
55+
finally:
56+
sys.path.remove(".")
57+
58+
4959
# Copyright 2019 Zyfra
5060
# Copyright 2021 Iterative
5161
#

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
"xgboost": ["xgboost"],
7575
"lightgbm": ["lightgbm"],
7676
"fastapi": ["uvicorn", "fastapi"],
77+
"prometheus": ["prometheus-fastapi-instrumentator"],
7778
"streamlit": ["uvicorn", "fastapi", "streamlit", "streamlit_pydantic"],
7879
"sagemaker": ["docker", "boto3", "sagemaker"],
7980
"torch": ["torch"],
@@ -214,6 +215,7 @@
214215
"serializer.pil_numpy = mlem.contrib.pil:PILImageSerializer",
215216
"builder.pip = mlem.contrib.pip.base:PipBuilder",
216217
"builder.whl = mlem.contrib.pip.base:WhlBuilder",
218+
"middleware.prometheus_fastapi = mlem.contrib.prometheus:PrometheusFastAPIMiddleware",
217219
"client.rmq = mlem.contrib.rabbitmq:RabbitMQClient",
218220
"server.rmq = mlem.contrib.rabbitmq:RabbitMQServer",
219221
"builder.requirements = mlem.contrib.requirements:RequirementsBuilder",

0 commit comments

Comments
 (0)