Enhanced speaker counting for short audio recordings (#2729)
* Update enhanced speaker counting for short samples

Signed-off-by: Taejin Park <[email protected]>

* Update and doc string change

Signed-off-by: Taejin Park <[email protected]>

* Reflected PR review comments

Signed-off-by: Taejin Park <[email protected]>

* Ran style fix again to fix it

Signed-off-by: Taejin Park <[email protected]>

Co-authored-by: Nithin Rao <[email protected]>
tango4j and nithinraok authored Aug 30, 2021
1 parent d356243 commit b2ace62
Showing 1 changed file with 106 additions and 18 deletions.
124 changes: 106 additions & 18 deletions nemo/collections/asr/parts/utils/nmse_clustering.py
@@ -30,6 +30,8 @@
# This file is part of https://github.com/scikit-learn/scikit-learn/blob/114616d9f6ce9eba7c1aacd3d4a254f868010e25/sklearn/manifold/_spectral_embedding.py and
# https://github.com/tango4j/Auto-Tuning-Spectral-Clustering.

from collections import Counter

import numpy as np
import torch
from sklearn.cluster._kmeans import k_means
@@ -44,6 +46,7 @@
from torch.linalg import eigh as eigh

TORCH_EIGN = True

except ImportError:
TORCH_EIGN = False
from scipy.linalg import eigh as eigh
@@ -118,6 +121,71 @@ def getMinimumConnection(mat, max_N, n_list):
return affinity_mat, p_value


def addAnchorEmb(emb, anchor_sample_n, anchor_spk_n, sigma):
"""
Add randomly generated synthetic embeddings to make eigen analysis more stable.
We refer to these embeddings as anchor embeddings.
emb (float):
The input embeddings from the embedding extractor.
anchor_sample_n (int):
The number of embedding samples per speaker.
anchor_sample_n = 10 is recommended.
anchor_spk_n (int):
The number of speakers for synthetic embedding.
anchor_spk_n = 3 is recommended.
sigma (int):
The amplitude of synthetic noise for each embedding vector.
If the sigma value is too small, under-counting could happen.
If the sigma value is too large, over-counting could happen.
sigma = 50 is recommended.
"""
emb_dim = emb.shape[1]
mean, std_org = np.mean(emb, axis=0), np.std(emb, axis=0)
new_emb_list = []
for _ in range(anchor_spk_n):
emb_m = np.tile(np.random.randn(1, emb_dim), (anchor_sample_n, 1))
emb_noise = np.random.randn(anchor_sample_n, emb_dim).T
emb_noise = np.dot(np.diag(std_org), emb_noise / np.max(np.abs(emb_noise))).T
emb_gen = emb_m + sigma * emb_noise
new_emb_list.append(emb_gen)

new_emb_list.append(emb)
new_emb_np = np.vstack(new_emb_list)
return new_emb_np, anchor_sample_n * anchor_spk_n
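
A minimal usage sketch, assuming the addAnchorEmb above is in scope; the segment count, embedding dimension, and seed are illustrative values, not taken from this change:

import numpy as np

# Suppose the extractor produced 20 segment embeddings of dimension 192 (assumed numbers).
np.random.seed(0)
emb = np.random.randn(20, 192)

# addAnchorEmb appends anchor_spk_n synthetic "speakers", each contributing
# anchor_sample_n noisy copies of a random centroid, scaled by sigma and the
# per-dimension std of the real embeddings.
emb_aug, anchor_length = addAnchorEmb(emb, anchor_sample_n=10, anchor_spk_n=3, sigma=50)

assert anchor_length == 30                         # 3 speakers x 10 samples
assert emb_aug.shape == (20 + anchor_length, 192)  # real rows plus anchor rows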


def getEnhancedSpeakerCount(key, emb, cuda, random_test_count=5, anchor_spk_n=3, anchor_sample_n=10, sigma=50):
"""
Calculates the number of speakers using NME analysis with anchor embeddings.
"""
est_num_of_spk_list = []
for seed in range(random_test_count):
np.random.seed(seed)
emb_aug, anchor_length = addAnchorEmb(emb, anchor_sample_n, anchor_spk_n, sigma)
mat = getCosAffinityMatrix(emb_aug)
nmesc = NMESC(
mat,
max_num_speaker=emb.shape[0],
max_rp_threshold=0.25,
sparse_search=True,
sparse_search_volume=30,
fixed_thres=None,
NME_mat_size=300,
cuda=cuda,
)
est_num_of_spk, _, _ = nmesc.NMEanalysis()
est_num_of_spk_list.append(est_num_of_spk)

ctt = Counter(est_num_of_spk_list)
oracle_num_speakers = max(ctt.most_common(1)[0][0] - anchor_spk_n, 1)
return oracle_num_speakers
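
The count returned above is a majority vote over the random trials minus the synthetic anchor speakers; a toy sketch of that vote with made-up per-trial estimates:

from collections import Counter

# Hypothetical per-trial estimates obtained on the anchor-augmented embeddings.
est_num_of_spk_list = [5, 5, 4, 5, 6]
anchor_spk_n = 3

ctt = Counter(est_num_of_spk_list)
most_common_est = ctt.most_common(1)[0][0]               # 5 is the most frequent estimate
enhanced_count = max(most_common_est - anchor_spk_n, 1)  # remove the 3 synthetic speakers -> 2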


def getCosAffinityMatrix(emb):
"""
Calculates cosine similarity values among speaker embeddings.
@@ -327,10 +395,8 @@ def NMEanalysis(self):
if self.use_subsampling_for_NME:
subsample_ratio = self.subsampleAffinityMat(self.NME_mat_size)

"""
Scans p_values and find a p_value that generates
the smallest g_p value.
"""
# Scans p_values and finds a p_value that generates
# the smallest g_p value.
eig_ratio_list, est_spk_n_dict = [], {}
self.p_value_list = self.getPvalueList()
for p_value in self.p_value_list:
@@ -342,16 +408,14 @@ def NMEanalysis(self):
rp_p_value = self.p_value_list[index_nn]
affinity_mat = getAffinityGraphMat(self.mat, rp_p_value)

"""
Checks whether affinity graph is fully connected.
If not, it adds minimum number of connections to make it fully connected.
"""
# Checks whether affinity graph is fully connected.
# If not, it adds the minimum number of connections to make it fully connected.
if not isGraphFullyConnected(affinity_mat):
affinity_mat, rp_p_value = getMinimumConnection(self.mat, self.max_N, self.p_value_list)

p_hat_value = int(subsample_ratio * rp_p_value)
est_num_of_spk = est_spk_n_dict[rp_p_value]
return est_num_of_spk, p_hat_value
return est_num_of_spk, p_hat_value, eig_ratio_list[index_nn]

def subsampleAffinityMat(self, NME_mat_size):
"""
@@ -426,7 +490,16 @@ def getPvalueList(self):
return p_value_list


def COSclustering(key, emb, oracle_num_speakers=None, max_num_speaker=8, min_samples=6, fixed_thres=None, cuda=False):
def COSclustering(
key,
emb,
oracle_num_speakers=None,
max_num_speaker=8,
min_samples_for_NMESC=6,
enhanced_count_thres=80,
fixed_thres=None,
cuda=False,
):
"""
Clustering method for speaker diarization based on cosine similarity.
@@ -443,19 +516,33 @@ def COSclustering(key, emb, oracle_num_speakers=None, max_num_speaker=8, min_sam
max_num_speaker: (int)
Maximum number of clusters to consider for each session
min_samples: (int)
min_samples_for_NMESC: (int)
Minimum number of samples required for NME clustering, this avoids
zero p_neighbour_lists. Default of 6 is selected since (1/rp_threshold) >= 4
when max_rp_threshold = 0.25. Thus, NME analysis is skipped for matrices
smaller than (min_samples)x(min_samples).
zero p_neighbour_lists. If the input has fewer segments than min_samples_for_NMESC,
it is directed to the enhanced speaker counting mode.
enhanced_count_thres: (int)
For short audio recordings (under 60 seconds), the clustering algorithm cannot
accumulate a sufficient speaker profile for each cluster.
Thus, getEnhancedSpeakerCount() employs anchor embeddings (dummy representations)
to mitigate the effect of cluster sparsity.
enhanced_count_thres = 80 is recommended.
Returns:
Y: (List[int])
Speaker label for each segment.
"""
mat = getCosAffinityMatrix(emb)
if emb.shape[0] == 1:
return np.array([0])
elif emb.shape[0] <= max(enhanced_count_thres, min_samples_for_NMESC) and oracle_num_speakers is None:
est_num_of_spk_enhanced = getEnhancedSpeakerCount(key, emb, cuda)
else:
est_num_of_spk_enhanced = None

if oracle_num_speakers:
max_num_speaker = oracle_num_speakers

mat = getCosAffinityMatrix(emb)
nmesc = NMESC(
mat,
max_num_speaker=max_num_speaker,
@@ -467,15 +554,16 @@ def COSclustering(key, emb, oracle_num_speakers=None, max_num_speaker=8, min_sam
cuda=cuda,
)

if emb.shape[0] > min_samples:
est_num_of_spk, p_hat_value = nmesc.NMEanalysis()
if emb.shape[0] > min_samples_for_NMESC:
est_num_of_spk, p_hat_value, best_g_p_value = nmesc.NMEanalysis()
affinity_mat = getAffinityGraphMat(mat, p_hat_value)
else:
affinity_mat = mat
est_num_of_spk, _, _ = estimateNumofSpeakers(affinity_mat, max_num_speaker, cuda)

if oracle_num_speakers:
est_num_of_spk = oracle_num_speakers
elif est_num_of_spk_enhanced:
est_num_of_spk = est_num_of_spk_enhanced

spectral_model = _SpectralClustering(n_clusters=est_num_of_spk, cuda=cuda)
Y = spectral_model.predict(affinity_mat)
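
A hypothetical call to the updated COSclustering, assuming the function is in scope; the session key, segment count, and embedding dimension are assumed values. With fewer than enhanced_count_thres=80 segments and no oracle count, the enhanced counting path above is exercised:

import numpy as np

np.random.seed(0)
emb = np.random.randn(40, 192)     # 40 segment embeddings of dimension 192 (assumed)

labels = COSclustering(
    key="session_0",               # illustrative session id
    emb=emb,
    oracle_num_speakers=None,      # unknown speaker count, so it is estimated
    max_num_speaker=8,
    enhanced_count_thres=80,
    cuda=False,
)

assert len(labels) == emb.shape[0]  # one speaker label per input segment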