18 changes: 14 additions & 4 deletions src/datasets/packaged_modules/webdataset/webdataset.py
@@ -7,6 +7,7 @@
 import pyarrow as pa
 
 import datasets
+from datasets.features.features import cast_to_python_objects
 
 
 logger = datasets.utils.logging.get_logger(__name__)
@@ -64,7 +65,10 @@ def _split_generators(self, dl_manager):
                 "The TAR archives of the dataset should be in WebDataset format, "
                 "but the files in the archive don't share the same prefix or the same types."
             )
-        pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
+        pa_tables = [
+            pa.Table.from_pylist(cast_to_python_objects([example], only_1d_for_numpy=True))
+            for example in first_examples
+        ]
         if datasets.config.PYARROW_VERSION.major < 14:
             inferred_arrow_schema = pa.concat_tables(pa_tables, promote=True).schema
         else:
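
Why the cast above is needed: pyarrow has no conversion for `torch.Tensor` values, so inferring a schema straight from a decoded `.pth` example would fail. A minimal sketch, not part of the diff (the example dict below is illustrative):

```python
import pyarrow as pa
import torch

from datasets.features.features import cast_to_python_objects

example = {"__key__": "00000", "pth": torch.ones(128)}

# Without the cast, pa.Table.from_pylist([example]) raises, since pyarrow
# does not recognize torch.Tensor values.

# With the cast, the tensor becomes a 1-D float32 NumPy array, which Arrow
# converts to list<item: float>, i.e. Sequence(Value("float32")) downstream:
casted = cast_to_python_objects([example], only_1d_for_numpy=True)
table = pa.Table.from_pylist(casted)
print(table.schema)
```
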
@@ -256,16 +260,21 @@ def cbor_loads(data: bytes):
     return cbor.loads(data)
 
 
+def torch_loads(data: bytes):
+    import torch
+
+    return torch.load(io.BytesIO(data), weights_only=True)
+
+
 # Obtained by checking `decoders` in `webdataset.autodecode`
 # and removing unsafe extension decoders.
 # Removed Pickle decoders:
 # - "pyd": lambda data: pickle.loads(data)
 # - "pickle": lambda data: pickle.loads(data)
-# Removed Torch decoders:
-# - "pth": lambda data: torch_loads(data)
-# Modified NumPy decoders to fix CVE-2019-6446 (add allow_pickle=False):
+# Modified NumPy and Torch decoders to fix CVE-2019-6446 (add allow_pickle=False and weights_only=True):
 # - "npy": npy_loads,
 # - "npz": lambda data: np.load(io.BytesIO(data)),
+# - "pth": lambda data: torch_loads(data)
 DECODERS = {
     "txt": text_loads,
     "text": text_loads,
@@ -284,5 +293,6 @@ def cbor_loads(data: bytes):
     "npy": npy_loads,
     "npz": npz_loads,
     "cbor": cbor_loads,
+    "pth": torch_loads,
 }
 WebDataset.DECODERS = DECODERS
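
The new `torch_loads` decoder pins `weights_only=True`, so `torch.load` only deserializes tensors and primitive containers instead of executing arbitrary pickles. An illustrative sketch of that behavior (the `Payload` class is a stand-in for attacker-controlled data; the exact exception type depends on the torch version):

```python
import io

import torch

from datasets.packaged_modules.webdataset.webdataset import torch_loads

# Tensor payloads round-trip normally:
buf = io.BytesIO()
torch.save(torch.ones(4), buf)
assert torch_loads(buf.getvalue()).shape == (4,)


class Payload:
    """Stand-in for an attacker-controlled pickled object."""


buf = io.BytesIO()
torch.save(Payload(), buf)
try:
    torch_loads(buf.getvalue())  # refused instead of unpickled
except Exception as exc:  # exact exception type varies by torch version
    print(f"rejected: {type(exc).__name__}")
```
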
10 changes: 10 additions & 0 deletions tests/fixtures/files.py
@@ -551,6 +551,16 @@ def audio_file():
     return os.path.join("tests", "features", "data", "test_audio_44100.wav")
 
 
+@pytest.fixture(scope="session")
+def tensor_file(tmp_path_factory):
+    import torch
+
+    path = tmp_path_factory.mktemp("data") / "tensor.pth"
+    with open(path, "wb") as f:
+        torch.save(torch.ones(128), f)
+    return path
+
+
 @pytest.fixture(scope="session")
 def zip_image_path(image_file, tmp_path_factory):
     path = tmp_path_factory.mktemp("data") / "dataset.img.zip"
49 changes: 47 additions & 2 deletions tests/packaged_modules/test_webdataset.py
@@ -4,10 +4,10 @@
 import numpy as np
 import pytest
 
-from datasets import Audio, DownloadManager, Features, Image, Value
+from datasets import Audio, DownloadManager, Features, Image, Sequence, Value
 from datasets.packaged_modules.webdataset.webdataset import WebDataset
 
-from ..utils import require_pil, require_sndfile
+from ..utils import require_pil, require_sndfile, require_torch
 
 
 @pytest.fixture
@@ -50,6 +50,20 @@ def bad_wds_file(tmp_path, image_file, text_file):
     return str(filename)
 
 
+@pytest.fixture
+def tensor_wds_file(tmp_path, tensor_file):
+    json_file = tmp_path / "data.json"
+    filename = tmp_path / "file.tar"
+    num_examples = 3
+    with json_file.open("w", encoding="utf-8") as f:
+        f.write(json.dumps({"text": "this is a text"}))
+    with tarfile.open(str(filename), "w") as f:
+        for example_idx in range(num_examples):
+            f.add(json_file, f"{example_idx:05d}.json")
+            f.add(tensor_file, f"{example_idx:05d}.pth")
+    return str(filename)
+
+
 @require_pil
 def test_image_webdataset(image_wds_file):
     import PIL.Image
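
The fixture relies on the WebDataset convention that files sharing a key prefix (`00000.json` + `00000.pth`) become fields of a single example, keyed by extension. A quick illustrative sketch of the resulting layout, assuming the fixture's tar is at `file.tar`:

```python
import tarfile

with tarfile.open("file.tar") as tar:  # the path returned by tensor_wds_file
    print(tar.getnames())
# ['00000.json', '00000.pth', '00001.json', '00001.pth', '00002.json', '00002.pth']
# -> three examples; each gets a "json" field and a "pth" field keyed by extension.
```
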
@@ -145,3 +159,34 @@ def test_webdataset_with_features(image_wds_file):
     assert isinstance(decoded["json"], dict)
     assert isinstance(decoded["json"]["caption"], str)
     assert isinstance(decoded["jpg"], PIL.Image.Image)
+
+
+@require_torch
+def test_tensor_webdataset(tensor_wds_file):
+    import torch
+
+    data_files = {"train": [tensor_wds_file]}
+    webdataset = WebDataset(data_files=data_files)
+    split_generators = webdataset._split_generators(DownloadManager())
+    assert webdataset.info.features == Features(
+        {
+            "__key__": Value("string"),
+            "__url__": Value("string"),
+            "json": {"text": Value("string")},
+            "pth": Sequence(Value("float32")),
+        }
+    )
+    assert len(split_generators) == 1
+    split_generator = split_generators[0]
+    assert split_generator.name == "train"
+    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
+    _, examples = zip(*generator)
+    assert len(examples) == 3
+    assert isinstance(examples[0]["json"], dict)
+    assert isinstance(examples[0]["json"]["text"], str)
+    assert isinstance(examples[0]["pth"], torch.Tensor)  # keep encoded to avoid unnecessary copies
+    encoded = webdataset.info.features.encode_example(examples[0])
+    decoded = webdataset.info.features.decode_example(encoded)
+    assert isinstance(decoded["json"], dict)
+    assert isinstance(decoded["json"]["text"], str)
+    assert isinstance(decoded["pth"], list)
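
For reference, a sketch of the behavior this test exercises, seen through the public API (illustrative; assumes a local `file.tar` shaped like the `tensor_wds_file` fixture above):

```python
import torch
from datasets import load_dataset

ds = load_dataset("webdataset", data_files={"train": ["file.tar"]}, split="train")

print(ds.features["pth"])  # Sequence(Value(dtype='float32'), ...)
print(type(ds[0]["pth"]))  # <class 'list'> once stored/decoded via Arrow

# with_format("torch") hands the column back as tensors:
print(type(ds.with_format("torch")[0]["pth"]))  # <class 'torch.Tensor'>
```
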