Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions src/datasets/packaged_modules/webdataset/webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
current_example = {}
current_example["__key__"] = example_key
current_example["__url__"] = tar_path
current_example[field_name.lower()] = f.read()
if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
fs.write_bytes(filename, current_example[field_name.lower()])
current_example[field_name] = f.read()
if field_name.split(".")[-1].lower() in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
fs.write_bytes(filename, current_example[field_name])
extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
with fsspec.open(extracted_file_path) as f:
current_example[field_name.lower()] = f.read()
current_example[field_name] = f.read()
fs.delete(filename)
data_extension = xbasename(extracted_file_path).split(".")[-1]
data_extension = xbasename(extracted_file_path).split(".")[-1].lower()
else:
data_extension = field_name.split(".")[-1]
data_extension = field_name.split(".")[-1].lower()
if data_extension in cls.DECODERS:
current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need it lowercased to check whether it's in cls.DECODERS, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in the proposed fix the data_extension is lowercased, but the field_name is not.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes !

if current_example:
Expand Down Expand Up @@ -91,19 +91,15 @@ def _split_generators(self, dl_manager):
inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
features = datasets.Features.from_arrow_schema(inferred_arrow_schema)

# Set Image types
for field_name in first_examples[0]:
extension = field_name.rsplit(".", 1)[-1]
extension = field_name.rsplit(".", 1)[-1].lower()
# Set Image types
if extension in self.IMAGE_EXTENSIONS:
features[field_name] = datasets.Image()
# Set Audio types
for field_name in first_examples[0]:
extension = field_name.rsplit(".", 1)[-1]
# Set Audio types
if extension in self.AUDIO_EXTENSIONS:
features[field_name] = datasets.Audio()
# Set Video types
for field_name in first_examples[0]:
extension = field_name.rsplit(".", 1)[-1]
# Set Video types
if extension in self.VIDEO_EXTENSIONS:
features[field_name] = datasets.Video()
self.info.features = features
Expand Down
64 changes: 64 additions & 0 deletions tests/packaged_modules/test_webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,27 @@ def image_wds_file(tmp_path, image_file):
return str(filename)


@pytest.fixture
def upper_lower_case_file(tmp_path):
    """Write a tar archive whose member names mix upper- and lower-case parts.

    The archive holds three examples; each example contributes one JSON
    caption file per ``(tag, ext)`` variant, named ``<example>.<tag>.<ext>``.
    Example names alternate between an ``A`` and an ``a`` suffix.

    Returns:
        The path of the created tar archive, as a string.
    """
    # (tag, extension) pairs that differ only by letter case in places.
    variants = (
        ("INFO1", "json"),
        ("info2", "json"),
        ("info3", "JSON"),
        ("info3", "json"),  # should probably remove if testing on a case insensitive filesystem
    )
    num_examples = 3
    tar_path = tmp_path / "file.tar"
    example_names = [
        f"{example_idx:05d}_{'a' if example_idx % 2 else 'A'}" for example_idx in range(num_examples)
    ]

    with tarfile.open(tar_path, "w") as tar:
        for example_name in example_names:
            for tag, ext in variants:
                member_name = f"{example_name}.{tag}.{ext}"
                caption_path = tmp_path / member_name
                payload = json.dumps({"caption": f"caption for {example_name}.{tag}.{ext}"})
                # Each file is written and archived before the next one is created.
                caption_path.write_text(payload, encoding="utf-8")
                tar.add(caption_path, arcname=member_name)

    return str(tar_path)


@pytest.fixture
def audio_wds_file(tmp_path, audio_file):
json_file = tmp_path / "data.json"
Expand Down Expand Up @@ -133,6 +154,49 @@ def test_image_webdataset(image_wds_file):
assert isinstance(decoded["jpg"], PIL.Image.Image)


def test_upper_lower_case(upper_lower_case_file):
    """Check that field names keep their original letter case end to end.

    Loads the mixed-case archive through ``WebDataset``, then verifies the
    inferred features, the single ``train`` split, the generated examples,
    and an encode/decode round trip of every example.
    """
    variants = [
        ("INFO1", "json"),
        ("info2", "json"),
        ("info3", "JSON"),
        ("info3", "json"),
    ]
    variant_keys = [f"{tag}.{ext}" for tag, ext in variants]

    builder = WebDataset(data_files={"train": [upper_lower_case_file]})
    split_generators = builder._split_generators(DownloadManager())

    # Every variant must show up as its own column, with its case preserved.
    expected_features = {
        "__key__": Value("string"),
        "__url__": Value("string"),
    }
    for key in variant_keys:
        expected_features[key] = {"caption": Value("string")}
    assert builder.info.features == Features(expected_features)

    assert len(split_generators) == 1
    train_split = split_generators[0]
    assert train_split.name == "train"

    _keys, examples = zip(*builder._generate_examples(**train_split.gen_kwargs))
    assert len(examples) == 3

    features = builder.info.features
    for example_idx, example in enumerate(examples):
        example_name = example["__key__"]
        assert example_name == f"{example_idx:05d}_{'a' if example_idx % 2 else 'A'}"

        for key in variant_keys:
            field = example[key]
            assert isinstance(field, dict)
            assert field["caption"] == f"caption for {example_name}.{key}"

        # The captions must survive an encode/decode round trip unchanged.
        decoded = features.decode_example(features.encode_example(example))
        for key in variant_keys:
            assert decoded[key]["caption"] == example[key]["caption"]


@require_pil
def test_image_webdataset_missing_keys(image_wds_file):
import PIL.Image
Expand Down
Loading