Commit 11320c3

Add IterableDataset.push_to_hub() (#7595)
* add to_parquet and push_to_hub
* style
* fix
* for datasetdict as well
* update docs
* docs
1 parent 641128d commit 11320c3
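
The commit adds `to_parquet()` and `push_to_hub()` to `IterableDataset`, and `push_to_hub()` to `IterableDatasetDict`. A minimal usage sketch based on the commit description and the docstring in the diff below — the dataset name, repo id, and filter are placeholders, and the `to_parquet()` call assumes a local-path signature analogous to `Dataset.to_parquet()`:

```python
from datasets import load_dataset

# Streaming mode yields an IterableDataset; transforms stay lazy
ds = load_dataset("rajpurkar/squad", split="train", streaming=True)
ds = ds.filter(lambda ex: len(ex["context"]) < 1000)

# New in this commit: write the streamed examples out as Parquet...
ds.to_parquet("squad_train_filtered.parquet")  # assumed signature, by analogy with Dataset.to_parquet
# ...or upload them to the Hub as Parquet shards over HTTP (no git/git-lfs needed)
ds.push_to_hub("username/squad-filtered", private=True)
```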

File tree

7 files changed: +824 -23 lines changed


docs/source/package_reference/main_classes.mdx

Lines changed: 3 additions & 0 deletions
@@ -175,6 +175,8 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
 - take
 - shard
 - repeat
+- to_parquet
+- push_to_hub
 - load_state_dict
 - state_dict
 - info
@@ -208,6 +210,7 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable
 - rename_column
 - rename_columns
 - select_columns
+- push_to_hub
 
 ## Features

src/datasets/dataset_dict.py

Lines changed: 298 additions & 0 deletions
@@ -2369,3 +2369,301 @@ def cast(
         ```
         """
         return IterableDatasetDict({k: dataset.cast(features=features) for k, dataset in self.items()})
+
+    def push_to_hub(
+        self,
+        repo_id,
+        config_name: str = "default",
+        set_default: Optional[bool] = None,
+        data_dir: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        commit_description: Optional[str] = None,
+        private: Optional[bool] = None,
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        create_pr: Optional[bool] = False,
+        # max_shard_size: Optional[Union[int, str]] = None, # TODO(QL): add arg
+        num_shards: Optional[dict[str, int]] = None,
+        embed_external_files: bool = True,
+    ) -> CommitInfo:
+        """Pushes the [`DatasetDict`] to the hub as a Parquet dataset.
+        The [`DatasetDict`] is pushed using HTTP requests and does not require git or git-lfs to be installed.
+
+        Each dataset split will be pushed independently. The pushed dataset will keep the original split names.
+
+        The resulting Parquet files are self-contained by default: if your dataset contains [`Image`] or [`Audio`]
+        data, the Parquet files will store the bytes of your images or audio files.
+        You can disable this by setting `embed_external_files` to False.
+
+        Args:
+            repo_id (`str`):
+                The ID of the repository to push to in the following format: `<user>/<dataset_name>` or
+                `<org>/<dataset_name>`. Also accepts `<dataset_name>`, which will default to the namespace
+                of the logged-in user.
+            config_name (`str`):
+                Configuration name of a dataset. Defaults to "default".
+            set_default (`bool`, *optional*):
+                Whether to set this configuration as the default one. Otherwise, the default configuration is the one
+                named "default".
+            data_dir (`str`, *optional*):
+                Directory name that will contain the uploaded data files. Defaults to the `config_name` if different
+                from "default", else "data".
+
+                <Added version="2.17.0"/>
+            commit_message (`str`, *optional*):
+                Message to commit while pushing. Will default to `"Upload dataset"`.
+            commit_description (`str`, *optional*):
+                Description of the commit that will be created.
+                Additionally, description of the PR if a PR is created (`create_pr` is True).
+
+                <Added version="2.16.0"/>
+            private (`bool`, *optional*):
+                Whether to make the repo private. If `None` (default), the repo will be public unless the
+                organization's default is private. This value is ignored if the repo already exists.
+            token (`str`, *optional*):
+                An optional authentication token for the Hugging Face Hub. If no token is passed, will default
+                to the token saved locally when logging in with `huggingface-cli login`. Will raise an error
+                if no token is passed and the user is not logged-in.
+            revision (`str`, *optional*):
+                Branch to push the uploaded files to. Defaults to the `"main"` branch.
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether to create a PR with the uploaded files or directly commit.
+            num_shards (`Dict[str, int]`, *optional*):
+                Number of shards to write. Equals to this dataset's `.num_shards` by default.
+                Use a dictionary to define a different num_shards for each split.
+            embed_external_files (`bool`, defaults to `True`):
+                Whether to embed file bytes in the shards.
+                In particular, this will do the following before the push for the fields of type:
+
+                - [`Audio`] and [`Image`] remove local path information and embed file content in the Parquet files.
+
+        Return:
+            huggingface_hub.CommitInfo
+
+        Example:
+
+        ```python
+        >>> dataset_dict.push_to_hub("<organization>/<dataset_id>")
+        >>> dataset_dict.push_to_hub("<organization>/<dataset_id>", private=True)
+        >>> dataset_dict.push_to_hub("<organization>/<dataset_id>", num_shards={"train": 1024, "test": 8})
+        ```
+
+        If you want to add a new configuration (or subset) to a dataset (e.g. if the dataset has multiple tasks/versions/languages):
+
+        ```python
+        >>> english_dataset.push_to_hub("<organization>/<dataset_id>", "en")
+        >>> french_dataset.push_to_hub("<organization>/<dataset_id>", "fr")
+        >>> # later
+        >>> english_dataset = load_dataset("<organization>/<dataset_id>", "en")
+        >>> french_dataset = load_dataset("<organization>/<dataset_id>", "fr")
+        ```
+        """
+        if num_shards is None:
+            num_shards = dict.fromkeys(self)
+        elif not isinstance(num_shards, dict):
+            raise ValueError(
+                "Please provide one `num_shards` per dataset in the dataset dictionary, e.g. {{'train': 128, 'test': 4}}"
+            )
+
+        self._check_values_type()
+        self._check_values_features()
+        total_uploaded_size = 0
+        total_dataset_nbytes = 0
+        info_to_dump: DatasetInfo = next(iter(self.values())).info.copy()
+        info_to_dump.config_name = config_name
+        info_to_dump.splits = SplitDict()
+
+        for split in self.keys():
+            if not re.match(_split_re, split):
+                raise ValueError(f"Split name should match '{_split_re}' but got '{split}'.")
+
+        api = HfApi(endpoint=config.HF_ENDPOINT, token=token)
+
+        repo_url = api.create_repo(
+            repo_id,
+            token=token,
+            repo_type="dataset",
+            private=private,
+            exist_ok=True,
+        )
+        repo_id = repo_url.repo_id
+
+        if revision is not None and not revision.startswith("refs/pr/"):
+            # We do not call create_branch for a PR reference: 400 Bad Request
+            api.create_branch(
+                repo_id,
+                branch=revision,
+                token=token,
+                repo_type="dataset",
+                exist_ok=True,
+            )
+
+        if not data_dir:
+            data_dir = config_name if config_name != "default" else "data"  # for backward compatibility
+
+        additions = []
+        for split in self.keys():
+            logger.info(f"Pushing split {split} to the Hub.")
+            # The split=key needs to be removed before merging
+            split_additions, uploaded_size, dataset_nbytes = self[split]._push_parquet_shards_to_hub(
+                repo_id,
+                data_dir=data_dir,
+                split=split,
+                token=token,
+                revision=revision,
+                create_pr=create_pr,
+                # max_shard_size=max_shard_size, # TODO(QL): add arg
+                num_shards=num_shards.get(split),
+                embed_external_files=embed_external_files,
+            )
+            additions += split_additions
+            total_uploaded_size += uploaded_size
+            total_dataset_nbytes += dataset_nbytes
+            info_to_dump.splits[split] = SplitInfo(str(split), num_bytes=dataset_nbytes, num_examples=len(self[split]))
+        info_to_dump.download_checksums = None
+        info_to_dump.download_size = total_uploaded_size
+        info_to_dump.dataset_size = total_dataset_nbytes
+        info_to_dump.size_in_bytes = total_uploaded_size + total_dataset_nbytes
+
+        # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern)
+        # and delete old split shards (if they exist)
+        repo_with_dataset_card, repo_with_dataset_infos = False, False
+        repo_splits: list[str] = []  # use a list to keep the order of the splits
+        deletions: list[CommitOperationDelete] = []
+        repo_files_to_add = [addition.path_in_repo for addition in additions]
+        for repo_file in api.list_repo_tree(
+            repo_id=repo_id,
+            revision=revision,
+            repo_type="dataset",
+            token=token,
+            recursive=True,
+        ):
+            if not isinstance(repo_file, RepoFile):
+                continue
+            if repo_file.rfilename == config.REPOCARD_FILENAME:
+                repo_with_dataset_card = True
+            elif repo_file.rfilename == config.DATASETDICT_INFOS_FILENAME:
+                repo_with_dataset_infos = True
+            elif (
+                repo_file.rfilename.startswith(tuple(f"{data_dir}/{split}-" for split in self.keys()))
+                and repo_file.rfilename not in repo_files_to_add
+            ):
+                deletions.append(CommitOperationDelete(path_in_repo=repo_file.rfilename))
+            elif fnmatch.fnmatch(
+                repo_file.rfilename,
+                PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*"),
+            ):
+                pattern = glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED)
+                split_pattern_fields = string_to_dict(repo_file.rfilename, pattern)
+                assert split_pattern_fields is not None
+                repo_split = split_pattern_fields["split"]
+                if repo_split not in repo_splits:
+                    repo_splits.append(repo_split)
+
+        # get the info from the README to update them
+        if repo_with_dataset_card:
+            dataset_card_path = api.hf_hub_download(
+                repo_id,
+                config.REPOCARD_FILENAME,
+                repo_type="dataset",
+                revision=revision,
+            )
+            dataset_card = DatasetCard.load(Path(dataset_card_path))
+            dataset_card_data = dataset_card.data
+            metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data)
+        # get the deprecated dataset_infos.json to update them
+        elif repo_with_dataset_infos:
+            dataset_card = None
+            dataset_card_data = DatasetCardData()
+            metadata_configs = MetadataConfigs()
+        else:
+            dataset_card = None
+            dataset_card_data = DatasetCardData()
+            metadata_configs = MetadataConfigs()
+        # create the metadata configs if it was uploaded with push_to_hub before metadata configs existed
+        if not metadata_configs and repo_splits:
+            default_metadata_configs_to_dump = {
+                "data_files": [{"split": split, "path": f"data/{split}-*"} for split in repo_splits]
+            }
+            MetadataConfigs({"default": default_metadata_configs_to_dump}).to_dataset_card_data(dataset_card_data)
+        metadata_config_to_dump = {
+            "data_files": [{"split": split, "path": f"{data_dir}/{split}-*"} for split in self.keys()],
+        }
+        configs_to_dump = {config_name: metadata_config_to_dump}
+        if set_default and config_name != "default":
+            if metadata_configs:
+                current_default_config_name = metadata_configs.get_default_config_name()
+                if current_default_config_name == "default":
+                    raise ValueError(
+                        "There exists a configuration named 'default'. To set a different configuration as default, "
+                        "rename the 'default' one first."
+                    )
+                if current_default_config_name:
+                    _ = metadata_configs[current_default_config_name].pop("default")
+                    configs_to_dump[current_default_config_name] = metadata_configs[current_default_config_name]
+            metadata_config_to_dump["default"] = True
+        # push to the deprecated dataset_infos.json
+        if repo_with_dataset_infos:
+            dataset_infos_path = api.hf_hub_download(
+                repo_id,
+                config.DATASETDICT_INFOS_FILENAME,
+                repo_type="dataset",
+                revision=revision,
+            )
+            with open(dataset_infos_path, encoding="utf-8") as f:
+                dataset_infos: dict = json.load(f)
+            dataset_infos[config_name] = asdict(info_to_dump)
+            additions.append(
+                CommitOperationAdd(
+                    path_in_repo=config.DATASETDICT_INFOS_FILENAME,
+                    path_or_fileobj=json.dumps(dataset_infos, indent=4).encode("utf-8"),
+                )
+            )
+        # push to README
+        DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(dataset_card_data)
+        MetadataConfigs(configs_to_dump).to_dataset_card_data(dataset_card_data)
+        dataset_card = DatasetCard(f"---\n{dataset_card_data}\n---\n") if dataset_card is None else dataset_card
+        additions.append(
+            CommitOperationAdd(
+                path_in_repo=config.REPOCARD_FILENAME,
+                path_or_fileobj=str(dataset_card).encode(),
+            )
+        )
+
+        commit_message = commit_message if commit_message is not None else "Upload dataset"
+        if len(additions) <= config.UPLOADS_MAX_NUMBER_PER_COMMIT:
+            commit_info = api.create_commit(
+                repo_id,
+                operations=additions + deletions,
+                commit_message=commit_message,
+                commit_description=commit_description,
+                token=token,
+                repo_type="dataset",
+                revision=revision,
+                create_pr=create_pr,
+            )
+        else:
+            logger.info(
+                f"Number of files to upload is larger than {config.UPLOADS_MAX_NUMBER_PER_COMMIT}. Splitting the push into multiple commits."
+            )
+            num_commits = math.ceil(len(additions) / config.UPLOADS_MAX_NUMBER_PER_COMMIT)
+            for i in range(0, num_commits):
+                operations = additions[
+                    i * config.UPLOADS_MAX_NUMBER_PER_COMMIT : (i + 1) * config.UPLOADS_MAX_NUMBER_PER_COMMIT
+                ] + (deletions if i == 0 else [])
+                commit_info = api.create_commit(
+                    repo_id,
+                    operations=operations,
+                    commit_message=commit_message + f" (part {i:05d}-of-{num_commits:05d})",
+                    commit_description=commit_description,
+                    token=token,
+                    repo_type="dataset",
+                    revision=revision,
+                    create_pr=create_pr,
+                )
+                logger.info(
+                    f"Commit #{i + 1} completed"
+                    + (f" (still {num_commits - i - 1} to go)" if num_commits - i - 1 else "")
+                    + "."
+                )
+        return commit_info
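
A usage sketch for the method above (repo ids and split names are placeholders; the arguments follow the docstring): pushing an `IterableDatasetDict` as a named configuration with per-split shard counts.

```python
from datasets import load_dataset

# Streaming load returns an IterableDatasetDict (e.g. with "train" and "test" splits)
dataset_dict = load_dataset("username/raw-corpus", streaming=True)

dataset_dict.push_to_hub(
    "username/clean-corpus",
    config_name="en",                     # files land under the "en/" data_dir since the name differs from "default"
    set_default=True,                     # per the code above, raises if the current default config is named "default"
    num_shards={"train": 64, "test": 4},  # one entry per split, as the num_shards check requires
    commit_message="Upload English subset",
)
```

If the resulting additions exceed `UPLOADS_MAX_NUMBER_PER_COMMIT`, the method splits the upload into several commits automatically, as the `else` branch above shows.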

src/datasets/features/audio.py

Lines changed: 12 additions & 4 deletions
@@ -241,7 +241,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str
             storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
         return array_cast(storage, self.pa_type)
 
-    def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
+    def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
         """Embed audio files into the Arrow array.
 
         Args:
@@ -252,12 +252,20 @@ def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
             `pa.StructArray`: Array in the Audio arrow storage type, that is
             `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
         """
+        if token_per_repo_id is None:
+            token_per_repo_id = {}
 
         @no_op_if_value_is_null
         def path_to_bytes(path):
-            with xopen(path, "rb") as f:
-                bytes_ = f.read()
-            return bytes_
+            source_url = path.split("::")[-1]
+            pattern = (
+                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
+            )
+            source_url_fields = string_to_dict(source_url, pattern)
+            token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
+            download_config = DownloadConfig(token=token)
+            with xopen(path, "rb", download_config=download_config) as f:
+                return f.read()
 
         bytes_array = pa.array(
             [

src/datasets/features/image.py

Lines changed: 12 additions & 4 deletions
@@ -250,7 +250,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
             )
         return array_cast(storage, self.pa_type)
 
-    def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
+    def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
         """Embed image files into the Arrow array.
 
         Args:
@@ -261,12 +261,20 @@ def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
             `pa.StructArray`: Array in the Image arrow storage type, that is
             `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
         """
+        if token_per_repo_id is None:
+            token_per_repo_id = {}
 
         @no_op_if_value_is_null
        def path_to_bytes(path):
-            with xopen(path, "rb") as f:
-                bytes_ = f.read()
-            return bytes_
+            source_url = path.split("::")[-1]
+            pattern = (
+                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
+            )
+            source_url_fields = string_to_dict(source_url, pattern)
+            token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
+            download_config = DownloadConfig(token=token)
+            with xopen(path, "rb", download_config=download_config) as f:
+                return f.read()
 
         bytes_array = pa.array(
             [
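
Both the `Audio` and `Image` diffs above make the same change: `embed_storage()` gains a `token_per_repo_id` argument, takes the last segment of a chained path (e.g. `zip://file.wav::https://huggingface.co/datasets/<repo_id>/...`) as the source URL, and looks up a per-repository token for the download. The sketch below re-implements only that lookup idea with plain string handling; it is illustrative, does not use the library's `string_to_dict`/`xopen`/`DownloadConfig` helpers, and the URL prefix and token dict are assumptions.

```python
from typing import Optional

HUB_PREFIX = "https://huggingface.co/datasets/"  # assumed, mirrors config.HF_ENDPOINT + "/datasets/"


def token_for_path(path: str, token_per_repo_id: Optional[dict] = None) -> Optional[str]:
    """Return the auth token registered for the Hub repo that `path` points to, if any."""
    if token_per_repo_id is None:
        token_per_repo_id = {}
    source_url = path.split("::")[-1]  # the last segment of a chained path is the real URL
    if not source_url.startswith(HUB_PREFIX):
        return None  # not a Hub URL: no token lookup needed
    repo_id = "/".join(source_url[len(HUB_PREFIX):].split("/")[:2])  # "<namespace>/<name>"
    return token_per_repo_id.get(repo_id)


# An audio file inside a zip archive hosted on a (hypothetical) private dataset repo
path = "zip://clip.wav::https://huggingface.co/datasets/org/private-audio/resolve/main/archive.zip"
print(token_for_path(path, {"org/private-audio": "hf_xxx"}))  # -> hf_xxx
```

In the actual diff, the resolved token is wrapped in a `DownloadConfig` and passed to `xopen`, so bytes from private Hub repositories can be embedded into the Parquet shards during `push_to_hub()`.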
