@@ -129,7 +129,10 @@ def evaluate(args):
129
129
idim = vocab_size , odim = odim , spk_num = spk_num , ** am_config ["model" ])
130
130
elif am_name == 'speedyspeech' :
131
131
am = am_class (
132
- vocab_size = vocab_size , tone_size = tone_size , ** am_config ["model" ])
132
+ vocab_size = vocab_size ,
133
+ tone_size = tone_size ,
134
+ spk_num = spk_num ,
135
+ ** am_config ["model" ])
133
136
elif am_name == 'tacotron2' :
134
137
am = am_class (idim = vocab_size , odim = odim , ** am_config ["model" ])
135
138
@@ -171,25 +174,31 @@ def evaluate(args):
171
174
InputSpec ([- 1 ], dtype = paddle .int64 ),
172
175
InputSpec ([1 ], dtype = paddle .int64 )
173
176
])
174
- paddle .jit .save (am_inference ,
175
- os .path .join (args .inference_dir , args .am ))
176
- am_inference = paddle .jit .load (
177
- os .path .join (args .inference_dir , args .am ))
178
177
else :
179
178
am_inference = jit .to_static (
180
179
am_inference ,
181
180
input_spec = [InputSpec ([- 1 ], dtype = paddle .int64 )])
182
- paddle .jit .save (am_inference ,
183
- os .path .join (args .inference_dir , args .am ))
184
- am_inference = paddle .jit .load (
185
- os .path .join (args .inference_dir , args .am ))
181
+ paddle .jit .save (am_inference ,
182
+ os .path .join (args .inference_dir , args .am ))
183
+ am_inference = paddle .jit .load (
184
+ os .path .join (args .inference_dir , args .am ))
186
185
elif am_name == 'speedyspeech' :
187
- am_inference = jit .to_static (
188
- am_inference ,
189
- input_spec = [
190
- InputSpec ([- 1 ], dtype = paddle .int64 ),
191
- InputSpec ([- 1 ], dtype = paddle .int64 )
192
- ])
186
+ if am_dataset in {"aishell3" , "vctk" } and args .speaker_dict :
187
+ am_inference = jit .to_static (
188
+ am_inference ,
189
+ input_spec = [
190
+ InputSpec ([- 1 ], dtype = paddle .int64 ), # text
191
+ InputSpec ([- 1 ], dtype = paddle .int64 ), # tone
192
+ None , # duration
193
+ InputSpec ([- 1 ], dtype = paddle .int64 ) # spk_id
194
+ ])
195
+ else :
196
+ am_inference = jit .to_static (
197
+ am_inference ,
198
+ input_spec = [
199
+ InputSpec ([- 1 ], dtype = paddle .int64 ),
200
+ InputSpec ([- 1 ], dtype = paddle .int64 )
201
+ ])
193
202
194
203
paddle .jit .save (am_inference ,
195
204
os .path .join (args .inference_dir , args .am ))
@@ -242,7 +251,12 @@ def evaluate(args):
242
251
mel = am_inference (part_phone_ids )
243
252
elif am_name == 'speedyspeech' :
244
253
part_tone_ids = tone_ids [i ]
245
- mel = am_inference (part_phone_ids , part_tone_ids )
254
+ if am_dataset in {"aishell3" , "vctk" }:
255
+ spk_id = paddle .to_tensor (args .spk_id )
256
+ mel = am_inference (part_phone_ids , part_tone_ids ,
257
+ spk_id )
258
+ else :
259
+ mel = am_inference (part_phone_ids , part_tone_ids )
246
260
elif am_name == 'tacotron2' :
247
261
mel = am_inference (part_phone_ids )
248
262
# vocoder
@@ -269,8 +283,9 @@ def main():
269
283
type = str ,
270
284
default = 'fastspeech2_csmsc' ,
271
285
choices = [
272
- 'speedyspeech_csmsc' , 'fastspeech2_csmsc' , 'fastspeech2_ljspeech' ,
273
- 'fastspeech2_aishell3' , 'fastspeech2_vctk' , 'tacotron2_csmsc'
286
+ 'speedyspeech_csmsc' , 'speedyspeech_aishell3' , 'fastspeech2_csmsc' ,
287
+ 'fastspeech2_ljspeech' , 'fastspeech2_aishell3' , 'fastspeech2_vctk' ,
288
+ 'tacotron2_csmsc'
274
289
],
275
290
help = 'Choose acoustic model type of tts task.' )
276
291
parser .add_argument (
0 commit comments