@@ -413,19 +413,19 @@ def test(self):
413
413
414
414
def compute_result_transcripts (self , audio , audio_len ):
415
415
if self .args .model_type == "online" :
416
- output_probs , output_lens , trans = self .static_forward_online (
416
+ output_probs , output_lens , trans_batch = self .static_forward_online (
417
417
audio , audio_len , decoder_chunk_size = 1 )
418
- result_transcripts = trans [- 1 : ]
418
+ result_transcripts = [ trans [- 1 ] for trans in trans_batch ]
419
419
elif self .args .model_type == "offline" :
420
- batch_size = output_probs .shape [0 ]
421
- self .model .decoder .reset_decoder (batch_size = batch_size )
422
420
output_probs , output_lens = self .static_forward_offline (audio ,
423
421
audio_len )
422
+ batch_size = output_probs .shape [0 ]
423
+ self .model .decoder .reset_decoder (batch_size = batch_size )
424
424
425
425
self .model .decoder .next (output_probs , output_lens )
426
426
427
427
trans_best , trans_beam = self .model .decoder .decode ()
428
-
428
+
429
429
result_transcripts = trans_best
430
430
431
431
else :
@@ -485,7 +485,7 @@ def static_forward_online(self, audio, audio_len,
485
485
x_list = np .split (x_batch , batch_size , axis = 0 )
486
486
x_len_list = np .split (x_len_batch , batch_size , axis = 0 )
487
487
488
- trans = []
488
+ trans_batch = []
489
489
for x , x_len in zip (x_list , x_len_list ):
490
490
if self .args .enable_auto_log is True :
491
491
self .autolog .times .start ()
@@ -518,14 +518,14 @@ def static_forward_online(self, audio, audio_len,
518
518
h_box_handle = self .predictor .get_input_handle (input_names [2 ])
519
519
c_box_handle = self .predictor .get_input_handle (input_names [3 ])
520
520
521
- trans_chunk_list = []
521
+ trans = []
522
522
probs_chunk_list = []
523
523
probs_chunk_lens_list = []
524
524
if self .args .enable_auto_log is True :
525
525
# record the model preprocessing time
526
526
self .autolog .times .stamp ()
527
-
528
- self .model .decoder .reset_decoder (batch_size = 1 )
527
+
528
+ self .model .decoder .reset_decoder (batch_size = 1 )
529
529
for i in range (0 , num_chunk ):
530
530
start = i * chunk_stride
531
531
end = start + chunk_size
@@ -569,8 +569,7 @@ def static_forward_online(self, audio, audio_len,
569
569
probs_chunk_lens_list .append (output_chunk_lens )
570
570
trans_best , trans_beam = self .model .decoder .decode ()
571
571
trans .append (trans_best [0 ])
572
-
573
-
572
+ trans_batch .append (trans )
574
573
output_probs = np .concatenate (probs_chunk_list , axis = 1 )
575
574
output_lens = np .sum (probs_chunk_lens_list , axis = 0 )
576
575
vocab_size = output_probs .shape [2 ]
@@ -592,7 +591,7 @@ def static_forward_online(self, audio, audio_len,
592
591
self .autolog .times .end ()
593
592
output_probs = np .concatenate (output_probs_list , axis = 0 )
594
593
output_lens = np .concatenate (output_lens_list , axis = 0 )
595
- return output_probs , output_lens , trans
594
+ return output_probs , output_lens , trans_batch
596
595
597
596
def static_forward_offline (self , audio , audio_len ):
598
597
"""
0 commit comments