Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 182 additions & 40 deletions src/datasets/formatting/torch_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,90 +21,232 @@
import pyarrow as pa

from .. import config
from ..utils.py_utils import map_nested
from .formatting import TensorFormatter


if TYPE_CHECKING:
import torch

# Import torch once at module level once
try:
import torch

_torch_available = True
except ImportError:
_torch_available = False
torch = None


class TorchFormatter(TensorFormatter[Mapping, "torch.Tensor", Mapping]):
def __init__(self, features=None, token_per_repo_id=None, **torch_tensor_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.torch_tensor_kwargs = torch_tensor_kwargs
import torch # noqa import torch at initialization

if not _torch_available:
raise ImportError("PyTorch is required but not available")

def _consolidate(self, column):
import torch

if isinstance(column, list) and column:
if all(
isinstance(x, torch.Tensor) and x.shape == column[0].shape and x.dtype == column[0].dtype
for x in column
):
return torch.stack(column)
"""Smarter consolidation that only stacks when safe and beneficial."""
if not isinstance(column, list) or not column:
return column

# Check if all items are tensors with matching properties
first = column[0]
if not isinstance(first, torch.Tensor):
return column

# Fast check: if all tensors have same shape, dtype, and device, we can stack
if all(
isinstance(x, torch.Tensor)
and x.shape == first.shape
and x.dtype == first.dtype
and x.device == first.device
for x in column
):
return torch.stack(column)

return column

def _tensorize(self, value):
import torch

"""Zero/low-copy tensor conversion with smart dtype handling."""
# Fast path for strings, bytes, None
if isinstance(value, (str, bytes, type(None))):
return value
elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
return value.tolist()

default_dtype = {}

if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer):
default_dtype = {"dtype": torch.int64}

# Convert dtype to np.int64 if it's either np.uint16 or np.uint32 to ensure compatibility.
# np.uint64 is excluded from this conversion as there is no compatible PyTorch dtype that can handle it without loss.
if value.dtype in [np.uint16, np.uint32]:
value = value.astype(np.int64)

elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": torch.float32}
# Handle string arrays
if isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
return value.tolist()

# PIL Image fast path - avoid extra copies
if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image

if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
if value.ndim == 2:
value = value[:, :, np.newaxis]
# Single conversion path: PIL -> numpy -> torch
arr = np.asarray(value)
if arr.ndim == 2:
arr = arr[:, :, np.newaxis]
# Use moveaxis instead of transpose
arr = np.moveaxis(arr, -1, 0) # HWC -> CHW
# Ensure contiguous for zero-copy conversion
if not arr.flags.c_contiguous:
arr = np.ascontiguousarray(arr)
# Ensure array is writable for torch conversion
if not arr.flags.writeable:
arr = arr.copy()
return torch.from_numpy(arr)

value = value.transpose((2, 0, 1))
# Video/Audio decoder passthrough
if config.TORCHVISION_AVAILABLE and "torchvision" in sys.modules:
from torchvision.io import VideoReader

if isinstance(value, VideoReader):
return value # TODO(QL): set output to torch tensors ?
return value

if config.TORCHCODEC_AVAILABLE and "torchcodec" in sys.modules:
from torchcodec.decoders import AudioDecoder, VideoDecoder

if isinstance(value, (VideoDecoder, AudioDecoder)):
return value # TODO(QL): set output to jax arrays ?
return value

# Support for other tensor libraries via __array__
if hasattr(value, "__array__") and not isinstance(value, torch.Tensor):
value = value.__array__()

# Fast numpy conversion paths
if isinstance(value, np.ndarray):
# Handle integer types with smart casting
if np.issubdtype(value.dtype, np.integer):
# Check if user specified a dtype, otherwise default to int64
kwargs = self.torch_tensor_kwargs.copy()
target_dtype = kwargs.get("dtype", torch.int64)

# Safe casting for unsigned types
if value.dtype in (np.uint16, np.uint32):
# Cast to int64 in numpy (fast) then convert to torch
value = value.astype(np.int64)
if target_dtype == torch.int64:
if not value.flags.writeable:
value = value.copy()
return torch.from_numpy(value)
else:
if not value.flags.writeable:
value = value.copy()
kwargs.setdefault("dtype", target_dtype)
return torch.as_tensor(value, **kwargs)
elif value.dtype == np.uint64:
# Check if values fit in int64 range
if np.all(value <= np.iinfo(np.int64).max):
value = value.astype(np.int64)
if target_dtype == torch.int64:
if not value.flags.writeable:
value = value.copy()
return torch.from_numpy(value)
else:
if not value.flags.writeable:
value = value.copy()
kwargs.setdefault("dtype", target_dtype)
return torch.as_tensor(value, **kwargs)
else:
# Fallback to safe conversion via Python ints
kwargs.setdefault("dtype", target_dtype)
return torch.tensor(value, **kwargs)
else:
# Use zero-copy conversion for compatible integer types
if value.dtype == np.int64 and target_dtype == torch.int64:
# Perfect match, zero-copy conversion
if not value.flags.writeable:
value = value.copy()
return torch.from_numpy(value)
else:
# Need dtype conversion, use as_tensor for efficiency
if not value.flags.writeable:
value = value.copy()
kwargs.setdefault("dtype", target_dtype)
return torch.as_tensor(value, **kwargs)

# Handle floating point types
elif np.issubdtype(value.dtype, np.floating):
# Check if user specified a dtype, otherwise default to float32
kwargs = self.torch_tensor_kwargs.copy()
target_dtype = kwargs.get("dtype", torch.float32)

if value.dtype == np.float32 and target_dtype == torch.float32:
# Zero-copy conversion, but ensure array is writable
if not value.flags.writeable:
value = value.copy()
return torch.from_numpy(value)
else:
# Need dtype conversion
if not value.flags.writeable:
value = value.copy()
kwargs.setdefault("dtype", target_dtype)
return torch.as_tensor(value, **kwargs)
else:
# Other numpy types, use zero-copy when possible
if not value.flags.writeable:
value = value.copy()
return torch.from_numpy(value)

# Handle numpy scalars
elif isinstance(value, np.number):
kwargs = self.torch_tensor_kwargs.copy()
if np.issubdtype(value.dtype, np.integer):
# Use torch.as_tensor for scalar conversion with dtype control
kwargs.setdefault("dtype", torch.int64)
return torch.as_tensor(value, **kwargs)
elif np.issubdtype(value.dtype, np.floating):
kwargs.setdefault("dtype", torch.float32)
return torch.as_tensor(value, **kwargs)
else:
return torch.as_tensor(value, **kwargs)

# Handle Python lists/tuples of numbers efficiently
elif isinstance(value, (list, tuple)):
# Try to convert to numpy first for faster tensor creation
try:
arr = np.array(value)
if arr.dtype.kind in "iuf": # integer, unsigned, float
return self._tensorize(arr) # Recursive call to handle numpy path
except (ValueError, TypeError):
pass # Fall back to torch.tensor

# Default fallback with dtype defaults
default_dtype = {}
if isinstance(value, (int, float)):
if isinstance(value, int):
default_dtype = {"dtype": torch.int64}
else:
default_dtype = {"dtype": torch.float32}

return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})

def _recursive_tensorize(self, data_struct):
import torch

# support for torch, tf, jax etc.
"""Optimized recursive walker with reduced Python overhead."""
# Handle tensor-like objects with __array__ interface
if hasattr(data_struct, "__array__") and not isinstance(data_struct, torch.Tensor):
data_struct = data_struct.__array__()
# support for nested types like struct of list of struct

# Handle object arrays (nested structures)
if isinstance(data_struct, np.ndarray):
if data_struct.dtype == object: # torch tensors cannot be instantied from an array of objects
return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
if data_struct.dtype == object:
# Use list comprehension instead of map_nested
result = [self._recursive_tensorize(item) for item in data_struct]
return self._consolidate(result)
# Handle lists and tuples
elif isinstance(data_struct, (list, tuple)):
return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
result = [self._recursive_tensorize(item) for item in data_struct]
return self._consolidate(result)
# Handle dictionaries
elif isinstance(data_struct, dict):
return {key: self._recursive_tensorize(value) for key, value in data_struct.items()}

# Base case: tensorize the leaf value
return self._tensorize(data_struct)

def recursive_tensorize(self, data_struct: dict):
return map_nested(self._recursive_tensorize, data_struct, map_list=False)
"""Public interface maintaining compatibility."""
return self._recursive_tensorize(data_struct)

def format_row(self, pa_table: pa.Table) -> Mapping:
row = self.numpy_arrow_extractor().extract_row(pa_table)
Expand Down
Loading