
Commit c214aa5

Add multiprocessing (#552)
* add multiprocessing
* test multiprocessing and fix empty tables issues
* improve logging
* minor
* fix unit
* fix tqdm in notebook for multiprocessing
1 parent f9e8f7e commit c214aa5
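
The headline change is a new num_proc argument on Dataset.map and Dataset.filter, exercised by the tests further down. A minimal usage sketch, assuming a small in-memory dataset built with Dataset.from_dict; the column and values mirror the dummy dataset used in the tests:

import nlp

# Functions passed to map/filter with num_proc must be picklable, i.e. defined at
# module level, so they can be shipped to the worker processes.
def picklable_map_function(x):
    return {"id": int(x["filename"].split("_")[-1])}

def picklable_filter_function(x):
    return int(x["filename"].split("_")[-1]) < 10

dset = nlp.Dataset.from_dict({"filename": ["my_name-train_" + str(i) for i in range(30)]})
dset = dset.map(picklable_map_function, num_proc=2)        # processed by 2 worker processes
dset = dset.filter(picklable_filter_function, num_proc=2)  # filter accepts the same argument
print(len(dset))  # 10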

File tree

7 files changed: +386 -163 lines


src/nlp/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
 from pyarrow import total_allocated_bytes

 from . import datasets
-from .arrow_dataset import Dataset
+from .arrow_dataset import Dataset, concatenate_datasets
 from .arrow_reader import ArrowReader, ReadInstruction
 from .arrow_writer import ArrowWriter
 from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
@@ -43,7 +43,7 @@
 )
 from .info import DatasetInfo, MetricInfo
 from .inspect import inspect_dataset, inspect_metric, list_datasets, list_metrics
-from .load import concatenate_datasets, import_main_class, load_dataset, load_metric, prepare_module
+from .load import import_main_class, load_dataset, load_metric, prepare_module
 from .metric import Metric
 from .splits import NamedSplit, Split, SplitBase, SplitDict, SplitGenerator, SplitInfo, SubSplitInfo, percent
 from .utils import *
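
concatenate_datasets now lives in nlp.arrow_dataset (next to Dataset, where the multiprocessing code stitches processed shards back together) instead of nlp.load, but it is still re-exported at the package top level, so existing imports keep working. A small sketch, again using Dataset.from_dict purely for illustration:

from nlp import Dataset, concatenate_datasets  # same top-level import as before

part_a = Dataset.from_dict({"text": ["a", "b"]})
part_b = Dataset.from_dict({"text": ["c"]})
combined = concatenate_datasets([part_a, part_b])  # features must match across parts
print(len(combined))  # 3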

src/nlp/arrow_dataset.py

Lines changed: 310 additions & 41 deletions
Large diffs are not rendered by default.
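
Since the arrow_dataset.py diff is collapsed here, the following is only a rough illustration of the shard-then-pool-then-concatenate strategy that the tests below exercise, not the actual implementation; map_with_processes and _map_shard are hypothetical names:

from multiprocessing import Pool

import nlp

def _map_shard(args):
    # Each worker runs the ordinary single-process map on its own shard.
    shard, function = args
    return shard.map(function)

def map_with_processes(dset, function, num_proc):
    # Split the dataset into num_proc contiguous shards.
    shard_size = (len(dset) + num_proc - 1) // num_proc
    shards = [
        dset.select(range(rank * shard_size, min((rank + 1) * shard_size, len(dset))))
        for rank in range(num_proc)
    ]
    # Process every shard in its own worker process (shards and function must pickle).
    with Pool(num_proc) as pool:
        mapped_shards = pool.map(_map_shard, [(shard, function) for shard in shards])
    # Stitch the processed shards back together.
    return nlp.concatenate_datasets(mapped_shards)

The tests below are consistent with this picture: with a cache file set and num_proc=3, three cache files are produced, one per shard.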

src/nlp/arrow_reader.py

Lines changed: 5 additions & 4 deletions
@@ -159,12 +159,13 @@ def _read_files(self, files) -> pa.Table:
         skip/take indicates which example read in the file: `ds.slice(skip, take)`
         """
         assert len(files) > 0 and all(isinstance(f, dict) for f in files), "please provide valid file informations"
-        pa_batches = []
+        pa_tables = []
         for f_dict in files:
             pa_table: pa.Table = self._get_dataset_from_filename(f_dict)
-            pa_batches.extend(pa_table.to_batches())
-        assert len(pa_batches) > 0, "tried to read an empty arrow table"
-        pa_table = pa.Table.from_batches(pa_batches)
+            pa_tables.append(pa_table)
+        pa_tables = [t for t in pa_tables if len(t) > 0]
+        pa_tables = pa_tables or [pa.Table.from_batches([], schema=pa.schema(self._info.features.type))]
+        pa_table = pa.concat_tables(pa_tables)
         return pa_table

     def get_file_instructions(self, name, instruction, split_infos):
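
The reader change replaces batch-level concatenation, which breaks when a file holds zero record batches, with table-level concatenation plus an explicit empty-table fallback that preserves the schema. A standalone pyarrow illustration of the same idea (the schema here is made up):

import pyarrow as pa

schema = pa.schema([("filename", pa.string())])
full_table = pa.Table.from_pydict({"filename": ["a", "b"]}, schema=schema)
empty_table = pa.Table.from_batches([], schema=schema)  # e.g. an empty split

# Old approach: an empty table contributes no record batches, so the assert fired.
assert len(empty_table.to_batches()) == 0

# New approach: drop empty tables, fall back to a schema-only empty table if
# everything was empty, then concatenate whole tables.
tables = [t for t in [full_table, empty_table] if len(t) > 0]
tables = tables or [pa.Table.from_batches([], schema=schema)]
print(pa.concat_tables(tables).num_rows)  # 2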

src/nlp/arrow_writer.py

Lines changed: 4 additions & 1 deletion
@@ -134,6 +134,7 @@ def __init__(
         disable_nullable: bool = False,
         update_features: bool = False,
         with_metadata: bool = True,
+        unit: str = "examples",
     ):
         if path is None and stream is None:
             raise ValueError("At least one of path and stream must be provided.")
@@ -161,6 +162,7 @@ def __init__(
         self.writer_batch_size = writer_batch_size or DEFAULT_MAX_BATCH_SIZE
         self.update_features = update_features
         self.with_metadata = with_metadata
+        self.unit = unit

         self._num_examples = 0
         self._num_bytes = 0
@@ -290,8 +292,9 @@ def finalize(self, close_stream=True):
         if close_stream:
             self.stream.close()
         logger.info(
-            "Done writing %s examples in %s bytes %s.",
+            "Done writing %s %s in %s bytes %s.",
             self._num_examples,
+            self.unit,
             self._num_bytes,
             self._path if self._path else "",
         )
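
The new unit argument only changes the wording of the final log line, which matters now that the writer is also used for things other than examples (such as indices tables for the shards created by multiprocessing). A sketch of how it might be used; ArrowWriter is an internal class, so treat the path and exact call pattern as illustrative assumptions:

import pyarrow as pa
from nlp.arrow_writer import ArrowWriter

writer = ArrowWriter(path="/tmp/indices.arrow", unit="indices")
writer.write_table(pa.Table.from_pydict({"indices": [0, 1, 2]}))
writer.finalize()  # now logs "Done writing 3 indices ..." instead of "... 3 examples ..."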

src/nlp/fingerprint.py

Lines changed: 6 additions & 3 deletions
@@ -123,8 +123,11 @@ def _fingerprint(func):

         @wraps(func)
         def wrapper(*args, **kwargs):
-            self: "Dataset" = args[0]
-            args = args[1:]
+            if args:
+                self: "Dataset" = args[0]
+                args = args[1:]
+            else:
+                self: "Dataset" = kwargs.pop("self")
             kwargs_for_fingerprint = dict(kwargs)
             kwargs_for_fingerprint.update(zip(func.__code__.co_varnames, args))

@@ -145,7 +148,7 @@ def wrapper(*args, **kwargs):
                 new_inplace_history_item = (func.__name__, deepcopy(args), deepcopy(kwargs))
             else:
                 for fingerprint_name in fingerprint_names:  # transforms like `train_test_split` have several hashes
-                    if fingerprint_name not in kwargs:
+                    if kwargs.get(fingerprint_name) is None:
                         kwargs_for_fingerprint["fingerprint_name"] = fingerprint_name
                         kwargs[fingerprint_name] = update_fingerprint(self._fingerprint, func, kwargs_for_fingerprint)

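
The wrapper previously assumed it was always called with self as the first positional argument. With multiprocessing, the decorated transform can apparently be invoked with keyword arguments only, in which case self arrives in kwargs. A toy sketch of the same pattern, independent of the library's actual decorator:

from functools import wraps

def accepts_self_from_kwargs(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if args:                      # normal bound-method call: obj.method(...)
            self, args = args[0], args[1:]
        else:                         # keyword-only call: Cls.method(self=obj, ...)
            self = kwargs.pop("self")
        print("fingerprinting", type(self).__name__ + "." + func.__name__)
        return func(self, *args, **kwargs)
    return wrapper

class Toy:
    @accepts_self_from_kwargs
    def double(self, value):
        return value * 2

t = Toy()
t.double(3)                  # self comes through *args
Toy.double(self=t, value=3)  # self comes through **kwargs, as in the keyword-only path

The second hunk also makes the wrapper generate a fingerprint when the fingerprint keyword is passed explicitly as None, not only when it is missing from kwargs.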

src/nlp/load.py

Lines changed: 2 additions & 112 deletions
@@ -24,19 +24,16 @@
 import shutil
 from hashlib import sha256
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 from urllib.parse import urlparse

-import numpy as np
-import pyarrow as pa
 from filelock import FileLock

 from .arrow_dataset import Dataset
 from .builder import DatasetBuilder
 from .dataset_dict import DatasetDict
 from .features import Features
-from .fingerprint import update_fingerprint
-from .info import DATASET_INFOS_DICT_FILE_NAME, DatasetInfo
+from .info import DATASET_INFOS_DICT_FILE_NAME
 from .metric import Metric
 from .splits import Split
 from .utils.download_manager import GenerateMode
@@ -560,110 +557,3 @@ def load_dataset(
         builder_instance._save_infos()

     return ds
-
-
-def concatenate_datasets(
-    dsets: List["Dataset"],
-    info: Optional[Any] = None,
-    split: Optional[Any] = None,
-):
-    """
-    Converts a list of :obj:``nlp.Dataset`` with the same schema into a single :obj:``nlp.Dataset``.
-
-    Args:
-        dsets (:obj:``List[nlp.Dataset]``): A list of Datasets to concatenate
-        info (:obj:``nlp.DatasetInfo``, `optional`, defaults to :obj:``None``): If specified, the dataset info containing info like
-            description, citation, etc.
-        split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
-    """
-    if not all([dset.features.type == dsets[0].features.type for dset in dsets]):
-        raise ValueError("Features must match for all datasets")
-
-    # Datasets tables should all come from disk or memory, but not a mix
-
-    dsets_in_memory = [not dset._data_files for dset in dsets]
-    if any(dset_in_memory != dsets_in_memory[0] for dset_in_memory in dsets_in_memory):
-        raise ValueError(
-            "Datasets should ALL come from memory, or should ALL come from disk.\n"
-            "However datasets {} come from memory and datasets {} come from disk.".format(
-                [i for i in range(len(dsets)) if dsets_in_memory[i]],
-                [i for i in range(len(dsets)) if not dsets_in_memory[i]],
-            )
-        )
-
-    # Concatenate tables
-
-    table = pa.concat_tables([dset._data for dset in dsets])
-    data_files = [f for dset in dsets for f in dset._data_files]
-    inplace_history = [h for dset in dsets for h in dset._inplace_history]
-
-    def apply_offset_to_indices_table(table, offset):
-        if offset == 0:
-            return table
-        else:
-            array = table["indices"]
-            if isinstance(array, pa.ChunkedArray):
-                new_array = pa.array(np.concatenate([c.to_numpy() for c in array.chunks]) + offset, pa.uint64())
-            else:
-                new_array = pa.array(array.to_numpy() + offset, pa.uint64())
-            return pa.Table.from_arrays([new_array], names=["indices"])

-    # Concatenate indices if they exist
-
-    if any(dset._indices is not None for dset in dsets):
-
-        # Datasets indices tables should all come from disk or memory, but not a mix
-        # Datasets with no indices tables are replaced with a dataset with an indicies table in memory
-
-        indices_mappings_in_memory = [not dset._indices_data_files for dset in dsets]
-        if any(
-            indices_mapping_in_memory != indices_mappings_in_memory[0]
-            for indices_mapping_in_memory in indices_mappings_in_memory
-        ):
-            raise ValueError(
-                "Datasets' indices should ALL come from memory, or should ALL come from disk.\n"
-                "However datasets' indices {} come from memory and datasets' indices {} come from disk.".format(
-                    [i for i in range(len(dsets)) if indices_mappings_in_memory[i]],
-                    [i for i in range(len(dsets)) if not indices_mappings_in_memory[i]],
-                )
-            )
-        indices_in_memory = indices_mappings_in_memory[0]
-
-        # Create missing indices tables in memory
-
-        if indices_in_memory:
-            for i in range(len(dsets)):
-                if dsets[i]._indices is None:
-                    dsets[i] = dsets[i].select(range(len(dsets[i])))
-            assert all(dset._indices is not None for dset in dsets), "each dataset should have an indices table"
-
-        # An offset needs to be applied to the indices before concatenating
-
-        indices_tables = []
-        offset = 0
-        for dset in dsets:
-            indices_tables.append(apply_offset_to_indices_table(dset._indices, offset))
-            offset += len(dset._data)
-
-        # Concatenate indices
-
-        indices_table = pa.concat_tables(indices_tables)
-        indices_data_files = None if indices_in_memory else [f for dset in dsets for f in dset._indices_data_files]
-    else:
-        indices_table = None
-        indices_data_files = None
-    if info is None:
-        info = DatasetInfo.from_merge([dset.info for dset in dsets])
-    fingerprint = update_fingerprint(
-        "".join(dset._fingerprint for dset in dsets), concatenate_datasets, {"info": info, "split": split}
-    )
-    return Dataset(
-        table,
-        info=info,
-        split=split,
-        data_files=data_files,
-        indices_table=indices_table,
-        indices_data_files=indices_data_files,
-        fingerprint=fingerprint,
-        inplace_history=inplace_history,
-    )
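
The removed concatenate_datasets (moved into arrow_dataset.py) offsets each dataset's indices table by the total length of the tables that precede it before concatenating. A standalone pyarrow sketch of that offset step, lifted from the removed helper with made-up inputs:

import numpy as np
import pyarrow as pa

def apply_offset_to_indices_table(table, offset):
    if offset == 0:
        return table
    array = table["indices"]
    if isinstance(array, pa.ChunkedArray):
        new_array = pa.array(np.concatenate([c.to_numpy() for c in array.chunks]) + offset, pa.uint64())
    else:
        new_array = pa.array(array.to_numpy() + offset, pa.uint64())
    return pa.Table.from_arrays([new_array], names=["indices"])

first = pa.Table.from_arrays([pa.array([0, 1, 2], pa.uint64())], names=["indices"])
second = pa.Table.from_arrays([pa.array([0, 1], pa.uint64())], names=["indices"])
# The second dataset's indices are shifted by the number of rows in the first one.
combined = pa.concat_tables([first, apply_offset_to_indices_table(second, offset=3)])
print(combined["indices"].to_pylist())  # [0, 1, 2, 3, 4]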

tests/test_arrow_dataset.py

Lines changed: 57 additions & 0 deletions
@@ -20,6 +20,14 @@ def __getstate__(self):
         raise pickle.PicklingError()


+def picklable_map_function(x):
+    return {"id": int(x["filename"].split("_")[-1])}
+
+
+def picklable_filter_function(x):
+    return int(x["filename"].split("_")[-1]) < 10
+
+
 class BaseDatasetTest(TestCase):
     def _create_dummy_dataset(self, multiple_columns=False):
         if multiple_columns:
@@ -493,6 +501,34 @@ def func(x, i):
                 Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
             )

+    def test_map_multiprocessing(self):
+        dset = self._create_dummy_dataset()
+
+        self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
+        fingerprint = dset._fingerprint
+        dset_test = dset.map(picklable_map_function, num_proc=2)
+        self.assertEqual(len(dset_test), 30)
+        self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
+        self.assertDictEqual(
+            dset_test.features,
+            Features({"filename": Value("string"), "id": Value("int64")}),
+        )
+        self.assertEqual(len(dset_test._data_files), 0)
+        self.assertNotEqual(dset_test._fingerprint, fingerprint)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dset = dset.map(cache_file_name=os.path.join(tmp_dir, "test.arrow"))
+            fingerprint = dset._fingerprint
+            dset_test = dset.map(picklable_map_function, num_proc=3)
+            self.assertEqual(len(dset_test), 30)
+            self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
+            self.assertDictEqual(
+                dset_test.features,
+                Features({"filename": Value("string"), "id": Value("int64")}),
+            )
+            self.assertEqual(len(dset_test._data_files), 3)
+            self.assertNotEqual(dset_test._fingerprint, fingerprint)
+
     def test_new_features(self):
         dset = self._create_dummy_dataset()

@@ -642,6 +678,27 @@ def test_filter(self):
         self.assertDictEqual(dset_filter_even_num.features, Features({"filename": Value("string")}))
         self.assertNotEqual(dset_filter_even_num._fingerprint, fingerprint)

+    def test_filter_multiprocessing(self):
+        dset = self._create_dummy_dataset()
+
+        fingerprint = dset._fingerprint
+        dset_filter_first_ten = dset.filter(picklable_filter_function, num_proc=2)
+        self.assertEqual(len(dset_filter_first_ten), 10)
+        self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
+        self.assertDictEqual(dset_filter_first_ten.features, Features({"filename": Value("string")}))
+        self.assertEqual(len(dset_filter_first_ten._data_files), 0)
+        self.assertNotEqual(dset_filter_first_ten._fingerprint, fingerprint)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dset = dset.map(cache_file_name=os.path.join(tmp_dir, "test.arrow"))
+            fingerprint = dset._fingerprint
+            dset_filter_first_ten = dset.filter(picklable_filter_function, num_proc=2)
+            self.assertEqual(len(dset_filter_first_ten), 10)
+            self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
+            self.assertDictEqual(dset_filter_first_ten.features, Features({"filename": Value("string")}))
+            self.assertEqual(len(dset_filter_first_ten._data_files), 2)
+            self.assertNotEqual(dset_filter_first_ten._fingerprint, fingerprint)
+
     def test_keep_features_after_transform_specified(self):
         features = Features(
             {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}