Skip to content

Commit b1dc73d

Browse files
committed
add the init_chunk_decoder()
1 parent 76ae568 commit b1dc73d

File tree

6 files changed

+108
-114
lines changed

6 files changed

+108
-114
lines changed

deepspeech/decoders/swig/ctc_beam_search_decoder.cpp

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -486,9 +486,23 @@ CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {
486486
}
487487
}
488488

489-
CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch(size_t batch_size,
490-
Scorer *ext_scorer) {
491-
this->batch_size = batch_size;
489+
CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch(
490+
const std::vector<std::string> &vocabulary,
491+
size_t batch_size,
492+
size_t beam_size,
493+
size_t num_processes,
494+
double cutoff_prob,
495+
size_t cutoff_top_n,
496+
Scorer *ext_scorer,
497+
size_t blank_id)
498+
: batch_size(batch_size),
499+
beam_size(beam_size),
500+
num_processes(num_processes),
501+
cutoff_prob(cutoff_prob),
502+
cutoff_top_n(cutoff_top_n),
503+
blank_id(blank_id) {
504+
this->vocabulary = vocabulary;
505+
492506
for (size_t i = 0; i < batch_size; i++) {
493507
CtcBeamSearchDecoderStorage *decoder_storage =
494508
new CtcBeamSearchDecoderStorage();
@@ -501,18 +515,17 @@ CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch(size_t batch_size,
501515

502516
void CtcBeamSearchDecoderBatch::next(
503517
const std::vector<std::vector<std::vector<double>>> &probs_split,
504-
const std::vector<std::string> &vocabulary,
505-
size_t beam_size,
506-
size_t num_processes,
507-
double cutoff_prob,
508-
size_t cutoff_top_n,
509-
Scorer *ext_scorer,
510-
size_t blank_id) {
518+
Scorer *ext_scorer) {
511519
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
512520
// thread pool
513521
ThreadPool pool(num_processes);
514522
// number of samples
515-
size_t batch_size = probs_split.size();
523+
size_t probs_num = probs_split.size();
524+
VALID_CHECK_EQ(this->batch_size,
525+
probs_num,
526+
"The batch size of the current input data should be same "
527+
"with the input data before");
528+
516529
// enqueue the tasks of decoding
517530
std::vector<std::future<void>> res;
518531
for (size_t i = 0; i < batch_size; ++i) {
@@ -521,12 +534,12 @@ void CtcBeamSearchDecoderBatch::next(
521534
std::ref(this->decoder_storage_vector[i]->root),
522535
std::ref(this->decoder_storage_vector[i]->prefixes),
523536
probs_split[i],
524-
vocabulary,
525-
beam_size,
526-
cutoff_prob,
527-
cutoff_top_n,
537+
this->vocabulary,
538+
this->beam_size,
539+
this->cutoff_prob,
540+
this->cutoff_top_n,
528541
ext_scorer,
529-
blank_id));
542+
this->blank_id));
530543
}
531544

532545
for (size_t i = 0; i < batch_size; ++i) {
@@ -536,28 +549,25 @@ void CtcBeamSearchDecoderBatch::next(
536549
};
537550

538551
std::vector<std::vector<std::pair<double, std::string>>>
539-
CtcBeamSearchDecoderBatch::decode(const std::vector<std::string> &vocabulary,
540-
size_t beam_size,
541-
size_t num_processes,
542-
Scorer *ext_scorer) {
543-
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
552+
CtcBeamSearchDecoderBatch::decode(Scorer *ext_scorer) {
553+
VALID_CHECK_GT(
554+
this->num_processes, 0, "num_processes must be nonnegative!");
544555
// thread pool
545-
ThreadPool pool(num_processes);
556+
ThreadPool pool(this->num_processes);
546557
// number of samples
547558
// enqueue the tasks of decoding
548559
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
549560
for (size_t i = 0; i < this->batch_size; ++i) {
550561
res.emplace_back(
551562
pool.enqueue(get_decode_result,
552563
std::ref(this->decoder_storage_vector[i]->prefixes),
553-
vocabulary,
554-
beam_size,
564+
this->vocabulary,
565+
this->beam_size,
555566
ext_scorer));
556567
}
557-
558568
// get decoding results
559569
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
560-
for (size_t i = 0; i < batch_size; ++i) {
570+
for (size_t i = 0; i < this->batch_size; ++i) {
561571
batch_results.emplace_back(res[i].get());
562572
}
563573
return batch_results;

deepspeech/decoders/swig/ctc_beam_search_decoder.h

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,31 @@ class CtcBeamSearchDecoderStorage {
9292
// TODO[HYX]: Support batch_size > 1
9393
class CtcBeamSearchDecoderBatch {
9494
public:
95-
CtcBeamSearchDecoderBatch(size_t batch_size, Scorer *ext_scorer);
95+
CtcBeamSearchDecoderBatch(const std::vector<std::string> &vocabulary,
96+
size_t batch_size,
97+
size_t beam_size,
98+
size_t num_processes,
99+
double cutoff_prob,
100+
size_t cutoff_top_n,
101+
Scorer *ext_scorer,
102+
size_t blank_id);
96103

97104
~CtcBeamSearchDecoderBatch();
98105
void next(const std::vector<std::vector<std::vector<double>>> &probs_split,
99-
const std::vector<std::string> &vocabulary,
100-
size_t beam_size,
101-
size_t num_processes,
102-
double cutoff_prob,
103-
size_t cutoff_top_n,
104-
Scorer *ext_scorer,
105-
size_t blank_id);
106+
Scorer *ext_scorer);
107+
106108
std::vector<std::vector<std::pair<double, std::string>>> decode(
107-
const std::vector<std::string> &vocabulary,
108-
size_t beam_size,
109-
size_t num_processes,
110109
Scorer *ext_scorer);
111110

112111

113112
private:
113+
std::vector<std::string> vocabulary;
114114
size_t batch_size;
115+
size_t beam_size;
116+
size_t num_processes;
117+
double cutoff_prob;
118+
size_t cutoff_top_n;
119+
size_t blank_id;
115120
std::vector<CtcBeamSearchDecoderStorage *> decoder_storage_vector;
116121
};
117122

deepspeech/decoders/swig_wrapper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,10 @@ def ctc_beam_search_decoder_batch(probs_split,
134134
return batch_beam_results
135135

136136

137-
def get_ctc_beam_search_chunk_decoder(batch_size, ext_scoring_func):
138-
chunk_decoder = swig_decoders.CtcBeamSearchDecoderBatch(batch_size,
139-
ext_scoring_func)
137+
def get_ctc_beam_search_chunk_decoder(vocabulary, batch_size, beam_size,
138+
num_processes, cutoff_prob, cutoff_top_n,
139+
ext_scoring_func, blank_id):
140+
chunk_decoder = swig_decoders.CtcBeamSearchDecoderBatch(
141+
vocabulary, batch_size, beam_size, num_processes, cutoff_prob,
142+
cutoff_top_n, ext_scoring_func, blank_id)
140143
return chunk_decoder

deepspeech/exps/deepspeech2/model.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,10 @@ def static_forward_online(self,
537537
trans_chunk_list = []
538538
probs_chunk_list = []
539539
probs_chunk_lens_list = []
540+
self.model.init_chunk_decoder(
541+
1, vocab_list, cfg.decoding_method, cfg.lang_model_path,
542+
cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob,
543+
cfg.cutoff_top_n, cfg.num_proc_bsearch)
540544
for i in range(0, num_chunk):
541545
start = i * chunk_stride
542546
end = start + chunk_size
@@ -577,15 +581,10 @@ def static_forward_online(self,
577581
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
578582
chunk_state_c_box = output_state_c_handle.copy_to_cpu()
579583
self.model.decode_get_next(
580-
output_chunk_probs, output_chunk_lens, vocab_list,
581-
cfg.decoding_method, cfg.lang_model_path, cfg.alpha,
582-
cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
583-
cfg.num_proc_bsearch)
584+
probs=output_chunk_probs, probs_len=output_chunk_lens)
584585
probs_chunk_list.append(output_chunk_probs)
585586
probs_chunk_lens_list.append(output_chunk_lens)
586-
trans = self.model.decode_get_trans(
587-
1, vocab_list, cfg.decoding_method, cfg.alpha, cfg.beta,
588-
cfg.beam_size, cfg.num_proc_bsearch)
587+
trans = self.model.decode_get_trans()
589588
batch_trans_list.append(trans[0])
590589
self.model.del_chunk_decoder()
591590

deepspeech/models/ds2_online/deepspeech2.py

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -334,54 +334,31 @@ def decode(self, audio, audio_len, vocab_list, decoding_method,
334334
cutoff_top_n, num_processes)
335335

336336
@paddle.no_grad()
337-
def decode_get_next(self, probs, probs_len, vocab_list, decoding_method,
338-
lang_model_path, beam_alpha, beam_beta, beam_size,
339-
cutoff_prob, cutoff_top_n, num_processes):
340-
341-
if self.chunk_decoder is None:
342-
self.decoder.init_decode(
343-
beam_alpha=beam_alpha,
344-
beam_beta=beam_beta,
345-
lang_model_path=lang_model_path,
346-
vocab_list=vocab_list,
347-
decoding_method=decoding_method)
348-
batch_size = probs.shape[0]
349-
self.chunk_decoder = self.decoder.get_chunk_decoder(decoding_method,
350-
batch_size)
351-
352-
self.decoder.decoder_next(self.chunk_decoder, probs, probs_len,
353-
vocab_list, decoding_method, beam_alpha,
354-
beam_beta, beam_size, cutoff_prob,
355-
cutoff_top_n, num_processes)
356-
357-
def decode_get_trans(self, batch_size, vocab_list, decoding_method,
358-
beam_alpha, beam_beta, beam_size, num_processes):
359-
assert (self.chunk_decoder is not None)
360-
trans = self.decoder.chunk_decoder_to_decode(
361-
self.chunk_decoder, batch_size, vocab_list, decoding_method,
362-
beam_alpha, beam_beta, beam_size, num_processes)
363-
return trans
337+
def init_chunk_decoder(self, batch_size, vocab_list, decoding_method,
338+
lang_model_path, beam_alpha, beam_beta, beam_size,
339+
cutoff_prob, cutoff_top_n, num_processes):
340+
self.decoder.init_decode(
341+
beam_alpha=beam_alpha,
342+
beam_beta=beam_beta,
343+
lang_model_path=lang_model_path,
344+
vocab_list=vocab_list,
345+
decoding_method=decoding_method)
346+
if self.chunk_decoder is not None:
347+
self.del_chunk_decoder()
348+
self.chunk_decoder = self.decoder.get_chunk_decoder(
349+
vocab_list, batch_size, beam_alpha, beam_beta, beam_size,
350+
num_processes, cutoff_prob, cutoff_top_n)
364351

365352
@paddle.no_grad()
366-
def decode_chunk(self, probs, probs_len, vocab_list, decoding_method,
367-
lang_model_path, beam_alpha, beam_beta, beam_size,
368-
cutoff_prob, cutoff_top_n, num_processes):
353+
def decode_get_next(self, probs, probs_len):
354+
if self.chunk_decoder is None:
355+
raise Exception("You need to initialize the chunk decoder firstly")
356+
self.decoder.chunk_decoder_next(self.chunk_decoder, probs, probs_len)
369357

358+
def decode_get_trans(self):
370359
if self.chunk_decoder is None:
371-
self.decoder.init_decode(
372-
beam_alpha=beam_alpha,
373-
beam_beta=beam_beta,
374-
lang_model_path=lang_model_path,
375-
vocab_list=vocab_list,
376-
decoding_method=decoding_method)
377-
batch_size = probs.shape[0]
378-
self.chunk_decoder = self.decoder.get_chunk_decoder(decoding_method,
379-
batch_size)
380-
381-
trans = self.decoder.chunk_decoder_to_decode(
382-
self.chunk_decoder, probs, probs_len, vocab_list, decoding_method,
383-
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
384-
cutoff_top_n, num_processes)
360+
raise Exception("You need to initialize the chunk decoder firstly")
361+
trans = self.decoder.chunk_decoder_decode(self.chunk_decoder)
385362
return trans
386363

387364
def del_chunk_decoder(self):

deepspeech/modules/ctc.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from deepspeech.utils.log import Log
2222

2323
logger = Log(__name__).getlog()
24+
2425
try:
2526
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
2627
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401
@@ -66,6 +67,7 @@ def __init__(self,
6667
batch_average=batch_average,
6768
grad_norm_type=grad_norm_type)
6869

70+
self.decoding_method = "ctc_beam_search"
6971
# CTCDecoder LM Score handle
7072
self._ext_scorer = None
7173

@@ -227,7 +229,7 @@ def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
227229

228230
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
229231
decoding_method):
230-
232+
self.decoding_method = decoding_method
231233
if decoding_method == "ctc_beam_search":
232234
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
233235
vocab_list)
@@ -275,40 +277,38 @@ def decode_probs(self, probs, logits_lens, vocab_list, decoding_method,
275277
raise ValueError(f"Not support: {decoding_method}")
276278
return result_transcripts
277279

278-
def get_chunk_decoder(self, decoding_method, batch_size):
279-
if decoding_method == "ctc_beam_search":
280+
def get_chunk_decoder(self, vocabulary, batch_size, beam_alpha, beam_beta,
281+
beam_size, num_processes, cutoff_prob, cutoff_top_n):
282+
num_processes = min(num_processes, batch_size)
283+
if self._ext_scorer is not None:
284+
self._ext_scorer.reset_params(beam_alpha, beam_beta)
285+
if self.decoding_method == "ctc_beam_search":
280286
chunk_decoder = get_ctc_beam_search_chunk_decoder(
281-
batch_size=batch_size, ext_scoring_func=self._ext_scorer)
287+
vocabulary=vocabulary,
288+
batch_size=batch_size,
289+
beam_size=beam_size,
290+
num_processes=num_processes,
291+
cutoff_prob=cutoff_prob,
292+
cutoff_top_n=cutoff_top_n,
293+
ext_scoring_func=self._ext_scorer,
294+
blank_id=self.blank_id)
282295
else:
283296
raise ValueError(f"Not support: {decoding_method}")
284297
return chunk_decoder
285298

286-
def decoder_next(self, chunk_decoder, probs, logits_lens, vocab_list,
287-
decoding_method, beam_alpha, beam_beta, beam_size,
288-
cutoff_prob, cutoff_top_n, num_processes):
299+
def chunk_decoder_next(self, chunk_decoder, probs, logits_lens):
289300
probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)]
290301
probs_split = [probs_seq.tolist() for probs_seq in probs_split]
291-
num_processes = min(num_processes, len(probs_split))
292-
if self._ext_scorer is not None:
293-
self._ext_scorer.reset_params(beam_alpha, beam_beta)
294-
if decoding_method == "ctc_beam_search":
295-
chunk_decoder.next(probs_split, vocab_list, beam_size,
296-
num_processes, cutoff_prob, cutoff_top_n,
297-
self._ext_scorer, self.blank_id)
302+
if self.decoding_method == "ctc_beam_search":
303+
chunk_decoder.next(probs_split, self._ext_scorer)
298304
else:
299305
raise ValueError(f"Not support: {decoding_method}")
300306

301307
return
302308

303-
def chunk_decoder_to_decode(self, chunk_decoder, batch_size, vocab_list,
304-
decoding_method, beam_alpha, beam_beta,
305-
beam_size, num_processes):
306-
num_processes = min(num_processes, batch_size)
307-
if self._ext_scorer is not None:
308-
self._ext_scorer.reset_params(beam_alpha, beam_beta)
309-
if decoding_method == "ctc_beam_search":
310-
batch_beam_results = chunk_decoder.decode(
311-
vocab_list, beam_size, num_processes, self._ext_scorer)
309+
def chunk_decoder_decode(self, chunk_decoder):
310+
if self.decoding_method == "ctc_beam_search":
311+
batch_beam_results = chunk_decoder.decode(self._ext_scorer)
312312
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
313313
for beam_results in batch_beam_results]
314314
results = [result[0][1] for result in batch_beam_results]

0 commit comments

Comments
 (0)