client/python/gradio_client/utils.py (+29 −0)

@@ -13,6 +13,9 @@
import tempfile
import time
import warnings
import ipaddress
import socket
from urllib.parse import urlparse
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -287,6 +290,32 @@ def is_valid_url(possible_url: str) -> bool:
    )
    return is_http_url_like(possible_url) and probe_url(possible_url)

def is_safe_url(url: str) -> bool:
    try:
        # Parse the given URL and extract its hostname
        parsed_url = urlparse(url)
        domain = parsed_url.hostname
    except Exception:
        return False

    # Reject URLs whose scheme (protocol) is not http or https,
    # as well as URLs with no hostname at all
    if parsed_url.scheme not in ("http", "https") or domain is None:
        return False

    try:
        # Resolve the domain to an IP address with a DNS lookup
        resolved_ip = socket.gethostbyname(domain)

        # Convert the resolved address string to an ipaddress object
        ip = ipaddress.ip_address(resolved_ip)

        # Reject addresses in private ranges (these include loopback and
        # link-local, e.g. 127.0.0.1 and 169.254.0.0/16) or multicast ranges
        if ip.is_private or ip.is_multicast:
            return False
    except (socket.gaierror, ValueError):
        # The hostname did not resolve, or the resolved address was not a valid IP
        return False
    return True
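
A quick, hypothetical usage sketch (not part of the diff; results for real hostnames depend on what DNS returns at call time):

from gradio_client.utils import is_safe_url

print(is_safe_url("ftp://example.com/file.txt"))  # False: scheme is not http/https
print(is_safe_url("http://127.0.0.1:7860/x"))     # False: loopback is in a private range
print(is_safe_url("http://10.0.0.5/internal"))    # False: 10.0.0.0/8 is private
print(is_safe_url("https://example.com/"))        # True, assuming it resolves to a public IP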


async def get_pred_from_ws(
    websocket: WebSocketCommonProtocol,
gradio/processing_utils.py (+6 −2)

@@ -271,14 +271,16 @@ def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
def save_url_to_cache(url: str, cache_dir: str) -> str:
    """Downloads a file and makes a temporary file path for a copy if one does not
    already exist. Otherwise returns the path to the existing temp file."""
    if not client_utils.is_safe_url(url):
        raise ValueError("URL is not safe or violates security policy.")
    temp_dir = hash_url(url)
    temp_dir = Path(cache_dir) / temp_dir
    temp_dir.mkdir(exist_ok=True, parents=True)
    name = client_utils.strip_invalid_filename_characters(Path(url).name)
    full_temp_file_path = str(abspath(temp_dir / name))

    if not Path(full_temp_file_path).exists():
-        with sync_client.stream("GET", url, follow_redirects=True) as r, open(
+        with sync_client.stream("GET", url) as r, open(
            full_temp_file_path, "wb"
        ) as f:
            for chunk in r.iter_raw():
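
Note on the follow_redirects=True removal: is_safe_url runs once, before the request, so with automatic redirects enabled a vetted public URL could still redirect to a private address and bypass the check. If redirect support is ever needed again, one possible approach (a sketch only; fetch_with_checked_redirects is a hypothetical helper, not part of this PR) is to follow each hop manually and re-validate it:

import httpx

from gradio_client import utils as client_utils

def fetch_with_checked_redirects(url: str, max_hops: int = 5) -> httpx.Response:
    """Follow redirects by hand, re-running the SSRF check on every hop."""
    with httpx.Client(follow_redirects=False) as client:
        for _ in range(max_hops):
            if not client_utils.is_safe_url(url):
                raise ValueError("URL is not safe or violates security policy.")
            response = client.get(url)
            if response.next_request is None:
                return response  # not a redirect; this is the final response
            url = str(response.next_request.url)
    raise ValueError("Too many redirects.")

Even then a check-then-fetch race remains (DNS can change between the check and the request), so dropping redirects outright, as this PR does, is the more conservative choice.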
@@ -290,14 +292,16 @@ def save_url_to_cache(url: str, cache_dir: str) -> str:
async def async_save_url_to_cache(url: str, cache_dir: str) -> str:
    """Downloads a file and makes a temporary file path for a copy if one does not
    already exist. Otherwise returns the path to the existing temp file. Uses async httpx."""
    if not client_utils.is_safe_url(url):
        raise ValueError("URL is not safe or violates security policy.")
    temp_dir = hash_url(url)
    temp_dir = Path(cache_dir) / temp_dir
    temp_dir.mkdir(exist_ok=True, parents=True)
    name = client_utils.strip_invalid_filename_characters(Path(url).name)
    full_temp_file_path = str(abspath(temp_dir / name))

    if not Path(full_temp_file_path).exists():
-        async with async_client.stream("GET", url, follow_redirects=True) as response:
+        async with async_client.stream("GET", url) as response:
            async with aiofiles.open(full_temp_file_path, "wb") as f:
                async for chunk in response.aiter_raw():
                    await f.write(chunk)
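
For illustration, a hypothetical caller (the cache directory and URL below are made up; async_save_url_to_cache itself is from this diff) would now see a cloud-metadata address rejected before any request is made:

import asyncio

from gradio import processing_utils

async def main():
    try:
        path = await processing_utils.async_save_url_to_cache(
            "http://169.254.169.254/latest/meta-data/", "/tmp/gradio-cache"
        )
        print("cached at", path)
    except ValueError as err:
        # 169.254.169.254 is link-local, so is_safe_url returns False
        print("blocked:", err)

asyncio.run(main())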