@@ -14,13 +14,12 @@
 
 import math
 from collections import defaultdict
-from typing import List, Optional
+from typing import List, Optional, Dict
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-from wenet.utils.common import (add_sos_eos, log_add, WHISPER_LANGS,
-                                add_whisper_tokens)
+from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens)
 from wenet.utils.ctc_utils import remove_duplicates_and_blank
 from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                               mask_finished_scores, subsequent_mask)
@@ -253,6 +252,7 @@ def attention_beam_search(
     encoder_mask: torch.Tensor,
     beam_size: int = 10,
     length_penalty: float = 0.0,
+    infos: Dict[str, List[str]] = None,
 ) -> List[DecodeResult]:
     device = encoder_out.device
     batch_size = encoder_out.shape[0]
@@ -265,17 +265,20 @@ def attention_beam_search(
         running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
     encoder_mask = encoder_mask.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
         running_size, 1, maxlen)  # (B*N, 1, max_len)
-
     if getattr(model, 'special_tokens', None) is not None \
             and "transcribe" in model.special_tokens:
-        hyps = torch.ones([running_size, 4], dtype=torch.long,
-                          device=device)  # (B*N, 4)
-        # TODO(xcsong): add args for language, task, etc
-        hyps[:, 0] = model.special_tokens["sot"]
-        hyps[:,
-             1] = model.special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh")
-        hyps[:, 2] = model.special_tokens["transcribe"]
-        hyps[:, 3] = model.special_tokens["no_timestamps"]
+        tasks, langs = infos["tasks"], infos["langs"]
+        tasks = [t for t in tasks for _ in range(beam_size)]
+        langs = [l for l in langs for _ in range(beam_size)]
+        hyps = torch.ones([running_size, 0], dtype=torch.long,
+                          device=device)  # (B*N, 0)
+        hyps, _ = add_whisper_tokens(model.special_tokens,
+                                     hyps,
+                                     model.ignore_id,
+                                     tasks=tasks,
+                                     no_timestamp=True,
+                                     langs=langs,
+                                     use_prev=False)
     else:
         hyps = torch.ones([running_size, 1], dtype=torch.long,
                           device=device).fill_(model.sos)  # (B*N, 1)
@@ -360,6 +363,7 @@ def attention_rescoring(
     encoder_lens: torch.Tensor,
     ctc_weight: float = 0.0,
     reverse_weight: float = 0.0,
+    infos: Dict[str, List[str]] = None,
 ) -> List[DecodeResult]:
     """
     Args:
@@ -382,15 +386,15 @@ def attention_rescoring(
                                  dtype=torch.long)  # (beam_size,)
         if getattr(model, 'special_tokens', None) is not None \
                 and "transcribe" in model.special_tokens:
-            # TODO(xcsong): add args for language, task, etc
             prev_len = hyps_pad.size(1)
-            hyps_pad, _ = add_whisper_tokens(model.special_tokens,
-                                             hyps_pad,
-                                             model.ignore_id,
-                                             task="transcribe",
-                                             no_timestamp=True,
-                                             language="zh",
-                                             use_prev=False)
+            hyps_pad, _ = add_whisper_tokens(
+                model.special_tokens,
+                hyps_pad,
+                model.ignore_id,
+                tasks=[infos["tasks"][b]] * len(hyps),
+                no_timestamp=True,
+                langs=[infos["langs"][b]] * len(hyps),
+                use_prev=False)
             cur_len = hyps_pad.size(1)
             hyps_lens = hyps_lens + cur_len - prev_len
             prefix_len = 4
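
Note on the new `infos` argument: both `attention_beam_search` and `attention_rescoring` now take the Whisper task and language from `infos["tasks"]` and `infos["langs"]` (one entry per utterance in the batch) instead of the hard-coded "transcribe" / "zh" prompt removed above. A minimal caller-side sketch follows; only the "tasks"/"langs" keys and the keyword arguments come from this diff, while the leading positional arguments (model, encoder_out, encoder_mask) are assumed from the surrounding code.

# Hypothetical usage sketch, not part of the patch: build per-utterance
# decoding options and pass them to the beam search.
batch_size = encoder_out.shape[0]
infos = {
    "tasks": ["transcribe"] * batch_size,  # one task per utterance
    "langs": ["zh"] * batch_size,          # one language per utterance
}
results = attention_beam_search(model, encoder_out, encoder_mask,
                                beam_size=10, length_penalty=0.0,
                                infos=infos)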