
Commit 97ca0da

yt605155624 authored and luotao1 committed
[TTS]add StarGANv2VC preprocess (PaddlePaddle#3163)
1 parent 92f4213 commit 97ca0da

File tree

8 files changed, +485 −23 lines changed


examples/vctk/vc3/conf/default.yaml

Lines changed: 15 additions & 4 deletions
@@ -1,12 +1,23 @@
 ###########################################################
 #               FEATURE EXTRACTION SETTING                #
 ###########################################################
-# actually not used; what is really used is 16000
-sr: 24000
+# the original code loads audio at 24k but extracts mel at 16k; both loading and mel extraction should later be changed to 24k
+fs: 16000
 n_fft: 2048
-win_length: 1200
-hop_length: 300
+n_shift: 300
+win_length: 1200  # Window length (in samples), 50 ms.
+# If set to null, it will be the same as fft_size.
+window: "hann"    # Window function.
+
+fmin: 0           # Minimum frequency of Mel basis.
+fmax: 8000        # Maximum frequency of Mel basis. sr // 2
 n_mels: 80
+# only for StarGANv2 VC
+norm:             # None here
+htk: True
+power: 2.0
+
+
 ###########################################################
 #                     MODEL SETTING                       #
 ###########################################################
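For reference, a minimal sketch (using librosa, which is not necessarily what the preprocess script calls internally) of how the settings above are typically consumed when computing log-mel features; the input file name is hypothetical, and details such as padding and the log base may differ from the actual extractor.

```python
# Illustrative only: maps the config values above onto a log-mel computation.
import librosa
import numpy as np

fs, n_fft, n_shift, win_length = 16000, 2048, 300, 1200
fmin, fmax, n_mels = 0, 8000, 80

wav, _ = librosa.load("p225_001.wav", sr=fs)  # hypothetical input file
spec = librosa.stft(
    wav, n_fft=n_fft, hop_length=n_shift, win_length=win_length, window="hann")
power_spec = np.abs(spec)**2  # power: 2.0
mel_basis = librosa.filters.mel(
    sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax,
    htk=True, norm=None)      # htk: True, norm: None
log_mel = np.log(np.maximum(mel_basis @ power_spec, 1e-10))  # [80, T]
```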

examples/vctk/vc3/local/preprocess.sh

Lines changed: 23 additions & 4 deletions
@@ -6,13 +6,32 @@ stop_stage=100
 config_path=$1
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${BIN_DIR}/preprocess.py \
+        --dataset=vctk \
+        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
+        --dumpdir=dump \
+        --config=${config_path} \
+        --num-cpu=20
 
 fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-
-fi
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "Normalize ..."
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --speaker-dict=dump/speaker_id_map.txt
 
 fi
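Stage 0 writes a metadata.jsonl per split which stage 1 then normalizes. A small sketch of inspecting it, assuming stage 0 has already been run; the field names 'utt_id', 'speaker', and 'speech' are inferred from the normalize.py script added further down.

```python
# Illustrative only: peek at the raw metadata that normalize.py consumes.
import jsonlines

with jsonlines.open("dump/train/raw/metadata.jsonl", "r") as reader:
    records = list(reader)

print(len(records))
print(records[0]["utt_id"], records[0]["speaker"], records[0]["speech"])
```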

paddlespeech/t2s/datasets/am_batch_fn.py

Lines changed: 66 additions & 12 deletions
@@ -669,18 +669,72 @@ def vits_multi_spk_batch_fn(examples):
     return batch
 
 
-# not finished yet
-def starganv2_vc_batch_fn(examples):
-    batch = {
-        "x_real": None,
-        "y_org": None,
-        "x_ref": None,
-        "x_ref2": None,
-        "y_trg": None,
-        "z_trg": None,
-        "z_trg2": None,
-    }
-    return batch
+# built via a factory function because extra parameters need to be passed in
+def build_starganv2_vc_collate_fn(latent_dim: int=16, max_mel_length: int=192):
+
+    return StarGANv2VCCollateFn(
+        latent_dim=latent_dim, max_mel_length=max_mel_length)
+
+
+class StarGANv2VCCollateFn:
+    """Functor class of common_collate_fn()"""
+
+    def __init__(self, latent_dim: int=16, max_mel_length: int=192):
+        self.latent_dim = latent_dim
+        self.max_mel_length = max_mel_length
+
+    def random_clip(self, mel: np.ndarray):
+        # mel: [80, T]
+        mel_length = mel.shape[1]
+        if mel_length > self.max_mel_length:
+            random_start = np.random.randint(0,
+                                             mel_length - self.max_mel_length)
+            mel = mel[:, random_start:random_start + self.max_mel_length]
+        return mel
+
+    def __call__(self, examples):
+        return self.starganv2_vc_batch_fn(examples)
+
+    def starganv2_vc_batch_fn(self, examples):
+        batch_size = len(examples)
+
+        label = [np.array(item["label"], dtype=np.int64) for item in examples]
+        ref_label = [
+            np.array(item["ref_label"], dtype=np.int64) for item in examples
+        ]
+
+        # mels need to be randomly clipped to max_mel_length
+        mel = [self.random_clip(item["mel"]) for item in examples]
+        ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
+        ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]
+
+        mel = batch_sequences(mel)
+        ref_mel = batch_sequences(ref_mel)
+        ref_mel_2 = batch_sequences(ref_mel_2)
+
+        # convert each batch to paddle.Tensor
+        # (B,)
+        label = paddle.to_tensor(label)
+        ref_label = paddle.to_tensor(ref_label)
+        # [B, 80, T]
+        mel = paddle.to_tensor(mel)
+        ref_mel = paddle.to_tensor(ref_mel)
+        ref_mel_2 = paddle.to_tensor(ref_mel_2)
+
+        z_trg = paddle.randn([batch_size, self.latent_dim])
+        z_trg2 = paddle.randn([batch_size, self.latent_dim])
+
+        batch = {
+            "x_real": mel,
+            "y_org": label,
+            "x_ref": ref_mel,
+            "x_ref2": ref_mel_2,
+            "y_trg": ref_label,
+            "z_trg": z_trg,
+            "z_trg2": z_trg2
+        }
+
+        return batch
 
 
 # for PaddleSlim
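A hedged usage sketch of the new collate function: the fake examples below only mimic the dict keys ('mel', 'label', 'ref_mel', 'ref_mel_2', 'ref_label') produced by StarGANv2VCDataTable; the shapes and values are made up.

```python
# Illustrative only: what one collated StarGANv2-VC batch looks like.
import numpy as np
from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn

collate_fn = build_starganv2_vc_collate_fn(latent_dim=16, max_mel_length=192)

def fake_example(T: int):
    # mimics one item returned by StarGANv2VCDataTable.__getitem__
    return {
        "utt_id": "fake",
        "mel": np.random.randn(80, T).astype(np.float32),
        "label": 0,
        "ref_mel": np.random.randn(80, T).astype(np.float32),
        "ref_mel_2": np.random.randn(80, T).astype(np.float32),
        "ref_label": 1,
    }

batch = collate_fn([fake_example(250), fake_example(300)])
print(batch["x_real"].shape)  # [2, 80, 192] after random clipping
print(batch["z_trg"].shape)   # [2, 16] random latent codes
```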

paddlespeech/t2s/datasets/data_table.py

Lines changed: 53 additions & 0 deletions
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from multiprocessing import Manager
 from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import List
 
+import numpy as np
 from paddle.io import Dataset
 
 
@@ -131,3 +133,54 @@ def __len__(self) -> int:
             The length of the dataset
         """
         return len(self.data)
+
+
+class StarGANv2VCDataTable(DataTable):
+    def __init__(self, data: List[Dict[str, Any]]):
+        super().__init__(data)
+        raw_data = data
+        spk_id_set = list(set([item['spk_id'] for item in raw_data]))
+        data_list_per_class = {}
+        for spk_id in spk_id_set:
+            data_list_per_class[spk_id] = []
+        for item in raw_data:
+            for spk_id in spk_id_set:
+                if item['spk_id'] == spk_id:
+                    data_list_per_class[spk_id].append(item)
+        self.data_list_per_class = data_list_per_class
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        """Get an example given an index.
+        Args:
+            idx (int): Index of the example to get
+
+        Returns:
+            Dict[str, Any]: A converted example
+        """
+        if self.use_cache and self.caches[idx] is not None:
+            return self.caches[idx]
+
+        data = self._get_metadata(idx)
+
+        # clipping is done in the batch_fn instead
+        # `data` is a dict like:
+        """
+        {'utt_id': 'p225_111', 'spk_id': '1', 'speech': 'path of *.npy'}
+        """
+        ref_data = random.choice(self.data)
+        ref_label = ref_data['spk_id']
+        ref_data_2 = random.choice(self.data_list_per_class[ref_label])
+        # mel_tensor, label, ref_mel_tensor, ref2_mel_tensor, ref_label
+        new_example = {
+            'utt_id': data['utt_id'],
+            'mel': np.load(data['speech']),
+            'label': int(data['spk_id']),
+            'ref_mel': np.load(ref_data['speech']),
+            'ref_mel_2': np.load(ref_data_2['speech']),
+            'ref_label': int(ref_label)
+        }
+
+        if self.use_cache:
+            self.caches[idx] = new_example
+
+        return new_example
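A self-contained sketch of how StarGANv2VCDataTable pairs each utterance with a random reference speaker and a second utterance of that same speaker; the tiny random .npy files below only stand in for the real dumped VCTK features.

```python
# Illustrative only: exercise StarGANv2VCDataTable on fake metadata.
import numpy as np
from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable

# fake dumped features, one file per utterance
for name in ("p225_001", "p226_001", "p226_002"):
    np.save(f"{name}_speech.npy", np.random.randn(80, 200).astype(np.float32))

metadata = [
    {"utt_id": "p225_001", "spk_id": "0", "speech": "p225_001_speech.npy"},
    {"utt_id": "p226_001", "spk_id": "1", "speech": "p226_001_speech.npy"},
    {"utt_id": "p226_002", "spk_id": "1", "speech": "p226_002_speech.npy"},
]
dataset = StarGANv2VCDataTable(data=metadata)
example = dataset[0]
# 'mel' is the indexed utterance; 'ref_mel' is a random utterance and
# 'ref_mel_2' another utterance from that same reference speaker.
print(example["label"], example["ref_label"], example["mel"].shape)
```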
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Normalize feature files and dump them."""
+import argparse
+import logging
+from operator import itemgetter
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import tqdm
+
+from paddlespeech.t2s.datasets.data_table import DataTable
+
+
+def main():
+    """Run preprocessing process."""
+    parser = argparse.ArgumentParser(
+        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="directory including feature files to be normalized. "
+        "you need to specify either *-scp or rootdir.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump normalized feature files.")
+
+    parser.add_argument(
+        "--speaker-dict", type=str, default=None, help="speaker id map file.")
+
+    args = parser.parse_args()
+
+    dumpdir = Path(args.dumpdir).expanduser()
+    # use absolute path
+    dumpdir = dumpdir.resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    # get dataset
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
+    dataset = DataTable(
+        metadata, converters={
+            "speech": np.load,
+        })
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    vocab_speaker = {}
+    with open(args.speaker_dict, 'rt') as f:
+        spk_id = [line.strip().split() for line in f.readlines()]
+    for spk, id in spk_id:
+        vocab_speaker[spk] = int(id)
+
+    # process each file
+    output_metadata = []
+
+    for item in tqdm.tqdm(dataset):
+        utt_id = item['utt_id']
+        speech = item['speech']
+
+        # normalize
+        # mean and std are hard-coded here for now
+        mean, std = -4, 4
+        speech = (speech - mean) / std
+        speech_path = dumpdir / f"{utt_id}_speech.npy"
+        np.save(speech_path, speech.astype(np.float32), allow_pickle=False)
+
+        spk_id = vocab_speaker[item["speaker"]]
+        record = {
+            "utt_id": item['utt_id'],
+            "spk_id": spk_id,
+            "speech": str(speech_path),
+        }
+
+        output_metadata.append(record)
+    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
+    with jsonlines.open(output_metadata_path, 'w') as writer:
+        for item in output_metadata:
+            writer.write(item)
+    logging.info(f"metadata dumped into {output_metadata_path}")
+
+
+if __name__ == "__main__":
+    main()
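The normalization above is just a fixed affine transform; a minimal sketch of applying and undoing it, where mean = -4 and std = 4 are copied from the script and the denormalize helper is illustrative only (e.g. for mapping model outputs back before vocoding).

```python
# Illustrative only: the hard-coded log-mel normalization and its inverse.
import numpy as np

mean, std = -4.0, 4.0

def normalize(log_mel: np.ndarray) -> np.ndarray:
    return (log_mel - mean) / std

def denormalize(norm_mel: np.ndarray) -> np.ndarray:
    return norm_mel * std + mean

log_mel = np.random.randn(80, 192).astype(np.float32)
assert np.allclose(denormalize(normalize(log_mel)), log_mel, atol=1e-5)
```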
