162 changes: 153 additions & 9 deletions flaxvision/utils.py
@@ -2,15 +2,121 @@
import numpy as np
import jax.numpy as jnp
from flax import nn
import logging
import os

try:
import psutil
PSUTIL_AVAILABLE = True
except ImportError:
PSUTIL_AVAILABLE = False


class TensorSecurityError(Exception):
"""Exception raised for tensor security violations"""
pass


def validate_tensor_security(tensor, max_size_gb=1.0):
"""
Validate tensor for security issues including:
- Finite values (no NaN/Inf)
- Memory size limits
- Valid data types
"""
if tensor is None:
raise TensorSecurityError("Tensor cannot be None")

  # Convert to a numpy array; detach and move torch tensors to CPU first,
  # since .numpy() fails on CUDA tensors
  if hasattr(tensor, 'detach'):
    numpy_tensor = tensor.detach().cpu().numpy()
  else:
    numpy_tensor = np.asarray(tensor)

  # Check dtype first: np.isfinite raises a TypeError on non-numeric dtypes,
  # so validate the type before probing the values
  if not np.issubdtype(numpy_tensor.dtype, np.number):
    raise TensorSecurityError(f"Tensor has invalid data type: {numpy_tensor.dtype}")

  # Check for finite values
  if not np.isfinite(numpy_tensor).all():
    raise TensorSecurityError("Tensor contains non-finite values (NaN or Inf)")

  # Check tensor size limit (prevent memory exhaustion)
  tensor_size_bytes = numpy_tensor.nbytes
  max_size_bytes = max_size_gb * 1024 * 1024 * 1024

  if tensor_size_bytes > max_size_bytes:
    raise TensorSecurityError(f"Tensor size ({tensor_size_bytes} bytes) exceeds maximum allowed size ({max_size_bytes} bytes)")

  # Check for reasonable tensor dimensions
  if numpy_tensor.ndim > 6:  # more than 6 dimensions is suspicious for vision weights
    raise TensorSecurityError(f"Tensor has too many dimensions ({numpy_tensor.ndim})")

return numpy_tensor
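
As a usage sketch (the arrays below are made-up examples, not part of the PR):

import numpy as np

good = np.ones((3, 3), dtype=np.float32)
validate_tensor_security(good)  # passes and returns the array

bad = np.full((3, 3), np.nan)
try:
  validate_tensor_security(bad)
except TensorSecurityError as e:
  print(f"rejected: {e}")  # non-finite values are refused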


def validate_parameter_structure(params_dict, expected_keys=None, max_params=10000):
"""
Validate parameter dictionary structure:
- Check parameter count limits
- Validate key structure
- Ensure reasonable parameter sizes
"""
if not isinstance(params_dict, dict):
raise TensorSecurityError("Parameters must be a dictionary")

  # Check the entry count at this level (nested dicts are validated recursively below)
  total_params = len(params_dict)
  if total_params > max_params:
    raise TensorSecurityError(f"Too many parameters ({total_params}), maximum allowed: {max_params}")

# Validate each parameter
for key, value in params_dict.items():
if not isinstance(key, str):
raise TensorSecurityError(f"Parameter key must be string, got: {type(key)}")

if len(key) > 200: # Prevent extremely long keys
raise TensorSecurityError(f"Parameter key too long: {len(key)} characters")

    # Recursively validate nested dictionaries (expected_keys is accepted but not yet enforced)
    if isinstance(value, dict):
      validate_parameter_structure(value, expected_keys, max_params)
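
For illustration, a small hedged sketch of what this accepts and rejects (the key names are hypothetical):

validate_parameter_structure({'conv1': {'kernel': [1.0], 'bias': [0.0]}})  # passes: string keys, shallow nesting

try:
  validate_parameter_structure({42: 'bad'})  # non-string key
except TensorSecurityError as e:
  print(f"rejected: {e}")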


def monitor_memory_usage():
"""Monitor current memory usage and warn if approaching limits"""
if not PSUTIL_AVAILABLE:
logging.warning("psutil not available, memory monitoring disabled")
return 0.0

try:
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
memory_mb = memory_info.rss / (1024 * 1024)

# Warning threshold at 2GB
if memory_mb > 2048:
logging.warning(f"High memory usage detected: {memory_mb:.2f} MB")

return memory_mb
except Exception as e:
logging.warning(f"Failed to monitor memory usage: {e}")
return 0.0
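
A quick usage sketch; the return value is the process's resident set size in MB, or 0.0 when psutil is unavailable or the query fails:

mb = monitor_memory_usage()
print(f"current RSS: {mb:.1f} MB")  # a warning is logged above 2048 MB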


def load_torch_params(url):
return torch.hub.load_state_dict_from_url(url)

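As a usage sketch (the URL is torchvision's published ResNet-18 checkpoint, shown here only as an example input):

torch_params = load_torch_params('https://download.pytorch.org/models/resnet18-5c106cde.pth')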

def torch_to_flax(torch_params, get_flax_keys):
"""Convert PyTorch parameters to nested dictionaries"""

"""Convert PyTorch parameters to nested dictionaries with security validation"""

# Monitor memory usage at start
initial_memory = monitor_memory_usage()

# Validate input parameter structure
validate_parameter_structure(torch_params)

def add_to_params(params_dict, nested_keys, param, is_conv=False):
if len(nested_keys) == 1:
key, = nested_keys
@@ -33,19 +139,41 @@ def add_to_state(state_dict, keys, param):

flax_params, flax_state = {}, {}
for key, tensor in torch_params.items():
# Validate tensor security before processing
try:
validated_tensor = validate_tensor_security(tensor)
except TensorSecurityError as e:
logging.error(f"Tensor security validation failed for key '{key}': {e}")
raise

flax_keys = get_flax_keys(key.split('.'))
if flax_keys[-1] is None:
continue
    if flax_keys[-1] == 'mean' or flax_keys[-1] == 'var':
      add_to_state(flax_state, flax_keys, validated_tensor)
    else:
      add_to_params(flax_params, flax_keys, validated_tensor)

  # Check memory usage once conversion is complete
  current_memory = monitor_memory_usage()
  if current_memory > initial_memory + 1000:  # ~1 GB growth
    logging.warning(f"Memory usage increased significantly during parameter conversion: {current_memory - initial_memory:.2f} MB")

# Validate final parameter structures
validate_parameter_structure(flax_params)
validate_parameter_structure(flax_state)

return flax_params, flax_state
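
The get_flax_keys callable is supplied by each model port. A minimal hypothetical version for conv/batch-norm weights might look like the following (the name mapping is an assumption for illustration, not the repo's actual mapping, which also handles details such as conv kernel transposition):

def get_flax_keys(keys):
  # Map the last PyTorch key segment to its Flax counterpart;
  # a trailing None tells torch_to_flax to skip the entry entirely.
  mapping = {'weight': 'kernel',
             'running_mean': 'mean',
             'running_var': 'var',
             'num_batches_tracked': None}
  return keys[:-1] + [mapping.get(keys[-1], keys[-1])]

flax_params, flax_state = torch_to_flax(torch_params, get_flax_keys)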


def torch_to_linen(torch_params, get_flax_keys):
"""Convert PyTorch parameters to Linen nested dictionaries"""
"""Convert PyTorch parameters to Linen nested dictionaries with security validation"""

# Monitor memory usage at start
initial_memory = monitor_memory_usage()

# Validate input parameter structure
validate_parameter_structure(torch_params)

def add_to_params(params_dict, nested_keys, param, is_conv=False):
if len(nested_keys) == 1:
@@ -61,11 +189,27 @@ def add_to_params(params_dict, nested_keys, param, is_conv=False):

flax_params = {'params': {}, 'batch_stats': {}}
for key, tensor in torch_params.items():
# Validate tensor security before processing
try:
validated_tensor = validate_tensor_security(tensor)
except TensorSecurityError as e:
logging.error(f"Tensor security validation failed for key '{key}': {e}")
raise

flax_keys = get_flax_keys(key.split('.'))
if flax_keys[-1] is not None:
      if flax_keys[-1] in ('mean', 'var'):
        add_to_params(flax_params['batch_stats'], flax_keys, validated_tensor)
      else:
        add_to_params(flax_params['params'], flax_keys, validated_tensor)

  # Check memory usage once conversion is complete
  current_memory = monitor_memory_usage()
  if current_memory > initial_memory + 1000:  # ~1 GB growth
    logging.warning(f"Memory usage increased significantly during parameter conversion: {current_memory - initial_memory:.2f} MB")

# Validate final parameter structures
validate_parameter_structure(flax_params['params'])
validate_parameter_structure(flax_params['batch_stats'])

return flax_params
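
The returned dict already follows the Linen variable layout, so with a matching Linen module (a hypothetical model below) it can be passed straight to apply:

variables = torch_to_linen(torch_params, get_flax_keys)
# variables == {'params': {...}, 'batch_stats': {...}}
logits = model.apply(variables, batch)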
6 changes: 6 additions & 0 deletions package-lock.json