
Conversation

@lhoestq (Member) commented Sep 1, 2020

Adding multiprocessing to .map

It works in 3 steps:

  • shard the dataset into num_proc shards
  • spawn one process per shard and call map on them
  • concatenate the resulting datasets
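
Roughly, these three steps amount to doing the following by hand with the public API (just a sketch, not the actual implementation; it assumes Dataset.shard and concatenate_datasets):

from multiprocessing import Pool

from nlp import concatenate_datasets, load_dataset


def lower_context(x):
    return {"lowered": x.lower()}


def process_shard(shard):
    # step 2 (inside each worker): a regular single-process map on one shard
    return shard.map(lower_context, input_columns=["context"])


if __name__ == "__main__":
    num_proc = 4
    dataset = load_dataset("squad", split="train")

    # step 1: split the dataset into num_proc shards
    shards = [dataset.shard(num_shards=num_proc, index=rank) for rank in range(num_proc)]

    # step 2: map each shard in its own process
    with Pool(num_proc) as pool:
        processed_shards = pool.map(process_shard, shards)

    # step 3: concatenate the processed shards back into one dataset
    processed = concatenate_datasets(processed_shards)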

Example of usage:

from nlp import load_dataset

dataset = load_dataset("squad", split="train")

def function(x):
    return {"lowered": x.lower()}

processed = dataset.map(
    function,
    input_columns=["context"],
    num_proc=4,
    cache_file_name="playground/tmp.arrow",
    load_from_cache_file=False
)

Here it writes 4 files depending on the process rank:

  • playground/tmp_00000_of_00004.arrow
  • playground/tmp_00001_of_00004.arrow
  • playground/tmp_00002_of_00004.arrow
  • playground/tmp_00003_of_00004.arrow

The suffix format can be specified by the user.

If the cache_file_name is not specified, it writes into separate files depending on the fingerprint, as usual.
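
For illustration, the per-process file names above can be derived from cache_file_name roughly like this (the suffix template shown is an assumption based on the listed file names, not necessarily the exact default used in the code):

import os

def rank_cache_file_name(cache_file_name, rank, num_proc):
    # insert an assumed "_{rank:05d}_of_{num_proc:05d}" suffix before the file extension
    base, ext = os.path.splitext(cache_file_name)
    return f"{base}_{rank:05d}_of_{num_proc:05d}{ext}"

for rank in range(4):
    print(rank_cache_file_name("playground/tmp.arrow", rank, 4))
# playground/tmp_00000_of_00004.arrow ... playground/tmp_00003_of_00004.arrow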

I still need to:

  • write tests for this
  • try to improve the logging (currently it shows 4 progress bars, but if one finishes before the others, then the following messages are written over the progress bars)

@lhoestq (Member, Author) commented Sep 2, 2020

Logging looks like this:

Done writing 21900 indices in 3854400 bytes .
Process #0 will write at playground/tmp_00000_of_00004.arrow
Done writing 21900 indices in 3854400 bytes .
Process #1 will write at playground/tmp_00001_of_00004.arrow
Done writing 21900 indices in 3854400 bytes .
Process #2 will write at playground/tmp_00002_of_00004.arrow
Done writing 21899 indices in 3854224 bytes .
Process #3 will write at playground/tmp_00003_of_00004.arrow
Spawning 4 processes
#3: 100%|████████████████████████████████████████████████| 21899/21899 [00:02<00:00, 8027.41ex/s]
#0: 100%|████████████████████████████████████████████████| 21900/21900 [00:02<00:00, 7982.87ex/s]
#1: 100%|████████████████████████████████████████████████| 21900/21900 [00:02<00:00, 7923.89ex/s]
#2: 100%|████████████████████████████████████████████████| 21900/21900 [00:02<00:00, 7920.04ex/s]
Concatenating 4 shards from multiprocessing
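
For reference, one way to keep several per-process progress bars on their own lines is tqdm's shared lock together with the position argument; a minimal standalone sketch (not necessarily what this PR does internally):

from multiprocessing import Pool, RLock

from tqdm import tqdm


def work(rank):
    # each process draws its bar on its own line (position=rank), guarded by the shared lock
    for _ in tqdm(range(21900), desc=f"#{rank}", position=rank):
        pass


if __name__ == "__main__":
    num_proc = 4
    tqdm.set_lock(RLock())
    with Pool(num_proc, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)) as pool:
        pool.map(work, range(num_proc))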

@lhoestq force-pushed the add-multiprocessing branch from c25510f to f62f124 on September 2, 2020 08:41
@lhoestq marked this pull request as ready for review on September 2, 2020 09:29
@lhoestq requested a review from thomwolf on September 2, 2020 09:29
@lhoestq (Member, Author) commented Sep 2, 2020

I added tests and improved logging.
Both map and filter now support multiprocessing.
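
A quick sketch of filter with num_proc (the predicate is a module-level function so the worker processes can pickle it; the column name just follows the squad example above):

from nlp import load_dataset

def has_long_context(x):
    # module-level predicate, picklable by the worker processes
    return len(x["context"]) > 100

dataset = load_dataset("squad", split="train")
filtered = dataset.filter(has_long_context, num_proc=4)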

@thomwolf (Member) left a comment

LGTM

@thomwolf (Member) commented Sep 2, 2020

A bit strange that the benchmarks on map/filter are worse than on master.
(Maybe because they are not run on the same machine.)

@lhoestq (Member, Author) commented Sep 2, 2020

The benchmarks also got worse in other PRs (see here for example, where we get 16 sec for map fast-tokenizer batched and 18 sec for map identity).

@lhoestq merged commit c214aa5 into master on Sep 2, 2020
@lhoestq deleted the add-multiprocessing branch on September 2, 2020 10:01
@kandorm commented Sep 11, 2020

Hi,

when I use multiprocessing in .map:

dataset = load_dataset("text", data_files=file_path, split="train")
dataset = dataset.map(
    lambda ex: tokenizer(ex["text"], add_special_tokens=True, truncation=True, max_length=args.block_size),
    batched=True,
    num_proc=16,
)
dataset.set_format(type='torch', columns=['input_ids'])

I get the following error:

Traceback (most recent call last):
  File "src/run.py", line 373, in <module>
    main()
  File "src/run.py", line 295, in main
    get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
  File "src/run.py", line 153, in get_dataset
    dataset = dataset.map(lambda ex: tokenizer(ex["text"], add_special_tokens=True,
  File "/root/miniconda3/envs/py3.8/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1287, in map
    transformed_shards = [r.get() for r in results]
  File "/root/miniconda3/envs/py3.8/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1287, in <listcomp>
    transformed_shards = [r.get() for r in results]
  File "/root/miniconda3/envs/py3.8/lib/python3.8/multiprocessing/pool.py", line 771, in get
    raise self._value
    put(task)
  File "/root/miniconda3/envs/py3.8/lib/python3.8/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/root/miniconda3/envs/py3.8/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'get_dataset.<locals>.<lambda>'

I think you should use pathos to pickle the lambda function (and some other objects)!
I changed line 30 of src/datasets/arrow_dataset.py as follows:

# line 30 was: from multiprocessing import Pool, RLock
import pathos
from pathos.multiprocessing import Pool
from multiprocessing import RLock

and it works!
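
For anyone hitting the same error: another workaround, without adding a dependency, is to pass a module-level function (plus functools.partial for the extra arguments) instead of a lambda, since the standard pickler can serialize those. A sketch based on the snippet above (it assumes the tokenizer itself is picklable):

from functools import partial

def tokenize_function(ex, tokenizer, block_size):
    # a module-level function (unlike a lambda or closure) can be pickled by multiprocessing
    return tokenizer(ex["text"], add_special_tokens=True, truncation=True, max_length=block_size)

dataset = dataset.map(
    partial(tokenize_function, tokenizer=tokenizer, block_size=args.block_size),
    batched=True,
    num_proc=16,
)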

@lhoestq (Member, Author) commented Sep 11, 2020

That's very cool indeed!
Shall we consider adding this dependency @thomwolf?

@thomwolf (Member) commented

We already use dill so that's definitely a very interesting option indeed!

@abhi1nandy2 commented

It gets stuck on Debian 9 when num_proc > 1.

@lhoestq (Member, Author) commented Sep 22, 2020

Are you using a tokenizer?
Did you try to set TOKENIZERS_PARALLELISM=false?

Feel free to bring it up in #620, where we're discussing this issue.
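
(For reference, TOKENIZERS_PARALLELISM just needs to be set before tokenization starts; a minimal way to do it from Python:)

import os

# disable the tokenizers library's own parallelism; set this before any encoding happens
os.environ["TOKENIZERS_PARALLELISM"] = "false"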

@abhi1nandy2 commented

I set TOKENIZERS_PARALLELISM=false; only the warning went away, and the program was still stuck.
