
Commit a6a7f30

[whisper] support arbitrary language and task (#2342)
* [whisper] support arbitrary language and task
* [whisper] try to pass ut
* [whisper] try to pass ut
* [whisper] limit languages
* [whisper] limit languages
* [whisper] limit languages
* [whisper] try to pass ut
* [whisper] try to pass ut
1 parent 7468156 commit a6a7f30

File tree

11 files changed, +132 −68 lines changed


requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@ deepspeed<0.13.0
 librosa
 openai-whisper
 pre-commit==3.5.0
+langid

test/wenet/dataset/test_datapipes.py

Lines changed: 5 additions & 1 deletion
@@ -2,11 +2,13 @@
 import torch
 from torch.utils.data import datapipes
 from torch.utils.data.datapipes.iter import IterableWrapper
+from functools import partial

 from wenet.dataset.datapipes import (SortDataPipe, WenetRawDatasetSource,
                                      WenetTarShardDatasetSource)
 from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding,
-                                     parse_json, compute_fbank)
+                                     parse_json, compute_fbank,
+                                     detect_language, detect_task)


 @pytest.mark.parametrize("data_list", [
@@ -98,6 +100,8 @@ def test_dynamic_batch_datapipe(data_list):
     dataset = dataset.map(decode_wav)
     dataset = dataset.map(compute_fbank)
     dataset = dataset.map(fake_labels)
+    dataset = dataset.map(partial(detect_language, limited_langs=['zh', 'en']))
+    dataset = dataset.map(detect_task)
     max_frames_in_batch = 10000
     dataset = dataset.dynamic_batch(
         window_class=DynamicBatchWindow(max_frames_in_batch),

test/wenet/whisper/test_whisper.py

Lines changed: 2 additions & 4 deletions
@@ -94,8 +94,6 @@ def test_sinusoids(length, channels):
 @pytest.mark.parametrize("model,audio_path", [
     ("tiny", "test/resources/aishell-BAC009S0724W0121.wav"),
     ("base", "test/resources/librispeech-1995-1837-0001.wav"),
-    ("small", "test/resources/aishell-BAC009S0724W0121.wav"),
-    ("medium", "test/resources/librispeech-1995-1837-0001.wav"),
 ])
 def test_model(model, audio_path):
     default = os.path.join(os.path.expanduser("~"), ".cache")
@@ -362,9 +360,9 @@ def test_model(model, audio_path):
         configs['tokenizer_conf']['special_tokens'],
         torch.tensor([dummy_tokens], dtype=torch.long),
         ignore_id=-1,
-        task=task,
+        tasks=[task],
         no_timestamp=True,
-        language=language,
+        langs=[language],
         use_prev=False)
     L = wenet_tokens.size(1)
     tgt_mask = ~make_pad_mask(torch.tensor([L], dtype=torch.long),

wenet/bin/recognize.py

Lines changed: 3 additions & 1 deletion
@@ -244,6 +244,7 @@ def main():
            target = batch["target"].to(device)
            feats_lengths = batch["feats_lengths"].to(device)
            target_lengths = batch["target_lengths"].to(device)
+           infos = {"tasks": batch["tasks"], "langs": batch["langs"]}
            results = model.decode(
                args.modes,
                feats,
@@ -257,7 +258,8 @@ def main():
                context_graph=context_graph,
                blank_id=blank_id,
                blank_penalty=args.blank_penalty,
-               length_penalty=args.length_penalty)
+               length_penalty=args.length_penalty,
+               infos=infos)
            for i, key in enumerate(keys):
                for mode, hyps in results.items():
                    tokens = hyps[i].tokens
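
Since the langs/tasks lists now travel with every batch, decoding can be steered per utterance instead of being hard-wired to Chinese transcription. Below is a minimal sketch of how a caller might forward these hints, assuming the decode() signature threaded through asr_model.py in this commit; the helper name decode_batch and the example values are illustrative, not part of the change.

from typing import List

def decode_batch(model, feats, feats_lengths, langs: List[str],
                 tasks: List[str], modes: List[str], beam_size: int = 10):
    # Hypothetical helper: one lang/task entry per utterance in the batch,
    # forwarded through the new `infos` argument of decode().
    infos = {"langs": langs, "tasks": tasks}
    return model.decode(modes, feats, feats_lengths, beam_size, infos=infos)

# e.g. force English transcription for a two-utterance batch:
#   decode_batch(model, feats, feats_lengths,
#                langs=["en", "en"], tasks=["transcribe", "transcribe"],
#                modes=["attention", "attention_rescoring"])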

wenet/dataset/dataset.py

Lines changed: 4 additions & 0 deletions
@@ -94,6 +94,10 @@ def Dataset(data_type,
         spec_trim_conf = conf.get('spec_trim_conf', {})
         dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))

+    language_conf = conf.get('language_conf', {"limited_langs": ['zh', 'en']})
+    dataset = dataset.map(partial(processor.detect_language, **language_conf))
+    dataset = dataset.map(processor.detect_task)
+
     shuffle = conf.get('shuffle', True)
     if shuffle:
         shuffle_conf = conf.get('shuffle_conf', {})
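
On the config side this corresponds to an optional language_conf entry in the dataset conf. A small sketch of the relevant fragment, written as the Python dict that Dataset consumes; only language_conf/limited_langs come from this commit, the surrounding keys (shuffle, shuffle_conf) already exist.

# Sketch of the dataset `conf` dict; if 'language_conf' is omitted,
# Dataset falls back to {"limited_langs": ['zh', 'en']} as shown above.
conf = {
    'language_conf': {
        # restrict langid to the languages actually present in the data
        'limited_langs': ['zh', 'en'],
    },
    'shuffle': True,
    'shuffle_conf': {},
    # ... other existing dataset options unchanged
}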

wenet/dataset/processor.py

Lines changed: 32 additions & 0 deletions
@@ -17,6 +17,8 @@
 import json
 from subprocess import PIPE, Popen
 from urllib.parse import urlparse
+from langid.langid import LanguageIdentifier, model
+import logging
 import librosa
 import random

@@ -31,6 +33,10 @@

 AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])

+lid = LanguageIdentifier.from_modelstring(model, norm_probs=True)
+
+logging.getLogger('langid').setLevel(logging.INFO)
+

 class UrlOpenError(Exception):

@@ -79,6 +85,28 @@ def parse_speaker(sample, speaker_dict):
     return sample


+def detect_language(sample, limited_langs):
+    assert 'txt' in sample
+    # NOTE(xcsong): Because language classification may not be very accurate
+    #   (for example, Chinese being classified as Japanese), our workaround,
+    #   given we know for certain that the training data only consists of
+    #   Chinese and English, is to limit the classification results to reduce
+    #   the impact of misclassification.
+    lid.set_languages(limited_langs)
+    # i.e., ('zh', 0.9999999909903544)
+    sample['lang'] = lid.classify(sample['txt'])[0]
+    return sample
+
+
+def detect_task(sample):
+    # TODO(xcsong): Currently, the task is hard-coded to 'transcribe'.
+    #   In the future, we could dynamically determine the task based on
+    #   the contents of sample. For instance, if a sample contains both
+    #   'txt_en' and 'txt_zh', the task should be set to 'translate'.
+    sample['task'] = "transcribe"
+    return sample
+
+
 def decode_wav(sample):
     """ Parse key/wav/txt from json line

@@ -457,6 +485,8 @@ def padding(data):
         torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
     ]
     sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
+    langs = [sample[i]['lang'] for i in order]
+    tasks = [sample[i]['task'] for i in order]
     label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
                                  dtype=torch.int32)
     wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
@@ -477,6 +507,8 @@ def padding(data):
         "target_lengths": label_lengths,
         "pcm": padded_wavs,
         "pcm_length": wav_lengths,
+        "langs": langs,
+        "tasks": tasks,
     }
     if 'speaker' in sample[0]:
         speaker = torch.tensor([sample[i]['speaker'] for i in order],
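
For reference, detect_language is a thin wrapper around langid (added to requirements.txt in this commit). A standalone sketch of the same calls; the example text and probability are illustrative only.

from langid.langid import LanguageIdentifier, model

# Build the classifier once; norm_probs=True normalizes scores to probabilities.
lid = LanguageIdentifier.from_modelstring(model, norm_probs=True)

# Restrict candidates, as detect_language does via limited_langs, so that a
# Chinese transcript cannot be misclassified as e.g. Japanese.
lid.set_languages(['zh', 'en'])

lang, prob = lid.classify("今天天气怎么样")  # e.g. ('zh', 0.99...)
print(lang, prob)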

wenet/paraformer/paraformer.py

Lines changed: 8 additions & 3 deletions
@@ -218,9 +218,14 @@ def forward(
         }

     def _calc_att_loss(
-            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
-            ys_pad: torch.Tensor, ys_pad_emb: torch.Tensor,
-            ys_pad_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        self,
+        encoder_out: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_emb: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        infos: Dict[str, List[str]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         decoder_out, _, _ = self.decoder(encoder_out, encoder_mask, ys_pad_emb,
                                          ys_pad_lens)
         loss_att = self.criterion_att(decoder_out, ys_pad)

wenet/transformer/asr_model.py

Lines changed: 10 additions & 4 deletions
@@ -110,8 +110,11 @@ def forward(
             encoder_out, encoder_mask = self.filter_blank_embedding(
                 ctc_probs, encoder_out)
         if self.ctc_weight != 1.0:
-            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
-                                                    text, text_lengths)
+            loss_att, acc_att = self._calc_att_loss(
+                encoder_out, encoder_mask, text, text_lengths, {
+                    "langs": batch["langs"],
+                    "tasks": batch["tasks"]
+                })
         else:
             loss_att = None
             acc_att = None
@@ -174,6 +177,7 @@ def _calc_att_loss(
         encoder_mask: torch.Tensor,
         ys_pad: torch.Tensor,
         ys_pad_lens: torch.Tensor,
+        infos: Dict[str, List[str]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                             self.ignore_id)
@@ -256,6 +260,7 @@ def decode(
         blank_id: int = 0,
         blank_penalty: float = 0.0,
         length_penalty: float = 0.0,
+        infos: Dict[str, List[str]] = None,
    ) -> Dict[str, List[DecodeResult]]:
        """ Decode input speech

@@ -292,7 +297,8 @@ def decode(
        results = {}
        if 'attention' in methods:
            results['attention'] = attention_beam_search(
-               self, encoder_out, encoder_mask, beam_size, length_penalty)
+               self, encoder_out, encoder_mask, beam_size, length_penalty,
+               infos)
        if 'ctc_greedy_search' in methods:
            results['ctc_greedy_search'] = ctc_greedy_search(
                ctc_probs, encoder_lens, blank_id)
@@ -314,7 +320,7 @@ def decode(
                ctc_probs, encoder_out)
            results['attention_rescoring'] = attention_rescoring(
                self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
-               reverse_weight)
+               reverse_weight, infos)
        return results

    @torch.jit.export

wenet/transformer/search.py

Lines changed: 24 additions & 20 deletions
@@ -14,13 +14,12 @@

 import math
 from collections import defaultdict
-from typing import List, Optional
+from typing import List, Optional, Dict

 import torch
 from torch.nn.utils.rnn import pad_sequence

-from wenet.utils.common import (add_sos_eos, log_add, WHISPER_LANGS,
-                                add_whisper_tokens)
+from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens)
 from wenet.utils.ctc_utils import remove_duplicates_and_blank
 from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                               mask_finished_scores, subsequent_mask)
@@ -253,6 +252,7 @@ def attention_beam_search(
     encoder_mask: torch.Tensor,
     beam_size: int = 10,
     length_penalty: float = 0.0,
+    infos: Dict[str, List[str]] = None,
 ) -> List[DecodeResult]:
     device = encoder_out.device
     batch_size = encoder_out.shape[0]
@@ -265,17 +265,20 @@ def attention_beam_search(
         running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
     encoder_mask = encoder_mask.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
         running_size, 1, maxlen)  # (B*N, 1, max_len)
-
     if getattr(model, 'special_tokens', None) is not None \
             and "transcribe" in model.special_tokens:
-        hyps = torch.ones([running_size, 4], dtype=torch.long,
-                          device=device)  # (B*N, 4)
-        # TODO(xcsong): add args for language, task, etc
-        hyps[:, 0] = model.special_tokens["sot"]
-        hyps[:,
-             1] = model.special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh")
-        hyps[:, 2] = model.special_tokens["transcribe"]
-        hyps[:, 3] = model.special_tokens["no_timestamps"]
+        tasks, langs = infos["tasks"], infos["langs"]
+        tasks = [t for t in tasks for _ in range(beam_size)]
+        langs = [l for l in langs for _ in range(beam_size)]
+        hyps = torch.ones([running_size, 0], dtype=torch.long,
+                          device=device)  # (B*N, 0)
+        hyps, _ = add_whisper_tokens(model.special_tokens,
+                                     hyps,
+                                     model.ignore_id,
+                                     tasks=tasks,
+                                     no_timestamp=True,
+                                     langs=langs,
+                                     use_prev=False)
     else:
         hyps = torch.ones([running_size, 1], dtype=torch.long,
                           device=device).fill_(model.sos)  # (B*N, 1)
@@ -360,6 +363,7 @@ def attention_rescoring(
     encoder_lens: torch.Tensor,
     ctc_weight: float = 0.0,
     reverse_weight: float = 0.0,
+    infos: Dict[str, List[str]] = None,
 ) -> List[DecodeResult]:
     """
     Args:
@@ -382,15 +386,15 @@ def attention_rescoring(
                                  dtype=torch.long)  # (beam_size,)
         if getattr(model, 'special_tokens', None) is not None \
                 and "transcribe" in model.special_tokens:
-            # TODO(xcsong): add args for language, task, etc
             prev_len = hyps_pad.size(1)
-            hyps_pad, _ = add_whisper_tokens(model.special_tokens,
-                                             hyps_pad,
-                                             model.ignore_id,
-                                             task="transcribe",
-                                             no_timestamp=True,
-                                             language="zh",
-                                             use_prev=False)
+            hyps_pad, _ = add_whisper_tokens(
+                model.special_tokens,
+                hyps_pad,
+                model.ignore_id,
+                tasks=[infos["tasks"][b]] * len(hyps),
+                no_timestamp=True,
+                langs=[infos["langs"][b]] * len(hyps),
+                use_prev=False)
             cur_len = hyps_pad.size(1)
            hyps_lens = hyps_lens + cur_len - prev_len
            prefix_len = 4
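
For intuition: the hard-coded block removed above built a fixed 4-token Whisper-style prompt (<sot>, language, task, no_timestamps) for every hypothesis, while add_whisper_tokens now builds that prefix from the per-sample langs/tasks. A simplified per-sample sketch of the prefix, assuming special_tokens maps token names to ids and whisper_langs is the language list formerly imported as WHISPER_LANGS; this is an illustration, not the actual add_whisper_tokens implementation.

from typing import Dict, List

def whisper_prompt(special_tokens: Dict[str, int], lang: str, task: str,
                   whisper_langs: List[str]) -> List[int]:
    # Mirrors the removed hard-coded code, but with language and task taken
    # from the sample instead of being fixed to "zh" / "transcribe".
    return [
        special_tokens["sot"],
        special_tokens["sot"] + 1 + whisper_langs.index(lang),  # language token
        special_tokens[task],                                   # e.g. "transcribe"
        special_tokens["no_timestamps"],
    ]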

0 commit comments