
Commit 7ce446d

[Distributed] Fix load_dataset error when multiprocessing + add test (#544)
* Fix #543 + add test
* fix tests
* make csv/json/pandas/text able to access distant data
1 parent 8cf7dab commit 7ce446d
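The failure this commit targets (#543) shows up when several processes call `load_dataset` on the same files at the same time, as happens in distributed training. Below is a minimal sketch of that usage pattern, assuming the `nlp` package is installed; the CSV path and the worker count are placeholders for illustration, not part of this commit:

# Sketch: several worker processes load the same CSV dataset concurrently.
# With the FileLock added to builder.py below, one process prepares the
# cache while the others wait on the lock and then reuse the prepared data.
from multiprocessing import Pool

from nlp import load_dataset


def load_in_worker(_):
    # "my_train.csv" is a hypothetical local file used only for illustration.
    return load_dataset("./datasets/csv", data_files={"train": "my_train.csv"})


if __name__ == "__main__":
    with Pool(processes=4) as pool:
        pool.map(load_in_worker, range(4))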

File tree

6 files changed: +149 −114 lines

datasets/csv/csv.py
datasets/json/json.py
datasets/pandas/pandas.py
datasets/text/text.py
src/nlp/builder.py
tests/test_dataset_common.py

datasets/csv/csv.py

Lines changed: 5 additions & 4 deletions
@@ -56,15 +56,16 @@ def _info(self):
     def _split_generators(self, dl_manager):
         """ We handle string, list and dicts in datafiles
         """
-        if isinstance(self.config.data_files, (str, list, tuple)):
-            files = self.config.data_files
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        if isinstance(data_files, (str, list, tuple)):
+            files = data_files
             if isinstance(files, str):
                 files = [files]
             return [nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"files": files})]
         splits = []
         for split_name in [nlp.Split.TRAIN, nlp.Split.VALIDATION, nlp.Split.TEST]:
-            if split_name in self.config.data_files:
-                files = self.config.data_files[split_name]
+            if split_name in data_files:
+                files = data_files[split_name]
                 if isinstance(files, str):
                     files = [files]
                 splits.append(nlp.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
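Because `data_files` is now routed through `dl_manager.download_and_extract`, the csv loader (and the json, pandas and text loaders changed the same way below) can take remote URLs as well as local paths, which is the "access distant data" part of the commit message. A hedged usage sketch with a made-up URL:

# Sketch only: the URL is a placeholder. download_and_extract downloads and
# caches the remote file locally before the CSV loader reads it; plain local
# paths are passed through unchanged.
from nlp import load_dataset

dataset = load_dataset(
    "./datasets/csv",
    data_files={"train": "https://example.com/data/train.csv"},
)
print(dataset["train"])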

datasets/json/json.py

Lines changed: 5 additions & 4 deletions
@@ -43,15 +43,16 @@ def _info(self):
     def _split_generators(self, dl_manager):
         """ We handle string, list and dicts in datafiles
         """
-        if isinstance(self.config.data_files, (str, list, tuple)):
-            files = self.config.data_files
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        if isinstance(data_files, (str, list, tuple)):
+            files = data_files
             if isinstance(files, str):
                 files = [files]
             return [nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"files": files})]
         splits = []
         for split_name in [nlp.Split.TRAIN, nlp.Split.VALIDATION, nlp.Split.TEST]:
-            if split_name in self.config.data_files:
-                files = self.config.data_files[split_name]
+            if split_name in data_files:
+                files = data_files[split_name]
                 if isinstance(files, str):
                     files = [files]
                 splits.append(nlp.SplitGenerator(name=split_name, gen_kwargs={"files": files}))

datasets/pandas/pandas.py

Lines changed: 5 additions & 4 deletions
@@ -13,15 +13,16 @@ def _info(self):
     def _split_generators(self, dl_manager):
         """ We handle string, list and dicts in datafiles
         """
-        if isinstance(self.config.data_files, (str, list, tuple)):
-            files = self.config.data_files
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        if isinstance(data_files, (str, list, tuple)):
+            files = data_files
             if isinstance(files, str):
                 files = [files]
             return [nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"files": files})]
         splits = []
         for split_name in [nlp.Split.TRAIN, nlp.Split.VALIDATION, nlp.Split.TEST]:
-            if split_name in self.config.data_files:
-                files = self.config.data_files[split_name]
+            if split_name in data_files:
+                files = data_files[split_name]
                 if isinstance(files, str):
                     files = [files]
                 splits.append(nlp.SplitGenerator(name=split_name, gen_kwargs={"files": files}))

datasets/text/text.py

Lines changed: 11 additions & 13 deletions
@@ -11,22 +11,20 @@ def _split_generators(self, dl_manager):
         If str or List[str], then the dataset returns only the 'train' split.
         If dict, then keys should be from the `nlp.Split` enum.
         """
-        if isinstance(self.config.data_files, (str, list, tuple)):
-            # Handle case with only one split
-            files = self.config.data_files
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        if isinstance(data_files, (str, list, tuple)):
+            files = data_files
             if isinstance(files, str):
                 files = [files]
             return [nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"files": files})]
-        else:
-            # Handle case with several splits and a dict mapping
-            splits = []
-            for split_name in [nlp.Split.TRAIN, nlp.Split.VALIDATION, nlp.Split.TEST]:
-                if split_name in self.config.data_files:
-                    files = self.config.data_files[split_name]
-                    if isinstance(files, str):
-                        files = [files]
-                    splits.append(nlp.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
-            return splits
+        splits = []
+        for split_name in [nlp.Split.TRAIN, nlp.Split.VALIDATION, nlp.Split.TEST]:
+            if split_name in data_files:
+                files = data_files[split_name]
+                if isinstance(files, str):
+                    files = [files]
+                splits.append(nlp.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+        return splits
 
     def _generate_examples(self, files):
         """ Read files sequentially, then lines sequentially. """

src/nlp/builder.py

Lines changed: 92 additions & 88 deletions
@@ -27,6 +27,7 @@
 from typing import Dict, List, Optional, Union
 
 import xxhash
+from filelock import FileLock
 
 from . import utils
 from .arrow_dataset import Dataset
@@ -391,101 +392,104 @@ def download_and_prepare(
                 dataset_name=self.name, download_config=download_config, data_dir=self.config.data_dir
             )
 
-        data_exists = os.path.exists(self._cache_dir)
-        if data_exists and download_mode == REUSE_DATASET_IF_EXISTS:
-            logger.info("Reusing dataset %s (%s)", self.name, self._cache_dir)
-            self.download_post_processing_resources(dl_manager)
-            return
-
-        # Currently it's not possible to overwrite the data because it would
-        # conflict with versioning: If the last version has already been generated,
-        # it will always be reloaded and cache_dir will be set at construction.
-        if data_exists and download_mode != REUSE_CACHE_IF_EXISTS:
-            raise ValueError(
-                "Trying to overwrite an existing dataset {} at {}. A dataset with "
-                "the same version {} already exists. If the dataset has changed, "
-                "please update the version number.".format(self.name, self._cache_dir, self.config.version)
-            )
+        # Prevent parallel disk operations
+        lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace("/", "_") + ".lock")
+        with FileLock(lock_path):
+            data_exists = os.path.exists(self._cache_dir)
+            if data_exists and download_mode == REUSE_DATASET_IF_EXISTS:
+                logger.info("Reusing dataset %s (%s)", self.name, self._cache_dir)
+                self.download_post_processing_resources(dl_manager)
+                return
+
+            # Currently it's not possible to overwrite the data because it would
+            # conflict with versioning: If the last version has already been generated,
+            # it will always be reloaded and cache_dir will be set at construction.
+            if data_exists and download_mode != REUSE_CACHE_IF_EXISTS:
+                raise ValueError(
+                    "Trying to overwrite an existing dataset {} at {}. A dataset with "
+                    "the same version {} already exists. If the dataset has changed, "
+                    "please update the version number.".format(self.name, self._cache_dir, self.config.version)
+                )
 
-        logger.info("Generating dataset %s (%s)", self.name, self._cache_dir)
-        if not is_remote_url(self._cache_dir):  # if cache dir is local, check for available space
-            os.makedirs(self._cache_dir_root, exist_ok=True)
-            if not utils.has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
-                raise IOError(
-                    "Not enough disk space. Needed: {} (download: {}, generated: {}, post-processed: {})".format(
-                        utils.size_str(self.info.size_in_bytes or 0),
-                        utils.size_str(self.info.download_size or 0),
-                        utils.size_str(self.info.dataset_size or 0),
-                        utils.size_str(self.info.post_processing_size or 0),
+            logger.info("Generating dataset %s (%s)", self.name, self._cache_dir)
+            if not is_remote_url(self._cache_dir):  # if cache dir is local, check for available space
+                os.makedirs(self._cache_dir_root, exist_ok=True)
+                if not utils.has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
+                    raise IOError(
+                        "Not enough disk space. Needed: {} (download: {}, generated: {}, post-processed: {})".format(
+                            utils.size_str(self.info.size_in_bytes or 0),
+                            utils.size_str(self.info.download_size or 0),
+                            utils.size_str(self.info.dataset_size or 0),
+                            utils.size_str(self.info.post_processing_size or 0),
+                        )
                     )
+
+            @contextlib.contextmanager
+            def incomplete_dir(dirname):
+                """Create temporary dir for dirname and rename on exit."""
+                if is_remote_url(dirname):
+                    yield dirname
+                else:
+                    tmp_dir = dirname + ".incomplete"
+                    os.makedirs(tmp_dir)
+                    try:
+                        yield tmp_dir
+                        if os.path.isdir(dirname):
+                            shutil.rmtree(dirname)
+                        os.rename(tmp_dir, dirname)
+                    finally:
+                        if os.path.exists(tmp_dir):
+                            shutil.rmtree(tmp_dir)
+
+            # Print is intentional: we want this to always go to stdout so user has
+            # information needed to cancel download/preparation if needed.
+            # This comes right before the progress bar.
+            print(
+                f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} "
+                f"(download: {utils.size_str(self.info.download_size)}, generated: {utils.size_str(self.info.dataset_size)}, "
+                f"post-processed: {utils.size_str(self.info.post_processing_size)}, "
+                f"total: {utils.size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
+            )
+
+            if self.manual_download_instructions is not None:
+                assert (
+                    dl_manager.manual_dir is not None
+                ), "The dataset {} with config {} requires manual data. \n Please follow the manual download instructions: {}. \n Manual data can be loaded with `nlp.load_dataset({}, data_dir='<path/to/manual/data>')".format(
+                    self.name, self.config.name, self.manual_download_instructions, self.name
                 )
 
-        @contextlib.contextmanager
-        def incomplete_dir(dirname):
-            """Create temporary dir for dirname and rename on exit."""
-            if is_remote_url(dirname):
-                yield dirname
-            else:
-                tmp_dir = dirname + ".incomplete"
-                os.makedirs(tmp_dir)
-                try:
-                    yield tmp_dir
-                    if os.path.isdir(dirname):
-                        shutil.rmtree(dirname)
-                    os.rename(tmp_dir, dirname)
-                finally:
-                    if os.path.exists(tmp_dir):
-                        shutil.rmtree(tmp_dir)
-
-        # Print is intentional: we want this to always go to stdout so user has
-        # information needed to cancel download/preparation if needed.
-        # This comes right before the progress bar.
-        print(
-            f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} "
-            f"(download: {utils.size_str(self.info.download_size)}, generated: {utils.size_str(self.info.dataset_size)}, "
-            f"post-processed: {utils.size_str(self.info.post_processing_size)}, "
-            f"total: {utils.size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
-        )
+            # Create a tmp dir and rename to self._cache_dir on successful exit.
+            with incomplete_dir(self._cache_dir) as tmp_data_dir:
+                # Temporarily assign _cache_dir to tmp_data_dir to avoid having to forward
+                # it to every sub function.
+                with utils.temporary_assignment(self, "_cache_dir", tmp_data_dir):
+                    # Try to download the already prepared dataset files
+                    downloaded_from_gcs = False
+                    if try_from_hf_gcs:
+                        try:
+                            self._download_prepared_from_hf_gcs()
+                            downloaded_from_gcs = True
+                        except (DatasetNotOnHfGcs, MissingFilesOnHfGcs):
+                            logger.info("Dataset not on Hf google storage. Downloading and preparing it from source")
+                    if not downloaded_from_gcs:
+                        self._download_and_prepare(
+                            dl_manager=dl_manager, verify_infos=verify_infos, **download_and_prepare_kwargs
+                        )
+                    # Sync info
+                    self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
+                    self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
+                    self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
+                    # Save info
+                    self._save_info()
+
+            # Download post processing resources
+            self.download_post_processing_resources(dl_manager)
 
-        if self.manual_download_instructions is not None:
-            assert (
-                dl_manager.manual_dir is not None
-            ), "The dataset {} with config {} requires manual data. \n Please follow the manual download instructions: {}. \n Manual data can be loaded with `nlp.load_dataset({}, data_dir='<path/to/manual/data>')".format(
-                self.name, self.config.name, self.manual_download_instructions, self.name
+            print(
+                f"Dataset {self.name} downloaded and prepared to {self._cache_dir}. "
+                f"Subsequent calls will reuse this data."
             )
 
-        # Create a tmp dir and rename to self._cache_dir on successful exit.
-        with incomplete_dir(self._cache_dir) as tmp_data_dir:
-            # Temporarily assign _cache_dir to tmp_data_dir to avoid having to forward
-            # it to every sub function.
-            with utils.temporary_assignment(self, "_cache_dir", tmp_data_dir):
-                # Try to download the already prepared dataset files
-                downloaded_from_gcs = False
-                if try_from_hf_gcs:
-                    try:
-                        self._download_prepared_from_hf_gcs()
-                        downloaded_from_gcs = True
-                    except (DatasetNotOnHfGcs, MissingFilesOnHfGcs):
-                        logger.info("Dataset not on Hf google storage. Downloading and preparing it from source")
-                if not downloaded_from_gcs:
-                    self._download_and_prepare(
-                        dl_manager=dl_manager, verify_infos=verify_infos, **download_and_prepare_kwargs
-                    )
-                # Sync info
-                self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
-                self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
-                self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
-                # Save info
-                self._save_info()
-
-        # Download post processing resources
-        self.download_post_processing_resources(dl_manager)
-
-        print(
-            f"Dataset {self.name} downloaded and prepared to {self._cache_dir}. "
-            f"Subsequent calls will reuse this data."
-        )
-
     def _download_prepared_from_hf_gcs(self):
         relative_data_dir = self._relative_data_dir(with_version=True, with_hash=False)
         reader = ArrowReader(self._cache_dir, self.info)
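The heart of the multiprocessing fix is serializing cache preparation with a `filelock.FileLock` keyed on the cache directory: the first process to take the lock builds the dataset, and every other process finds the data already prepared when its turn comes. A standalone sketch of that lock-then-check pattern follows, with illustrative names (`prepare_cache`, the lock path) that are not the library's API:

# Illustrative sketch of the lock-then-check pattern used above.
import os

from filelock import FileLock


def prepare_cache(cache_dir: str) -> None:
    """Build cache_dir exactly once even if many processes call this."""
    lock_path = cache_dir.rstrip("/") + ".lock"
    with FileLock(lock_path):
        if os.path.exists(cache_dir):
            # Another process already prepared the cache; just reuse it.
            return
        tmp_dir = cache_dir + ".incomplete"
        os.makedirs(tmp_dir)
        # ... write dataset files into tmp_dir ...
        os.rename(tmp_dir, cache_dir)  # publish atomically on success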

tests/test_dataset_common.py

Lines changed: 31 additions & 1 deletion
@@ -17,6 +17,8 @@
 import logging
 import os
 import tempfile
+from multiprocessing import Pool
+from unittest import TestCase
 
 import requests
 from absl.testing import parameterized
@@ -27,6 +29,7 @@
     DownloadConfig,
     GenerateMode,
     MockDownloadManager,
+    cached_path,
     hf_api,
     hf_bucket_url,
     import_main_class,
@@ -42,7 +45,7 @@
 
 class DatasetTester(object):
     def __init__(self, parent):
-        self.parent = parent
+        self.parent = parent if parent is not None else TestCase()
 
     def load_builder_class(self, dataset_name, is_local=False):
         # Download/copy dataset script
@@ -219,6 +222,33 @@ def test_load_real_dataset_all_configs(self, dataset_name):
                 self.assertTrue(len(dataset[split]) > 0)
 
 
+def distributed_load_dataset(args):
+    data_name, tmp_dir, datafiles = args
+    dataset = load_dataset(data_name, cache_dir=tmp_dir, data_files=datafiles)
+    return dataset
+
+
+class DistributedDatasetTest(TestCase):
+    def test_load_dataset_distributed(self):
+        num_workers = 5
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            data_name = "./datasets/csv"
+            data_base_path = os.path.join(data_name, "dummy/0.0.0/dummy_data.zip")
+            local_path = cached_path(
+                data_base_path, cache_dir=tmp_dir, extract_compressed_file=True, force_extract=True
+            )
+            datafiles = {
+                "train": os.path.join(local_path, "dummy_data/train.csv"),
+                "dev": os.path.join(local_path, "dummy_data/dev.csv"),
+                "test": os.path.join(local_path, "dummy_data/test.csv"),
+            }
+            args = data_name, tmp_dir, datafiles
+            with Pool(processes=num_workers) as pool:  # start num_workers processes
+                result = pool.apply_async(distributed_load_dataset, (args,))
+                _ = result.get(timeout=20)
+                _ = pool.map(distributed_load_dataset, [args] * num_workers)
+
+
 def get_aws_dataset_names():
     api = hf_api.HfApi()
     # fetch all dataset names
