Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
16108de
add voxceleb1 dataset prepare process
LeoMax-Xiong Feb 24, 2022
35b7968
remove invalid directory
LeoMax-Xiong Feb 25, 2022
6f7e965
add kaldi feats ark dataset
LeoMax-Xiong Feb 25, 2022
d7da629
add kaldi feats egs dataset
LeoMax-Xiong Feb 26, 2022
70d3b01
remove invalid code
LeoMax-Xiong Feb 26, 2022
1395b5f
Merge branch 'PaddlePaddle:develop' into develop
LeoMax-Xiong Mar 2, 2022
7ef60eb
add voxceleb1 data prepare
LeoMax-Xiong Mar 2, 2022
0780d18
remove personal code test=doc
LeoMax-Xiong Mar 2, 2022
3a943ca
repair the variable name bug
LeoMax-Xiong Mar 2, 2022
dc28ebe
move the csv vox format to paddleaudio, test=doc
LeoMax-Xiong Mar 3, 2022
57c4f4a
add sid learning rate and training model
LeoMax-Xiong Mar 3, 2022
6af2bc3
add sid loss wraper for voxceleb, test=doc
LeoMax-Xiong Mar 3, 2022
7668f61
add sid dataloader for training, test=doc
LeoMax-Xiong Mar 3, 2022
4648059
add training process for sid, test=doc
LeoMax-Xiong Mar 3, 2022
1f74af1
add training log info and comment, test=doc
LeoMax-Xiong Mar 3, 2022
97ec012
add speaker verification using cosine score, test=doc
LeoMax-Xiong Mar 4, 2022
016ed6d
repair the code according to the part comment, test=doc
LeoMax-Xiong Mar 4, 2022
ac4967e
optimize the data prepare process
LeoMax-Xiong Mar 6, 2022
2d89c80
add waveform augment pipeline, test=doc
LeoMax-Xiong Mar 7, 2022
7db7eb8
add extract audio embedding api, test=doc
LeoMax-Xiong Mar 7, 2022
386ef3f
add voxceleb augment unit test, test=doc
LeoMax-Xiong Mar 8, 2022
14efbf5
check extract embedding result, test=doc
LeoMax-Xiong Mar 8, 2022
60d73bb
add state 0 to prepare the voxcele data and augment data
LeoMax-Xiong Mar 9, 2022
0dee8f4
Merge branch 'PaddlePaddle:develop' into develop
LeoMax-Xiong Mar 9, 2022
4473405
merge develop to vox12, test=doc
LeoMax-Xiong Mar 9, 2022
0e87037
refactor to compilance paddleaudio
LeoMax-Xiong Mar 9, 2022
993d678
remove unused code, test=doc
LeoMax-Xiong Mar 9, 2022
584a2c0
add ecapa-tdnn config yaml file
LeoMax-Xiong Mar 9, 2022
8ed5c28
add vox2 data into VoxCeleb class
LeoMax-Xiong Mar 10, 2022
311fa87
add some comments to the code
LeoMax-Xiong Mar 13, 2022
7eb8fa7
convert save_freq to save_interval, test=doc
LeoMax-Xiong Mar 13, 2022
506d26a
change the code style to s2t code style, test=doc
LeoMax-Xiong Mar 14, 2022
d28ccfa
add vector cli component, test=doc
LeoMax-Xiong Mar 20, 2022
9c6735f
add vector voxceleb12 base mode url, test=doc
LeoMax-Xiong Mar 21, 2022
b9eafdd
change - to _ to distinguish field
LeoMax-Xiong Mar 21, 2022
9874fb7
add some comments in code
LeoMax-Xiong Mar 22, 2022
d85d1de
exec pre-commit in paddlespeech vector, test=doc
LeoMax-Xiong Mar 22, 2022
5221c27
add voxceleb dataset and trial info, test=doc
LeoMax-Xiong Mar 23, 2022
e2684e7
refactor the data prepare process
LeoMax-Xiong Mar 23, 2022
62cbce6
add vectorwrapper to extract audio embedding
LeoMax-Xiong Mar 24, 2022
0bb67d8
add vector cli unit test, test=doc
LeoMax-Xiong Mar 24, 2022
305bacd
Merge branch 'develop' into vox12
LeoMax-Xiong Mar 24, 2022
0f78d25
add vector cli batch and pipeline test demo, test=doc
LeoMax-Xiong Mar 24, 2022
3054659
remove debug info, test=doc
LeoMax-Xiong Mar 24, 2022
faf6b8d
add the vec cli test audio name, test=doc
LeoMax-Xiong Mar 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
487 changes: 175 additions & 312 deletions dataset/voxceleb/voxceleb1.py

Large diffs are not rendered by default.

256 changes: 255 additions & 1 deletion examples/voxceleb/sv0/local/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler

from dataset.voxceleb.voxceleb1 import VoxCeleb1
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddleaudio.features.core import melspectrogram
from paddlespeech.vector.training.time import Timer
from paddlespeech.vector.datasets.batch import feature_normalize
from paddlespeech.vector.datasets.batch import waveform_collate_fn
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification

# feat configuration
cpu_feat_conf = {
'n_mels': 80,
'window_size': 400,
'hop_length': 160,
}

def main(args):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)

# stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
Expand All @@ -27,7 +48,208 @@ def main(args):
local_rank = paddle.distributed.get_rank()

# stage2: data prepare
# note: some cmd must do in rank==0
train_ds = VoxCeleb1('train', target_dir=args.data_dir)
dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)

# stage3: build the dnn backbone model network
#"channels": [1024, 1024, 1024, 1024, 3072],
model_conf = {
"input_size": 80,
"channels": [512, 512, 512, 512, 1536],
"kernel_sizes": [5, 3, 3, 3, 1],
"dilations": [1, 2, 3, 4, 1],
"attention_channels": 128,
"lin_neurons": 192,
}
ecapa_tdnn = EcapaTdnn(**model_conf)

# stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)

# stage5: build the optimizer, we now only construct the AdamW optimizer
lr_schedule = CyclicLRScheduler(
base_lr=args.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_schedule, parameters=model.parameters())

# stage6: build the loss function, we now only support LogSoftmaxWrapper
criterion = LogSoftmaxWrapper(
loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))

# stage7: confirm training start epoch
# if pre-trained model exists, start epoch confirmed by the pre-trained model
start_epoch = 0
if args.load_checkpoint:
print("load the check point")
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
try:
# load model checkpoint
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdparams'))
model.set_state_dict(state_dict)

# load optimizer checkpoint
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdopt'))
optimizer.set_state_dict(state_dict)
if local_rank == 0:
print(f'Checkpoint loaded from {args.load_checkpoint}')
except FileExistsError:
if local_rank == 0:
print('Train from scratch.')

try:
start_epoch = int(args.load_checkpoint[-1])
print(f'Restore training from epoch {start_epoch}.')
except ValueError:
pass

# stage8: we build the batch sampler for paddle.DataLoader
train_sampler = DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
train_loader = DataLoader(
train_ds,
batch_sampler=train_sampler,
num_workers=args.num_workers,
collate_fn=waveform_collate_fn,
return_list=True,
use_buffer_reader=True, )

# stage9: start to train
# we will comment the training process
steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * args.epochs)
timer.start()

for epoch in range(start_epoch + 1, args.epochs + 1):
# at the begining, model must set to train mode
model.train()

avg_loss = 0
num_corrects = 0
num_samples = 0
for batch_idx, batch in enumerate(train_loader):
# stage 9-1: batch data is audio sample points and speaker id label
waveforms, labels = batch['waveforms'], batch['labels']

# stage 9-2: audio sample augment method, which is done on the audio sample point
# todo

# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform, **cpu_feat_conf)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))

# stage 9-4: feature normalize, which help converge and imporve the performance
feats = feature_normalize(
feats, mean_norm=True, std_norm=False) # Features normalization

# stage 9-5: model forward, such ecapa-tdnn, x-vector
logits = model(feats)

# stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
loss = criterion(logits, labels)

# stage 9-7: update the gradient and clear the gradient cache
loss.backward()
optimizer.step()
if isinstance(optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()

# stage 9-8: Calculate average loss per batch
avg_loss += loss.numpy()[0]

# stage 9-9: Calculate metrics, which is one-best accuracy
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
timer.count() # step plus one in timer

# stage 9-10: print the log information only on 0-rank per log-freq batchs
if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
lr = optimizer.get_lr()
avg_loss /= args.log_freq
avg_acc = num_corrects / num_samples

print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
epoch, args.epochs, batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
print(print_msg)

avg_loss = 0
num_corrects = 0
num_samples = 0

# stage 9-11: save the model parameters only on 0-rank per save-freq batchs
if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
if local_rank != 0:
paddle.distributed.barrier(
) # Wait for valid step in main process
continue # Resume trainning on other process

# stage 9-12: construct the valid dataset dataloader
dev_sampler = BatchSampler(
dev_ds,
batch_size=args.batch_size // 4,
shuffle=False,
drop_last=False)
dev_loader = DataLoader(
dev_ds,
batch_sampler=dev_sampler,
collate_fn=waveform_collate_fn,
num_workers=args.num_workers,
return_list=True, )

# set the model to eval mode
model.eval()
num_corrects = 0
num_samples = 0

# stage 9-13: evaluation the valid dataset batch data
print('Evaluate on validation dataset')
with paddle.no_grad():
for batch_idx, batch in enumerate(dev_loader):
waveforms, labels = batch['waveforms'], batch['labels']

feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform, **cpu_feat_conf)
feats.append(feat)

feats = paddle.to_tensor(np.asarray(feats))
feats = feature_normalize(
feats, mean_norm=True, std_norm=False)
logits = model(feats)

preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]

print_msg = '[Evaluation result]'
print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
print(print_msg)

# stage 9-14: Save model parameters
save_dir = os.path.join(args.checkpoint_dir,
'epoch_{}'.format(epoch))
print('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(save_dir, 'model.pdopt'))

if nranks > 1:
paddle.distributed.barrier() # Main process


if __name__ == "__main__":
Expand All @@ -41,6 +263,38 @@ def main(args):
default="./data/",
type=str,
help="data directory")
parser.add_argument("--learning-rate",
type=float,
default=1e-8,
help="Learning rate used to train with warmup.")
parser.add_argument("--load-checkpoint",
type=str,
default=None,
help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--batch-size",
type=int, default=64,
help="Total examples' number in batch for training.")
parser.add_argument("--num-workers",
type=int,
default=0,
help="Number of workers in dataloader.")
parser.add_argument("--epochs",
type=int,
default=50,
help="Number of epoches for fine-tuning.")
parser.add_argument("--log-freq",
type=int,
default=10,
help="Log the training infomation every n steps.")
parser.add_argument("--save-freq",
type=int,
default=1,
help="Save checkpoint every n epoch.")
parser.add_argument("--checkpoint-dir",
type=str,
default='./checkpoint',
help="Directory to save model checkpoints.")

args = parser.parse_args()
# yapf: enable

Expand Down
Loading