
Commit 77b16a5
Author: Chris de Vries
Introduce faster tokenizer for BERT
This change introduces a new tokenizer for BERT that is 3.5x faster on a 2017 13-inch MacBook Pro. It was tested by tokenizing the test string u"UNwant\u00E9d,running" from test_transforms.py::bert_tokenizer 100,000 times using the timeit module. The existing implementation with the Cython-optimized wordpiece took 5.56 seconds, while the new implementation took 1.58 seconds. The changes were originally authored by Eric Lind <[email protected]>, and this commit integrates them with Gluon NLP.
Parent: 726281e
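The timing above can be reproduced roughly as follows. This is a minimal sketch, not the project's test code: it assumes the extension has been built and is importable as gluonnlp.data.fast_bert_tokenizer (as wired up in setup.py below), and it substitutes a toy vocabulary for the full BERT wordpiece vocabulary used by test_transforms.py::bert_tokenizer.

import timeit

from gluonnlp.data.fast_bert_tokenizer import BasicTokenizer, WordpieceTokenizer

# Toy vocabulary for illustration only; the benchmark in the commit message
# uses the full BERT wordpiece vocabulary.
vocab = {"un", "##want", "##ed", ",", "runn", "##ing", "[UNK]"}

basic = BasicTokenizer(lower=True)
wordpiece = WordpieceTokenizer(vocab=vocab)

def tokenize():
    pieces = []
    for token in basic.tokenize(u"UNwant\u00E9d,running"):
        pieces.extend(wordpiece.tokenize(token))
    return pieces

# 100,000 iterations, as in the commit message (about 1.58 s on the machine cited there).
print(timeit.timeit(tokenize, number=100000))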

File tree: 4 files changed, +316 additions, -218 deletions

setup.py (1 addition, 1 deletion)

@@ -89,6 +89,6 @@ def find_version(*file_paths):
         ],
     },
     ext_modules=[
-        Extension('gluonnlp.data.wordpiece', sources=['src/gluonnlp/data/wordpiece.pyx']),
+        Extension('gluonnlp.data.fast_bert_tokenizer', sources=['src/gluonnlp/data/fast_bert_tokenizer.pyx']),
     ],
 )
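For context, a .pyx source only becomes importable after a Cython build step; the rest of gluonnlp's setup.py is not shown in this diff. A minimal standalone script for compiling an extension like this might look as follows (a sketch under that assumption, not the project's actual build configuration):

from setuptools import setup, Extension
from Cython.Build import cythonize

setup(
    name='fast-bert-tokenizer-demo',  # placeholder package name, not gluonnlp's
    ext_modules=cythonize(
        [Extension('gluonnlp.data.fast_bert_tokenizer',
                   sources=['src/gluonnlp/data/fast_bert_tokenizer.pyx'])],
        compiler_directives={'language_level': 3},
    ),
)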
src/gluonnlp/data/fast_bert_tokenizer.pyx (new file, 266 additions, 0 deletions)
@@ -0,0 +1,266 @@
"""Used to tokenize text for use with a BERT model."""

from typing import List, Dict, Tuple
import unicodedata


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    return text.strip().split()


cdef class BasicTokenizer:
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
    cdef public bint lower

    def __init__(self, lower=True):
        """Constructs a BasicTokenizer.

        :param lower: Whether to lower case the input.
        """
        self.lower = lower

    def tokenize(self, text) -> List[str]:
        """Tokenizes a piece of text."""
        # Development notes:
        # - The original BERT code loops over every char in pure Python 4 times
        #   (several more times if you include loops that are happening inside built-ins).
        #   This optimized version uses generators and only loops over each char explicitly twice.
        # - This runs in two separate steps because I thought it would be better to apply
        #   `lower` and do accent normalization on the whole string at once rather than parts.
        #   In Python this limits the amount of looping so it provides a speedup. But in Cython
        #   that may not actually be true.

        # Step 1: normalize whitespace, filter control characters, and add spaces around
        # Chinese characters.
        step1_text = "".join(_step1(text)).strip()
        if self.lower:
            step1_text = step1_text.lower()

        # Normalize unicode characters to strip accents.
        # This isn't part of either step1 or step2 because it runs on the entire
        # string and any looping over chars takes place in a built-in C loop
        # that is likely more optimized than anything that I can write here.
        step1_text = unicodedata.normalize("NFD", step1_text)

        # Step 2: filter non-spacing marks (Mn unicode category) and
        # add spaces around any punctuation.
        # This is pretty simple in comparison to the other step.
        output_tokens = "".join(_step2(step1_text)).split()
        return output_tokens


cdef class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    cdef public vocab
    cdef public str unk_token
    cdef public long max_input_chars_per_word

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text) -> List[str]:
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        :param text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer`.
        :returns: A list of wordpiece tokens.
        """
        cdef long max_input_chars_per_word = self.max_input_chars_per_word
        cdef:
            bint is_bad
            long start
            long end
            Py_ssize_t n_chars

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            n_chars = len(chars)
            if n_chars > max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < n_chars:
                end = n_chars
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        # Now it's a subword
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _step1(str text):
    """First step in pre-processing text for BERT.

    This function yields unicode characters while normalizing all whitespace to spaces,
    filtering control characters, and adding spaces around Chinese characters.
    """
    cdef bint prev_ch_whitespace = False
    cdef str ch
    cdef str cat
    cdef Py_UCS4 cp

    for ch in text:
        cp = <Py_UCS4>ch  # Casting this here removes the need for some extra error checking in the loop.

        # `is_control` used unicodedata.category for every character that's not \t, \n, or \r,
        # which is basically everything. So it's better to just call it on everything
        # to begin with and pass the result around.
        cat = unicodedata.category(ch)
        if cp == 0 or cp == 0xfffd or _is_control(cp, cat):
            continue
        if _is_whitespace(cp, cat):
            yield " "
            prev_ch_whitespace = True
        else:
            # From the original BERT code:
            # ---------------------------
            # This was added on November 1st, 2018 for the multilingual and Chinese
            # models. This is also applied to the English models now, but it doesn't
            # matter since the English models were not trained on any Chinese data
            # and generally don't have any Chinese data in them (there are Chinese
            # characters in the vocabulary because Wikipedia does have some Chinese
            # words in the English Wikipedia.).

            # NB: Our regression tests will fail if we get rid of this because
            # our dev datasets have Chinese characters in them.
            # I have no idea if this is important for production or not.
            if _is_chinese_char(cp):
                # Add whitespace around any CJK character.
                if not prev_ch_whitespace:
                    yield " "
                yield ch
                yield " "
            else:
                yield ch
            prev_ch_whitespace = False


def _step2(str text):
    """After encoding normalization, whitespace normalization, Chinese character normalization,
    and accent stripping, this step runs and filters non-spacing marks (Mn unicode category) and
    adds spaces around any punctuation.
    """
    cdef str ch
    cdef str cat

    for ch in text:
        cat = unicodedata.category(ch)
        # Filter some chars (non-spacing mark)
        if cat == "Mn":
            continue
        # Add whitespace around any punctuation
        if _is_punctuation(ch, cat):
            yield " "
            yield ch
            yield " "
        else:
            yield ch


cdef inline bint _is_punctuation(Py_UCS4 cp, str cat):
    """Checks whether `cp` is a punctuation character.

    We treat all non-letter/number ASCII as punctuation.
    Characters such as "^", "$", and "`" are not in the Unicode
    Punctuation class but we treat them as punctuation anyways, for
    consistency.
    """
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    if cat.startswith("P"):
        return True
    return False


cdef inline bint _is_control(Py_UCS4 ch, str cat):
    """Checks whether `ch` is a control character."""
    # Some of these are technically control characters but we count them as whitespace.
    if ch == u"\t" or ch == u"\n" or ch == u"\r":
        return False
    if cat in ("Cc", "Cf"):
        return True
    return False


cdef inline bint _is_whitespace(Py_UCS4 ch, str cat):
    """Checks whether `ch` is a whitespace character.

    \t, \n, and \r are technically control characters but we treat them
    as whitespace since they are generally considered as such.
    """
    if ch == u" " or ch == u"\t" or ch == u"\n" or ch == u"\r":
        return True
    if cat == "Zs":
        return True
    return False


cdef inline bint _is_chinese_char(Py_UCS4 cp):
    """Checks whether `cp` is the codepoint of a CJK character.

    This defines a "chinese character" as anything in the CJK Unicode block:
    https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)

    Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    despite its name. The modern Korean Hangul alphabet is a different block,
    as is Japanese Hiragana and Katakana. Those alphabets are used to write
    space-separated words, so they are not treated specially and handled
    like all of the other languages.
    """
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
            (cp >= 0x3400 and cp <= 0x4DBF) or
            (cp >= 0x20000 and cp <= 0x2A6DF) or
            (cp >= 0x2A700 and cp <= 0x2B73F) or
            (cp >= 0x2B740 and cp <= 0x2B81F) or
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or
            (cp >= 0x2F800 and cp <= 0x2FA1F)):
        return True
    return False


# Public functions for testing
def is_punctuation(Py_UCS4 cp, str cat):
    return _is_punctuation(cp, cat)

def is_control(Py_UCS4 ch, str cat):
    return _is_control(ch, cat)

def is_whitespace(Py_UCS4 ch, str cat):
    return _is_whitespace(ch, cat)

def is_chinese_char(Py_UCS4 cp):
    return _is_chinese_char(cp)
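A small usage sketch of the classes above, illustrating the greedy longest-match-first behaviour described in WordpieceTokenizer.tokenize. The vocabulary is a toy set chosen to reproduce the docstring example, not a real BERT vocabulary:

from gluonnlp.data.fast_bert_tokenizer import BasicTokenizer, WordpieceTokenizer

vocab = {"un", "##aff", "##able", "[UNK]"}
basic = BasicTokenizer(lower=True)
wordpiece = WordpieceTokenizer(vocab=vocab)

basic_tokens = basic.tokenize("Unaffable")          # ["unaffable"]
pieces = wordpiece.tokenize(" ".join(basic_tokens))
print(pieces)                                       # ["un", "##aff", "##able"]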
