Skip to content

Commit 12a32b9

Browse files
authored
Speed up Tokenization by optimizing cast_to_python_objects (#523)
* optimisations for cast_to_python_objects * fix has_changed for empty list * add tests for cast_to_python_objects * remove map_all_sequences_to_list
1 parent 90256f2 commit 12a32b9

File tree

4 files changed

+133
-25
lines changed

4 files changed

+133
-25
lines changed

src/nlp/arrow_writer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from .features import Features
3030
from .info import DatasetInfo
3131
from .utils.file_utils import HF_DATASETS_CACHE, hash_url_to_filename
32-
from .utils.py_utils import map_all_sequences_to_lists
3332

3433

3534
logger = logging.getLogger(__name__)
@@ -170,7 +169,6 @@ def write(self, example: Dict[str, Any], writer_batch_size: Optional[int] = None
170169
Args:
171170
example: the Example to add.
172171
"""
173-
example = map_all_sequences_to_lists(example)
174172
self.current_rows.append(example)
175173
self._num_examples += 1
176174
if writer_batch_size is None:
@@ -186,7 +184,6 @@ def write_batch(
186184
Args:
187185
example: the Example to add.
188186
"""
189-
batch_examples = map_all_sequences_to_lists(batch_examples)
190187
if self.pa_writer is None:
191188
self._build_writer(inferred_schema=pa.Table.from_pydict(batch_examples).schema)
192189
pa_table: pa.Table = pa.Table.from_pydict(batch_examples, schema=self._schema)

src/nlp/features.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import tensorflow as tf
3939

4040

41-
def string_to_arrow(type_str: str):
41+
def string_to_arrow(type_str: str) -> pa.DataType:
4242
if type_str not in pa.__dict__:
4343
if str(type_str + "_") not in pa.__dict__:
4444
raise ValueError(
@@ -53,22 +53,75 @@ def string_to_arrow(type_str: str):
5353
return pa.__dict__[arrow_data_type_str]()
5454

5555

56-
def _cast_to_python_objects(obj):
57-
""" Cast numpy/pytorch/tensorflow/pandas objects to python lists. """
56+
def _cast_to_python_objects(obj: Any) -> Tuple[Any, bool]:
57+
"""
58+
Cast numpy/pytorch/tensorflow/pandas objects to python lists.
59+
It works recursively.
60+
61+
To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
62+
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
63+
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
64+
65+
Args:
66+
obj: the object (nested struct) to cast
67+
68+
Returns:
69+
casted_obj: the casted object
70+
has_changed (bool): True if the object has been changed, False if it is identical
71+
"""
5872
if isinstance(obj, np.ndarray):
59-
return obj.tolist()
73+
return obj.tolist(), True
6074
elif _torch_available and isinstance(obj, torch.Tensor):
61-
return obj.detach().cpu().numpy().tolist()
75+
return obj.detach().cpu().numpy().tolist(), True
6276
elif _tf_available and isinstance(obj, tf.Tensor):
63-
return obj.numpy().tolist()
77+
return obj.numpy().tolist(), True
78+
elif isinstance(obj, pd.Series):
79+
return obj.values.tolist(), True
6480
elif isinstance(obj, pd.DataFrame):
65-
return obj.values.tolist()
81+
return obj.to_dict("list"), True
82+
elif isinstance(obj, dict):
83+
output = {}
84+
has_changed = False
85+
for k, v in obj.items():
86+
casted_v, has_changed_v = _cast_to_python_objects(v)
87+
has_changed |= has_changed_v
88+
output[k] = casted_v
89+
return output if has_changed else obj, has_changed
90+
elif isinstance(obj, (list, tuple)):
91+
if len(obj) > 0:
92+
for first_elmt in obj:
93+
if first_elmt is not None:
94+
break
95+
casted_first_elmt, has_changed_first_elmt = _cast_to_python_objects(first_elmt)
96+
if has_changed_first_elmt:
97+
return [_cast_to_python_objects(elmt)[0] for elmt in obj], True
98+
else:
99+
if isinstance(obj, list):
100+
return obj, False
101+
else:
102+
return list(obj), True
103+
else:
104+
return obj if isinstance(obj, list) else [], isinstance(obj, tuple)
66105
else:
67-
return obj
106+
return obj, False
107+
108+
109+
def cast_to_python_objects(obj: Any) -> Any:
110+
"""
111+
Cast numpy/pytorch/tensorflow/pandas objects to python lists.
112+
It works recursively.
113+
114+
To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
115+
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
116+
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
68117
118+
Args:
119+
obj: the object (nested struct) to cast
69120
70-
def cast_to_python_objects(obj):
71-
return utils.map_nested(_cast_to_python_objects, obj, map_list=True, map_tuple=True, map_numpy=False)
121+
Returns:
122+
casted_obj: the casted object
123+
"""
124+
return _cast_to_python_objects(obj)[0]
72125

73126

74127
@dataclass

src/nlp/utils/py_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,6 @@
4141
memoize = functools.lru_cache
4242

4343

44-
def map_all_sequences_to_lists(data_struct):
45-
# Could add support for more exotic data_struct, like OrderedDict
46-
def sequences_to_list(seq):
47-
if isinstance(seq, (tuple, np.ndarray)):
48-
return list(seq)
49-
else:
50-
return seq
51-
52-
return map_nested(sequences_to_list, data_struct)
53-
54-
5544
def size_str(size_in_bytes):
5645
"""Returns a human readable size string.
5746

tests/test_features.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from unittest import TestCase
2+
from unittest.mock import patch
3+
4+
import numpy as np
5+
import pandas as pd
26

37
from nlp.arrow_dataset import Dataset
4-
from nlp.features import Features, Sequence, Value
8+
from nlp.features import Features, Sequence, Value, _cast_to_python_objects, cast_to_python_objects
9+
10+
from .utils import require_tf, require_torch
511

612

713
class FeaturesTest(TestCase):
@@ -24,3 +30,66 @@ def test_from_arrow_schema_with_sequence(self):
2430
self.assertEqual(original_features.type, new_features.type)
2531
self.assertDictEqual(dset[0], new_dset[0])
2632
self.assertDictEqual(dset[:], new_dset[:])
33+
34+
def test_cast_to_python_objects_list(self):
35+
obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
36+
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
37+
casted_obj = cast_to_python_objects(obj)
38+
self.assertDictEqual(casted_obj, expected_obj)
39+
40+
def test_cast_to_python_objects_tuple(self):
41+
obj = {"col_1": [{"vec": (1, 2, 3), "txt": "foo"}] * 3, "col_2": [(1, 2), (3, 4), (5, 6)]}
42+
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
43+
casted_obj = cast_to_python_objects(obj)
44+
self.assertDictEqual(casted_obj, expected_obj)
45+
46+
def test_cast_to_python_objects_numpy(self):
47+
obj = {"col_1": [{"vec": np.arange(1, 4), "txt": "foo"}] * 3, "col_2": np.arange(1, 7).reshape(3, 2)}
48+
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
49+
casted_obj = cast_to_python_objects(obj)
50+
self.assertDictEqual(casted_obj, expected_obj)
51+
52+
def test_cast_to_python_objects_series(self):
53+
obj = {
54+
"col_1": pd.Series([{"vec": [1, 2, 3], "txt": "foo"}] * 3),
55+
"col_2": pd.Series([[1, 2], [3, 4], [5, 6]]),
56+
}
57+
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
58+
casted_obj = cast_to_python_objects(obj)
59+
self.assertDictEqual(casted_obj, expected_obj)
60+
61+
def test_cast_to_python_objects_dataframe(self):
62+
obj = pd.DataFrame({"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]})
63+
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
64+
casted_obj = cast_to_python_objects(obj)
65+
self.assertDictEqual(casted_obj, expected_obj)
66+
67+
@require_torch
68+
def test_cast_to_python_objects_torch(self):
69+
import torch
70+
71+
obj = {
72+
"col_1": [{"vec": torch.Tensor(np.arange(1, 4)), "txt": "foo"}] * 3,
73+
"col_2": torch.Tensor(np.arange(1, 7).reshape(3, 2)),
74+
}
75+
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
76+
casted_obj = cast_to_python_objects(obj)
77+
self.assertDictEqual(casted_obj, expected_obj)
78+
79+
@require_tf
80+
def test_cast_to_python_objects_tf(self):
81+
import tensorflow as tf
82+
83+
obj = {
84+
"col_1": [{"vec": tf.constant(np.arange(1, 4)), "txt": "foo"}] * 3,
85+
"col_2": tf.constant(np.arange(1, 7).reshape(3, 2)),
86+
}
87+
expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
88+
casted_obj = cast_to_python_objects(obj)
89+
self.assertDictEqual(casted_obj, expected_obj)
90+
91+
@patch("nlp.features._cast_to_python_objects", side_effect=_cast_to_python_objects)
92+
def test_dont_iterate_over_each_element_in_a_list(self, mocked_cast):
93+
obj = {"col_1": [[1, 2], [3, 4], [5, 6]]}
94+
cast_to_python_objects(obj)
95+
self.assertEqual(mocked_cast.call_count, 4) # 4 = depth of obj

0 commit comments

Comments
 (0)