Commit d10e846

simonreise and lhoestq authored
Add custom fingerprint support to from_generator (#7533)
* Add custom suffix support to from_generator
* Renamed a new arg to fingerprint
* Changed name to config_id in builder
* Change version
* Added a test
* Version update
* Update version
* Update tests/test_arrow_dataset.py
* Rename config_id to fingerprint in generator.py
* Apply suggestions from code review
* Update src/datasets/io/generator.py
* Apply suggestions from code review

---------

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent fb445ff commit d10e846

File tree

4 files changed (+31, -5 lines)


src/datasets/arrow_dataset.py

Lines changed: 8 additions & 0 deletions
@@ -1120,6 +1120,7 @@ def from_generator(
         gen_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
         split: NamedSplit = Split.TRAIN,
+        fingerprint: Optional[str] = None,
         **kwargs,
     ):
         """Create a Dataset from a generator.
@@ -1146,6 +1147,12 @@ def from_generator(
                 Split name to be assigned to the dataset.

                 <Added version="2.21.0"/>
+            fingerprint (`str`, *optional*):
+                Fingerprint that will be used to generate dataset ID.
+                By default `fingerprint` is generated by hashing the generator function and all the args which can be slow
+                if it uses large objects like AI models.
+
+                <Added version="4.3.0"/>
             **kwargs (additional keyword arguments):
                 Keyword arguments to be passed to :[`GeneratorConfig`].

@@ -1183,6 +1190,7 @@ def from_generator(
             gen_kwargs=gen_kwargs,
             num_proc=num_proc,
             split=split,
+            fingerprint=fingerprint,
             **kwargs,
         ).read()


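As context for the new argument, here is a minimal usage sketch (not part of the commit; the generator and the large captured object are illustrative, and it assumes a datasets release that includes this change):

from datasets import Dataset

# Stand-in for a large object (e.g. an AI model) captured by the generator;
# hashing it to compute the default fingerprint could be slow.
big_object = list(range(1_000_000))

def gen():
    for i in range(3):
        yield {"col_1": str(i), "col_2": i, "col_3": float(len(big_object))}

# Passing an explicit fingerprint avoids hashing the generator and its closure.
ds = Dataset.from_generator(gen, fingerprint="my-precomputed-fingerprint")
print(ds._fingerprint)  # my-precomputed-fingerprint
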
src/datasets/builder.py

Lines changed: 8 additions & 5 deletions
@@ -313,6 +313,7 @@ def __init__(
         data_dir: Optional[str] = None,
         storage_options: Optional[dict] = None,
         writer_batch_size: Optional[int] = None,
+        config_id: Optional[str] = None,
         **config_kwargs,
     ):
         # DatasetBuilder name
@@ -343,6 +344,7 @@ def __init__(
         self.config, self.config_id = self._create_builder_config(
             config_name=config_name,
             custom_features=features,
+            config_id=config_id,
             **config_kwargs,
         )

@@ -502,7 +504,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st
         return legacy_relative_data_dir

     def _create_builder_config(
-        self, config_name=None, custom_features=None, **config_kwargs
+        self, config_name=None, custom_features=None, config_id=None, **config_kwargs
     ) -> tuple[BuilderConfig, str]:
         """Create and validate BuilderConfig object as well as a unique config id for this config.
         Raises ValueError if there are multiple builder configs and config_name and DEFAULT_CONFIG_NAME are None.
@@ -570,10 +572,11 @@ def _create_builder_config(
             )

         # compute the config id that is going to be used for caching
-        config_id = builder_config.create_config_id(
-            config_kwargs,
-            custom_features=custom_features,
-        )
+        if config_id is None:
+            config_id = builder_config.create_config_id(
+                config_kwargs,
+                custom_features=custom_features,
+            )
         is_custom = (config_id not in self.builder_configs) and config_id != "default"
         if is_custom:
             logger.info(f"Using custom data configuration {config_id}")

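Since the config_id computed here names the cache entry for the build, a caller-supplied id lets repeated calls share that entry. A rough sketch of the effect, assuming the change above (the fingerprint value and generator are illustrative):

from datasets import Dataset

def gen():
    yield {"col_1": "a", "col_2": 1, "col_3": 1.0}

# Both calls resolve to the same builder config_id ("default-fingerprint=stable-id"),
# so the second call should reuse the Arrow files cached by the first
# rather than deriving a hash-based id from the generator again.
ds1 = Dataset.from_generator(gen, cache_dir="./cache", fingerprint="stable-id")
ds2 = Dataset.from_generator(gen, cache_dir="./cache", fingerprint="stable-id")
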
src/datasets/io/generator.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,7 @@ def __init__(
         gen_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
         split: NamedSplit = Split.TRAIN,
+        fingerprint: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(
@@ -32,8 +33,10 @@ def __init__(
             generator=generator,
             gen_kwargs=gen_kwargs,
             split=split,
+            config_id="default-fingerprint=" + fingerprint if fingerprint else None,
             **kwargs,
         )
+        self.fingerprint = fingerprint

     def read(self):
         # Build iterable dataset
@@ -56,4 +59,6 @@ def read(self):
         dataset = self.builder.as_dataset(
             split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
         )
+        if self.fingerprint:
+            dataset._fingerprint = self.fingerprint
         return dataset

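Putting the two pieces above together: a given fingerprint is forwarded to the builder as the config_id string "default-fingerprint=<fingerprint>" and also stamped onto the finished dataset; when it is omitted, both fall back to the hash-based defaults. A small sketch of the observable behaviour (values are illustrative):

from datasets import Dataset

def gen():
    yield {"col_1": "a", "col_2": 1, "col_3": 1.0}

# Explicit fingerprint: forwarded as config_id "default-fingerprint=abc123"
# and written back to the dataset after it is built.
ds = Dataset.from_generator(gen, fingerprint="abc123")
assert ds._fingerprint == "abc123"

# No fingerprint: config_id stays None, so the builder derives a hashed id
# and the dataset keeps its hash-derived fingerprint.
ds_default = Dataset.from_generator(gen)
assert ds_default._fingerprint != "abc123"
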
tests/test_arrow_dataset.py

Lines changed: 10 additions & 0 deletions
@@ -4114,6 +4114,16 @@ def test_dataset_from_generator_split(split, data_generator, tmp_path):
     _check_generator_dataset(dataset, expected_features, expected_split)


+@pytest.mark.parametrize("fingerprint", [None, "test-dataset"])
+def test_dataset_from_generator_fingerprint(fingerprint, data_generator, tmp_path):
+    cache_dir = tmp_path / "cache"
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, fingerprint=fingerprint)
+    _check_generator_dataset(dataset, expected_features, NamedSplit("train"))
+    if fingerprint:
+        assert dataset._fingerprint == fingerprint
+
+
 @require_not_windows
 @require_dill_gt_0_3_2
 @require_pyspark

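To exercise just the new test locally, something along these lines should work (assuming a development checkout of datasets with pytest installed):

pytest tests/test_arrow_dataset.py -k test_dataset_from_generator_fingerprint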