Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 29c9572

Browse files
[Parallel] Avoid race condition when downloading vocab (#1078)
* use os.replace
* Update files.py
* Update utils.py
1 parent f8a4318 commit 29c9572

File tree

3 files changed

+56
-26
lines changed

3 files changed

+56
-26
lines changed

src/gluonnlp/data/utils.py

Lines changed: 16 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -17,17 +17,19 @@
1717

1818
"""Utility classes and functions. They help organize and keep statistics of datasets."""
1919
import collections
20-
import errno
2120
import os
2221
import tarfile
2322
import zipfile
24-
import time
23+
import random
24+
import sys
25+
import shutil
2526

2627
import numpy as np
2728
from mxnet.gluon.data import SimpleDataset
2829
from mxnet.gluon.utils import _get_repo_url, check_sha1, download
2930

3031
from .. import _constants as C
32+
from .. import utils
3133

3234
__all__ = [
3335
'Counter', 'count_tokens', 'concat_sequence', 'slice_sequence', 'train_valid_split',
@@ -303,6 +305,11 @@ def _load_pretrained_vocab(name, root, cls=None):
303305
root = os.path.expanduser(root)
304306
file_path = os.path.join(root, file_name + '.vocab')
305307
sha1_hash = _vocab_sha1[name]
308+
309+
temp_num = str(random.Random().randint(1, sys.maxsize))
310+
temp_root = os.path.join(root, temp_num)
311+
temp_file_path = os.path.join(temp_root, file_name + '.vocab')
312+
temp_zip_file_path = os.path.join(root, temp_num + file_name + '.zip')
306313
if os.path.exists(file_path):
307314
if check_sha1(file_path, sha1_hash):
308315
return _load_vocab_file(file_path, cls)
@@ -311,34 +318,19 @@ def _load_pretrained_vocab(name, root, cls=None):
311318
else:
312319
print('Vocab file is not found. Downloading.')
313320

314-
if not os.path.exists(root):
315-
try:
316-
os.makedirs(root)
317-
except OSError as e:
318-
if e.errno == errno.EEXIST and os.path.isdir(root):
319-
pass
320-
else:
321-
raise e
321+
utils.mkdir(root)
322322

323-
prefix = str(time.time())
324-
zip_file_path = os.path.join(root, prefix + file_name + '.zip')
325323
repo_url = _get_repo_url()
326324
if repo_url[-1] != '/':
327325
repo_url = repo_url + '/'
328326
download(_url_format.format(repo_url=repo_url, file_name=file_name),
329-
path=zip_file_path,
330-
overwrite=True)
331-
with zipfile.ZipFile(zip_file_path) as zf:
327+
path=temp_zip_file_path, overwrite=True)
328+
with zipfile.ZipFile(temp_zip_file_path) as zf:
332329
if not os.path.exists(file_path):
333-
zf.extractall(root)
334-
try:
335-
os.remove(zip_file_path)
336-
except OSError as e:
337-
# file has already been removed.
338-
if e.errno == 2:
339-
pass
340-
else:
341-
raise e
330+
utils.mkdir(temp_root)
331+
zf.extractall(temp_root)
332+
os.replace(temp_file_path, file_path)
333+
shutil.rmtree(temp_root)
342334

343335
if check_sha1(file_path, sha1_hash):
344336
return _load_vocab_file(file_path, cls)

src/gluonnlp/utils/files.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
# pylint:disable=redefined-outer-name,logging-format-interpolation
1818
"""Utility functions for files."""
1919

20-
__all__ = ['mkdir', 'glob']
20+
__all__ = ['mkdir', 'glob', 'remove']
2121

2222
import os
2323
import warnings
@@ -45,8 +45,28 @@ def glob(url, separator=','):
4545
result.extend(_glob.glob(os.path.expanduser(pattern.strip())))
4646
return result
4747

48+
def remove(filename):
49+
"""Remove a file
50+
51+
Parameters
52+
----------
53+
filename : str
54+
The name of the target file to remove
55+
"""
56+
if C.S3_PREFIX in filename:
57+
msg = 'Removing objects on S3 is not supported: {}'.format(filename)
58+
raise NotImplementedError(msg)
59+
try:
60+
os.remove(filename)
61+
except OSError as e:
62+
# file has already been removed.
63+
if e.errno == 2:
64+
pass
65+
else:
66+
raise e
67+
4868
def mkdir(dirname):
49-
"""Create directory.
69+
"""Create a directory.
5070
5171
Parameters
5272
----------

tests/unittest/test_datasets.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import io
2222
import random
2323
import warnings
24+
import threading
2425

2526
from flaky import flaky
2627
import mxnet as mx
@@ -702,6 +703,7 @@ def test_numpy_dataset():
702703
(nlp.data.GlueMRPC, 'mrpc', 'dev', 408, 3),
703704
(nlp.data.GlueMRPC, 'mrpc', 'test', 1725, 2),
704705
])
706+
705707
@pytest.mark.serial
706708
@pytest.mark.remote_required
707709
def test_glue_data(cls, name, segment, length, fields):
@@ -713,3 +715,19 @@ def test_glue_data(cls, name, segment, length, fields):
713715

714716
for i, x in enumerate(dataset):
715717
assert len(x) == fields, x
718+
719+
@pytest.mark.serial
720+
@pytest.mark.remote_required
721+
def test_parallel_load_pretrained_vocab():
722+
def fn(name):
723+
root = 'test_parallel_load_pretrained_vocab'
724+
_ = nlp.data.utils._load_pretrained_vocab(name, root=root)
725+
threads = []
726+
name = 'openwebtext_book_corpus_wiki_en_uncased'
727+
for _ in range(10):
728+
x = threading.Thread(target=fn, args=(name,))
729+
threads.append(x)
730+
for t in threads:
731+
t.start()
732+
for t in threads:
733+
t.join()

0 commit comments

Comments
 (0)