Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 16 additions & 24 deletions src/gluonnlp/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@

"""Utility classes and functions. They help organize and keep statistics of datasets."""
import collections
import errno
import os
import tarfile
import zipfile
import time
import random
import sys
import shutil

import numpy as np
from mxnet.gluon.data import SimpleDataset
from mxnet.gluon.utils import _get_repo_url, check_sha1, download

from .. import _constants as C
from .. import utils

__all__ = [
'Counter', 'count_tokens', 'concat_sequence', 'slice_sequence', 'train_valid_split',
Expand Down Expand Up @@ -303,6 +305,11 @@ def _load_pretrained_vocab(name, root, cls=None):
root = os.path.expanduser(root)
file_path = os.path.join(root, file_name + '.vocab')
sha1_hash = _vocab_sha1[name]

temp_num = str(random.Random().randint(1, sys.maxsize))
temp_root = os.path.join(root, temp_num)
temp_file_path = os.path.join(temp_root, file_name + '.vocab')
temp_zip_file_path = os.path.join(root, temp_num + file_name + '.zip')
if os.path.exists(file_path):
if check_sha1(file_path, sha1_hash):
return _load_vocab_file(file_path, cls)
Expand All @@ -311,34 +318,19 @@ def _load_pretrained_vocab(name, root, cls=None):
else:
print('Vocab file is not found. Downloading.')

if not os.path.exists(root):
try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(root):
pass
else:
raise e
utils.mkdir(root)

prefix = str(time.time())
zip_file_path = os.path.join(root, prefix + file_name + '.zip')
repo_url = _get_repo_url()
if repo_url[-1] != '/':
repo_url = repo_url + '/'
download(_url_format.format(repo_url=repo_url, file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
path=temp_zip_file_path, overwrite=True)
with zipfile.ZipFile(temp_zip_file_path) as zf:
if not os.path.exists(file_path):
zf.extractall(root)
try:
os.remove(zip_file_path)
except OSError as e:
# file has already been removed.
if e.errno == 2:
pass
else:
raise e
utils.mkdir(temp_root)
zf.extractall(temp_root)
os.replace(temp_file_path, file_path)
shutil.rmtree(temp_root)

if check_sha1(file_path, sha1_hash):
return _load_vocab_file(file_path, cls)
Expand Down
24 changes: 22 additions & 2 deletions src/gluonnlp/utils/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint:disable=redefined-outer-name,logging-format-interpolation
"""Utility functions for files."""

__all__ = ['mkdir', 'glob']
__all__ = ['mkdir', 'glob', 'remove']

import os
import warnings
Expand Down Expand Up @@ -45,8 +45,28 @@ def glob(url, separator=','):
result.extend(_glob.glob(os.path.expanduser(pattern.strip())))
return result

def remove(filename):
    """Remove a file.

    Removing a file that does not exist is a no-op, so concurrent or
    repeated removals of the same path are safe.

    Parameters
    ----------
    filename : str
        The name of the target file to remove.

    Raises
    ------
    NotImplementedError
        If `filename` refers to an object on S3; removing remote objects
        is not supported.
    OSError
        If removal fails for any reason other than the file being absent
        (e.g. a permission error).
    """
    if C.S3_PREFIX in filename:
        msg = 'Removing objects on S3 is not supported: {}'.format(filename)
        raise NotImplementedError(msg)
    try:
        os.remove(filename)
    except FileNotFoundError:
        # The file has already been removed (possibly by a concurrent
        # caller); treat this as success. FileNotFoundError is the
        # portable equivalent of the old `e.errno == 2` (ENOENT) check.
        pass

def mkdir(dirname):
"""Create directory.
"""Create a directory.

Parameters
----------
Expand Down
18 changes: 18 additions & 0 deletions tests/unittest/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io
import random
import warnings
import threading

from flaky import flaky
import mxnet as mx
Expand Down Expand Up @@ -702,6 +703,7 @@ def test_numpy_dataset():
(nlp.data.GlueMRPC, 'mrpc', 'dev', 408, 3),
(nlp.data.GlueMRPC, 'mrpc', 'test', 1725, 2),
])

@pytest.mark.serial
@pytest.mark.remote_required
def test_glue_data(cls, name, segment, length, fields):
Expand All @@ -713,3 +715,19 @@ def test_glue_data(cls, name, segment, length, fields):

for i, x in enumerate(dataset):
assert len(x) == fields, x

@pytest.mark.serial
@pytest.mark.remote_required
def test_parallel_load_pretrained_vocab():
    """Fetch the same pretrained vocab concurrently from many threads.

    Exercises the thread-safety of the download/extract path in
    ``_load_pretrained_vocab``: all workers target the same cache root.
    """
    def fn(name):
        root = 'test_parallel_load_pretrained_vocab'
        _ = nlp.data.utils._load_pretrained_vocab(name, root=root)

    name = 'openwebtext_book_corpus_wiki_en_uncased'
    # Spawn ten workers racing to materialize the identical vocab file.
    workers = [threading.Thread(target=fn, args=(name,)) for _ in range(10)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()