24 changes: 24 additions & 0 deletions docs/source/audio_load.mdx
@@ -96,3 +96,27 @@ For more information about creating your own `AudioFolder` dataset, take a look
</Tip>

For a guide on how to load any type of dataset, take a look at the <a class="underline decoration-sky-400 decoration-2 font-semibold" href="./loading">general loading guide</a>.

## Audio decoding

By default, audio files are decoded sequentially as NumPy arrays when you iterate over a dataset.
However, it is possible to speed up iteration significantly using multithreaded decoding:

```python
>>> import os
>>> num_threads = min(32, (os.cpu_count() or 1) + 4)
>>> dataset = dataset.decode(num_threads=num_threads)
>>> for example in dataset: # up to 20 times faster!
... ...
```

You can enable multithreading using `num_threads`. This is especially useful to speed up remote data streaming.
However, it can be slower than `num_threads=0` for local data on fast disks.

If you are not interested in the audio decoded as NumPy arrays and would like to access the path/bytes instead, you can disable decoding:

```python
>>> dataset = dataset.decode(False)
```
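
With decoding disabled, each example then contains the file reference instead of a decoded array. A minimal sketch of what you get back, assuming a dataset with an `audio` column (the path is a placeholder):

```python
>>> example = next(iter(dataset))
>>> example["audio"]  # not decoded: just the file reference
{'path': 'path/to/audio.wav', 'bytes': None}
```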

Note: [`IterableDataset.decode`] is only available for streaming datasets at the moment.
24 changes: 24 additions & 0 deletions docs/source/image_load.mdx
@@ -138,3 +138,27 @@ You can load a WebDataset like this:

>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
```

## Image decoding

By default, images are decoded sequentially as `PIL.Image` objects when you iterate over a dataset.
However, it is possible to speed up iteration significantly using multithreaded decoding:

```python
>>> import os
>>> num_threads = min(32, (os.cpu_count() or 1) + 4)
>>> dataset = dataset.decode(num_threads=num_threads)
>>> for example in dataset: # up to 20 times faster!
... ...
```

You can enable multithreading using `num_threads`. This is especially useful to speed up remote data streaming.
However, it can be slower than `num_threads=0` for local data on fast disks.

If you are not interested in the images decoded as `PIL.Image` objects and would like to access the path/bytes instead, you can disable decoding:

```python
>>> dataset = dataset.decode(False)
```
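
With decoding disabled, each example then contains the file reference instead of a decoded `PIL.Image`. A minimal sketch of what you get back, assuming a dataset with an `image` column (the path is a placeholder):

```python
>>> example = next(iter(dataset))
>>> example["image"]  # not decoded: just the file reference
{'path': 'path/to/image.png', 'bytes': None}
```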

Note: [`IterableDataset.decode`] is only available for streaming datasets at the moment.
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -163,6 +163,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
- select_columns
- cast_column
- cast
- decode
- __iter__
- iter
- map
26 changes: 26 additions & 0 deletions docs/source/video_load.mdx
@@ -169,3 +169,29 @@ You can load a WebDataset like this:

>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
```

## Video decoding

By default, videos are decoded sequentially as torchvision `VideoReader` objects when you iterate over a dataset.
Decoding is lazy: only each video's metadata is read up front, and the frames are not read until you access them.
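
For example, frames are only decoded when you read them from the reader. A minimal sketch, assuming a dataset with a `video` column:

```python
>>> example = next(iter(dataset))
>>> reader = example["video"]  # a torchvision.io.VideoReader
>>> frame = next(iter(reader))  # the first frame is decoded only here
>>> frame["data"], frame["pts"]  # frame tensor and its presentation timestamp
```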

However, it is possible to speed up iteration significantly using multithreaded decoding:

```python
>>> import os
>>> num_threads = min(32, (os.cpu_count() or 1) + 4)
>>> dataset = dataset.decode(num_threads=num_threads)
>>> for example in dataset: # up to 20 times faster!
... ...
```

You can enable multithreading using `num_threads`. This is especially useful to speed up remote data streaming.
However, it can be slower than `num_threads=0` for local data on fast disks.

If you are not interested in the videos decoded as torchvision `VideoReader` objects and would like to access the path/bytes instead, you can disable decoding:

```python
>>> dataset = dataset.decode(False)
```
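
With decoding disabled, each example then contains the file reference instead of a `VideoReader`. A minimal sketch of what you get back, assuming a dataset with a `video` column (the path is a placeholder):

```python
>>> example = next(iter(dataset))
>>> example["video"]  # not decoded: just the file reference
{'path': 'path/to/video.mp4', 'bytes': None}
```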

Note: [`IterableDataset.decode`] is only available for streaming datasets at the moment.
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
@@ -3551,7 +3551,7 @@ def iter_outputs(shard_iterable):
task.cancel(msg="KeyboardInterrupt")
try:
loop.run_until_complete(asyncio.gather(*tasks))
except asyncio.CancelledError:
except (asyncio.CancelledError, ValueError):
logger.debug("Tasks canceled.")
raise

121 changes: 117 additions & 4 deletions src/datasets/iterable_dataset.py
@@ -2,6 +2,7 @@
import copy
import inspect
import itertools
import multiprocessing.pool
import sys
from collections import Counter
from collections.abc import Iterable, Iterator
@@ -24,6 +25,7 @@
Value,
_align_features,
_check_if_features_can_be_aligned,
_visit,
cast_to_python_objects,
)
from .formatting import (
@@ -1010,6 +1012,7 @@ def __init__(
fn_kwargs: Optional[dict] = None,
formatting: Optional["FormattingConfig"] = None,
features: Optional[Features] = None,
max_num_running_async_map_functions_in_parallel: Optional[int] = None,
):
super().__init__()
self.ex_iterable = ex_iterable
@@ -1023,6 +1026,9 @@ def __init__(
self.fn_kwargs = fn_kwargs or {}
self.formatting = formatting # required for iter_arrow
self._features = features
self.max_num_running_async_map_functions_in_parallel = (
max_num_running_async_map_functions_in_parallel or config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL
)
# sanity checks
if formatting and formatting.is_table:
# batch_size should match for iter_arrow
@@ -1036,6 +1042,8 @@ def __init__(
f"The {formatting.format_type.capitalize()}-formatted {type(self).__name__} has batch_size={batch_size if batched else 1} which is"
f"different from {ex_iterable.batch_size=} from its underlying iterable."
)
        # to enable graceful shutdown of the event loops and tasks this iterable owns
self._owned_loops_and_tasks: list[tuple[asyncio.AbstractEventLoop, list[asyncio.Task]]] = []

@property
def iter_arrow(self):
@@ -1174,6 +1182,7 @@ async def async_apply_function(key_example, indices):
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
self._owned_loops_and_tasks.append((loop, tasks))
else:
loop = None

@@ -1191,15 +1200,15 @@ def iter_outputs():
indices.append(i)
tasks.append(loop.create_task(async_apply_function(key_example, i)))
# keep the total active tasks under a certain number
if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
if len(tasks) >= self.max_num_running_async_map_functions_in_parallel:
done, pending = loop.run_until_complete(
asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
)
while tasks and len(pending) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
while tasks and len(pending) >= self.max_num_running_async_map_functions_in_parallel:
done, pending = loop.run_until_complete(
asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
)
if len(tasks) >= 10 * config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
if len(tasks) >= 10 * self.max_num_running_async_map_functions_in_parallel:
loop.run_until_complete(tasks[0])
# yield finished tasks
while tasks and tasks[0].done():
@@ -1257,7 +1266,7 @@ def iter_outputs():
task.cancel(msg="KeyboardInterrupt")
try:
loop.run_until_complete(asyncio.gather(*tasks))
except asyncio.CancelledError:
except (asyncio.CancelledError, ValueError):
logger.debug("Tasks canceled.")
raise

@@ -1347,6 +1356,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample
fn_kwargs=self.fn_kwargs,
formatting=self.formatting,
features=self.features,
max_num_running_async_map_functions_in_parallel=self.max_num_running_async_map_functions_in_parallel,
)

def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable":
@@ -1363,6 +1373,7 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "M
fn_kwargs=self.fn_kwargs,
formatting=self.formatting,
features=self.features,
max_num_running_async_map_functions_in_parallel=self.max_num_running_async_map_functions_in_parallel,
)

@property
@@ -3189,6 +3200,99 @@ def cast(
token_per_repo_id=self._token_per_repo_id,
)

def decode(self, enable: bool = True, num_threads: int = 0) -> "IterableDataset":
"""
Enable or disable the dataset features decoding for audio, image, video.

When enabled (default), media types are decoded:

* audio -> dict of "array" and "sampling_rate" and "path"
* image -> PIL.Image
* video -> torchvision.io.VideoReader

        You can enable multithreading using `num_threads`. This is especially useful to speed up remote
        data streaming. However, it can be slower than `num_threads=0` for local data on fast disks.

        Disabling decoding is useful if you want to iterate over the paths or bytes of the media files
without actually decoding their content. To disable decoding you can use `.decode(False)`, which
is equivalent to calling `.cast()` or `.cast_column()` with all the Audio, Image and Video types
set to `decode=False`.

Args:
enable (`bool`, defaults to `True`):
Enable or disable features decoding.
num_threads (`int`, defaults to `0`):
Enable multithreading for features decoding.

Returns:
            `IterableDataset`: A copy of the dataset with cast features.

Examples:

Disable decoding:

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("sshh12/planet-textures", split="train", streaming=True)
>>> next(iter(ds))
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024>,
'text': 'A distant celestial object with an icy crust, displaying a light blue shade, covered with round pits and rugged terrains.'}
>>> ds = ds.decode(False)
>>> ds.features
{'image': Image(mode=None, decode=False, id=None),
'text': Value(dtype='string', id=None)}
>>> next(iter(ds))
{
'image': {
'path': 'hf://datasets/sshh12/planet-textures@69dc4cef7a5c4b2cfe387727ec8ea73d4bff7302/train/textures/0000.png',
'bytes': None
},
'text': 'A distant celestial object with an icy crust, displaying a light blue shade, covered with round pits and rugged terrains.'
}
```

Speed up streaming with multithreading:

```py
>>> import os
>>> from datasets import load_dataset
>>> from tqdm import tqdm
>>> ds = load_dataset("sshh12/planet-textures", split="train", streaming=True)
>>> num_threads = min(32, (os.cpu_count() or 1) + 4)
>>> ds = ds.decode(num_threads=num_threads)
        >>> for _ in tqdm(ds): # up to 20 times faster!
... ...
```
"""
if not self.features:
raise ValueError(
"Features decoding is only available for datasets with known features, but features are Unknown. "
"Please set the datasets features with `ds = ds.cast(features)`."
)
ds = self

def set_decoding(decode: bool, feature):
if hasattr(feature, "decode"):
feature.decode = decode

if enable and num_threads > 0:
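            # Two-phase approach: cast the features to decode=False so iteration yields
            # raw data, then re-apply decoding in a thread pool through an async map.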
disabled_decoding_features = self.features.copy()
enabled_decoding_features = self.features.copy()

_visit(disabled_decoding_features, partial(set_decoding, False))
_visit(enabled_decoding_features, partial(set_decoding, True))
ds = ds.cast(disabled_decoding_features)
pool = multiprocessing.pool.ThreadPool(num_threads)
func = partial(_apply_async, pool, enabled_decoding_features.decode_example)
ds = ds.map(func, features=enabled_decoding_features)
assert isinstance(ds._ex_iterable, MappedExamplesIterable)
ds._ex_iterable.max_num_running_async_map_functions_in_parallel = 2 * num_threads
else:
features = ds.features.copy()
_visit(features, partial(set_decoding, enable))
ds = ds.cast(features)
return ds

def _step(self, step: int, offset: int) -> "IterableDataset":
ex_iterable = StepExamplesIterable(self._ex_iterable, step=step, offset=offset)
return IterableDataset(
@@ -3407,3 +3511,12 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
distributed=distributed,
token_per_repo_id=dataset._token_per_repo_id,
)


async def _apply_async(pool, func, x):
    # Run func(x) in the thread pool and yield control back to the event loop
    # until the result is ready, so many decode calls can run concurrently.
    future = pool.apply_async(func, (x,))
    while True:
        if future.ready():
            return future.get()
        else:
            await asyncio.sleep(0)
28 changes: 28 additions & 0 deletions tests/test_iterable_dataset.py
@@ -2474,3 +2474,31 @@ def test_iterable_dataset_batch():
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]


class DecodableFeature:
decode_example_num_calls = 0

def __init__(self):
self.decode = True

def decode_example(self, example, token_per_repo_id=None):
type(self).decode_example_num_calls += 1
return "decoded" if self.decode else example


def test_decode():
data = [{"i": i} for i in range(10)]
features = Features({"i": DecodableFeature()})
ds = IterableDataset.from_generator(lambda: (x for x in data), features=features)
assert next(iter(ds)) == {"i": "decoded"}
assert DecodableFeature.decode_example_num_calls == 1
ds = ds.decode(False)
assert next(iter(ds)) == {"i": 0}
assert DecodableFeature.decode_example_num_calls == 1
ds = ds.decode(True)
assert next(iter(ds)) == {"i": "decoded"}
assert DecodableFeature.decode_example_num_calls == 2
ds = ds.decode(num_threads=1)
assert next(iter(ds)) == {"i": "decoded"}
assert DecodableFeature.decode_example_num_calls == 4