67 changes: 36 additions & 31 deletions python/ray/data/_internal/datasource/huggingface_datasource.py
@@ -118,43 +118,48 @@ def list_parquet_urls_from_dataset(
def estimate_inmemory_data_size(self) -> Optional[int]:
return self._dataset.dataset_size

def _read_dataset(self) -> Iterable[Block]:
import numpy as np
import pandas as pd
import pyarrow
# Note: we pass `self` here instead of `self._dataset` because
# we need to trigger the try-import logic at the top of the file
# to avoid an import error for dataset_modules.
for batch in self._dataset.with_format("arrow").iter(
batch_size=self._batch_size
):
# HuggingFace IterableDatasets do not fully support methods like
# `set_format`, `with_format`, and `formatted_as`, so the dataset
# can return whatever is the default configured batch type, even if
# the format is manually overridden before iterating above.
# Therefore, we limit support to batch formats which have native
# block types in Ray Data (pyarrow.Table, pd.DataFrame),
# or can easily be converted to such (dict, np.array).
# See: https://github.com/huggingface/datasets/issues/3444
if not isinstance(batch, (pyarrow.Table, pd.DataFrame, dict, np.ndarray)):
raise ValueError(
f"Batch format {type(batch)} isn't supported. Only the "
f"following batch formats are supported: "
f"dict (corresponds to `None` in `dataset.with_format()`), "
f"pyarrow.Table, np.array, pd.DataFrame."
)
# Ensure np.arrays are wrapped in a dict
# (subsequently converted to a pyarrow.Table).
if isinstance(batch, np.ndarray):
batch = {"item": batch}
if isinstance(batch, dict):
batch = pyarrow_table_from_pydict(batch)
# Ensure that we return the default block type.
block = BlockAccessor.for_block(batch).to_default()
yield block

def get_read_tasks(
self,
parallelism: int,
) -> List[ReadTask]:
# Note: `parallelism` arg is currently not used by HuggingFaceDatasource.
# We always generate a single ReadTask to perform the read.
_check_pyarrow_version()
import numpy as np
import pandas as pd
import pyarrow

def _read_dataset(dataset: "datasets.IterableDataset") -> Iterable[Block]:
for batch in dataset.with_format("arrow").iter(batch_size=self._batch_size):
# HuggingFace IterableDatasets do not fully support methods like
# `set_format`, `with_format`, and `formatted_as`, so the dataset
# can return whatever is the default configured batch type, even if
# the format is manually overridden before iterating above.
# Therefore, we limit support to batch formats which have native
# block types in Ray Data (pyarrow.Table, pd.DataFrame),
# or can easily be converted to such (dict, np.array).
# See: https://github.com/huggingface/datasets/issues/3444
if not isinstance(batch, (pyarrow.Table, pd.DataFrame, dict, np.ndarray)):
raise ValueError(
f"Batch format {type(batch)} isn't supported. Only the "
f"following batch formats are supported: "
f"dict (corresponds to `None` in `dataset.with_format()`), "
f"pyarrow.Table, np.array, pd.DataFrame."
)
# Ensure np.arrays are wrapped in a dict
# (subsequently converted to a pyarrow.Table).
if isinstance(batch, np.ndarray):
batch = {"item": batch}
if isinstance(batch, dict):
batch = pyarrow_table_from_pydict(batch)
# Ensure that we return the default block type.
block = BlockAccessor.for_block(batch).to_default()
yield block

# TODO(scottjlee): IterableDataset doesn't provide APIs
# for getting number of rows, byte size, etc., so the
@@ -169,7 +174,7 @@ def _read_dataset(dataset: "datasets.IterableDataset") -> Iterable[Block]:
)
read_tasks: List[ReadTask] = [
ReadTask(
lambda hfds=self._dataset: _read_dataset(hfds),
self._read_dataset,
meta,
)
]
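
For readers skimming the diff, the loop in `_read_dataset` boils down to normalizing whatever batch type the HuggingFace iterable yields into an Arrow-friendly block. Here is a minimal, self-contained sketch of that normalization; the helper name `normalize_batch` is hypothetical, and the Ray-internal helpers (`pyarrow_table_from_pydict`, `BlockAccessor.for_block(...).to_default()`) are replaced with plain pyarrow calls for illustration.

```python
# Hypothetical sketch of the batch normalization done in _read_dataset.
# Only pyarrow, pandas, and numpy are assumed; Ray-internal helpers are
# swapped for plain pyarrow calls.
import numpy as np
import pandas as pd
import pyarrow as pa


def normalize_batch(batch) -> pa.Table:
    """Convert a HuggingFace batch into an Arrow table."""
    if not isinstance(batch, (pa.Table, pd.DataFrame, dict, np.ndarray)):
        raise ValueError(f"Batch format {type(batch)} isn't supported.")
    if isinstance(batch, np.ndarray):
        # Bare arrays are wrapped under a single "item" column.
        batch = {"item": batch}
    if isinstance(batch, dict):
        batch = pa.Table.from_pydict(batch)
    if isinstance(batch, pd.DataFrame):
        batch = pa.Table.from_pandas(batch)
    return batch


print(normalize_batch({"a": [1, 2, 3]}).num_rows)   # 3
print(normalize_batch(np.arange(4)).column_names)   # ['item']
```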
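
The other change in this file swaps the lambda that captured `self._dataset` for the bound method `self._read_dataset`. The sketch below is a hypothetical, stand-alone illustration (not Ray code) of the cloudpickle behavior the inline comment relies on: a bound method of an importable class is serialized by reference, so deserializing it on a worker re-imports the defining module and re-runs its top-level (try-)import logic, whereas a lambda pickled by value would not.

```python
# Hypothetical illustration (not Ray code): unpickling a bound method of an
# importable class re-imports its defining module, re-running any top-level
# import logic there.
import os
import sys
import tempfile
import textwrap

import cloudpickle

# Write a throwaway module with an observable import-time side effect.
module_src = textwrap.dedent(
    """
    print("module-level (try-)import logic ran")  # stands in for the try-import

    class Source:
        def read(self):
            return "data"
    """
)
tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "fake_datasource.py"), "w") as f:
    f.write(module_src)
sys.path.insert(0, tmp_dir)

import fake_datasource  # noqa: E402  -- prints once on the "driver"

payload = cloudpickle.dumps(fake_datasource.Source().read)  # pickled by reference

# Simulate a worker process that has not imported the module yet.
del sys.modules["fake_datasource"]

read_fn = cloudpickle.loads(payload)  # re-imports fake_datasource, prints again
print(read_fn())                      # -> data
```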
16 changes: 16 additions & 0 deletions python/ray/data/tests/test_huggingface.py
@@ -72,6 +72,22 @@ def test_from_huggingface_streaming(batch_format, ray_start_regular_shared):
assert ds.count() == 355


@pytest.mark.skipif(
datasets.Version(datasets.__version__) < datasets.Version("2.8.0"),
reason="IterableDataset.iter() added in 2.8.0",
)
def test_from_huggingface_dynamic_generated(ray_start_regular_shared):
# https://github.com/ray-project/ray/issues/49529
hfds = datasets.load_dataset(
"dataset-org/dream",
split="test",
streaming=True,
trust_remote_code=True,
)
ds = ray.data.from_huggingface(hfds)
ds.take(1)


if __name__ == "__main__":
import sys

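
For context, an end-to-end use of the streaming path exercised by the new test might look roughly like the sketch below; the dataset name, column, and transform are assumptions for illustration, not part of this PR.

```python
# Illustrative sketch only: the "squad" dataset, its "question" column, and
# the map_batches transform are assumptions, not part of this PR.
import datasets
import numpy as np
import ray

hf_ds = datasets.load_dataset("squad", split="train", streaming=True)
ds = ray.data.from_huggingface(hf_ds)  # streaming datasets become a single read task

# Blocks arrive Arrow-backed, so downstream operations work as usual.
ds = ds.map_batches(
    lambda batch: {"n_chars": np.array([len(q) for q in batch["question"]])}
)
print(ds.take(2))
```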