1 | | -"""Contains various CTC decoder.""" |
| 1 | +"""Contains various CTC decoders.""" |
2 | 2 | from __future__ import absolute_import |
3 | 3 | from __future__ import division |
4 | 4 | from __future__ import print_function |
5 | 5 |
6 | | -import numpy as np |
| 6 | +import os |
7 | 7 | from itertools import groupby |
| 8 | +import numpy as np |
| 9 | +import kenlm |
| 10 | +import multiprocessing |
8 | 11 |
9 | 12 |
10 | 13 | def ctc_best_path_decode(probs_seq, vocabulary): |
@@ -36,24 +39,250 @@ def ctc_best_path_decode(probs_seq, vocabulary): |
36 | 39 | return ''.join([vocabulary[index] for index in index_list]) |
37 | 40 |
38 | 41 |
39 | | -def ctc_decode(probs_seq, vocabulary, method): |
40 | | - """CTC-like sequence decoding from a sequence of likelihood probablilites. |
| 42 | +class Scorer(object): |
| 43 | + """External defined scorer to evaluate a sentence in beam search |
| 44 | + decoding, consisting of language model and word count. |
41 | 45 |
42 | | - :param probs_seq: 2-D list of probabilities over the vocabulary for each |
43 | | - character. Each element is a list of float probabilities |
44 | | - for one character. |
45 | | - :type probs_seq: list |
| 46 | + :param alpha: Weight associated with the language model score. |
| 47 | + :type alpha: float |
| 48 | + :param beta: Weight associated with the word count. |
| 49 | + :type beta: float |
| 50 | + :param model_path: Path to the language model file. |
| 51 | + :type model_path: basestring |
| 52 | + """ |
| 53 | + |
| 54 | + def __init__(self, alpha, beta, model_path): |
| 55 | + self._alpha = alpha |
| 56 | + self._beta = beta |
| 57 | + if not os.path.isfile(model_path): |
| 58 | + raise IOError("Invalid language model path: %s" % model_path) |
| 59 | + self._language_model = kenlm.LanguageModel(model_path) |
| 60 | + |
| 61 | + # n-gram language model scoring |
| 62 | + def language_model_score(self, sentence): |
| 63 | + # log10 probability of the last word |
| 64 | + log_cond_prob = list( |
| 65 | + self._language_model.full_scores(sentence, eos=False))[-1][0] |
| 66 | + return np.power(10, log_cond_prob) |
| 67 | + |
| 68 | + # word insertion term |
| 69 | + def word_count(self, sentence): |
| 70 | + words = sentence.strip().split(' ') |
| 71 | + return len(words) |
| 72 | + |
| 73 | + # execute evaluation |
| 74 | + def __call__(self, sentence, log=False): |
| 75 | + """Evaluation function, gathering all the scores. |
| 76 | +
| 77 | + :param sentence: The input sentence for evaluation. |
| 78 | + :type sentence: basestring |
| 79 | + :param log: Whether to return the score in log representation. |
| 80 | + :type log: bool |
| 81 | + :return: Evaluation score, in decimal or log representation. |
| 82 | + :rtype: float |
| 83 | + """ |
| 84 | + lm = self.language_model_score(sentence) |
| 85 | + word_cnt = self.word_count(sentence) |
| 86 | + if not log: |
| 87 | + score = np.power(lm, self._alpha) \ |
| 88 | + * np.power(word_cnt, self._beta) |
| 89 | + else: |
| 90 | + score = self._alpha * np.log(lm) \ |
| 91 | + + self._beta * np.log(word_cnt) |
| 92 | + return score |
| 93 | + |
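| | +# A minimal usage sketch for Scorer (the model path "lm.klm" and the |
| | +# alpha/beta values are illustrative placeholders, not part of this module): |
| | +# |
| | +# scorer = Scorer(alpha=0.26, beta=0.1, model_path="lm.klm") |
| | +# # decimal score: p^alpha * wc^beta, where p is the language model |
| | +# # probability of the last word and wc is the word count |
| | +# print(scorer("hello world")) |
| | +# print(scorer("hello world", log=True))  # same score in log representation |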
| 94 | + |
| 95 | +def ctc_beam_search_decoder(probs_seq, |
| 96 | + beam_size, |
| 97 | + vocabulary, |
| 98 | + blank_id=0, |
| 99 | + cutoff_prob=1.0, |
| 100 | + ext_scoring_func=None, |
| 101 | + nproc=False): |
| 102 | + """Beam search decoder for a CTC-trained network. It uses beam search |
| 103 | + with width beam_size to find the most probable label sequences, and |
| 104 | + returns beam_size results in descending order of probability. The |
| 105 | + implementation is based on Prefix Beam Search |
| 106 | + (https://arxiv.org/abs/1408.2873), with the unclear parts redesigned. |
| 107 | +
| 108 | + :param probs_seq: 2-D list with length num_time_steps, where each |
| 109 | + element is a list of normalized probabilities over the |
| 110 | + vocabulary and blank for one time step. |
| 111 | + :type probs_seq: 2-D list |
| 112 | + :param beam_size: Width for beam search. |
| 113 | + :type beam_size: int |
46 | 114 | :param vocabulary: Vocabulary list. |
47 | 115 | :type vocabulary: list |
48 | | - :param method: Decoding method name, with options: "best_path". |
49 | | - :type method: basestring |
50 | | - :return: Decoding result string. |
51 | | - :rtype: baseline |
52 | | - """ |
| 116 | + :param blank_id: ID of blank, default 0. |
| 117 | + :type blank_id: int |
| 118 | + :param cutoff_prob: Cutoff probability in pruning, |
| 119 | + default 1.0, no pruning. |
| 120 | + :type cutoff_prob: float |
| 121 | + :param ext_scoring_func: Externally defined scoring function for |
| 122 | + partially decoded sentences, e.g. word count |
| 123 | + or language model. |
| 124 | + :type ext_scoring_func: function |
| 125 | + :param nproc: Whether the decoder is called from multiple processes. |
| 126 | + :type nproc: bool |
| 127 | + :return: Decoding log probabilities and result sentences in descending order. |
| 128 | + :rtype: list |
| 129 | + """ |
| 130 | + # dimension check |
53 | 131 | for prob_list in probs_seq: |
54 | 132 | if not len(prob_list) == len(vocabulary) + 1: |
55 | | - raise ValueError("probs dimension mismatchedd with vocabulary") |
56 | | - if method == "best_path": |
57 | | - return ctc_best_path_decode(probs_seq, vocabulary) |
58 | | - else: |
59 | | - raise ValueError("Decoding method [%s] is not supported.") |
| 133 | + raise ValueError("probs dimension mismatched with vocabulary") |
| 134 | + num_time_steps = len(probs_seq) |
| 135 | + |
| 136 | + # blank_id check |
| 137 | + probs_dim = len(probs_seq[0]) |
| 138 | + if blank_id >= probs_dim: |
| 139 | + raise ValueError("blank_id must be smaller than probs dimension") |
| 140 | + |
| 141 | + # If the decoder is called from multiple processes, use the global scorer |
| 142 | + # instantiated in ctc_beam_search_decoder_nproc(). |
| 143 | + if nproc: |
| 144 | + global ext_nproc_scorer |
| 145 | + ext_scoring_func = ext_nproc_scorer |
| 146 | + |
| 147 | + ## initialize |
| 148 | + # the set containing selected prefixes |
| 149 | + prefix_set_prev = {'\t': 1.0} |
| 150 | + probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} |
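| | + # '\t' is a sentinel for the empty prefix; probs_b_prev and probs_nb_prev |
| | + # hold, for each prefix, the total probability of the paths that end |
| | + # in a blank and in a non-blank label respectively |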
| 151 | + |
| 152 | + ## extend prefix in loop |
| 153 | + for time_step in xrange(num_time_steps): |
| 154 | + # the set containing candidate prefixes |
| 155 | + prefix_set_next = {} |
| 156 | + probs_b_cur, probs_nb_cur = {}, {} |
| 157 | + prob = probs_seq[time_step] |
| 158 | + prob_idx = [[i, prob[i]] for i in xrange(len(prob))] |
| 159 | + cutoff_len = len(prob_idx) |
| 160 | + # if pruning is enabled |
| 161 | + if cutoff_prob < 1.0: |
| 162 | + prob_idx = sorted(prob_idx, key=lambda x: x[1], reverse=True) |
| 163 | + cutoff_len = 0 |
| 164 | + cum_prob = 0.0 |
| 165 | + for i in xrange(len(prob_idx)): |
| 166 | + cum_prob += prob_idx[i][1] |
| 167 | + cutoff_len += 1 |
| 168 | + if cum_prob >= cutoff_prob: |
| 169 | + break |
| 170 | + prob_idx = prob_idx[0:cutoff_len] |
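| | + # e.g. with sorted probabilities [0.4, 0.3, 0.2, 0.1] and |
| | + # cutoff_prob=0.8, the cumulative sum first reaches 0.9 >= 0.8 at |
| | + # the third entry, so only the top three characters are kept |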
| 171 | + |
| 172 | + for l in prefix_set_prev: |
| 173 | + if l not in prefix_set_next: |
| 174 | + probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0 |
| 175 | + |
| 176 | + # extend prefix by traversing prob_idx |
| 177 | + for index in xrange(cutoff_len): |
| 178 | + c, prob_c = prob_idx[index][0], prob_idx[index][1] |
| 179 | + |
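| | + # three cases for extending the prefix l with character c: a blank, |
| | + # a repeat of the last character, or a new character (an external |
| | + # score is applied when a word is completed by a space) |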
| 180 | + if c == blank_id: |
| 181 | + probs_b_cur[l] += prob_c * ( |
| 182 | + probs_b_prev[l] + probs_nb_prev[l]) |
| 183 | + else: |
| 184 | + last_char = l[-1] |
| | + # skip the blank slot when mapping the index to a character |
| 185 | + new_char = vocabulary[c] if c < blank_id else vocabulary[c - 1] |
| 186 | + l_plus = l + new_char |
| 187 | + if l_plus not in prefix_set_next: |
| 188 | + probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 |
| 189 | + |
| 190 | + if new_char == last_char: |
| 191 | + probs_nb_cur[l_plus] += prob_c * probs_b_prev[l] |
| 192 | + probs_nb_cur[l] += prob_c * probs_nb_prev[l] |
| 193 | + elif new_char == ' ': |
| 194 | + if (ext_scoring_func is None) or (len(l) == 1): |
| 195 | + score = 1.0 |
| 196 | + else: |
| 197 | + prefix = l[1:] |
| 198 | + score = ext_scoring_func(prefix) |
| 199 | + probs_nb_cur[l_plus] += score * prob_c * ( |
| 200 | + probs_b_prev[l] + probs_nb_prev[l]) |
| 201 | + else: |
| 202 | + probs_nb_cur[l_plus] += prob_c * ( |
| 203 | + probs_b_prev[l] + probs_nb_prev[l]) |
| 204 | + # add l_plus into prefix_set_next |
| 205 | + prefix_set_next[l_plus] = probs_nb_cur[ |
| 206 | + l_plus] + probs_b_cur[l_plus] |
| 207 | + # add l into prefix_set_next |
| 208 | + prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] |
| 209 | + # update probs |
| 210 | + probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur |
| 211 | + |
| 212 | + ## store top beam_size prefixes |
| 213 | + prefix_set_prev = sorted( |
| 214 | + prefix_set_next.iteritems(), key=lambda x: x[1], reverse=True) |
| 215 | + if beam_size < len(prefix_set_prev): |
| 216 | + prefix_set_prev = prefix_set_prev[:beam_size] |
| 217 | + prefix_set_prev = dict(prefix_set_prev) |
| 218 | + |
| 219 | + beam_result = [] |
| 220 | + for (seq, prob) in prefix_set_prev.items(): |
| 221 | + if prob > 0.0 and len(seq) > 1: |
| 222 | + result = seq[1:] |
| 223 | + # score last word by external scorer |
| 224 | + if (ext_scoring_func is not None) and (result[-1] != ' '): |
| 225 | + prob = prob * ext_scoring_func(result) |
| 226 | + log_prob = np.log(prob) |
| 227 | + beam_result.append([log_prob, result]) |
| 228 | + |
| 229 | + ## output top beam_size decoding results |
| 230 | + beam_result = sorted(beam_result, key=lambda x: x[0], reverse=True) |
| 231 | + return beam_result |
| 232 | + |
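| | +# A minimal usage sketch (the toy probabilities and vocabulary are |
| | +# illustrative only; here the blank occupies the last index, and each row |
| | +# sums to one over the vocabulary plus blank): |
| | +# |
| | +# vocab = ['a', 'b', ' '] |
| | +# probs = [[0.6, 0.2, 0.1, 0.1], |
| | +#          [0.2, 0.5, 0.2, 0.1]] |
| | +# results = ctc_beam_search_decoder( |
| | +#     probs, beam_size=4, vocabulary=vocab, blank_id=len(vocab)) |
| | +# top_log_prob, top_sentence = results[0] |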
| 233 | + |
| 234 | +def ctc_beam_search_decoder_nproc(probs_split, |
| 235 | + beam_size, |
| 236 | + vocabulary, |
| 237 | + blank_id=0, |
| 238 | + cutoff_prob=1.0, |
| 239 | + ext_scoring_func=None, |
| 240 | + num_processes=None): |
| 241 | + """Beam search decoder using multiple processes. |
| 242 | +
| 243 | + :param probs_split: 3-D list with length batch_size, where each element |
| 244 | + is a 2-D list of probabilities that can be consumed |
| 245 | + by ctc_beam_search_decoder. |
| 246 | + :type probs_split: 3-D list |
| 247 | + :param beam_size: Width for beam search. |
| 248 | + :type beam_size: int |
| 249 | + :param vocabulary: Vocabulary list. |
| 250 | + :type vocabulary: list |
| 251 | + :param blank_id: ID of blank, default 0. |
| 252 | + :type blank_id: int |
| 253 | + :param cutoff_prob: Cutoff probability in pruning, |
| 254 | + default 1.0, no pruning. |
| 255 | + :type cutoff_prob: float |
| 256 | + :param ext_scoring_func: Externally defined scoring function for |
| 257 | + partially decoded sentences, e.g. word count |
| 258 | + or language model. |
| 259 | + :type ext_scoring_func: function |
| 260 | + :param num_processes: Number of processes, default None, equal to the |
| 261 | + number of CPUs. |
| 262 | + :type num_processes: int |
| 263 | + :return: Decoding log probabilities and result sentences in descending order. |
| 264 | + :rtype: list |
| 265 | + """ |
| 266 | + if num_processes is None: |
| 267 | + num_processes = multiprocessing.cpu_count() |
| 268 | + if not num_processes > 0: |
| 269 | + raise ValueError("Number of processes must be positive!") |
| 270 | + |
| 271 | + # use a global variable to pass the external scorer to the beam search decoder |
| 272 | + global ext_nproc_scorer |
| 273 | + ext_nproc_scorer = ext_scoring_func |
| 274 | + nproc = True |
| 275 | + |
| 276 | + pool = multiprocessing.Pool(processes=num_processes) |
| 277 | + results = [] |
| 278 | + for probs_list in probs_split: |
| 279 | + args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None, |
| 280 | + nproc) |
| 281 | + results.append(pool.apply_async(ctc_beam_search_decoder, args)) |
| 282 | + |
| 283 | + pool.close() |
| 284 | + pool.join() |
| 285 | + beam_search_results = [] |
| 286 | + for result in results: |
| 287 | + beam_search_results.append(result.get()) |
| 288 | + return beam_search_results |
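| | +# A minimal usage sketch (probs_utt_1/probs_utt_2, vocab and the scorer |
| | +# parameters are illustrative placeholders; one 2-D probability matrix is |
| | +# expected per utterance in the batch): |
| | +# |
| | +# scorer = Scorer(alpha=0.26, beta=0.1, model_path="lm.klm") |
| | +# batch_results = ctc_beam_search_decoder_nproc( |
| | +#     probs_split=[probs_utt_1, probs_utt_2], |
| | +#     beam_size=20, |
| | +#     vocabulary=vocab, |
| | +#     blank_id=len(vocab), |
| | +#     ext_scoring_func=scorer) |
| | +# for beam_result in batch_results: |
| | +#     print(beam_result[0])  # top [log_prob, sentence] for each utterance |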