Skip to content

Commit fc72295

Browse files
authored
Merge pull request #1651 from ccrrong/ami
[vec] add speaker diarization pipeline
2 parents 44ee5cd + d16e625 commit fc72295

File tree

8 files changed

+1018
-55
lines changed

8 files changed

+1018
-55
lines changed

examples/ami/sd0/conf/ecapa_tdnn.yaml

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
###########################################################
2+
# AMI DATA PREPARE SETTING #
3+
###########################################################
4+
split_type: 'full_corpus_asr'
5+
skip_TNO: True
6+
# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
7+
mic_type: 'Mix-Headset'
8+
vad_type: 'oracle'
9+
max_subseg_dur: 3.0
10+
overlap: 1.5
11+
# Some more exp folders (for cleaner structure).
12+
embedding_dir: emb #!ref <save_folder>/emb
13+
meta_data_dir: metadata #!ref <save_folder>/metadata
14+
ref_rttm_dir: ref_rttms #!ref <save_folder>/ref_rttms
15+
sys_rttm_dir: sys_rttms #!ref <save_folder>/sys_rttms
16+
der_dir: DER #!ref <save_folder>/DER
17+
18+
19+
###########################################################
20+
# FEATURE EXTRACTION SETTING #
21+
###########################################################
22+
# currently, we only support fbank
23+
sr: 16000 # sample rate
24+
n_mels: 80
25+
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
26+
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
27+
#left_frames: 0
28+
#right_frames: 0
29+
#deltas: False
30+
31+
32+
###########################################################
33+
# MODEL SETTING #
34+
###########################################################
35+
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
36+
# if we want use another model, please choose another configuration yaml file
37+
seed: 1234
38+
emb_dim: 192
39+
batch_size: 16
40+
model:
41+
input_size: 80
42+
channels: [1024, 1024, 1024, 1024, 3072]
43+
kernel_sizes: [5, 3, 3, 3, 1]
44+
dilations: [1, 2, 3, 4, 1]
45+
attention_channels: 128
46+
lin_neurons: 192
47+
# Will automatically download ECAPA-TDNN model (best).
48+
49+
###########################################################
50+
# SPECTRAL CLUSTERING SETTING #
51+
###########################################################
52+
backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity
53+
affinity: 'cos' # options: cos, nn
54+
max_num_spkrs: 10
55+
oracle_n_spkrs: True
56+
57+
58+
###########################################################
59+
# DER EVALUATION SETTING #
60+
###########################################################
61+
ignore_overlap: True
62+
forgiveness_collar: 0.25
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import argparse
15+
import json
16+
import os
17+
import pickle
18+
import sys
19+
20+
import numpy as np
21+
import paddle
22+
from paddle.io import BatchSampler
23+
from paddle.io import DataLoader
24+
from tqdm.contrib import tqdm
25+
from yacs.config import CfgNode
26+
27+
from paddlespeech.s2t.utils.log import Log
28+
from paddlespeech.vector.cluster.diarization import EmbeddingMeta
29+
from paddlespeech.vector.io.batch import batch_feature_normalize
30+
from paddlespeech.vector.io.dataset_from_json import JSONDataset
31+
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
32+
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
33+
from paddlespeech.vector.training.seeding import seed_everything
34+
35+
# Logger setup
36+
logger = Log(__name__).getlog()
37+
38+
39+
def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
    """Write the metadata subset belonging to a single recording.

    Every entry of ``full_meta_data`` whose key starts with ``rec_id``
    is copied into a new dict and dumped as indented json.

    Arguments
    ---------
    full_meta_data : dict
        Full meta (parsed json) containing all the recordings.
    rec_id : str
        The recording ID for which meta (json) has to be prepared.
    out_meta_file : str
        Path of the output meta (json) file.
    """
    # Select the sub-segments of this recording via a key-prefix match.
    subset = {
        key: value
        for key, value in full_meta_data.items()
        if str(key).startswith(rec_id)
    }

    with open(out_meta_file, mode="w") as json_f:
        json.dump(subset, json_f, indent=2)
60+
61+
62+
def create_dataloader(json_file, batch_size, conf=None):
    """Create the dataset and its data-processing pipeline for one recording.

    This is used for multi-mic processing.

    Arguments
    ---------
    json_file : str
        Path to the per-recording sub-segment metadata (json) file.
    batch_size : int
        Number of sub-segments per batch.
    conf : CfgNode, optional
        Feature-extraction configuration (``n_mels``, ``window_size``,
        ``hop_size``). Defaults to the module-level ``config`` for backward
        compatibility with existing two-argument callers.

    Returns
    -------
    paddle.io.DataLoader
        Loader yielding mean-normalized melspectrogram feature batches.
    """
    # Previously this function read the module-level global `config`
    # unconditionally, which only exists when the file runs as a script
    # (it is built under the `__main__` guard) — importing this module
    # and calling create_dataloader() raised NameError. Keep the global
    # as a fallback so the existing call in main() is unchanged.
    if conf is None:
        conf = config

    # create datasets
    dataset = JSONDataset(
        json_file=json_file,
        feat_type='melspectrogram',
        n_mels=conf.n_mels,
        window_size=conf.window_size,
        hop_length=conf.hop_size)

    # create dataloader; shuffle order is irrelevant downstream because
    # each embedding stays keyed by its segment id.
    batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
    dataloader = DataLoader(dataset,
                            batch_sampler=batch_sampler,
                            collate_fn=lambda x: batch_feature_normalize(
                                x, mean_norm=True, std_norm=False),
                            return_list=True)

    return dataloader
84+
85+
86+
def main(args, config):
    """Extract speaker embeddings for every recording of an AMI subset.

    Loads a pre-trained ECAPA-TDNN checkpoint, then for each recording id
    found in the sub-segment metadata: writes a per-recording metadata json,
    builds a dataloader over its sub-segments, runs the backbone to get one
    embedding per sub-segment, and pickles the result as an ``EmbeddingMeta``
    under ``<data_dir>/<embedding_dir>/AMI_<dataset>/``.

    Arguments
    ---------
    args : argparse.Namespace
        Command-line options (device, config, data_dir, dataset,
        load_checkpoint).
    config : CfgNode
        Parsed yaml configuration (model, feature and folder settings).
    """
    # set the training device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    # (num_class=1: the classifier head is unused at inference time)
    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)

    # stage3: load the pre-trained model
    # we get the last model from the epoch and save_interval
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))

    # load model checkpoint to sid model
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # set the model to eval mode
    model.eval()

    # load meta data (sub-segment json produced by the data-prepare stage)
    meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
    with open(meta_file, "r") as f:
        full_meta = json.load(f)

    # get all the recording IDs in this dataset.
    all_keys = full_meta.keys()
    A = [word.rstrip().split("_")[0] for word in all_keys]
    # NOTE(review): the [1:] silently drops the first key's recording id;
    # this mirrors the SpeechBrain AMI recipe, but looks fragile — confirm
    # the first metadata key is intentionally excluded.
    all_rec_ids = list(set(A[1:]))
    all_rec_ids.sort()
    split = "AMI_" + args.dataset
    i = 1

    msg = "Extra embdding for " + args.dataset + " set"
    logger.info(msg)

    if len(all_rec_ids) <= 0:
        msg = "No recording IDs found! Please check if meta_data json file is properly generated."
        logger.error(msg)
        sys.exit()

    # extract embeddings for each recording of the dataset.
    for rec_id in tqdm(all_rec_ids):
        # This tag will be displayed in the log.
        tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
               str(len(all_rec_ids)) + "]")
        i = i + 1

        # log message.
        msg = "Embdding %s : %s " % (tag, rec_id)
        logger.debug(msg)

        # embedding directory (created lazily per split).
        if not os.path.exists(
                os.path.join(args.data_dir, config.embedding_dir, split)):
            os.makedirs(
                os.path.join(args.data_dir, config.embedding_dir, split))

        # file to store embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
                                           split, emb_file_name)

        # prepare a metadata (json) for one recording. This is basically a subset of full_meta.
        # lets keep this meta-info in embedding directory itself.
        json_file_name = rec_id + "." + config.mic_type + ".json"
        meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
                                         split, json_file_name)

        # write subset (meta for one recording) json metadata.
        prepare_subset_json(full_meta, rec_id, meta_per_rec_file)

        # prepare data loader over this recording's sub-segments.
        diary_set_loader = create_dataloader(meta_per_rec_file,
                                             config.batch_size)

        # extract embeddings (skip if already done — the pickle acts as a cache).
        if not os.path.isfile(diary_stat_emb_file):
            logger.debug("Extracting deep embeddings")
            # accumulator grows along axis 0, one row per sub-segment
            embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
            segset = []

            for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
                # extract the audio embedding for this batch
                ids, feats, lengths = batch['ids'], batch['feats'], batch[
                    'lengths']
                seg = [x for x in ids]
                segset = segset + seg
                emb = model.backbone(feats, lengths).squeeze(
                    -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
                embeddings = np.concatenate((embeddings, emb), axis=0)

            # object dtype keeps variable-length segment-id strings intact
            segset = np.array(segset, dtype="|O")
            stat_obj = EmbeddingMeta(
                segset=segset,
                stats=embeddings, )
            logger.debug("Saving Embeddings...")
            with open(diary_stat_emb_file, "wb") as output:
                pickle.dump(stat_obj, output)

        else:
            logger.debug("Skipping embedding extraction (as already present).")
197+
198+
199+
# Begin experiment!
if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)

    # CLI definition as a (flags, options) table — one row per argument.
    cli_args = [
        (('--device', ), dict(
            default="gpu",
            help="Select which device to perform diarization, defaults to gpu."
        )),
        (("--config", ), dict(
            default=None, type=str, help="configuration file")),
        (("--data-dir", ), dict(
            default="../save/", type=str, help="processsed data directory")),
        (("--dataset", ), dict(
            choices=['dev', 'eval'],
            default="dev",
            type=str,
            help="Select which dataset to extra embdding, defaults to dev")),
        (("--load-checkpoint", ), dict(
            type=str,
            default='',
            help="Directory to load model checkpoint to compute embeddings.")),
    ]
    for flags, options in cli_args:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    # Build the yaml configuration and freeze it before use.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)
    config.freeze()

    main(args, config)

examples/ami/sd0/local/data.sh

Lines changed: 0 additions & 49 deletions
This file was deleted.

0 commit comments

Comments
 (0)