71 changes: 28 additions & 43 deletions datasets/crd3/crd3.py
@@ -17,9 +17,9 @@
"""CRD3 dataset"""

from __future__ import absolute_import, division, print_function
import logging

import json
import logging
import os

import nlp
@@ -46,10 +46,11 @@

_URL = "https://github.com/RevanthRameshkumar/CRD3/archive/master.zip"


def get_train_test_dev_files(files, test_split, train_split, dev_split):
test_files, train_files, dev_files = [], [], []  # three separate lists; chained assignment would alias a single list
for file in files:
filename = os.path.split(file)[1].split('_')[0]
filename = os.path.split(file)[1].split("_")[0]
if filename in test_split:
test_files.append(file)
elif filename in train_split:
@@ -59,25 +60,23 @@ def get_train_test_dev_files(files, test_split, train_split, dev_split):
else:
logging.info("skipped file {}".format(file))
return test_files, train_files, dev_files
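
For reference, the split assignment above keys on the episode prefix of each aligned-data filename. A minimal sketch of that matching, using a hypothetical filename (the exact naming scheme is an assumption, not part of this diff):

import os

# Hypothetical aligned-data file; only the basename prefix matters here.
file = "CRD3-master/data/aligned data/c=2/C1E001_0.json"
prefix = os.path.split(file)[1].split("_")[0]
print(prefix)  # "C1E001" -- checked against the test/train/dev episode lists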


class CRD3(nlp.GeneratorBasedBuilder):

def _info(self):
return nlp.DatasetInfo(
description=_DESCRIPTION,
features=nlp.Features({
"chunk": nlp.Value("string"),
"chunk_id": nlp.Value("int32"),
"turn_start": nlp.Value("int32"),
"turn_end": nlp.Value("int32"),
"alignment_score": nlp.Value("float32"),
"turn_num": nlp.Value("int32"),
"turns":nlp.features.Sequence({
"names": nlp.Value("string"),
"utterances": nlp.Value("string"),
}),
}),
features=nlp.Features(
{
"chunk": nlp.Value("string"),
"chunk_id": nlp.Value("int32"),
"turn_start": nlp.Value("int32"),
"turn_end": nlp.Value("int32"),
"alignment_score": nlp.Value("float32"),
"turn_num": nlp.Value("int32"),
"turns": nlp.features.Sequence({"names": nlp.Value("string"), "utterances": nlp.Value("string"),}),
}
),
homepage="https://github.com/RevanthRameshkumar/CRD3",
citation=_CITATION,
)
@@ -89,7 +88,7 @@ def _split_generators(self, dl_manager):
dev_file = os.path.join(path, "CRD3-master", "data", "aligned data", "val_files")
with open(test_file) as f:
test_splits = [file.replace("\n", "") for file in f.readlines()]

with open(train_file) as f:
train_splits = [file.replace("\n", "") for file in f.readlines()]
with open(dev_file) as f:
@@ -98,29 +97,20 @@
c3 = "CRD3-master/data/aligned data/c=3"
c4 = "CRD3-master/data/aligned data/c=4"
files = [os.path.join(path, c2, file) for file in sorted(os.listdir(os.path.join(path, c2)))]
files.extend([os.path.join(path, c3, file) for file in sorted(os.listdir(os.path.join(path, c3)))])
files.extend([os.path.join(path, c4, file) for file in sorted(os.listdir(os.path.join(path, c4)))])
files.extend([os.path.join(path, c3, file) for file in sorted(os.listdir(os.path.join(path, c3)))])
files.extend([os.path.join(path, c4, file) for file in sorted(os.listdir(os.path.join(path, c4)))])

test_files, train_files, dev_files = get_train_test_dev_files(files, test_splits, train_splits, dev_splits)

return [
nlp.SplitGenerator(
name=nlp.Split.TRAIN,
gen_kwargs={"files_path": train_files},
),
nlp.SplitGenerator(
name=nlp.Split.TEST,
gen_kwargs={"files_path": test_files},
),
nlp.SplitGenerator(
name=nlp.Split.VALIDATION,
gen_kwargs={"files_path": dev_files},
)
nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"files_path": train_files},),
nlp.SplitGenerator(name=nlp.Split.TEST, gen_kwargs={"files_path": test_files},),
nlp.SplitGenerator(name=nlp.Split.VALIDATION, gen_kwargs={"files_path": dev_files},),
]

def _generate_examples(self, files_path):
"""Yields examples."""

for file in files_path:
with open(file) as f:
data = json.load(f)
@@ -130,21 +120,16 @@ def _generate_examples(self, files_path):
turn_start = row["ALIGNMENT"]["TURN START"]
turn_end = row["ALIGNMENT"]["TURN END"]
score = row["ALIGNMENT"]["ALIGNMENT SCORE"]
for id2, turn in enumerate(row['TURNS']):
for id2, turn in enumerate(row["TURNS"]):
turn_names = turn["NAMES"]
turn_utterances = turn["UTTERANCES"]
turn_num = turn["NUMBER"]
yield str(id1)+'_'+str(id2), {
"chunk":chunk,
yield str(id1) + "_" + str(id2), {
"chunk": chunk,
"chunk_id": chunk_id,
"turn_start": turn_start,
"turn_end": turn_end,
"alignment_score": score,
"turn_num": turn_num,
"turns": {
"names": turn_names,
"utterances": turn_utterances,
},
"turns": {"names": turn_names, "utterances": turn_utterances,},
}


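With the script above in place, the dataset should load like any other builder in the library. A minimal usage sketch, assuming the builder is registered under the name "crd3" (the name and split handling are assumptions, not part of this diff):

import nlp

# Load the train split produced by _split_generators above.
dataset = nlp.load_dataset("crd3", split="train")

example = dataset[0]
# Each example is one aligned chunk plus its turns, per the schema in _info().
print(example["chunk_id"], example["alignment_score"])
print(example["turns"]["names"][:3])
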
9 changes: 3 additions & 6 deletions src/nlp/arrow_dataset.py
@@ -583,8 +583,6 @@ def _convert_outputs(

map_nested_kwargs = {}
if format_type == "numpy":
import numpy as np

if "copy" not in format_kwargs:
format_kwargs["copy"] = False
command = partial(np.array, **format_kwargs)
@@ -603,8 +601,7 @@ def identity(x):
return x

command = identity

if isinstance(outputs, (list, tuple)):
if isinstance(outputs, (list, tuple, np.ndarray)):
return command(outputs)
elif isinstance(outputs, pd.DataFrame):
if format_columns is not None and not output_all_columns:
@@ -678,7 +675,7 @@ def _getitem(
if format_type == "pandas":
outputs = self._data[key].to_pandas(split_blocks=True)
elif format_type in ("numpy", "torch", "tensorflow"):
outputs = self._data[key].to_pandas(split_blocks=True).to_numpy()
outputs = self._data.to_pandas(split_blocks=True).to_dict("list")[key]
else:
outputs = self._data[key].to_pylist()
else:
@@ -698,7 +695,7 @@
else:
raise ValueError("Can only get row(s) (int or slice or list[int]) or columns (string).")

if (format_type is not None or format_columns is not None) and not isinstance(key, str):
if format_type is not None or format_columns is not None:
outputs = self._convert_outputs(
outputs,
format_type=format_type,
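
The _getitem and _convert_outputs changes above extend output formatting to whole-column access (dset["col_1"]), not just row and slice access. A minimal sketch of the resulting behavior, mirroring the tests below (the Dataset construction and import path are assumptions):

import numpy as np
from nlp import Dataset

dset = Dataset.from_dict({"col_1": [3, 2, 1, 0]})
dset.set_format(type="numpy", columns=["col_1"])

col = dset["col_1"]  # column access now goes through _convert_outputs too
assert isinstance(col, np.ndarray)  # stacked into a 1-D array, shape (4,)
np.testing.assert_array_equal(col, np.array([3, 2, 1, 0]))
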
20 changes: 19 additions & 1 deletion tests/test_arrow_dataset.py
@@ -26,6 +26,13 @@ def _create_dummy_dataset(self, multiple_columns=False):
def test_dummy_dataset(self):
dset = self._create_dummy_dataset()
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
self.assertEqual(dset[0]["filename"], "my_name-train_0")
self.assertEqual(dset["filename"][0], "my_name-train_0")

dset = self._create_dummy_dataset(multiple_columns=True)
self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")}))
self.assertEqual(dset[0]["col_1"], 3)
self.assertEqual(dset["col_1"][0], 3)

def test_from_pandas(self):
data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
@@ -78,20 +85,26 @@ def test_from_dict(self):
features = Features({"col_1": Value("string"), "col_2": Value("string")})
self.assertRaises(pa.ArrowTypeError, Dataset.from_dict, data, features=features)

def test_set_format_numpy(self):
def test_set_format_numpy_multiple_columns(self):
dset = self._create_dummy_dataset(multiple_columns=True)
dset.set_format(type="numpy", columns=["col_1"])
self.assertEqual(len(dset[0]), 1)
self.assertIsInstance(dset[0]["col_1"], np.ndarray)
self.assertListEqual(list(dset[0]["col_1"].shape), [])
self.assertEqual(dset[0]["col_1"].item(), 3)
self.assertIsInstance(dset["col_1"], np.ndarray)
self.assertListEqual(list(dset["col_1"].shape), [4])
np.testing.assert_array_equal(dset["col_1"], np.array([3, 2, 1, 0]))

dset.reset_format()
with dset.formated_as(type="numpy", columns=["col_1"]):
self.assertEqual(len(dset[0]), 1)
self.assertIsInstance(dset[0]["col_1"], np.ndarray)
self.assertListEqual(list(dset[0]["col_1"].shape), [])
self.assertEqual(dset[0]["col_1"].item(), 3)
self.assertIsInstance(dset["col_1"], np.ndarray)
self.assertListEqual(list(dset["col_1"].shape), [4])
np.testing.assert_array_equal(dset["col_1"], np.array([3, 2, 1, 0]))

self.assertEqual(dset.format["type"], "python")
self.assertEqual(dset.format["format_kwargs"], {})
@@ -115,6 +128,7 @@ def test_set_format_torch(self):
dset.set_format(type="torch", columns=["col_1"])
self.assertEqual(len(dset[0]), 1)
self.assertIsInstance(dset[0]["col_1"], torch.Tensor)
self.assertIsInstance(dset["col_1"], torch.Tensor)
self.assertListEqual(list(dset[0]["col_1"].shape), [])
self.assertEqual(dset[0]["col_1"].item(), 3)

@@ -694,21 +708,25 @@ def test_format_vectors(self):
for col in columns:
self.assertIsInstance(dset[0][col], (tf.Tensor, tf.RaggedTensor))
self.assertIsInstance(dset[:2][col][0], (tf.Tensor, tf.RaggedTensor)) # not stacked
self.assertIsInstance(dset[col][0], (tf.Tensor, tf.RaggedTensor)) # not stacked

dset.set_format("numpy")
self.assertIsNotNone(dset[0])
self.assertIsNotNone(dset[:2])
for col in columns:
self.assertIsInstance(dset[0][col], np.ndarray)
self.assertIsInstance(dset[:2][col], np.ndarray) # stacked
self.assertIsInstance(dset[col], np.ndarray) # stacked
self.assertEqual(dset[:2]["vec"].shape, (2, 3)) # stacked
self.assertEqual(dset["vec"][:2].shape, (2, 3)) # stacked

dset.set_format("torch", columns=["vec"])
self.assertIsNotNone(dset[0])
self.assertIsNotNone(dset[:2])
# torch.Tensor is only for numerical columns
self.assertIsInstance(dset[0]["vec"], torch.Tensor)
self.assertIsInstance(dset[:2]["vec"][0], torch.Tensor) # not stacked
self.assertIsInstance(dset["vec"][0], torch.Tensor) # not stacked

def test_format_nested(self):
dset = self._create_dummy_dataset()
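
The same column-level conversion applies to the torch format, as the updated test_set_format_torch asserts. A small sketch, assuming torch is installed and constructing the dataset directly (construction is an assumption; the assertions mirror the tests above):

import torch
from nlp import Dataset

dset = Dataset.from_dict({"col_1": [3, 2, 1, 0]})
dset.set_format(type="torch", columns=["col_1"])

assert isinstance(dset["col_1"], torch.Tensor)  # whole column as one tensor
assert dset[0]["col_1"].item() == 3  # each row yields a 0-d tensor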