Skip to content

Commit 534c358

Browse files
Merge commit from fork
* fix: remediated ssrf * fix: ran pre-commit * feat: added verbose error handling * feat: restructured exception handling around specific operations --------- Co-authored-by: nkoorty <[email protected]>
1 parent c8ffb58 commit 534c358

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

src/_bentoml_impl/serde.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from _bentoml_sdk.validators import DataframeSchema
2222
from _bentoml_sdk.validators import TensorSchema
2323
from bentoml._internal.utils.uri import is_http_url
24+
from bentoml._internal.utils.uri import is_safe_url
2425

2526
if t.TYPE_CHECKING:
2627
from starlette.requests import Request
@@ -163,7 +164,8 @@ async def parse_request(self, request: Request, cls: type[T]) -> T:
163164

164165
body = await request.body()
165166
if issubclass(cls, IORootModel) and cls.multipart_fields:
166-
if is_http_url(url := body.decode("utf-8", "ignore")):
167+
url = body.decode("utf-8", "ignore")
168+
if is_http_url(url) and is_safe_url(url):
167169
async with httpx.AsyncClient() as client:
168170
logger.debug("Request with URL, downloading file from %s", url)
169171
resp = await client.get(url)
@@ -189,12 +191,15 @@ async def ensure_file(obj: str | UploadFile) -> UploadFile:
189191

190192
if isinstance(obj, UploadFile):
191193
return obj
194+
195+
url = obj.strip("\"'")
196+
if not is_safe_url(url):
197+
raise ValueError("URL not allowed for security reasons")
198+
192199
async with httpx.AsyncClient() as client:
193-
obj = obj.strip("\"'") # The url may be JSON encoded
194-
logger.debug("Request with URL, downloading file from %s", obj)
195-
resp = await client.get(obj)
200+
resp = await client.get(url)
196201
body = io.BytesIO(await resp.aread())
197-
parsed = urlparse(obj)
202+
parsed = urlparse(url)
198203
return UploadFile(
199204
body,
200205
size=len(body.getvalue()),

src/bentoml/_internal/utils/uri.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import ipaddress
12
import os
23
import pathlib
4+
import socket
35
from urllib.parse import quote
46
from urllib.parse import unquote
57
from urllib.parse import urlparse
@@ -50,3 +52,45 @@ def encode_path_for_uri(path: str) -> str:
5052

5153
def is_http_url(url: str) -> bool:
5254
return urlparse(url).scheme in {"http", "https"}
55+
56+
57+
def is_safe_url(url: str) -> bool:
58+
"""Check if URL is safe for download (prevents basic SSRF)."""
59+
try:
60+
parsed = urlparse(url)
61+
except (ValueError, TypeError):
62+
return False
63+
64+
if parsed.scheme not in {"http", "https"}:
65+
return False
66+
67+
hostname = parsed.hostname
68+
if not hostname:
69+
return False
70+
71+
if hostname.lower() in {"localhost", "127.0.0.1", "::1", "169.254.169.254"}:
72+
return False
73+
74+
try:
75+
ip = ipaddress.ip_address(hostname)
76+
return not (ip.is_private or ip.is_loopback or ip.is_link_local)
77+
except ValueError:
78+
# hostname is not an IP address, need to resolve it
79+
pass
80+
81+
try:
82+
addr_info = socket.getaddrinfo(hostname, None)
83+
except socket.gaierror:
84+
# DNS resolution failed
85+
return False
86+
87+
for info in addr_info:
88+
try:
89+
ip = ipaddress.ip_address(info[4][0])
90+
if ip.is_private or ip.is_loopback or ip.is_link_local:
91+
return False
92+
except (ValueError, IndexError):
93+
# Skip malformed addresses
94+
continue
95+
96+
return True

0 commit comments

Comments
 (0)