Skip to content

Commit 5c4751e

Browse files
author
Yibing Liu
committed
adapt to the new data provider
1 parent 06f272a commit 5c4751e

File tree

4 files changed

+813
-31
lines changed

4 files changed

+813
-31
lines changed

deep_speech_2/decoder.py

Lines changed: 247 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
"""Contains various CTC decoder."""
1+
"""Contains various CTC decoders."""
22
from __future__ import absolute_import
33
from __future__ import division
44
from __future__ import print_function
55

6-
import numpy as np
6+
import os
77
from itertools import groupby
8+
import numpy as np
9+
import kenlm
10+
import multiprocessing
811

912

1013
def ctc_best_path_decode(probs_seq, vocabulary):
@@ -36,24 +39,250 @@ def ctc_best_path_decode(probs_seq, vocabulary):
3639
return ''.join([vocabulary[index] for index in index_list])
3740

3841

39-
def ctc_decode(probs_seq, vocabulary, method):
40-
"""CTC-like sequence decoding from a sequence of likelihood probablilites.
42+
class Scorer(object):
43+
"""External defined scorer to evaluate a sentence in beam search
44+
decoding, consisting of language model and word count.
4145
42-
:param probs_seq: 2-D list of probabilities over the vocabulary for each
43-
character. Each element is a list of float probabilities
44-
for one character.
45-
:type probs_seq: list
46+
:param alpha: Parameter associated with language model.
47+
:type alpha: float
48+
:param beta: Parameter associated with word count.
49+
:type beta: float
50+
:model_path: Path to load language model.
51+
:type model_path: basestring
52+
"""
53+
54+
def __init__(self, alpha, beta, model_path):
55+
self._alpha = alpha
56+
self._beta = beta
57+
if not os.path.isfile(model_path):
58+
raise IOError("Invaid language model path: %s" % model_path)
59+
self._language_model = kenlm.LanguageModel(model_path)
60+
61+
# n-gram language model scoring
62+
def language_model_score(self, sentence):
63+
#log prob of last word
64+
log_cond_prob = list(
65+
self._language_model.full_scores(sentence, eos=False))[-1][0]
66+
return np.power(10, log_cond_prob)
67+
68+
# word insertion term
69+
def word_count(self, sentence):
70+
words = sentence.strip().split(' ')
71+
return len(words)
72+
73+
# execute evaluation
74+
def __call__(self, sentence, log=False):
75+
"""Evaluation function, gathering all the scores.
76+
77+
:param sentence: The input sentence for evalutation
78+
:type sentence: basestring
79+
:param log: Whether return the score in log representation.
80+
:type log: bool
81+
:return: Evaluation score, in the decimal or log.
82+
:rtype: float
83+
"""
84+
lm = self.language_model_score(sentence)
85+
word_cnt = self.word_count(sentence)
86+
if log == False:
87+
score = np.power(lm, self._alpha) \
88+
* np.power(word_cnt, self._beta)
89+
else:
90+
score = self._alpha * np.log(lm) \
91+
+ self._beta * np.log(word_cnt)
92+
return score
93+
94+
95+
def ctc_beam_search_decoder(probs_seq,
96+
beam_size,
97+
vocabulary,
98+
blank_id=0,
99+
cutoff_prob=1.0,
100+
ext_scoring_func=None,
101+
nproc=False):
102+
'''Beam search decoder for CTC-trained network, using beam search with width
103+
beam_size to find many paths to one label, return beam_size labels in
104+
the descending order of probabilities. The implementation is based on Prefix
105+
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
106+
redesigned.
107+
108+
:param probs_seq: 2-D list with length num_time_steps, each element
109+
is a list of normalized probabilities over vocabulary
110+
and blank for one time step.
111+
:type probs_seq: 2-D list
112+
:param beam_size: Width for beam search.
113+
:type beam_size: int
46114
:param vocabulary: Vocabulary list.
47115
:type vocabulary: list
48-
:param method: Decoding method name, with options: "best_path".
49-
:type method: basestring
50-
:return: Decoding result string.
51-
:rtype: baseline
52-
"""
116+
:param blank_id: ID of blank, default 0.
117+
:type blank_id: int
118+
:param cutoff_prob: Cutoff probability in pruning,
119+
default 1.0, no pruning.
120+
:type cutoff_prob: float
121+
:param ext_scoring_func: External defined scoring function for
122+
partially decoded sentence, e.g. word count
123+
and language model.
124+
:type external_scoring_function: function
125+
:param nproc: Whether the decoder used in multiprocesses.
126+
:type nproc: bool
127+
:return: Decoding log probabilities and result sentences in descending order.
128+
:rtype: list
129+
'''
130+
# dimension check
53131
for prob_list in probs_seq:
54132
if not len(prob_list) == len(vocabulary) + 1:
55-
raise ValueError("probs dimension mismatchedd with vocabulary")
56-
if method == "best_path":
57-
return ctc_best_path_decode(probs_seq, vocabulary)
58-
else:
59-
raise ValueError("Decoding method [%s] is not supported.")
133+
raise ValueError("probs dimension mismatched with vocabulary")
134+
num_time_steps = len(probs_seq)
135+
136+
# blank_id check
137+
probs_dim = len(probs_seq[0])
138+
if not blank_id < probs_dim:
139+
raise ValueError("blank_id shouldn't be greater than probs dimension")
140+
141+
# If the decoder called in the multiprocesses, then use the global scorer
142+
# instantiated in ctc_beam_search_decoder_nproc().
143+
if nproc is True:
144+
global ext_nproc_scorer
145+
ext_scoring_func = ext_nproc_scorer
146+
147+
## initialize
148+
# the set containing selected prefixes
149+
prefix_set_prev = {'\t': 1.0}
150+
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
151+
152+
## extend prefix in loop
153+
for time_step in xrange(num_time_steps):
154+
# the set containing candidate prefixes
155+
prefix_set_next = {}
156+
probs_b_cur, probs_nb_cur = {}, {}
157+
prob = probs_seq[time_step]
158+
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
159+
cutoff_len = len(prob_idx)
160+
#If pruning is enabled
161+
if (cutoff_prob < 1.0):
162+
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
163+
cutoff_len = 0
164+
cum_prob = 0.0
165+
for i in xrange(len(prob_idx)):
166+
cum_prob += prob_idx[i][1]
167+
cutoff_len += 1
168+
if cum_prob >= cutoff_prob:
169+
break
170+
prob_idx = prob_idx[0:cutoff_len]
171+
172+
for l in prefix_set_prev:
173+
if not prefix_set_next.has_key(l):
174+
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
175+
176+
# extend prefix by travering prob_idx
177+
for index in xrange(cutoff_len):
178+
c, prob_c = prob_idx[index][0], prob_idx[index][1]
179+
180+
if c == blank_id:
181+
probs_b_cur[l] += prob_c * (
182+
probs_b_prev[l] + probs_nb_prev[l])
183+
else:
184+
last_char = l[-1]
185+
new_char = vocabulary[c]
186+
l_plus = l + new_char
187+
if not prefix_set_next.has_key(l_plus):
188+
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
189+
190+
if new_char == last_char:
191+
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
192+
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
193+
elif new_char == ' ':
194+
if (ext_scoring_func is None) or (len(l) == 1):
195+
score = 1.0
196+
else:
197+
prefix = l[1:]
198+
score = ext_scoring_func(prefix)
199+
probs_nb_cur[l_plus] += score * prob_c * (
200+
probs_b_prev[l] + probs_nb_prev[l])
201+
else:
202+
probs_nb_cur[l_plus] += prob_c * (
203+
probs_b_prev[l] + probs_nb_prev[l])
204+
# add l_plus into prefix_set_next
205+
prefix_set_next[l_plus] = probs_nb_cur[
206+
l_plus] + probs_b_cur[l_plus]
207+
# add l into prefix_set_next
208+
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
209+
# update probs
210+
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur
211+
212+
## store top beam_size prefixes
213+
prefix_set_prev = sorted(
214+
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
215+
if beam_size < len(prefix_set_prev):
216+
prefix_set_prev = prefix_set_prev[:beam_size]
217+
prefix_set_prev = dict(prefix_set_prev)
218+
219+
beam_result = []
220+
for (seq, prob) in prefix_set_prev.items():
221+
if prob > 0.0 and len(seq) > 1:
222+
result = seq[1:]
223+
# score last word by external scorer
224+
if (ext_scoring_func is not None) and (result[-1] != ' '):
225+
prob = prob * ext_scoring_func(result)
226+
log_prob = np.log(prob)
227+
beam_result.append([log_prob, result])
228+
229+
## output top beam_size decoding results
230+
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
231+
return beam_result
232+
233+
234+
def ctc_beam_search_decoder_nproc(probs_split,
235+
beam_size,
236+
vocabulary,
237+
blank_id=0,
238+
cutoff_prob=1.0,
239+
ext_scoring_func=None,
240+
num_processes=None):
241+
'''Beam search decoder using multiple processes.
242+
243+
:param probs_seq: 3-D list with length batch_size, each element
244+
is a 2-D list of probabilities can be used by
245+
ctc_beam_search_decoder.
246+
:type probs_seq: 3-D list
247+
:param beam_size: Width for beam search.
248+
:type beam_size: int
249+
:param vocabulary: Vocabulary list.
250+
:type vocabulary: list
251+
:param blank_id: ID of blank, default 0.
252+
:type blank_id: int
253+
:param cutoff_prob: Cutoff probability in pruning,
254+
default 0, no pruning.
255+
:type cutoff_prob: float
256+
:param ext_scoring_func: External defined scoring function for
257+
partially decoded sentence, e.g. word count
258+
and language model.
259+
:type external_scoring_function: function
260+
:param num_processes: Number of processes, default None, equal to the
261+
number of CPUs.
262+
:type num_processes: int
263+
:return: Decoding log probabilities and result sentences in descending order.
264+
:rtype: list
265+
'''
266+
if num_processes is None:
267+
num_processes = multiprocessing.cpu_count()
268+
if not num_processes > 0:
269+
raise ValueError("Number of processes must be positive!")
270+
271+
# use global variable to pass the externnal scorer to beam search decoder
272+
global ext_nproc_scorer
273+
ext_nproc_scorer = ext_scoring_func
274+
nproc = True
275+
276+
pool = multiprocessing.Pool(processes=num_processes)
277+
results = []
278+
for i, probs_list in enumerate(probs_split):
279+
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
280+
nproc)
281+
results.append(pool.apply_async(ctc_beam_search_decoder, args))
282+
283+
pool.close()
284+
pool.join()
285+
beam_search_results = []
286+
for result in results:
287+
beam_search_results.append(result.get())
288+
return beam_search_results

0 commit comments

Comments
 (0)