Skip to content

Commit c832b9c

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into ctcdecoders
2 parents 624e86d + 4907628 commit c832b9c

File tree

4 files changed

+82
-61
lines changed

4 files changed

+82
-61
lines changed

paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -137,20 +137,19 @@ def ctc_beam_search_decoding_batch(probs_split,
137137
return batch_beam_results
138138

139139

140-
class CTCBeamSearchDecoder(
141-
paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
140+
class CTCBeamSearchDecoder(paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
142141
"""Wrapper for CtcBeamSearchDecoderBatch.
143142
Args:
144-
vocab_list (list): [Vocabulary list.]
145-
beam_size (int): [Width for beam search.]
146-
num_processes (int): [Number of parallel processes.]
147-
param cutoff_prob (float): [Cutoff probability in vocabulary pruning,
148-
default 1.0, no pruning.]
149-
cutoff_top_n (int): [Cutoff number in pruning, only top cutoff_top_n
143+
vocab_list (list): Vocabulary list.
144+
beam_size (int): Width for beam search.
145+
num_processes (int): Number of parallel processes.
146+
param cutoff_prob (float): Cutoff probability in vocabulary pruning,
147+
default 1.0, no pruning.
148+
cutoff_top_n (int): Cutoff number in pruning, only top cutoff_top_n
150149
characters with highest probs in vocabulary will be
151-
used in beam search, default 40.]
152-
param ext_scorer (Scorer): [External scorer for partially decoded sentence, e.g. word count
153-
or language model.]
150+
used in beam search, default 40.
151+
param ext_scorer (Scorer): External scorer for partially decoded sentence, e.g. word count
152+
or language model.
154153
"""
155154

156155
def __init__(self, vocab_list, batch_size, beam_size, num_processes,

paddlespeech/s2t/models/ds2/deepspeech2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ def decode(self, audio, audio_len):
174174
self.decoder.reset_decoder(batch_size=batch_size)
175175
self.decoder.next(probs, eouts_len)
176176
trans_best, trans_beam = self.decoder.decode()
177-
177+
178178
return trans_best
179179

180180
@classmethod

paddlespeech/s2t/modules/ctc.py

Lines changed: 38 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -252,15 +252,15 @@ def init_decoder(self, batch_size, vocab_list, decoding_method,
252252
init ctc decoders
253253
Args:
254254
batch_size(int): Batch size for input data
255-
vocab_list (list): [List of tokens in the vocabulary, for decoding.]
256-
decoding_method (str): ["ctc_beam_search"]
257-
lang_model_path (str): [language model path]
258-
beam_alpha (float): [beam_alpha]
259-
beam_beta (float): [beam_beta]
260-
beam_size (int): [beam_size]
261-
cutoff_prob (float): [cutoff probability in beam search]
262-
cutoff_top_n (int): [cutoff_top_n]
263-
num_processes (int): [num_processes]
255+
vocab_list (list): List of tokens in the vocabulary, for decoding
256+
decoding_method (str): ctc_beam_search
257+
lang_model_path (str): language model path
258+
beam_alpha (float): beam_alpha
259+
beam_beta (float): beam_beta
260+
beam_size (int): beam_size
261+
cutoff_prob (float): cutoff probability in beam search
262+
cutoff_top_n (int): cutoff_top_n
263+
num_processes (int): num_processes
264264
265265
Raises:
266266
ValueError: when decoding_method not support.
@@ -299,15 +299,15 @@ def decode_probs_offline(self, probs, logits_lens, vocab_list,
299299
Args:
300300
probs (Tensor): activation after softmax
301301
logits_lens (Tensor): audio output lens
302-
vocab_list (list): [List of tokens in the vocabulary, for decoding.]
303-
decoding_method (str): ["ctc_beam_search"]
304-
lang_model_path (str): [language model path]
305-
beam_alpha (float): [beam_alpha]
306-
beam_beta (float): [beam_beta]
307-
beam_size (int): [beam_size]
308-
cutoff_prob (float): [cutoff probability in beam search]
309-
cutoff_top_n (int): [cutoff_top_n]
310-
num_processes (int): [num_processes]
302+
vocab_list (list): List of tokens in the vocabulary, for decoding
303+
decoding_method (str): ctc_beam_search
304+
lang_model_path (str): language model path
305+
beam_alpha (float): beam_alpha
306+
beam_beta (float): beam_beta
307+
beam_size (int): beam_size
308+
cutoff_prob (float): cutoff probability in beam search
309+
cutoff_top_n (int): cutoff_top_n
310+
num_processes (int): num_processes
311311
312312
Raises:
313313
ValueError: when decoding_method not support.
@@ -340,14 +340,14 @@ def get_decoder(self, vocab_list, batch_size, beam_alpha, beam_beta,
340340
"""
341341
init get ctc decoder
342342
Args:
343-
vocab_list (list): [List of tokens in the vocabulary, for decoding.]
343+
vocab_list (list): List of tokens in the vocabulary, for decoding.
344344
batch_size(int): Batch size for input data
345-
beam_alpha (float): [beam_alpha]
346-
beam_beta (float): [beam_beta]
347-
beam_size (int): [beam_size]
348-
num_processes (int): [num_processes]
349-
cutoff_prob (float): [cutoff probability in beam search]
350-
cutoff_top_n (int): [cutoff_top_n]
345+
beam_alpha (float): beam_alpha
346+
beam_beta (float): beam_beta
347+
beam_size (int): beam_size
348+
num_processes (int): num_processes
349+
cutoff_prob (float): cutoff probability in beam search
350+
cutoff_top_n (int): cutoff_top_n
351351
352352
Raises:
353353
ValueError: when decoding_method not support.
@@ -370,8 +370,8 @@ def next(self, probs, logits_lens):
370370
"""
371371
Input probs into ctc decoder
372372
Args:
373-
probs (list(list(float))): [probs for a batch of data]
374-
logits_lens (list(int)): [logits lens for a batch of data]
373+
probs (list(list(float))): probs for a batch of data
374+
logits_lens (list(int)): logits lens for a batch of data
375375
Raises:
376376
Exception: when the ctc decoder is not initialized
377377
ValueError: when decoding_method not support.
@@ -405,8 +405,8 @@ def decode(self):
405405
Exception: when the ctc decoder is not initialized
406406
ValueError: when decoding_method not support.
407407
Returns:
408-
results_best (list(str)): [The best result for a batch of data]
409-
results_beam (list(list(str))): [The beam search result for a batch of data]
408+
results_best (list(str)): The best result for a batch of data
409+
results_beam (list(list(str))): The beam search result for a batch of data
410410
"""
411411
if self.beam_search_decoder is None:
412412
raise Exception(
@@ -426,7 +426,12 @@ def decode(self):
426426

427427
return results_best, results_beam
428428

429-
def reset_decoder(self, batch_size=-1, beam_size=-1, num_processes=-1, cutoff_prob=-1.0, cutoff_top_n=-1):
429+
def reset_decoder(self,
430+
batch_size=-1,
431+
beam_size=-1,
432+
num_processes=-1,
433+
cutoff_prob=-1.0,
434+
cutoff_top_n=-1):
430435
if batch_size > 0:
431436
self.batch_size = batch_size
432437
if beam_size > 0:
@@ -445,7 +450,9 @@ def reset_decoder(self, batch_size=-1, beam_size=-1, num_processes=-1, cutoff_pr
445450
if self.beam_search_decoder is None:
446451
raise Exception(
447452
"You need to initialize the beam_search_decoder firstly")
448-
self.beam_search_decoder.reset_state(self.batch_size, self.beam_size, self.num_processes, self.cutoff_prob, self.cutoff_top_n)
453+
self.beam_search_decoder.reset_state(
454+
self.batch_size, self.beam_size, self.num_processes,
455+
self.cutoff_prob, self.cutoff_top_n)
449456

450457
def del_decoder(self):
451458
"""

paddlespeech/t2s/exps/synthesize_e2e.py

Lines changed: 33 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,10 @@ def evaluate(args):
129129
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
130130
elif am_name == 'speedyspeech':
131131
am = am_class(
132-
vocab_size=vocab_size, tone_size=tone_size, **am_config["model"])
132+
vocab_size=vocab_size,
133+
tone_size=tone_size,
134+
spk_num=spk_num,
135+
**am_config["model"])
133136
elif am_name == 'tacotron2':
134137
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
135138

@@ -171,25 +174,31 @@ def evaluate(args):
171174
InputSpec([-1], dtype=paddle.int64),
172175
InputSpec([1], dtype=paddle.int64)
173176
])
174-
paddle.jit.save(am_inference,
175-
os.path.join(args.inference_dir, args.am))
176-
am_inference = paddle.jit.load(
177-
os.path.join(args.inference_dir, args.am))
178177
else:
179178
am_inference = jit.to_static(
180179
am_inference,
181180
input_spec=[InputSpec([-1], dtype=paddle.int64)])
182-
paddle.jit.save(am_inference,
183-
os.path.join(args.inference_dir, args.am))
184-
am_inference = paddle.jit.load(
185-
os.path.join(args.inference_dir, args.am))
181+
paddle.jit.save(am_inference,
182+
os.path.join(args.inference_dir, args.am))
183+
am_inference = paddle.jit.load(
184+
os.path.join(args.inference_dir, args.am))
186185
elif am_name == 'speedyspeech':
187-
am_inference = jit.to_static(
188-
am_inference,
189-
input_spec=[
190-
InputSpec([-1], dtype=paddle.int64),
191-
InputSpec([-1], dtype=paddle.int64)
192-
])
186+
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
187+
am_inference = jit.to_static(
188+
am_inference,
189+
input_spec=[
190+
InputSpec([-1], dtype=paddle.int64), # text
191+
InputSpec([-1], dtype=paddle.int64), # tone
192+
None, # duration
193+
InputSpec([-1], dtype=paddle.int64) # spk_id
194+
])
195+
else:
196+
am_inference = jit.to_static(
197+
am_inference,
198+
input_spec=[
199+
InputSpec([-1], dtype=paddle.int64),
200+
InputSpec([-1], dtype=paddle.int64)
201+
])
193202

194203
paddle.jit.save(am_inference,
195204
os.path.join(args.inference_dir, args.am))
@@ -242,7 +251,12 @@ def evaluate(args):
242251
mel = am_inference(part_phone_ids)
243252
elif am_name == 'speedyspeech':
244253
part_tone_ids = tone_ids[i]
245-
mel = am_inference(part_phone_ids, part_tone_ids)
254+
if am_dataset in {"aishell3", "vctk"}:
255+
spk_id = paddle.to_tensor(args.spk_id)
256+
mel = am_inference(part_phone_ids, part_tone_ids,
257+
spk_id)
258+
else:
259+
mel = am_inference(part_phone_ids, part_tone_ids)
246260
elif am_name == 'tacotron2':
247261
mel = am_inference(part_phone_ids)
248262
# vocoder
@@ -269,8 +283,9 @@ def main():
269283
type=str,
270284
default='fastspeech2_csmsc',
271285
choices=[
272-
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
273-
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc'
286+
'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
287+
'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
288+
'tacotron2_csmsc'
274289
],
275290
help='Choose acoustic model type of tts task.')
276291
parser.add_argument(

0 commit comments

Comments (0)