Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
16108de
add voxceleb1 dataset prepare process
LeoMax-Xiong Feb 24, 2022
35b7968
remove invalid directory
LeoMax-Xiong Feb 25, 2022
6f7e965
add kaldi feats ark dataset
LeoMax-Xiong Feb 25, 2022
d7da629
add kaldi feats egs dataset
LeoMax-Xiong Feb 26, 2022
70d3b01
remove invalid code
LeoMax-Xiong Feb 26, 2022
1395b5f
Merge branch 'PaddlePaddle:develop' into develop
LeoMax-Xiong Mar 2, 2022
7ef60eb
add voxceleb1 data prepare
LeoMax-Xiong Mar 2, 2022
0780d18
remove personal code test=doc
LeoMax-Xiong Mar 2, 2022
3a943ca
repair the variable name bug
LeoMax-Xiong Mar 2, 2022
dc28ebe
move the csv vox format to paddleaudio, test=doc
LeoMax-Xiong Mar 3, 2022
57c4f4a
add sid learning rate and training model
LeoMax-Xiong Mar 3, 2022
6af2bc3
add sid loss wraper for voxceleb, test=doc
LeoMax-Xiong Mar 3, 2022
7668f61
add sid dataloader for training, test=doc
LeoMax-Xiong Mar 3, 2022
4648059
add training process for sid, test=doc
LeoMax-Xiong Mar 3, 2022
1f74af1
add training log info and comment, test=doc
LeoMax-Xiong Mar 3, 2022
97ec012
add speaker verification using cosine score, test=doc
LeoMax-Xiong Mar 4, 2022
016ed6d
repair the code according to the part comment, test=doc
LeoMax-Xiong Mar 4, 2022
ac4967e
optimize the data prepare process
LeoMax-Xiong Mar 6, 2022
2d89c80
add waveform augment pipeline, test=doc
LeoMax-Xiong Mar 7, 2022
7db7eb8
add extract audio embedding api, test=doc
LeoMax-Xiong Mar 7, 2022
386ef3f
add voxceleb augment unit test, test=doc
LeoMax-Xiong Mar 8, 2022
14efbf5
check extract embedding result, test=doc
LeoMax-Xiong Mar 8, 2022
60d73bb
add state 0 to prepare the voxcele data and augment data
LeoMax-Xiong Mar 9, 2022
0dee8f4
Merge branch 'PaddlePaddle:develop' into develop
LeoMax-Xiong Mar 9, 2022
4473405
merge develop to vox12, test=doc
LeoMax-Xiong Mar 9, 2022
0e87037
refactor to compilance paddleaudio
LeoMax-Xiong Mar 9, 2022
993d678
remove unused code, test=doc
LeoMax-Xiong Mar 9, 2022
584a2c0
add ecapa-tdnn config yaml file
LeoMax-Xiong Mar 9, 2022
8ed5c28
add vox2 data into VoxCeleb class
LeoMax-Xiong Mar 10, 2022
311fa87
add some comments to the code
LeoMax-Xiong Mar 13, 2022
7eb8fa7
convert save_freq to save_interval, test=doc
LeoMax-Xiong Mar 13, 2022
506d26a
change the code style to s2t code style, test=doc
LeoMax-Xiong Mar 14, 2022
d28ccfa
add vector cli component, test=doc
LeoMax-Xiong Mar 20, 2022
9c6735f
add vector voxceleb12 base mode url, test=doc
LeoMax-Xiong Mar 21, 2022
b9eafdd
change - to _ to distinguish field
LeoMax-Xiong Mar 21, 2022
9874fb7
add some comments in code
LeoMax-Xiong Mar 22, 2022
d85d1de
exec pre-commit in paddlespeech vector, test=doc
LeoMax-Xiong Mar 22, 2022
5221c27
add voxceleb dataset and trial info, test=doc
LeoMax-Xiong Mar 23, 2022
e2684e7
refactor the data prepare process
LeoMax-Xiong Mar 23, 2022
62cbce6
add vectorwrapper to extract audio embedding
LeoMax-Xiong Mar 24, 2022
0bb67d8
add vector cli unit test, test=doc
LeoMax-Xiong Mar 24, 2022
305bacd
Merge branch 'develop' into vox12
LeoMax-Xiong Mar 24, 2022
0f78d25
add vector cli batch and pipeline test demo, test=doc
LeoMax-Xiong Mar 24, 2022
3054659
remove debug info, test=doc
LeoMax-Xiong Mar 24, 2022
faf6b8d
add the vec cli test audio name, test=doc
LeoMax-Xiong Mar 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions examples/voxceleb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,6 @@ VoxCeleb2 stores files with the m4a audio format. To use them in PaddleSpeech,
ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s
```

``` shell
# copy this to root directory of data and
# chmod a+x convert.sh
# ./convert.sh
# https://unix.stackexchange.com/questions/103920/parallelize-a-bash-for-loop

open_sem(){
mkfifo pipe-$$
exec 3<>pipe-$$
rm pipe-$$
local i=$1
for((;i>0;i--)); do
printf %s 000 >&3
done
}
run_with_lock(){
local x
read -u 3 -n 3 x && ((0==x)) || exit $x
(
( "$@"; )
printf '%.3d' $? >&3
)&
}

N=32 # number of vCPU
open_sem $N
for f in $(find . -name "*.m4a"); do
run_with_lock ffmpeg -loglevel panic -i "$f" -ar 16000 "${f%.*}.wav"
done
```

You can do the conversion using ffmpeg https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830). This operation might take several hours and should be only once.

3. Put all the wav files in a folder called `wav`. You should have something like `voxceleb2/wav/id*/*.wav` (e.g, `voxceleb2/wav/id00012/21Uxsk56VDQ/00001.wav`)

4.
33 changes: 28 additions & 5 deletions examples/voxceleb/sv0/local/data_prepare.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,51 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import numpy as np
import paddle

from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.training.seeding import seed_everything

logger = Log(__name__).getlog()


def main(args):

# stage0: set the cpu device, all data prepare process will be done in cpu mode
paddle.set_device("cpu")
# set the random seed, it is a must for multiprocess training
seed_everything(args.seed)

# stage 1: generate the voxceleb csv file
# Note: this may occurs c++ execption, but the program will execute fine
# so we can ignore the execption
train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)
# so we ignore the execption
# we explicitly pass the vox2 base path to data prepare and generate the audio info
train_dataset = VoxCeleb(
'train', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)
dev_dataset = VoxCeleb(
'dev', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)

# stage 2: generate the augment noise csv file
if args.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)


if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
Expand All @@ -38,10 +57,14 @@ def main(args):
default="./data/",
type=str,
help="data directory")
parser.add_argument("--vox2-base-path",
default=None,
type=str,
help="vox2 base path, where is store the wav audio")
parser.add_argument("--augment",
action="store_true",
default=False,
help="Apply audio augments.")
args = parser.parse_args()
# yapf: enable
main(args)
main(args)
16 changes: 15 additions & 1 deletion examples/voxceleb/sv0/path.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
Expand All @@ -10,5 +24,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=ecapa-tdnn
MODEL=ecapa_tdnn
export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}
34 changes: 29 additions & 5 deletions examples/voxceleb/sv0/run.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

. ./path.sh
set -e
Expand All @@ -11,19 +24,30 @@ set -e
# stage 3: extract the training embeding to train the LDA and PLDA
######################################################################

# you can set the variable PPAUDIO_HOME to specifiy the downloaded the vox1 and vox2 dataset
# default the dataset is the ~/.paddleaudio/
# we can set the variable PPAUDIO_HOME to specifiy the root directory of the downloaded vox1 and vox2 dataset
# default the dataset will be stored in the ~/.paddleaudio/
# the vox2 dataset is stored in m4a format, we need to convert the audio from m4a to wav yourself
# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
# export PPAUDIO_HOME=

stage=0
dir=data.bak/ # data directory
exp_dir=exp/ecapa-tdnn/ # experiment directory
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
dir=data/
exp_dir=exp/ecapa-tdnn/ # experiment directory

# vox2 wav path, we must convert the m4a format to wav format
# and store them in the ${PPAUDIO_HOME}/datasets/vox2/wav/ directory
vox2_base_path=${PPAUDIO_HOME}/datasets/vox2/wav/
mkdir -p ${dir}
mkdir -p ${exp_dir}

if [ $stage -le 0 ]; then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
python3 local/data_prepare.py --data-dir ${dir} --augment
python3 local/data_prepare.py \
--data-dir ${dir} --augment --vox2-base-path ${vox2_base_path}
fi

if [ $stage -le 1 ]; then
Expand Down
2 changes: 1 addition & 1 deletion paddleaudio/paddleaudio/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
from .gtzan import GTZAN
from .tess import TESS
from .urban_sound import UrbanSound8K
from .voxceleb import VoxCeleb1
from .voxceleb import VoxCeleb
from .rirs_noises import OpenRIRNoise
44 changes: 31 additions & 13 deletions paddleaudio/paddleaudio/datasets/voxceleb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
from pathos.multiprocessing import Pool
from tqdm import tqdm

from .dataset import feat_funcs
from ..backends import load as load_audio
from ..utils import DATA_HOME
from ..utils import decompress
from .dataset import feat_funcs
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.download import download_and_decompress
from utils.utility import download
from utils.utility import unpack

logger = Log(__name__).getlog()

__all__ = ['VoxCeleb1']
__all__ = ['VoxCeleb']


class VoxCeleb1(Dataset):
class VoxCeleb(Dataset):
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
archieves_audio_dev = [
{
Expand Down Expand Up @@ -94,8 +94,18 @@ def __init__(
split_ratio: float=0.9, # train split ratio
seed: int=0,
target_dir: str=None,
vox2_base_path=None,
**kwargs):

"""VoxCeleb data prepare and get the specific dataset audio info

Args:
subset (str, optional): dataset name, such as train, dev, enroll or test. Defaults to 'train'.
feat_type (str, optional): feat type, such raw, melspectrogram(fbank) or mfcc . Defaults to 'raw'.
random_chunk (bool, optional): random select a duration from audio. Defaults to True.
chunk_duration (float, optional): chunk duration if random_chunk flag is set. Defaults to 3.0.
target_dir (str, optional): data dir, audio info will be stored in this directory. Defaults to None.
vox2_base_path (_type_, optional): vox2 directory. vox2 data must be converted from m4a to wav. Defaults to None.
"""
assert subset in self.subsets, \
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)

Expand All @@ -106,19 +116,20 @@ def __init__(
self.random_chunk = random_chunk
self.chunk_duration = chunk_duration
self.split_ratio = split_ratio
self.target_dir = target_dir if target_dir else VoxCeleb1.base_path
self.target_dir = target_dir if target_dir else VoxCeleb.base_path
self.vox2_base_path = vox2_base_path

# if we set the target dir, we will change the vox data info data from base path to target dir
VoxCeleb1.csv_path = os.path.join(
target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb1.csv_path
VoxCeleb1.meta_path = os.path.join(
VoxCeleb.csv_path = os.path.join(
target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path
VoxCeleb.meta_path = os.path.join(
target_dir, "voxceleb",
'meta') if target_dir else VoxCeleb1.meta_path
VoxCeleb1.veri_test_file = os.path.join(VoxCeleb1.meta_path,
'veri_test2.txt')
'meta') if target_dir else VoxCeleb.meta_path
VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
'veri_test2.txt')
# self._data = self._get_data()[:1000] # KP: Small dataset test.
self._data = self._get_data()
super(VoxCeleb1, self).__init__()
super(VoxCeleb, self).__init__()

# Set up a seed to reproduce training or predicting result.
# random.seed(seed)
Expand Down Expand Up @@ -300,7 +311,14 @@ def prepare_data(self):
# get all the train and dev audios file path
audio_files = []
speakers = set()
for path in [self.wav_path]:
for path in [self.wav_path, self.vox2_base_path]:
# if vox2 directory is not set and vox2 is not a directory
# we will not process this directory
if not path or not os.path.exists(path):
logger.warning(
f"{path} is an invalid path, please check again, "
"and we will ignore the vox2 base path")
continue
for file in glob.glob(
os.path.join(path, "**", "*.wav"), recursive=True):
spk = file.split('/wav/')[1].split('/')[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

logger = Log(__name__).getlog()


def extract_audio_embedding(args, config):
# stage 0: set the training device, cpu or gpu
paddle.set_device(args.device)
Expand Down Expand Up @@ -83,7 +84,7 @@ def extract_audio_embedding(args, config):
choices=['cpu', 'gpu'],
default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config",
parser.add_argument("--config",
default=None,
type=str,
help="configuration file")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,23 @@

import numpy as np
import paddle
from yacs.config import CfgNode
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode

from paddleaudio.paddleaudio.datasets import VoxCeleb1
from paddlespeech.s2t.utils.log import Log
from paddleaudio.paddleaudio.datasets import VoxCeleb
from paddleaudio.paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

logger = Log(__name__).getlog()


def main(args, config):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)
Expand All @@ -44,7 +45,7 @@ def main(args, config):

# stage2: build the speaker verification eval instance with backbone model
model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)

# stage3: load the pre-trained model
args.load_checkpoint = os.path.abspath(
Expand All @@ -57,7 +58,7 @@ def main(args, config):
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

# stage4: construct the enroll and test dataloader
enroll_dataset = VoxCeleb1(
enroll_dataset = VoxCeleb(
subset='enroll',
target_dir=args.data_dir,
feat_type='melspectrogram',
Expand All @@ -73,7 +74,7 @@ def main(args, config):
num_workers=config.num_workers,
return_list=True,)

test_dataset = VoxCeleb1(
test_dataset = VoxCeleb(
subset='test',
target_dir=args.data_dir,
feat_type='melspectrogram',
Expand Down Expand Up @@ -145,7 +146,7 @@ def main(args, config):
labels = []
enrol_ids = []
test_ids = []
with open(VoxCeleb1.veri_test_file, 'r') as f:
with open(VoxCeleb.veri_test_file, 'r') as f:
for line in f.readlines():
label, enrol_id, test_id = line.strip().split(' ')
labels.append(int(label))
Expand All @@ -171,7 +172,7 @@ def main(args, config):
choices=['cpu', 'gpu'],
default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config",
parser.add_argument("--config",
default=None,
type=str,
help="configuration file")
Expand Down
Loading