Skip to content

Commit bff7dd0

Browse files
committed
Only divide by MAX_WAV_VALUE when the audio is int16-scaled; pad mel/audio when shorter than segment_size; fix tqdm wrapping; crop mel/wav to matching length in train if necessary
1 parent 455f3f3 commit bff7dd0

File tree

2 files changed: +14 additions, −11 deletions

meldataset.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ def __getitem__(self, index):
247247
filename = self.audio_files[index]
248248
if self._cache_ref_count == 0:
249249
audio, sampling_rate = load_wav(filename, self.sampling_rate)
250-
audio = audio / MAX_WAV_VALUE
250+
if np.abs(audio).max() > 1:
251+
audio = audio / MAX_WAV_VALUE
251252
if not self.fine_tuning:
252253
audio = normalize(audio) * 0.95
253254
self.cached_wav = audio
@@ -328,13 +329,13 @@ def __getitem__(self, index):
328329
* self.hop_size : (mel_start + frames_per_seg)
329330
* self.hop_size,
330331
]
331-
else:
332-
mel = torch.nn.functional.pad(
333-
mel, (0, frames_per_seg - mel.size(2)), "constant"
334-
)
335-
audio = torch.nn.functional.pad(
336-
audio, (0, self.segment_size - audio.size(1)), "constant"
337-
)
332+
333+
mel = torch.nn.functional.pad(
334+
mel, (0, frames_per_seg - mel.size(2)), "constant"
335+
)
336+
audio = torch.nn.functional.pad(
337+
audio, (0, self.segment_size - audio.size(1)), "constant"
338+
)
338339

339340
mel_loss = mel_spectrogram(
340341
audio,

train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def validate(rank, a, h, loader, mode="seen"):
308308
print(f"step {steps} {mode} speaker validation...")
309309

310310
# Loop over validation set and compute metrics
311-
for j, batch in tqdm(enumerate(loader)):
311+
for j, batch in enumerate(tqdm(loader)):
312312
x, y, _, y_mel = batch
313313
y = y.to(device)
314314
if hasattr(generator, "module"):
@@ -326,7 +326,8 @@ def validate(rank, a, h, loader, mode="seen"):
326326
h.fmin,
327327
h.fmax_for_loss,
328328
)
329-
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
329+
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
330+
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
330331

331332
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
332333
if (
@@ -343,7 +344,8 @@ def validate(rank, a, h, loader, mode="seen"):
343344
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
344345

345346
# MRSTFT calculation
346-
val_mrstft_tot += loss_mrstft(y_g_hat, y).item()
347+
min_t = min(y.size(-1), y_g_hat.size(-1))
348+
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
347349

348350
# Log audio and figures to Tensorboard
349351
if j % a.eval_subsample == 0: # Subsample every nth from validation set

0 commit comments

Comments (0)