@drbh drbh commented Sep 24, 2025

perf: reduce copies in TorchFormatter

This PR changes the torch formatter to avoid unnecessary copies and casts when converting decoded batches to tensors.

Because many arrays are already in a torch-friendly memory layout and dtype, we can do zero‑copy conversions (torch.from_numpy) and only fall back to as_tensor when a dtype/device change is required. We also consolidate lists of same‑shape tensors with a cheap stack only when safe.
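As a rough illustration of the idea (not the actual code in this PR), the conversion path can be sketched as below. The helper names `to_tensor` and `consolidate` are hypothetical and exist only for this example; the sketch assumes numeric NumPy arrays as input.

```python
import numpy as np
import torch


def to_tensor(array, dtype=None):
    """Hypothetical helper: share memory when possible, convert otherwise."""
    if isinstance(array, np.ndarray) and dtype is None:
        # torch.from_numpy returns a tensor that shares memory with the array
        return torch.from_numpy(array)
    # A dtype change is requested (or the input is not an ndarray),
    # so let torch.as_tensor handle the conversion.
    return torch.as_tensor(array, dtype=dtype)


def consolidate(tensors):
    """Hypothetical helper: stack only when every item has the same shape and dtype."""
    if tensors and all(
        isinstance(t, torch.Tensor)
        and t.shape == tensors[0].shape
        and t.dtype == tensors[0].dtype
        for t in tensors
    ):
        return torch.stack(tensors)
    # Ragged or mixed inputs are returned unchanged as a list
    return tensors


arr = np.arange(6, dtype=np.float32).reshape(2, 3)
shared = to_tensor(arr)                       # shares memory with arr
cast = to_tensor(arr, dtype=torch.float64)    # dtype change requires a copy
stacked = consolidate([torch.zeros(3), torch.ones(3)])
ragged = consolidate([torch.zeros(3), torch.ones(2)])
```

Because `shared` aliases `arr`, writing to one is visible through the other, which is the trade-off of skipping the copy.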

Why it helps

  • Avoids extra materialization and dtype churn during batched map and indexing.
  • Preserves API and outputs; only changes internal conversion logic.

Small benchmark script (based on #6104)

import time
from datasets import load_dataset


def main():
    dataset = load_dataset("NightMachinery/hf_datasets_bug1")
    dataset = dataset["train"] if "train" in dataset else dataset
    t0 = time.time()
    dataset.set_format(type="torch")

    # identity map with small batches
    dataset = dataset.map(lambda x: x, batched=True, batch_size=20)

    # force materialization
    data = dataset[:300]
    print(len(data.keys()))

    t1 = time.time()
    print(f"Duration: {t1 - t0:.2f} s")


if __name__ == "__main__":
    main()

Without changes

uv run bench.py
# 303
# Duration: 7.26 s

With changes

uv run bench.py
# 303
# Duration: 4.43 s
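
For reference, the two durations above can be turned into a speedup factor and a relative reduction with a few lines of arithmetic. This is a single-run comparison on one machine, so treat the result as indicative only.

```python
# Durations reported above, in seconds
before_s = 7.26
after_s = 4.43

speedup = before_s / after_s
reduction_pct = (before_s - after_s) / before_s * 100

print(f"Speedup: {speedup:.2f}x")
print(f"Wall-clock reduction: {reduction_pct:.0f}%")
```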

Updated reproduction scripts

Below are some simple test cases using main and this refactor-torch-formatter branch. I've included the two scripts and output when running on a local machine.

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "datasets",
#     "pillow",
# ]
#
# [tool.uv.sources]
# datasets = { git = "https://github.com/huggingface/datasets.git" }
# ///
import time
import random
import numpy as np
from PIL import Image
from datasets import Dataset, load_dataset
import torch


def create_mock_images_dataset(num_samples=5000):
    """Create a deterministic mock dataset with PIL images."""
    random.seed(42)
    np.random.seed(42)

    images = []
    labels = []

    for i in range(num_samples):
        # Create deterministic RGB image
        width, height = 64, 64
        rgb_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
        image = Image.fromarray(rgb_array)
        images.append(image)
        labels.append(i % 10)  # 10 classes

    return Dataset.from_dict({"image": images, "label": labels})


def create_mock_text_dataset(num_samples=5000):
    """Create a deterministic mock dataset with text."""
    random.seed(42)

    words = ["apple", "banana", "cherry", "date", "elderberry", "fig", "grape", "honeydew"]
    texts = []
    labels = []

    for i in range(num_samples):
        # Create deterministic text
        text_length = 5 + (i % 20)  # 5-24 words
        text = " ".join(random.choices(words, k=text_length))
        texts.append(text)
        labels.append(i % 3)  # 3 classes

    return Dataset.from_dict({"text": texts, "label": labels})


def create_mock_ints_dataset(num_samples=5000):
    """Create a deterministic mock dataset with integers."""
    random.seed(42)

    data = []
    labels = []

    for i in range(num_samples):
        # Create deterministic integer arrays
        arr = [random.randint(0, 1000) for _ in range(50)]  # 50 integers each
        data.append(arr)
        labels.append(i % 5)  # 5 classes

    return Dataset.from_dict({"data": data, "label": labels})


def create_mock_floats_dataset(num_samples=5000):
    """Create a deterministic mock dataset with floats."""
    random.seed(42)

    data = []
    labels = []

    for i in range(num_samples):
        # Create deterministic float arrays
        arr = [random.uniform(0.0, 100.0) for _ in range(30)]  # 30 floats each
        data.append(arr)
        labels.append(i % 4)  # 4 classes

    return Dataset.from_dict({"data": data, "label": labels})


def benchmark_dataset(name, dataset, num_samples=1000):
    """Benchmark dataset access speed."""
    print(f"\n=== {name} Dataset Benchmark ===")

    t0 = time.time()
    dataset.set_format(type="torch")

    # identity map with small batches
    dataset = dataset.map(lambda x: x, batched=True, batch_size=20)

    # force materialization
    data = dataset[:num_samples]
    print(f"Keys: {list(data.keys())}")
    print(f"Sample count: {len(data[list(data.keys())[0]])}")

    t1 = time.time()
    print(f"Duration: {t1 - t0:.2f} s")
    print(f"Speed: {num_samples / (t1 - t0):.1f} samples/s")


def main():
    # PIL Images benchmark
    images_dataset = create_mock_images_dataset()
    benchmark_dataset("PIL Images", images_dataset)

    # Text benchmark
    text_dataset = create_mock_text_dataset()
    benchmark_dataset("Text", text_dataset)

    # Integers benchmark
    ints_dataset = create_mock_ints_dataset()
    benchmark_dataset("Integers", ints_dataset)

    # Floats benchmark
    floats_dataset = create_mock_floats_dataset()
    benchmark_dataset("Floats", floats_dataset)


if __name__ == "__main__":
    main()

output

uv run --refresh example1.py
=== PIL Images Dataset Benchmark ===
Map:   0%|                                                          | 0/5000 [00:00<?, ? examples/s]/Users/drbh/.cache/uv/environments-v2/example1-2aca1a30e84bdead/lib/python3.10/site-packages/datasets/features/image.py:352: UserWarning: Downcasting array dtype int64 to uint8 to be compatible with 'Pillow'
  warnings.warn(f"Downcasting array dtype {dtype} to {dest_dtype} to be compatible with 'Pillow'")
Map: 100%|█████████████████████████████████████████████| 5000/5000 [00:01<00:00, 3669.15 examples/s]
Keys: ['image', 'label']
Sample count: 1000
Duration: 2.14 s
Speed: 466.5 samples/s

=== Text Dataset Benchmark ===
Map: 100%|███████████████████████████████████████████| 5000/5000 [00:00<00:00, 141327.04 examples/s]
Keys: ['text', 'label']
Sample count: 1000
Duration: 0.04 s
Speed: 27004.3 samples/s

=== Integers Dataset Benchmark ===
Map: 100%|███████████████████████████████████████████| 5000/5000 [00:00<00:00, 112904.90 examples/s]
Keys: ['data', 'label']
Sample count: 1000
Duration: 0.05 s
Speed: 21680.6 samples/s

=== Floats Dataset Benchmark ===
Map: 100%|███████████████████████████████████████████| 5000/5000 [00:00<00:00, 104084.25 examples/s]
Keys: ['data', 'label']
Sample count: 1000
Duration: 0.05 s
Speed: 20215.1 samples/s

and this branch specifically

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "datasets",
#     "pillow",
# ]
#
# [tool.uv.sources]
# datasets = { git = "https://github.com/huggingface/datasets.git", rev = "refactor-torch-formatter" }
# ///
import time
import random
import numpy as np
from PIL import Image
from datasets import Dataset, load_dataset
import torch


def create_mock_images_dataset(num_samples=5000):
    """Create a deterministic mock dataset with PIL images."""
    random.seed(42)
    np.random.seed(42)

    images = []
    labels = []

    for i in range(num_samples):
        # Create deterministic RGB image
        width, height = 64, 64
        rgb_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
        image = Image.fromarray(rgb_array)
        images.append(image)
        labels.append(i % 10)  # 10 classes

    return Dataset.from_dict({"image": images, "label": labels})


def create_mock_text_dataset(num_samples=5000):
    """Create a deterministic mock dataset with text."""
    random.seed(42)

    words = [
        "apple",
        "banana",
        "cherry",
        "date",
        "elderberry",
        "fig",
        "grape",
        "honeydew",
    ]
    texts = []
    labels = []

    for i in range(num_samples):
        # Create deterministic text
        text_length = 5 + (i % 20)  # 5-24 words
        text = " ".join(random.choices(words, k=text_length))
        texts.append(text)
        labels.append(i % 3)  # 3 classes

    return Dataset.from_dict({"text": texts, "label": labels})


def create_mock_ints_dataset(num_samples=5000):
    """Create a deterministic mock dataset with integers."""
    random.seed(42)

    data = []
    labels = []

    for i in range(num_samples):
        # Create deterministic integer arrays
        arr = [random.randint(0, 1000) for _ in range(50)]  # 50 integers each
        data.append(arr)
        labels.append(i % 5)  # 5 classes

    return Dataset.from_dict({"data": data, "label": labels})


def create_mock_floats_dataset(num_samples=5000):
    """Create a deterministic mock dataset with floats."""
    random.seed(42)

    data = []
    labels = []

    for i in range(num_samples):
        # Create deterministic float arrays
        arr = [random.uniform(0.0, 100.0) for _ in range(30)]  # 30 floats each
        data.append(arr)
        labels.append(i % 4)  # 4 classes

    return Dataset.from_dict({"data": data, "label": labels})


def benchmark_dataset(name, dataset, num_samples=1000):
    """Benchmark dataset access speed."""
    print(f"\n=== {name} Dataset Benchmark ===")

    t0 = time.time()
    dataset.set_format(type="torch")

    # identity map with small batches
    dataset = dataset.map(lambda x: x, batched=True, batch_size=20)

    # force materialization
    data = dataset[:num_samples]
    print(f"Keys: {list(data.keys())}")
    print(f"Sample count: {len(data[list(data.keys())[0]])}")

    t1 = time.time()
    print(f"Duration: {t1 - t0:.2f} s")
    print(f"Speed: {num_samples / (t1 - t0):.1f} samples/s")


def main():
    # PIL Images benchmark
    images_dataset = create_mock_images_dataset()
    benchmark_dataset("PIL Images", images_dataset)

    # Text benchmark
    text_dataset = create_mock_text_dataset()
    benchmark_dataset("Text", text_dataset)

    # Integers benchmark
    ints_dataset = create_mock_ints_dataset()
    benchmark_dataset("Integers", ints_dataset)

    # Floats benchmark
    floats_dataset = create_mock_floats_dataset()
    benchmark_dataset("Floats", floats_dataset)


if __name__ == "__main__":
    main()

output

uv run --refresh example2.py
    Updated https://github.com/huggingface/datasets.git (2cb64d1b6503afb49d822b20979760efe4519d03)
      Built datasets @ git+https://github.com/huggingface/datasets.git@2cb64d1b6503afb49d822b20979760efe
Uninstalled 1 package in 20ms
Installed 1 package in 5ms

=== PIL Images Dataset Benchmark ===
Map:   0%|                                                          | 0/5000 [00:00<?, ? examples/s]/Users/drbh/.cache/uv/environments-v2/example2-d4af608668b706ec/lib/python3.10/site-packages/datasets/features/image.py:352: UserWarning: Downcasting array dtype int64 to uint8 to be compatible with 'Pillow'
  warnings.warn(f"Downcasting array dtype {dtype} to {dest_dtype} to be compatible with 'Pillow'")
Map: 100%|█████████████████████████████████████████████| 5000/5000 [00:01<00:00, 3645.14 examples/s]
Keys: ['image', 'label']
Sample count: 1000
Duration: 2.04 s
Speed: 491.2 samples/s

=== Text Dataset Benchmark ===
Map: 100%|████████████████████████████████████████████████████| 5000/5000 [00:00<00:00, 169877.28 examples/s]
Keys: ['text', 'label']
Sample count: 1000
Duration: 0.03 s
Speed: 32236.1 samples/s

=== Integers Dataset Benchmark ===
Map: 100%|████████████████████████████████████████████████████| 5000/5000 [00:00<00:00, 131940.33 examples/s]
Keys: ['data', 'label']
Sample count: 1000
Duration: 0.04 s
Speed: 25493.3 samples/s

=== Floats Dataset Benchmark ===
Map: 100%|████████████████████████████████████████████████████| 5000/5000 [00:00<00:00, 120621.64 examples/s]
Keys: ['data', 'label']
Sample count: 1000
Duration: 0.04 s
Speed: 23370.6 samples/s
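
To compare the two runs side by side, the reported samples/s figures can be summarized as relative changes. The numbers are copied from the outputs above; the timings are single runs over short durations, so small differences may be noise.

```python
# (main, refactor-torch-formatter) throughput in samples/s, from the outputs above
reported = {
    "PIL Images": (466.5, 491.2),
    "Text": (27004.3, 32236.1),
    "Integers": (21680.6, 25493.3),
    "Floats": (20215.1, 23370.6),
}

# Relative change of the branch over main, in percent
gains = {
    name: (branch / main - 1) * 100
    for name, (main, branch) in reported.items()
}

for name, gain in gains.items():
    print(f"{name}: {gain:+.1f}%")
```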

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@drbh drbh marked this pull request as ready for review September 24, 2025 21:54
@drbh drbh requested a review from lhoestq September 25, 2025 13:50
lhoestq commented Sep 25, 2025

can you re-read your PR please ?

@lhoestq lhoestq left a comment

thanks for cleaning it up ! lgtm :)

@lhoestq lhoestq merged commit c412a6f into main Sep 26, 2025
13 of 15 checks passed
@lhoestq lhoestq deleted the refactor-torch-formatter branch September 26, 2025 15:04