Skip to content

[vec] add speaker diarization pipeline #1651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions examples/ami/sd0/conf/ecapa_tdnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
###########################################################
# AMI DATA PREPARE SETTING #
###########################################################
split_type: 'full_corpus_asr'
skip_TNO: True
# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
mic_type: 'Mix-Headset'
vad_type: 'oracle'
max_subseg_dur: 3.0
overlap: 1.5
# Some more exp folders (for cleaner structure).
embedding_dir: emb #!ref <save_folder>/emb
meta_data_dir: metadata #!ref <save_folder>/metadata
ref_rttm_dir: ref_rttms #!ref <save_folder>/ref_rttms
sys_rttm_dir: sys_rttms #!ref <save_folder>/sys_rttms
der_dir: DER #!ref <save_folder>/DER


###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
sr: 16000 # sample rate
n_mels: 80
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
#left_frames: 0
#right_frames: 0
#deltas: False


###########################################################
# MODEL SETTING #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if we want use another model, please choose another configuration yaml file
seed: 1234
emb_dim: 192
batch_size: 16
model:
input_size: 80
channels: [1024, 1024, 1024, 1024, 3072]
kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1]
attention_channels: 128
lin_neurons: 192
# Will automatically download ECAPA-TDNN model (best).

###########################################################
# SPECTRAL CLUSTERING SETTING #
###########################################################
backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity
affinity: 'cos' # options: cos, nn
max_num_spkrs: 10
oracle_n_spkrs: True


###########################################################
# DER EVALUATION SETTING #
###########################################################
ignore_overlap: True
forgiveness_collar: 0.25
231 changes: 231 additions & 0 deletions examples/ami/sd0/local/compute_embdding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import pickle
import sys

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm.contrib import tqdm
from yacs.config import CfgNode

from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster.diarization import EmbeddingMeta
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset_from_json import JSONDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

# Logger setup
logger = Log(__name__).getlog()


def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
"""Prepares metadata for a given recording ID.

Arguments
---------
full_meta_data : json
Full meta (json) containing all the recordings
rec_id : str
The recording ID for which meta (json) has to be prepared
out_meta_file : str
Path of the output meta (json) file.
"""

subset = {}
for key in full_meta_data:
k = str(key)
if k.startswith(rec_id):
subset[key] = full_meta_data[key]

with open(out_meta_file, mode="w") as json_f:
json.dump(subset, json_f, indent=2)


def create_dataloader(json_file, batch_size):
"""Creates the datasets and their data processing pipelines.
This is used for multi-mic processing.
"""

# create datasets
dataset = JSONDataset(
json_file=json_file,
feat_type='melspectrogram',
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)

# create dataloader
batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
dataloader = DataLoader(dataset,
batch_sampler=batch_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
return_list=True)

return dataloader


def main(args, config):
# set the training device, cpu or gpu
paddle.set_device(args.device)
# set the random seed
seed_everything(config.seed)

# stage1: build the dnn backbone model network
ecapa_tdnn = EcapaTdnn(**config.model)

# stage2: build the speaker verification eval instance with backbone model
model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)

# stage3: load the pre-trained model
# we get the last model from the epoch and save_interval
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))

# load model checkpoint to sid model
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdparams'))
model.set_state_dict(state_dict)
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

# set the model to eval mode
model.eval()

# load meta data
meta_file = os.path.join(
args.data_dir,
config.meta_data_dir,
"ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
with open(meta_file, "r") as f:
full_meta = json.load(f)

# get all the recording IDs in this dataset.
all_keys = full_meta.keys()
A = [word.rstrip().split("_")[0] for word in all_keys]
all_rec_ids = list(set(A[1:]))
all_rec_ids.sort()
split = "AMI_" + args.dataset
i = 1

msg = "Extra embdding for " + args.dataset + " set"
logger.info(msg)

if len(all_rec_ids) <= 0:
msg = "No recording IDs found! Please check if meta_data json file is properly generated."
logger.error(msg)
sys.exit()

# extra different recordings embdding in a dataset.
for rec_id in tqdm(all_rec_ids):
# This tag will be displayed in the log.
tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
str(len(all_rec_ids)) + "]")
i = i + 1

# log message.
msg = "Embdding %s : %s " % (tag, rec_id)
logger.debug(msg)

# embedding directory.
if not os.path.exists(
os.path.join(args.data_dir, config.embedding_dir, split)):
os.makedirs(
os.path.join(args.data_dir, config.embedding_dir, split))

# file to store embeddings.
emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
split, emb_file_name)

# prepare a metadata (json) for one recording. This is basically a subset of full_meta.
# lets keep this meta-info in embedding directory itself.
json_file_name = rec_id + "." + config.mic_type + ".json"
meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
split, json_file_name)

# write subset (meta for one recording) json metadata.
prepare_subset_json(full_meta, rec_id, meta_per_rec_file)

# prepare data loader.
diary_set_loader = create_dataloader(meta_per_rec_file,
config.batch_size)

# extract embeddings (skip if already done).
if not os.path.isfile(diary_stat_emb_file):
logger.debug("Extracting deep embeddings")
embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
segset = []

for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
# extrac the audio embedding
ids, feats, lengths = batch['ids'], batch['feats'], batch[
'lengths']
seg = [x for x in ids]
segset = segset + seg
emb = model.backbone(feats, lengths).squeeze(
-1).numpy() # (N, emb_size, 1) -> (N, emb_size)
embeddings = np.concatenate((embeddings, emb), axis=0)

segset = np.array(segset, dtype="|O")
stat_obj = EmbeddingMeta(
segset=segset,
stats=embeddings, )
logger.debug("Saving Embeddings...")
with open(diary_stat_emb_file, "wb") as output:
pickle.dump(stat_obj, output)

else:
logger.debug("Skipping embedding extraction (as already present).")


# Begin experiment!
if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
'--device',
default="gpu",
help="Select which device to perform diarization, defaults to gpu.")
parser.add_argument(
"--config", default=None, type=str, help="configuration file")
parser.add_argument(
"--data-dir",
default="../save/",
type=str,
help="processsed data directory")
parser.add_argument(
"--dataset",
choices=['dev', 'eval'],
default="dev",
type=str,
help="Select which dataset to extra embdding, defaults to dev")
parser.add_argument(
"--load-checkpoint",
type=str,
default='',
help="Directory to load model checkpoint to compute embeddings.")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)

config.freeze()

main(args, config)
49 changes: 0 additions & 49 deletions examples/ami/sd0/local/data.sh

This file was deleted.

Loading