Skip to content

Commit ab2307c

Browse files
author
Yibing Liu
committed
update code & add test
1 parent 1af0222 commit ab2307c

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
## This is a prototype of ctc beam search decoder
2+
3+
import copy
4+
import random
5+
import numpy as np
6+
7+
# vocab = blank + space + English characters
8+
#vocab = ['-', ' '] + [chr(i) for i in range(97, 123)]
9+
10+
vocab = ['-', '_', 'a']
11+
12+
13+
def ids_str2list(ids_str):
14+
ids_str = ids_str.split(' ')
15+
ids_list = [int(elem) for elem in ids_str]
16+
return ids_list
17+
18+
19+
def ids_list2str(ids_list):
20+
ids_str = [str(elem) for elem in ids_list]
21+
ids_str = ' '.join(ids_str)
22+
return ids_str
23+
24+
25+
def ids_id2token(ids_list):
26+
ids_str = ''
27+
for ids in ids_list:
28+
ids_str += vocab[ids]
29+
return ids_str
30+
31+
32+
def ctc_beam_search_decoder(input_probs_matrix,
33+
beam_size,
34+
max_time_steps=None,
35+
lang_model=None,
36+
alpha=1.0,
37+
beta=1.0,
38+
blank_id=0,
39+
space_id=1,
40+
num_results_per_sample=None):
41+
'''
42+
beam search decoder for CTC-trained network, called outside of the recurrent group.
43+
adapted from Algorithm 1 in https://arxiv.org/abs/1408.2873.
44+
45+
param input_probs_matrix: probs matrix for input sequence, row major
46+
type input_probs_matrix: 2D matrix.
47+
param beam_size: width for beam search
48+
type beam_size: int
49+
max_time_steps: maximum steps' number for input sequence, <=len(input_probs_matrix)
50+
type max_time_steps: int
51+
lang_model: language model for scoring
52+
type lang_model: function
53+
54+
......
55+
56+
'''
57+
if num_results_per_sample is None:
58+
num_results_per_sample = beam_size
59+
assert num_results_per_sample <= beam_size
60+
61+
if max_time_steps is None:
62+
max_time_steps = len(input_probs_matrix)
63+
else:
64+
max_time_steps = min(max_time_steps, len(input_probs_matrix))
65+
assert max_time_steps > 0
66+
67+
vocab_dim = len(input_probs_matrix[0])
68+
assert blank_id < vocab_dim
69+
assert space_id < vocab_dim
70+
71+
## initialize
72+
start_id = -1
73+
# the set containing selected prefixes
74+
prefix_set_prev = {str(start_id): 1.0}
75+
probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0}
76+
77+
## extend prefix in loop
78+
for time_step in range(max_time_steps):
79+
# the set containing candidate prefixes
80+
prefix_set_next = {}
81+
probs_b_cur, probs_nb_cur = {}, {}
82+
for l in prefix_set_prev:
83+
prob = input_probs_matrix[time_step]
84+
85+
# convert ids in string to list
86+
ids_list = ids_str2list(l)
87+
end_id = ids_list[-1]
88+
if not prefix_set_next.has_key(l):
89+
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
90+
91+
# extend prefix by travering vocabulary
92+
for c in range(0, vocab_dim):
93+
if c == blank_id:
94+
probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
95+
else:
96+
l_plus = l + ' ' + str(c)
97+
if not prefix_set_next.has_key(l_plus):
98+
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
99+
100+
if c == end_id:
101+
probs_nb_cur[l_plus] += prob[c] * probs_b[l]
102+
probs_nb_cur[l] += prob[c] * probs_nb[l]
103+
elif c == space_id:
104+
lm = 1.0 if lang_model is None \
105+
else np.power(lang_model(ids_list), alpha)
106+
probs_nb_cur[l_plus] += lm * prob[c] * (
107+
probs_b[l] + probs_nb[l])
108+
else:
109+
probs_nb_cur[l_plus] += prob[c] * (
110+
probs_b[l] + probs_nb[l])
111+
# add l_plus into prefix_set_next
112+
prefix_set_next[l_plus] = probs_nb_cur[
113+
l_plus] + probs_b_cur[l_plus]
114+
# add l into prefix_set_next
115+
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
116+
# update probs
117+
probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
118+
probs_nb_cur)
119+
120+
## store top beam_size prefixes
121+
prefix_set_prev = sorted(
122+
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
123+
if beam_size < len(prefix_set_prev):
124+
prefix_set_prev = prefix_set_prev[:beam_size]
125+
prefix_set_prev = dict(prefix_set_prev)
126+
127+
beam_result = []
128+
for (seq, prob) in prefix_set_prev.items():
129+
if prob > 0.0:
130+
ids_list = ids_str2list(seq)
131+
log_prob = np.log(prob)
132+
beam_result.append([log_prob, ids_list[1:]])
133+
134+
## output top beam_size decoding results
135+
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
136+
if num_results_per_sample < beam_size:
137+
beam_result = beam_result[:num_results_per_sample]
138+
return beam_result
139+
140+
141+
def language_model(input):
142+
# TODO
143+
return random.uniform(0, 1)
144+
145+
146+
def simple_test():
147+
148+
input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]]
149+
150+
beam_result = ctc_beam_search_decoder(
151+
input_probs_matrix=input_probs_matrix,
152+
beam_size=20,
153+
blank_id=0,
154+
space_id=1, )
155+
156+
print "\nbeam search output:"
157+
for result in beam_result:
158+
print("%6f\t%s" % (result[0], ids_id2token(result[1])))
159+
160+
161+
if __name__ == '__main__':
162+
simple_test()
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import absolute_import
2+
from __future__ import print_function
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
from tensorflow.python.framework import ops
7+
from tensorflow.python.ops import array_ops
8+
import ctc_beam_search_decoder as tested_decoder
9+
10+
11+
def test_beam_search_decoder():
12+
max_time_steps = 6
13+
beam_size = 20
14+
num_results_per_sample = 20
15+
16+
input_prob_matrix_0 = np.asarray(
17+
[
18+
[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
19+
[0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
20+
[0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
21+
[0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
22+
[0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
23+
# Random entry added in at time=5
24+
[0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
25+
],
26+
dtype=np.float32)
27+
28+
# Add arbitrary offset - this is fine
29+
input_log_prob_matrix_0 = np.log(input_prob_matrix_0) #+ 2.0
30+
31+
# len max_time_steps array of batch_size x depth matrices
32+
inputs = ([
33+
input_log_prob_matrix_0[t, :][np.newaxis, :]
34+
for t in range(max_time_steps)
35+
])
36+
37+
inputs_t = [ops.convert_to_tensor(x) for x in inputs]
38+
inputs_t = array_ops.stack(inputs_t)
39+
40+
# run CTC beam search decoder in tensorflow
41+
with tf.Session() as sess:
42+
decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(
43+
inputs_t, [max_time_steps],
44+
beam_width=beam_size,
45+
top_paths=num_results_per_sample,
46+
merge_repeated=False)
47+
tf_decoded = sess.run(decoded)
48+
tf_log_probs = sess.run(log_probabilities)
49+
50+
# run tested CTC beam search decoder
51+
beam_result = tested_decoder.ctc_beam_search_decoder(
52+
input_probs_matrix=input_prob_matrix_0,
53+
beam_size=beam_size,
54+
blank_id=5, # default blank_id in tensorflow decoder is (num classes-1)
55+
space_id=4, # doesn't matter
56+
max_time_steps=max_time_steps,
57+
num_results_per_sample=num_results_per_sample)
58+
59+
# compare decoding result
60+
print(
61+
"{tf_decoder log probs} \t {tested_decoder log probs}: {tf_decoder result} {tested_decoder result}"
62+
)
63+
for index in range(len(beam_result)):
64+
print(('%6f\t%6f: ') % (tf_log_probs[0][index], beam_result[index][0]),
65+
tf_decoded[index].values, ' ', beam_result[index][1])
66+
67+
68+
if __name__ == '__main__':
69+
test_beam_search_decoder()

0 commit comments

Comments
 (0)