Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/gluonnlp/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
word_embedding_evaluation)
from .candidate_sampler import *
from .conll import *
from .glue import *
from .corpora import *
from .dataloader import *
from .dataset import *
Expand Down
25 changes: 22 additions & 3 deletions src/gluonnlp/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import io
import os
import warnings
import bisect
import numpy as np

Expand Down Expand Up @@ -116,7 +117,7 @@ class TSVDataset(SimpleDataset):
"""
def __init__(self, filename, encoding='utf8',
sample_splitter=line_splitter, field_separator=Splitter('\t'),
num_discard_samples=0, field_indices=None):
num_discard_samples=0, field_indices=None, allow_missing=False):
assert sample_splitter, 'sample_splitter must be specified.'

if not isinstance(filename, (tuple, list)):
Expand All @@ -128,6 +129,7 @@ def __init__(self, filename, encoding='utf8',
self._field_separator = field_separator
self._num_discard_samples = num_discard_samples
self._field_indices = field_indices
self._allow_missing = allow_missing
super(TSVDataset, self).__init__(self._read())

def _should_discard(self):
Expand All @@ -138,7 +140,11 @@ def _should_discard(self):
def _field_selector(self, fields):
    """Select the configured subset of columns from one split row.

    Returns *fields* unchanged when no ``field_indices`` were given.
    Re-raises IndexError with the offending row appended to the message
    when a requested column is missing.
    """
    if not self._field_indices:
        return fields
    try:
        result = [fields[i] for i in self._field_indices]
    except IndexError as e:
        # BUG FIX: `e.message` is Python 2 only; on Python 3 it raises
        # AttributeError and masks the real error. str(e) works on both.
        raise IndexError('%s. Fields = %s' % (str(e), str(fields)))
    return result

def _read(self):
all_samples = []
Expand All @@ -147,7 +153,20 @@ def _read(self):
content = fin.read()
samples = (s for s in self._sample_splitter(content) if not self._should_discard())
if self._field_separator:
samples = [self._field_selector(self._field_separator(s)) for s in samples]
if not self._allow_missing:
samples = [self._field_selector(self._field_separator(s)) for s in samples]
else:
selected_samples = []
num_missing = 0
for s in samples:
try:
fields = self._field_separator(s)
selected_samples.append(self._field_selector(fields))
except IndexError:
num_missing += 1
if num_missing > 0:
warnings.warn('%d incomplete samples in %s'%(num_missing, filename))
samples = selected_samples
all_samples += samples
return all_samples

Expand Down
283 changes: 283 additions & 0 deletions src/gluonnlp/data/glue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=
"""CoNLL format corpora."""

__all__ = ['GlueCoLA', 'GlueSST2', 'GlueSTSB', 'GlueQQP']

import codecs
import glob
import gzip
import zipfile
import io
import os
import shutil
import tarfile

from .dataset import TSVDataset
from mxnet.gluon.utils import download, check_sha1, _get_repo_file_url

from .. import _constants as C
from .registry import register
from ..base import get_home_dir

_glue_s3_uri = 's3://apache-mxnet/gluon/dataset/Glue/'

class _GlueDataset(TSVDataset):
    """Shared download/cache logic for GLUE benchmark TSV datasets.

    Ensures ``<root>/<segment>.tsv`` exists (downloading and unzipping the
    archive when missing or when its SHA1 does not match), then delegates
    parsing to :class:`TSVDataset`.

    Subclasses must implement :meth:`_repo_dir` to name the remote
    repository sub-path for their archives.
    """
    def __init__(self, root, data_file, **kwargs):
        # data_file is a (segment, zip_sha1, extracted_tsv_sha1) triple.
        root = os.path.expanduser(root)
        if not os.path.isdir(root):
            os.makedirs(root)
        self._root = root
        segment, zip_hash, data_hash = data_file
        filename = os.path.join(self._root, '%s.tsv' % segment)
        self._get_data(segment, zip_hash, data_hash, filename)
        super(_GlueDataset, self).__init__(filename, **kwargs)

    def _get_data(self, segment, zip_hash, data_hash, filename):
        """Download and extract the archive unless a valid TSV is cached.

        The local TSV is considered valid when it exists and its SHA1
        matches ``data_hash``; otherwise the zip (verified against
        ``zip_hash``) is fetched and its files are extracted flat into
        ``self._root``.
        """
        data_filename = '%s-%s.zip' % (segment, data_hash[:8])
        if not os.path.exists(filename) or not check_sha1(filename, data_hash):
            download(_get_repo_file_url(self._repo_dir(), data_filename),
                     path=self._root, sha1_hash=zip_hash)
            # unzip
            downloaded_path = os.path.join(self._root, data_filename)
            with zipfile.ZipFile(downloaded_path, 'r') as zf:
                # skip dir structures in the zip: extract every regular file
                # directly into self._root, dropping its directory prefix
                for zip_info in zf.infolist():
                    if zip_info.filename[-1] == '/':
                        continue
                    zip_info.filename = os.path.basename(zip_info.filename)
                    zf.extract(zip_info, self._root)

    def _repo_dir(self):
        # Remote repository sub-path, e.g. 'gluon/dataset/GLUE/CoLA'.
        raise NotImplementedError

@register(segment=['train', 'dev', 'test'])
class GlueCoLA(_GlueDataset):
    """The Corpus of Linguistic Acceptability (CoLA) dataset of the GLUE
    benchmark.

    'train' and 'dev' samples have two fields: sentence and acceptability
    label. 'test' samples have one field: the sentence only.

    From
    https://gluebenchmark.com/tasks

    Parameters
    ----------
    segment : {'train', 'dev', 'test'}, default 'train'
        Dataset segment.
    root : str, default '$MXNET_HOME/datasets/glue_cola'
        Path to temp folder for storing data.
        MXNET_HOME defaults to '~/.mxnet'.
    return_all_fields : bool, default False
        Return all fields of the source TSV instead of only the selected
        sentence/label columns.

    Examples
    --------
    >>> cola = gluonnlp.data.GlueCoLA('test', root='./datasets/cola')
    -etc-
    >>> len(cola)
    1063
    >>> len(cola[0])
    1
    >>> cola[0][0]
    ['Bill whistled past the house.']
    """
    def __init__(self, segment='train',
                 root=os.path.join(get_home_dir(), 'datasets', 'glue_cola'),
                 return_all_fields=False):
        # (segment, zip_sha1, tsv_sha1) per split, consumed by _GlueDataset.
        self._data_file = {'train': ('train', '662227ed4d98bb96b3495234b650e37826a5ef72',
                                     '7760a9c4b1fb05f6d003475cc7bb0d0118875190'),
                           'dev': ('dev', '6f3f5252b004eab187bf22ab5b0af31e739d3a3f',
                                   '30ece4de38e1929545c4154d4c71ad297c7f54b4'),
                           'test': ('test', 'b88180515ad041935793e74e3a76470b0c1b2c50',
                                    'f38b43d31bb06accf82a3d5b2fe434a752a74c9f')}
        data_file = self._data_file[segment]
        if segment in ['train', 'dev']:
            # train/dev TSVs carry no header row; sentence is column 3,
            # acceptability label is column 1
            A_IDX, LABEL_IDX = 3, 1
            field_indices = [A_IDX, LABEL_IDX] if not return_all_fields else None
            num_discard_samples = 0
        elif segment == 'test':
            # test TSV has a header row (discarded); sentence is column 1
            A_IDX = 1
            field_indices = [A_IDX] if not return_all_fields else None
            num_discard_samples = 1

        super(GlueCoLA, self).__init__(root, data_file,
            num_discard_samples=num_discard_samples, field_indices=field_indices)

    def _repo_dir(self):
        # Remote repository sub-path for the CoLA archives.
        return 'gluon/dataset/GLUE/CoLA'

@register(segment=['train', 'dev', 'test'])
class GlueSST2(_GlueDataset):
    """The Stanford Sentiment Treebank (SST-2) dataset of the GLUE
    benchmark.

    'train' and 'dev' samples have two fields: sentence and sentiment
    label. 'test' samples have one field: the sentence only.

    From
    https://gluebenchmark.com/tasks

    Parameters
    ----------
    segment : {'train', 'dev', 'test'}, default 'train'
        Dataset segment.
    root : str, default '$MXNET_HOME/datasets/glue_sst'
        Path to temp folder for storing data.
        MXNET_HOME defaults to '~/.mxnet'.
    return_all_fields : bool, default False
        Return all fields of the source TSV instead of only the selected
        sentence/label columns.

    Examples
    --------
    >>> sst = gluonnlp.data.GlueSST2('test', root='./datasets/sst')
    -etc-
    >>> len(sst)
    1821
    >>> len(sst[0])
    1
    """
    def __init__(self, segment='train',
                 root=os.path.join(get_home_dir(), 'datasets', 'glue_sst'),
                 return_all_fields=False):
        # (segment, zip_sha1, tsv_sha1) per split, consumed by _GlueDataset.
        self._data_file = {'train': ('train', 'bcde781bed5caa30d5e9a9d24e5c826965ed02a2',
                                     'ffbb67a55e27525e925b79fee110ca19585d70ca'),
                           'dev': ('dev', '85698e465ff6573fb80d0b34229c76df84cd766b',
                                   'e166f986cec68fd4cca0ae5ce5869b917f88a2fa'),
                           'test': ('test', 'efac1c275553ed78500e9b8d8629408f5f867b20',
                                    '3ce8041182bf82dbbbbfe13738b39d3c69722744')}
        data_file = self._data_file[segment]
        if segment in ['train', 'dev']:
            # header row discarded; sentence is column 0, label is column 1
            A_IDX, LABEL_IDX = 0, 1
            field_indices = [A_IDX, LABEL_IDX] if not return_all_fields else None
            num_discard_samples = 1
        elif segment == 'test':
            # header row discarded; sentence is column 1
            A_IDX = 1
            field_indices = [A_IDX] if not return_all_fields else None
            num_discard_samples = 1

        super(GlueSST2, self).__init__(root, data_file,
            num_discard_samples=num_discard_samples, field_indices=field_indices)

    def _repo_dir(self):
        # Remote repository sub-path for the SST-2 archives.
        return 'gluon/dataset/GLUE/SST-2'

@register(segment=['train', 'dev', 'test'])
class GlueSTSB(_GlueDataset):
    """The Semantic Textual Similarity Benchmark (STS-B) dataset of the
    GLUE benchmark.

    'train' and 'dev' samples have three fields: sentence A, sentence B
    and a similarity score. 'test' samples have two fields: the two
    sentences only.

    From
    http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark

    Parameters
    ----------
    segment : {'train', 'dev', 'test'}, default 'train'
        Dataset segment.
    root : str, default '$MXNET_HOME/datasets/glue_stsb'
        Path to temp folder for storing data.
        MXNET_HOME defaults to '~/.mxnet'.
    return_all_fields : bool, default False
        Return all fields of the source TSV instead of only the selected
        sentence/score columns.

    Examples
    --------
    >>> stsb = gluonnlp.data.GlueSTSB('test', root='./datasets/stsb')
    -etc-
    >>> len(stsb)
    1379
    >>> len(stsb[0])
    2
    """
    def __init__(self, segment='train',
                 # BUG FIX: default root was 'glue_sst', which collided with
                 # GlueSST2's cache directory and mixed up downloaded files.
                 root=os.path.join(get_home_dir(), 'datasets', 'glue_stsb'),
                 return_all_fields=False):
        # (segment, zip_sha1, tsv_sha1) per split, consumed by _GlueDataset.
        self._data_file = {'train': ('train', '9378bd341576810730a5c666ed03122e4c5ecc9f',
                                     '501e55248c6db2a3f416c75932a63693000a82bc'),
                           'dev': ('dev', '529c3e7c36d0807d88d0b2a5d4b954809ddd4228',
                                   'f8bcc33b01dfa2e9ba85601d0140020735b8eff3'),
                           'test': ('test', '6284872d6992d8ec6d96320af89c2f46ac076d18',
                                    '36553e5e2107b817257232350e95ff0f3271d844')}
        data_file = self._data_file[segment]
        if segment in ['train', 'dev']:
            # header row discarded; sentences are columns 7-8, score is 9
            A_IDX, B_IDX, LABEL_IDX = 7, 8, 9
            field_indices = [A_IDX, B_IDX, LABEL_IDX] if not return_all_fields else None
            num_discard_samples = 1
        elif segment == 'test':
            # header row discarded; sentences are columns 7-8, no score
            A_IDX, B_IDX = 7, 8
            field_indices = [A_IDX, B_IDX] if not return_all_fields else None
            num_discard_samples = 1

        super(GlueSTSB, self).__init__(root, data_file,
            num_discard_samples=num_discard_samples, field_indices=field_indices)

    def _repo_dir(self):
        # Remote repository sub-path for the STS-B archives.
        return 'gluon/dataset/GLUE/STS-B'

@register(segment=['train', 'dev', 'test'])
class GlueQQP(_GlueDataset):
    """The Quora Question Pairs (QQP) dataset of the GLUE benchmark.

    'train' and 'dev' samples have three fields: question A, question B
    and a duplicate label. 'test' samples have two fields: the two
    questions only. Rows with missing columns are skipped with a warning
    (``allow_missing=True``) because the source data contains broken
    samples.

    From
    https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs

    Parameters
    ----------
    segment : {'train', 'dev', 'test'}, default 'train'
        Dataset segment.
    root : str, default '$MXNET_HOME/datasets/glue_qqp'
        Path to temp folder for storing data.
        MXNET_HOME defaults to '~/.mxnet'.
    return_all_fields : bool, default False
        Return all fields of the source TSV instead of only the selected
        question/label columns.

    Examples
    --------
    >>> qqp = gluonnlp.data.GlueQQP('test', root='./datasets/qqp')
    -etc-
    >>> len(qqp)
    390965
    >>> len(qqp[0])
    2
    """
    def __init__(self, segment='train',
                 # BUG FIX: default root was 'glue_sst', which collided with
                 # GlueSST2's cache directory and mixed up downloaded files.
                 root=os.path.join(get_home_dir(), 'datasets', 'glue_qqp'),
                 return_all_fields=False):
        # (segment, zip_sha1, tsv_sha1) per split, consumed by _GlueDataset.
        self._data_file = {'train': ('train', '494f280d651f168ad96d6cd05f8d4ddc6be73ce9',
                                     '95c01e711ac8dbbda8f67f3a4291e583a72b6988'),
                           'dev': ('dev', '9957b60c4c62f9b98ec91b26a9d43529d2ee285d',
                                   '755e0bf2899b8ad315d4bd7d4c85ec51beee5ad0'),
                           'test': ('test', '1e325cc5dbeeb358f9429c619ebe974fc2d1a8ca',
                                    '0f50d1a62dd51fe932ba91be08238e47c3e2504a')}
        data_file = self._data_file[segment]
        if segment in ['train', 'dev']:
            # header row discarded; questions are columns 3-4, label is 5
            A_IDX, B_IDX, LABEL_IDX = 3, 4, 5
            field_indices = [A_IDX, B_IDX, LABEL_IDX] if not return_all_fields else None
            num_discard_samples = 1
        elif segment == 'test':
            # header row discarded; questions are columns 1-2, no label
            A_IDX, B_IDX = 1, 2
            field_indices = [A_IDX, B_IDX] if not return_all_fields else None
            num_discard_samples = 1
        # QQP may include broken samples
        super(GlueQQP, self).__init__(root, data_file,
            num_discard_samples=num_discard_samples, field_indices=field_indices,
            allow_missing=True)

    def _repo_dir(self):
        # Remote repository sub-path for the QQP archives.
        return 'gluon/dataset/GLUE/QQP'
43 changes: 43 additions & 0 deletions tests/unittest/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,3 +655,46 @@ def test_numpy_dataset():
assert np.all(dataset[1][1] == b[1])
dataset_b = dataset.get_field('b')
assert np.all(dataset_b == b)

# BUG FIX: the parametrize id-string listed 3 names for 5-tuples, the tuples
# in the list were missing separating commas (SyntaxError), and the function
# was additionally invoked directly at module level, which would download
# every GLUE dataset at import time. Expected lengths:
# GLUE paper: https://arxiv.org/pdf/1804.07461.pdf
# STS-B: http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark
# QQP: https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs
@pytest.mark.parametrize('cls,name,segment,length,fields', [
    (nlp.data.GlueCoLA, 'cola', 'train', 8551, 2),
    (nlp.data.GlueCoLA, 'cola', 'dev', 1043, 2),
    (nlp.data.GlueCoLA, 'cola', 'test', 1063, 1),
    (nlp.data.GlueSST2, 'sst', 'train', 67349, 2),
    (nlp.data.GlueSST2, 'sst', 'dev', 872, 2),
    (nlp.data.GlueSST2, 'sst', 'test', 1821, 1),
    (nlp.data.GlueSTSB, 'sts', 'train', 5749, 3),
    (nlp.data.GlueSTSB, 'sts', 'dev', 1500, 3),
    (nlp.data.GlueSTSB, 'sts', 'test', 1379, 2),
    (nlp.data.GlueQQP, 'qqp', 'train', 363849, 3),
    (nlp.data.GlueQQP, 'qqp', 'dev', 40430, 3),
    (nlp.data.GlueQQP, 'qqp', 'test', 390965, 2),
])
@pytest.mark.serial
@pytest.mark.remote_required
def test_glue_data(cls, name, segment, length, fields):
    """Check segment length and per-sample field count for GLUE datasets."""
    dataset = cls(segment=segment, root=os.path.join(
        'tests', 'externaldata', 'glue', name))
    assert len(dataset) == length, len(dataset)

    for _, x in enumerate(dataset):
        assert len(x) == fields, x