Skip to content

Commit 1924153

Browse files
committed
fix a bug
1 parent eb4edad commit 1924153

File tree

1 file changed

+11
-12
lines changed
  • paddlespeech/s2t/exps/deepspeech2

1 file changed

+11
-12
lines changed

paddlespeech/s2t/exps/deepspeech2/model.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -413,19 +413,19 @@ def test(self):
413413

414414
def compute_result_transcripts(self, audio, audio_len):
415415
if self.args.model_type == "online":
416-
output_probs, output_lens, trans = self.static_forward_online(
416+
output_probs, output_lens, trans_batch = self.static_forward_online(
417417
audio, audio_len, decoder_chunk_size=1)
418-
result_transcripts = trans[-1:]
418+
result_transcripts = [trans[-1] for trans in trans_batch]
419419
elif self.args.model_type == "offline":
420-
batch_size = output_probs.shape[0]
421-
self.model.decoder.reset_decoder(batch_size = batch_size)
422420
output_probs, output_lens = self.static_forward_offline(audio,
423421
audio_len)
422+
batch_size = output_probs.shape[0]
423+
self.model.decoder.reset_decoder(batch_size=batch_size)
424424

425425
self.model.decoder.next(output_probs, output_lens)
426426

427427
trans_best, trans_beam = self.model.decoder.decode()
428-
428+
429429
result_transcripts = trans_best
430430

431431
else:
@@ -485,7 +485,7 @@ def static_forward_online(self, audio, audio_len,
485485
x_list = np.split(x_batch, batch_size, axis=0)
486486
x_len_list = np.split(x_len_batch, batch_size, axis=0)
487487

488-
trans = []
488+
trans_batch = []
489489
for x, x_len in zip(x_list, x_len_list):
490490
if self.args.enable_auto_log is True:
491491
self.autolog.times.start()
@@ -518,14 +518,14 @@ def static_forward_online(self, audio, audio_len,
518518
h_box_handle = self.predictor.get_input_handle(input_names[2])
519519
c_box_handle = self.predictor.get_input_handle(input_names[3])
520520

521-
trans_chunk_list = []
521+
trans = []
522522
probs_chunk_list = []
523523
probs_chunk_lens_list = []
524524
if self.args.enable_auto_log is True:
525525
# record the model preprocessing time
526526
self.autolog.times.stamp()
527-
528-
self.model.decoder.reset_decoder(batch_size = 1)
527+
528+
self.model.decoder.reset_decoder(batch_size=1)
529529
for i in range(0, num_chunk):
530530
start = i * chunk_stride
531531
end = start + chunk_size
@@ -569,8 +569,7 @@ def static_forward_online(self, audio, audio_len,
569569
probs_chunk_lens_list.append(output_chunk_lens)
570570
trans_best, trans_beam = self.model.decoder.decode()
571571
trans.append(trans_best[0])
572-
573-
572+
trans_batch.append(trans)
574573
output_probs = np.concatenate(probs_chunk_list, axis=1)
575574
output_lens = np.sum(probs_chunk_lens_list, axis=0)
576575
vocab_size = output_probs.shape[2]
@@ -592,7 +591,7 @@ def static_forward_online(self, audio, audio_len,
592591
self.autolog.times.end()
593592
output_probs = np.concatenate(output_probs_list, axis=0)
594593
output_lens = np.concatenate(output_lens_list, axis=0)
595-
return output_probs, output_lens, trans
594+
return output_probs, output_lens, trans_batch
596595

597596
def static_forward_offline(self, audio, audio_len):
598597
"""

0 commit comments

Comments
 (0)