| 1 | +"""Used to tokenize text for use with a BERT model.""" |
| 2 | + |
| 3 | +from typing import List, Dict, Tuple |
| 4 | +import unicodedata |
| 5 | + |
| 6 | + |
| 7 | +def whitespace_tokenize(text): |
| 8 | + """Runs basic whitespace cleaning and splitting on a piece of text.""" |
| 9 | + return text.strip().split() |
| 10 | + |
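# A quick illustration of `whitespace_tokenize` (a sketch, not a doctest that
# runs at import time):
#   whitespace_tokenize("  hello   world \n")  ->  ["hello", "world"]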

cdef class BasicTokenizer:
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
    cdef public bint lower

    def __init__(self, lower=True):
        """Constructs a BasicTokenizer.

        :param lower: Whether to lower case the input.
        """
        self.lower = lower

    def tokenize(self, text) -> List[str]:
        """Tokenizes a piece of text."""
        # Development notes:
        # - The original BERT code loops over every char in pure Python 4 times
        #   (several more times if you count the loops happening inside built-ins).
        #   This optimized version uses generators and only loops over each char explicitly twice.
        # - This runs in two separate steps because applying `lower` and accent
        #   normalization to the whole string at once, rather than to parts, limits
        #   the amount of explicit looping. In Python that is a clear speedup,
        #   but in Cython it may not actually matter.

        # Step 1: normalize whitespace, filter control characters, and add spaces around
        # Chinese characters.
        step1_text = "".join(_step1(text)).strip()
        if self.lower:
            step1_text = step1_text.lower()

        # Normalize unicode characters to strip accents.
        # This isn't part of either step1 or step2 because it runs on the entire
        # string and any looping over chars takes place in a built-in C loop
        # that is likely more optimized than anything that I can write here.
        step1_text = unicodedata.normalize("NFD", step1_text)

        # Step 2: filter non-spacing marks (Mn unicode category) and
        # add spaces around any punctuation.
        # This is pretty simple in comparison to the other step.
        output_tokens = "".join(_step2(step1_text)).split()
        return output_tokens

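# A rough usage sketch for `BasicTokenizer` (assuming the module has been
# compiled and imported; this mirrors the behaviour of the original BERT code):
#   basic = BasicTokenizer(lower=True)
#   basic.tokenize(u"H\u00e9llo, World!")  ->  ["hello", ",", "world", "!"]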

cdef class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    cdef public vocab
    cdef public str unk_token
    cdef public long max_input_chars_per_word

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text) -> List[str]:
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        :param text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer`.
        :returns: A list of wordpiece tokens.
        """
        cdef long max_input_chars_per_word = self.max_input_chars_per_word
        cdef:
            bint is_bad
            long start
            long end
            Py_ssize_t n_chars

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            n_chars = len(chars)
            if n_chars > max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < n_chars:
                end = n_chars
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        # Now it's a subword
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens

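# A minimal sketch of the greedy longest-match-first behaviour described above.
# The vocabulary here is a tiny hypothetical set, not a real BERT vocab file:
#   vocab = {"un", "##aff", "##able", "[UNK]"}
#   wp = WordpieceTokenizer(vocab=vocab)
#   wp.tokenize("unaffable")    ->  ["un", "##aff", "##able"]
#   wp.tokenize("unbreakable")  ->  ["[UNK]"]  (no full decomposition exists in this vocab)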

def _step1(str text):
    """First step in pre-processing text for BERT.

    This function yields unicode characters while normalizing all whitespace to spaces,
    filtering control characters, and adding spaces around Chinese characters.
    """
    cdef bint prev_ch_whitespace = False
    cdef str ch
    cdef str cat
    cdef Py_UCS4 cp

    for ch in text:
        cp = <Py_UCS4>ch  # Casting this here removes the need for some extra error checking in the loop.

        # The original `is_control` called unicodedata.category() for every character
        # that isn't \t, \n, or \r (which is basically every character), so it's better
        # to just call it once on everything up front and pass the result around.
        cat = unicodedata.category(ch)
        if cp == 0 or cp == 0xfffd or _is_control(cp, cat):
            continue
        if _is_whitespace(cp, cat):
            yield " "
            prev_ch_whitespace = True
        else:
            # From the original BERT code:
            # ---------------------------
            # This was added on November 1st, 2018 for the multilingual and Chinese
            # models. This is also applied to the English models now, but it doesn't
            # matter since the English models were not trained on any Chinese data
            # and generally don't have any Chinese data in them (there are Chinese
            # characters in the vocabulary because Wikipedia does have some Chinese
            # words in the English Wikipedia.).

            # NB: Our regression tests will fail if we get rid of this because
            # our dev datasets have Chinese characters in them.
            # I have no idea if this is important for production or not.
            if _is_chinese_char(cp):
                # Add whitespace around any CJK character.
                if not prev_ch_whitespace:
                    yield " "
                yield ch
                yield " "
            else:
                yield ch
            prev_ch_whitespace = False

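# A couple of concrete examples of what `_step1` yields (sketches, shown joined
# for readability; any double spaces are harmless since the result is ultimately
# split on whitespace):
#   "".join(_step1(u"hello\tworld"))  ->  "hello world"
#   "".join(_step1(u"ab\u4e00cd"))    ->  "ab \u4e00 cd"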

def _step2(str text):
    """Second step in pre-processing text for BERT.

    After whitespace normalization, Chinese character handling, lower casing, and NFD
    normalization, this step filters non-spacing marks (Mn unicode category), which
    strips the decomposed accents, and adds spaces around any punctuation.
    """
    cdef str ch
    cdef str cat

    for ch in text:
        cat = unicodedata.category(ch)
        # Filter some chars (non-spacing mark)
        if cat == "Mn":
            continue
        # Add whitespace around any punctuation
        if _is_punctuation(ch, cat):
            yield " "
            yield ch
            yield " "
        else:
            yield ch

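# For example (a sketch; note the input is expected to already be NFD-normalized):
#   "".join(_step2(unicodedata.normalize("NFD", u"h\u00e9llo!")))  ->  "hello ! "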

cdef inline bint _is_punctuation(Py_UCS4 cp, str cat):
    """Checks whether `cp` is a punctuation character.

    We treat all non-letter/number ASCII as punctuation.
    Characters such as "^", "$", and "`" are not in the Unicode
    Punctuation class but we treat them as punctuation anyways, for
    consistency.
    """
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    if cat.startswith("P"):
        return True
    return False

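# For example, "$" is in the Unicode "Sc" (currency symbol) category, but because
# it is non-alphanumeric ASCII it is still treated as punctuation here:
#   _is_punctuation(u"$", unicodedata.category(u"$"))  ->  True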

cdef inline bint _is_control(Py_UCS4 ch, str cat):
    """Checks whether `ch` is a control character."""
    # Some of these are technically control characters but we count them as whitespace
    if ch == u"\t" or ch == u"\n" or ch == u"\r":
        return False
    if cat in ("Cc", "Cf"):
        return True
    return False


cdef inline bint _is_whitespace(Py_UCS4 ch, str cat):
    """Checks whether `ch` is a whitespace character.

    \t, \n, and \r are technically control characters but we treat them
    as whitespace since they are generally considered as such.
    """
    if ch == u" " or ch == u"\t" or ch == u"\n" or ch == u"\r":
        return True
    if cat == "Zs":
        return True
    return False


cdef inline bint _is_chinese_char(Py_UCS4 cp):
    """Checks whether `cp` is the codepoint of a CJK character.

    This defines a "chinese character" as anything in the CJK Unicode block:
    https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)

    Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    despite its name. The modern Korean Hangul alphabet is a different block,
    as is Japanese Hiragana and Katakana. Those alphabets are used to write
    space-separated words, so they are not treated specially and are handled
    like all of the other languages.
    """
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
            (cp >= 0x3400 and cp <= 0x4DBF) or
            (cp >= 0x20000 and cp <= 0x2A6DF) or
            (cp >= 0x2A700 and cp <= 0x2B73F) or
            (cp >= 0x2B740 and cp <= 0x2B81F) or
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or
            (cp >= 0x2F800 and cp <= 0x2FA1F)):
        return True
    return False

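# For example (a sketch illustrating the docstring above):
#   _is_chinese_char(u"\u4e00")  ->  True   (a CJK Unified Ideograph)
#   _is_chinese_char(u"\u3042")  ->  False  (Hiragana "a" is outside the CJK block)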

# Public functions for testing
def is_punctuation(Py_UCS4 cp, str cat):
    return _is_punctuation(cp, cat)

def is_control(Py_UCS4 ch, str cat):
    return _is_control(ch, cat)

def is_whitespace(Py_UCS4 ch, str cat):
    return _is_whitespace(ch, cat)

def is_chinese_char(Py_UCS4 cp):
    return _is_chinese_char(cp)