diff --git a/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml b/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml index 09e4bb378c69..6c24683dd885 100644 --- a/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml +++ b/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml @@ -52,7 +52,9 @@ diarizer: max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. sparse_search_volume: 10 # The higher the number, the more values will be examined with more time. maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. - + chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering. + embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) + msdd_model: model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) parameters: @@ -88,5 +90,4 @@ diarizer: arpa_language_model: null # Provide a KenLM language model in .arpa format. min_number_of_words: 3 # Min number of words for the left context. max_number_of_words: 10 # Max number of words for the right context. - logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. - + logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml b/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml index 2d53d0916dde..738cbfd0ca72 100644 --- a/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml +++ b/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml @@ -52,6 +52,8 @@ diarizer: max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. + chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering. + embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) msdd_model: model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) @@ -88,5 +90,4 @@ diarizer: arpa_language_model: null # Provide a KenLM language model in .arpa format. min_number_of_words: 3 # Min number of words for the left context. max_number_of_words: 10 # Max number of words for the right context. - logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. - + logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. 
\ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml b/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml index c9d7cdf32f45..8a7530577c01 100644 --- a/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml +++ b/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml @@ -44,7 +44,7 @@ diarizer: multiscale_weights: [1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`. - clustering: + clustering: parameters: oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. @@ -52,6 +52,8 @@ diarizer: max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. + chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering. + embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) msdd_model: model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) @@ -88,5 +90,4 @@ diarizer: arpa_language_model: null # Provide a KenLM language model in .arpa format. min_number_of_words: 3 # Min number of words for the left context. max_number_of_words: 10 # Max number of words for the right context. - logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. - + logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. \ No newline at end of file diff --git a/nemo/collections/asr/parts/utils/longform_clustering.py b/nemo/collections/asr/parts/utils/longform_clustering.py new file mode 100644 index 000000000000..171c074d9e10 --- /dev/null +++ b/nemo/collections/asr/parts/utils/longform_clustering.py @@ -0,0 +1,422 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
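The three inference configs above each gain the same two long-form clustering knobs, `chunk_cluster_count` and `embeddings_per_chunk`, under `diarizer.clustering.parameters`; later in this patch, `perform_clustering` reads them with `clustering_params.get(...)` and forwards them to `LongFormSpeakerClustering.forward_infer`. A minimal sketch of overriding them when loading one of these configs, assuming OmegaConf and a working directory at the repository root:

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml")

# check_input() in longform_clustering.py requires chunk_cluster_count < embeddings_per_chunk
# and chunk_cluster_count > max_num_speakers; the values below are the shipped defaults.
cfg.diarizer.clustering.parameters.chunk_cluster_count = 50
cfg.diarizer.clustering.parameters.embeddings_per_chunk = 10000  # roughly 40 minutes of audio per chunk

print(OmegaConf.to_yaml(cfg.diarizer.clustering.parameters))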
+ +from typing import Dict, List, Tuple +import torch +from tqdm import tqdm +from nemo.collections.asr.parts.utils.offline_clustering import ( + SpeakerClustering, + get_scale_interpolated_embs, + getCosAffinityMatrix, + split_input_data, +) +from nemo.collections.asr.parts.utils.online_clustering import get_merge_quantity, run_reducer + + +class LongFormSpeakerClustering(torch.nn.Module): + def __init__(self, cuda: bool = False): + """ + Initializes a speaker clustering class tailored for long-form audio, leveraging methods from the `SpeakerClustering` class. + The clustering algorithm for long-form content is executed via the `forward_infer` function (not shown here). Input embedding + vectors are divided into chunks, each of size `embeddings_per_chunk`. Within every chunk, the clustering algorithm aims + to identify `chunk_cluster_count` distinct clusters. The resulting clustering labels are then expanded to match the original + length of the input embeddings. + + NOTE: torch.jit.script currently does not support inherited methods with a `super()` call. + + Args: + cuda (bool): + Flag indicating whether CUDA is available for computation. + """ + super().__init__() + self.speaker_clustering = SpeakerClustering(cuda=cuda) + self.embeddings_in_scales: List[torch.Tensor] = [torch.tensor([0])] + self.timestamps_in_scales: List[torch.Tensor] = [torch.tensor([0])] + self.cuda = cuda + self.device = torch.device("cuda") if self.cuda else torch.device("cpu") + + def check_input(self, embeddings_per_chunk: int, chunk_cluster_count: int, max_num_speakers: int) -> None: + """ + Checks the validity of the input parameters. + + Args: + embeddings_per_chunk (int): + The size of the windows in which the algorithm aims to identify `chunk_cluster_count` clusters. + chunk_cluster_count (int): + The target number of clusters to identify within each window. + max_num_speakers (int): + The maximum number of speakers to be detected in the audio. + """ + if chunk_cluster_count is None or embeddings_per_chunk is None: + raise ValueError( + f"chunk_cluster_count ({chunk_cluster_count}) and embeddings_per_chunk ({embeddings_per_chunk}) should be set." + ) + elif ( + all(v is not None for v in [chunk_cluster_count, embeddings_per_chunk]) + and chunk_cluster_count >= embeddings_per_chunk + ): + raise ValueError( + f"chunk_cluster_count ({chunk_cluster_count}) should be smaller than embeddings_per_chunk ({embeddings_per_chunk})." + ) + + if chunk_cluster_count <= max_num_speakers: + raise ValueError( + f"chunk_cluster_count ({chunk_cluster_count}) should be larger than max_num_speakers ({max_num_speakers})." + ) + + def unpack_labels( + self, + Y_aggr: torch.Tensor, + window_range_list: List[List[int]], + absolute_merge_mapping: List[List[torch.Tensor]], + org_len: int, + ) -> torch.LongTensor: + """ + Unpack the labels from the aggregated labels to the original labels. + + Args: + Y_aggr (Tensor): + Aggregated label vector from the merged segments. + window_range_list (List[List[int]]): + List of window ranges for each of the merged segments. + absolute_merge_mapping (List[List[torch.Tensor]]): + List of absolute mappings for each of the merged segments. Each list element contains two tensors: + - The first tensor represents the absolute index of the bypassed segment (segments that remain unchanged). + - The second tensor represents the absolute index of the merged segment (segments that have had their indexes changed). + org_len (int): + Original length of the labels. 
In most cases, this is a fairly large number (on the order of 10^5). + + Returns: + Y_unpack (Tensor): + Unpacked labels derived from the aggregated labels. + """ + Y_unpack = torch.zeros((org_len,)).long().to(Y_aggr.device) + for (win_rng, abs_mapping) in zip(window_range_list, absolute_merge_mapping): + inferred_merged_embs = Y_aggr[win_rng[0] : win_rng[1]] + if len(abs_mapping[1]) > 0: + Y_unpack[abs_mapping[1]] = inferred_merged_embs[-1].clone() # Merged + if len(abs_mapping[0]) > 0: + Y_unpack[abs_mapping[0]] = inferred_merged_embs[:-1].clone() # Bypass + else: + if len(abs_mapping[0]) > 0: + Y_unpack[abs_mapping[0]] = inferred_merged_embs.clone() + return Y_unpack + + def split_embs_to_windows( + self, index: int, emb: torch.Tensor, embeddings_per_chunk: int, + ) -> Tuple[torch.Tensor, int]: + """ + Splits the embedding tensor into smaller window-sized tensors based on a given index. + + Args: + index (int): The index of the desired window. This determines the starting point + of the window using the formula: + start = embeddings_per_chunk * index + emb (Tensor): The embedding tensor which needs to be split. + embeddings_per_chunk (int): + The size of the windows in which the algorithm aims to identify `chunk_cluster_count` clusters. + + Returns: + emb_part (Tensor): + The window-sized tensor, which is a portion of the `emb`. + offset_index (int): + The starting position of the window in the `emb` tensor. + """ + if embeddings_per_chunk * (index + 1) > emb.shape[0]: + emb_part = emb[-1 * embeddings_per_chunk :] + offset_index = emb.shape[0] - embeddings_per_chunk + else: + emb_part = emb[embeddings_per_chunk * index : embeddings_per_chunk * (index + 1)] + offset_index = embeddings_per_chunk * index + return emb_part, offset_index + + def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor: + """ + A function wrapper designed for performing inference using an exported script format. + + Note: + A dictionary is used to facilitate inference with the exported jit model in the Triton server. + This is done using an easy-to-understand naming convention. + See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#special-conventions-for-pytorch-backend + + Args: + param_dict (dict): + Dictionary containing the arguments for speaker clustering. + See `forward_infer` function for the argument information. + + Returns: + (LongTensor): Speaker labels for the segments in the given input embeddings. 
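As this docstring notes, the exported TorchScript module is driven through a flat dictionary of tensors so that it can be served behind the Triton PyTorch backend. A minimal sketch of assembling such a dictionary with the key names read out in the body below, using random single-scale toy tensors (the embedding dimension and segment timing are arbitrary); scalar options are wrapped in tensors because the wrapper calls `.item()` on them:

import torch
from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering

n = 120                                              # toy single-scale session
starts = torch.arange(n, dtype=torch.float32) * 0.75
param_dict = {
    'embeddings': torch.randn(n, 192),               # (total segments, emb_dim)
    'timestamps': torch.stack([starts, starts + 1.5], dim=1),
    'multiscale_segment_counts': torch.LongTensor([n]),
    'multiscale_weights': torch.tensor([1.0]),
    'oracle_num_speakers': torch.tensor(-1),         # -1: estimate the speaker count
    'max_num_speakers': torch.tensor(8),
    'enhanced_count_thres': torch.tensor(80),
    'sparse_search_volume': torch.tensor(30),
    'max_rp_threshold': torch.tensor(0.15),
    'fixed_thres': torch.tensor(-1.0),
}
labels = LongFormSpeakerClustering(cuda=False)(param_dict)   # same as calling forward()
print(labels.shape)                                  # torch.Size([120])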
+ """ + embeddings_in_scales = param_dict['embeddings'] + timestamps_in_scales = param_dict['timestamps'] + multiscale_segment_counts = param_dict['multiscale_segment_counts'] + multiscale_weights = param_dict['multiscale_weights'] + oracle_num_speakers = int(param_dict['oracle_num_speakers'].item()) + max_num_speakers = int(param_dict['max_num_speakers'].item()) + enhanced_count_thres = int(param_dict['enhanced_count_thres'].item()) + sparse_search_volume = int(param_dict['sparse_search_volume'].item()) + max_rp_threshold = float(param_dict['max_rp_threshold'].item()) + fixed_thres = float(param_dict['fixed_thres'].item()) + return self.forward_infer( + embeddings_in_scales=embeddings_in_scales, + timestamps_in_scales=timestamps_in_scales, + multiscale_segment_counts=multiscale_segment_counts, + multiscale_weights=multiscale_weights, + oracle_num_speakers=oracle_num_speakers, + max_rp_threshold=max_rp_threshold, + max_num_speakers=max_num_speakers, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=sparse_search_volume, + fixed_thres=fixed_thres, + ) + + def get_div_ceil_count(self, numer: int, denomin: int) -> int: + """ + Calculates the ceiling of the division of two integers. + + Args: + numer (int): Numerator, the number of segments or clusters, for example. + denomin (int): Denominator, the number of speakers or clusters, for example. + + Returns: + (int): The ceiling of the division of the two integers (number of chunks). + """ + return int(torch.ceil(torch.tensor(numer / denomin)).item()) + + def long_forward_infer( + self, + embeddings_in_scales: torch.Tensor, + timestamps_in_scales: torch.Tensor, + multiscale_segment_counts: torch.LongTensor, + multiscale_weights: torch.Tensor, + oracle_num_speakers: int, + max_rp_threshold: float, + max_num_speakers: int, + sparse_search_volume: int, + fixed_thres: float, + chunk_cluster_count: int, + embeddings_per_chunk: int, + ) -> torch.LongTensor: + """ + This is forward function for long-form speaker clustering. + Please refer to `SpeakerClustering` class for the original argument information. + + In the `LongFormSpeakerClustering` process: + Step-1: Input embeddings are divided into smaller windows of size `embeddings_per_chunk`. + Step-2: Each window undergoes overclustering, resulting in `chunk_cluster_count` fine-grained clusters. + Step-3: These fine-grained clusters are merged to form the aggregated clustering labels `Y_aggr`. + Step-4: The `unpack_labels` function is then employed to map the aggregated labels `Y_aggr` back to the + original labels for all `org_len` input embeddings: `Y_unpack`. + + Args: + embeddings_in_scales (Tensor): + List containing concatenated Torch tensor embeddings across multiple scales. + The length of the list is equal to the number of scales. + Each tensor has dimensions of (Number of base segments) x (Embedding Dimension). + timestamps_in_scales (Tensor): + List containing concatenated Torch tensor timestamps across multiple scales. + The length of the list is equal to the number of scales. + Each tensor has dimensions of (Total number of segments across all scales) x 2. + Example: + >>> timestamps_in_scales[0] = \ + torch.Tensor([[0.4, 1.4], [0.9, 1.9], [1.4, 2.4], ... [121.2, 122.2]]) + multiscale_segment_counts (LongTensor): + A Torch tensor containing the number of segments for each scale. + The tensor has dimensions of (Number of scales). 
+ Example: + >>> multiscale_segment_counts = torch.LongTensor([31, 52, 84, 105, 120]) + multiscale_weights (Tensor): + Multi-scale weights used when merging affinity scores. + Example: + >>> multiscale_weights = torch.tensor([1.4, 1.3, 1.2, 1.1, 1.0]) + oracle_num_speakers (int): + The number of speakers in a session as given by the reference transcript. + max_num_speakers (int): + The upper bound for the number of speakers in each session. + max_rp_threshold (float): + Limits the range of parameter search. + The clustering performance can vary based on this range. + The default value is 0.15. + enhanced_count_thres (int): + For shorter audio recordings, the clustering algorithm might not accumulate enough speaker profiles for each cluster. + Thus, the function `getEnhancedSpeakerCount` uses anchor embeddings (dummy representations) to mitigate the effects of cluster sparsity. + A value of 80 is recommended for `enhanced_count_thres`. + sparse_search_volume (int): + The number of p_values considered during NME analysis. + The default is 30. Lower values speed up the NME-analysis but might lead to poorer parameter estimations. Values below 20 are not recommended. + fixed_thres (float): + If a `fixed_thres` value is provided, the NME-analysis process will be skipped. + This value should be optimized on a development set for best results. + By default, it is set to -1.0, and the function performs NME-analysis to estimate the threshold. + kmeans_random_trials (int): + The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1. + chunk_cluster_count (int): + The target number of clusters to identify within each chunk. + embeddings_per_chunk (int): + The size of the chunks in which the algorithm aims to identify `chunk_cluster_count` clusters. + + Returns: + Y_unpack (LongTensor): + Speaker labels for the segments in the provided input embeddings. 
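As a back-of-the-envelope check of the four steps above, using the defaults added to the inference configs in this patch (where the config comment equates 10000 embeddings with roughly 40 minutes of audio):

# Illustrative arithmetic only: 26000 base-scale embeddings is roughly 1.7 hours of audio
# by the config comment's rule of thumb.
embeddings_per_chunk, chunk_cluster_count = 10000, 50
org_len = 26000
num_chunks = -(-org_len // embeddings_per_chunk)     # ceil division, as in get_div_ceil_count -> 3
reduced = num_chunks * chunk_cluster_count           # each chunk keeps chunk_cluster_count embeddings -> 150
# Step-3 therefore clusters 150 reduced embeddings instead of 26000, and Step-4 maps the
# resulting labels back onto all 26000 original positions via unpack_labels().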
+ """ + self.check_input(embeddings_per_chunk, chunk_cluster_count, max_num_speakers) + + self.embeddings_in_scales, self.timestamps_in_scales = split_input_data( + embeddings_in_scales, timestamps_in_scales, multiscale_segment_counts + ) + emb, _ = get_scale_interpolated_embs( + multiscale_weights, self.embeddings_in_scales, self.timestamps_in_scales, self.device + ) + offset_index: int = 0 + window_offset: int = 0 + total_emb: List[torch.Tensor] = [] + window_range_list: List[List[int]] = [] + absolute_merge_mapping: List[List[torch.Tensor]] = [] + total_window_count = self.get_div_ceil_count(numer=emb.shape[0], denomin=embeddings_per_chunk) + + if not torch.jit.is_scripting(): + pbar = tqdm(range(total_window_count), desc="Clustering Sub-Windows", leave=True, unit="window") + else: + pbar = range(total_window_count) + + for win_index in pbar: + # Step-1: Split the embeddings into smaller chunks + emb_part, offset_index = self.split_embs_to_windows( + index=win_index, emb=emb, embeddings_per_chunk=embeddings_per_chunk + ) + + # Step-2: Perform overclustering on the chunks to identify `chunk_cluster_count` clusters + if emb_part.shape[0] == 1: + Y_part = torch.zeros((1,), dtype=torch.int64) + else: + mat = getCosAffinityMatrix(emb_part) + overcluster_count = min(chunk_cluster_count, mat.shape[0]) + Y_part = self.speaker_clustering.forward_unit_infer( + mat=mat, + oracle_num_speakers=overcluster_count, + max_rp_threshold=max_rp_threshold, + max_num_speakers=chunk_cluster_count, + sparse_search_volume=sparse_search_volume, + ) + + # Step-3: Merge the clusters to form the aggregated clustering labels `Y_aggr` + num_to_be_merged = int(min(embeddings_per_chunk, emb_part.shape[0]) - chunk_cluster_count) + min_count_per_cluster = self.get_div_ceil_count( + numer=chunk_cluster_count, denomin=len(torch.unique(Y_part)) + ) + + # We want only one embedding vector for each cluster, so we calculate the number of embedding vectors to be removed + class_target_vol = get_merge_quantity( + num_to_be_removed=num_to_be_merged, + pre_clus_labels=Y_part, + min_count_per_cluster=min_count_per_cluster, + ) + if not torch.jit.is_scripting(): + pbar.update(1) + + # `class_target_vol` is a list of cluster-indices from overclustering + for spk_idx, merge_quantity in enumerate(list(class_target_vol)): + merged_embs, merged_clus_labels, index_mapping = run_reducer( + pre_embs=emb_part, target_spk_idx=spk_idx, merge_quantity=merge_quantity, pre_clus_labels=Y_part, + ) + total_emb.append(merged_embs) + absolute_index_mapping = [x + offset_index for x in index_mapping] + absolute_merge_mapping.append(absolute_index_mapping) + window_range_list.append([window_offset, window_offset + merged_embs.shape[0]]) + window_offset += merged_embs.shape[0] + + if not torch.jit.is_scripting(): + pbar.close() + + # Concatenate the reduced embeddings then perform high-level clustering + reduced_embs = torch.cat(total_emb) + reduced_mat = getCosAffinityMatrix(reduced_embs) + + # Step-4: Map the aggregated labels `Y_aggr` back to the original labels for all `org_len` input embeddings: `Y_unpack` + Y_aggr = self.speaker_clustering.forward_unit_infer( + mat=reduced_mat, + oracle_num_speakers=oracle_num_speakers, + max_rp_threshold=max_rp_threshold, + max_num_speakers=max_num_speakers, + sparse_search_volume=sparse_search_volume, + fixed_thres=fixed_thres, + ) + if reduced_embs.shape[0] != Y_aggr.shape[0]: + raise ValueError( + f"The number of embeddings ({reduced_embs.shape[0]}) and the number of clustered labels ({Y_aggr.shape[0]}) do not 
match." + ) + + # Reassign the labels to the original embeddings + Y_unpack = self.unpack_labels( + Y_aggr=Y_aggr, + window_range_list=window_range_list, + absolute_merge_mapping=absolute_merge_mapping, + org_len=emb.shape[0], + ) + if Y_unpack.shape[0] != emb.shape[0]: + raise ValueError( + f"The number of raw input embeddings ({emb.shape[0]}) and the number of clustered labels ({Y_unpack.shape[0]}) do not match." + ) + return Y_unpack + + def forward_infer( + self, + embeddings_in_scales: torch.Tensor, + timestamps_in_scales: torch.Tensor, + multiscale_segment_counts: torch.LongTensor, + multiscale_weights: torch.Tensor, + oracle_num_speakers: int = -1, + max_rp_threshold: float = 0.15, + max_num_speakers: int = 8, + enhanced_count_thres: int = 80, + sparse_search_volume: int = 30, + fixed_thres: float = -1.0, + chunk_cluster_count: int = 50, + embeddings_per_chunk: int = 10000, + ) -> torch.LongTensor: + """ + This function is a wrapper designed for toggling between long-form and short-form speaker clustering. + The details of short-form clustering is in `SpeakerClustering` class. + NOTE: `torch.jit.script` currently does not support `**kwargs` in the function signature therefore, + we need to use a wrapper function to handle the arguments. + """ + if embeddings_per_chunk is not None and torch.max(multiscale_segment_counts) > embeddings_per_chunk: + return self.long_forward_infer( + embeddings_in_scales=embeddings_in_scales, + timestamps_in_scales=timestamps_in_scales, + multiscale_segment_counts=multiscale_segment_counts, + multiscale_weights=multiscale_weights, + oracle_num_speakers=oracle_num_speakers, + max_rp_threshold=max_rp_threshold, + max_num_speakers=max_num_speakers, + sparse_search_volume=sparse_search_volume, + fixed_thres=fixed_thres, + chunk_cluster_count=chunk_cluster_count, + embeddings_per_chunk=embeddings_per_chunk, + ) + else: + cluster_labels = self.speaker_clustering.forward_infer( + embeddings_in_scales=embeddings_in_scales, + timestamps_in_scales=timestamps_in_scales, + multiscale_segment_counts=multiscale_segment_counts, + multiscale_weights=multiscale_weights, + oracle_num_speakers=oracle_num_speakers, + max_rp_threshold=max_rp_threshold, + max_num_speakers=max_num_speakers, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=sparse_search_volume, + fixed_thres=fixed_thres, + ) + self.timestamps_in_scales = self.speaker_clustering.timestamps_in_scales + return cluster_labels diff --git a/nemo/collections/asr/parts/utils/offline_clustering.py b/nemo/collections/asr/parts/utils/offline_clustering.py index d62ba23d8b6b..3f6c90d945ef 100644 --- a/nemo/collections/asr/parts/utils/offline_clustering.py +++ b/nemo/collections/asr/parts/utils/offline_clustering.py @@ -402,7 +402,6 @@ def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tenso for scale_idx in scale_list: time_stamps_float = timestamps_in_scales[scale_idx] segment_anchor_list.append(torch.mean(time_stamps_float, dim=1)) - base_scale_idx = max(scale_list) base_scale_anchor = segment_anchor_list[base_scale_idx] session_scale_mapping_list = [] @@ -463,7 +462,7 @@ def get_scale_interpolated_embs( Torch device variable Returns: - context_emb (torch.tensor): + context_emb (Tensor): A set of scale-interpolated embedding vectors. 
Dimensions: (Number of base-scale segments) x (Dimensions of embedding vector) session_scale_mapping_list (list): @@ -697,6 +696,19 @@ def split_input_data( timestamps_in_scales (list): List containing split timestamps tensors by each scale """ + if len(embeddings_in_scales.shape) != 2: + raise ValueError( + f"embeddings_in_scales Tensor should have 2 dimensions, but got {len(embeddings_in_scales.shape)}." + ) + elif len(timestamps_in_scales.shape) != 2: + raise ValueError( + f"timestamps_in_scales Tensor should have 2 dimensions, but got {len(timestamps_in_scales.shape)}." + ) + elif not (torch.sum(multiscale_segment_counts) == embeddings_in_scales.shape[0] == timestamps_in_scales.shape[0]): + raise ValueError( + f"multiscale_segment_counts, embeddings_in_scales, and timestamps_in_scales should have the same length, \ + but got {multiscale_segment_counts.shape[0]}, {embeddings_in_scales.shape[0]}, and {timestamps_in_scales.shape[0]} respectively." + ) split_index: List[int] = multiscale_segment_counts.tolist() embeddings_in_scales = torch.split(embeddings_in_scales, split_index, dim=0) timestamps_in_scales = torch.split(timestamps_in_scales, split_index, dim=0) @@ -781,7 +793,7 @@ def forward(self, X) -> torch.Tensor: Returns: labels (Tensor): - clustering label output + Clustering label output """ if X.shape[0] != X.shape[1]: raise ValueError("The affinity matrix is not a square matrix.") @@ -936,7 +948,6 @@ def __init__( Use cuda for Eigen decomposition if cuda=True. device (torch.device): Torch device variable - """ self.max_num_speakers: int = max_num_speakers self.max_rp_threshold: float = max_rp_threshold @@ -1129,6 +1140,8 @@ def __init__( The minimum number of samples required for NME clustering. This avoids zero p_neighbour_lists. If the input has fewer segments than min_samples, it is directed to the enhanced speaker counting mode. + nme_mat_size (int): + The targeted matrix size for NME analysis. sparse_search (bool): Toggle sparse search mode. If True, limit the size of p_value_list to sparse_search_volume. maj_vote_spk_count (bool): @@ -1151,6 +1164,86 @@ def __init__( self.timestamps_in_scales: List[torch.Tensor] = [torch.Tensor(0)] self.device = torch.device("cuda") if self.cuda else torch.device("cpu") + def forward_unit_infer( + self, + mat: torch.Tensor, + oracle_num_speakers: int = -1, + max_num_speakers: int = 8, + max_rp_threshold: float = 0.15, + sparse_search_volume: int = 30, + est_num_of_spk_enhanced: torch.Tensor = torch.tensor(-1), + fixed_thres: float = -1.0, + kmeans_random_trials: int = 1, + ) -> torch.LongTensor: + """ + This function takes a cosine similarity matrix `mat` and returns the speaker labels for the segments + in the given input embeddings. + + Args: + mat (Tensor): + Cosine similarity matrix (affinity matrix) calculated from the provided speaker embeddings. + oracle_num_speakers (int): + The number of speakers in a session, as specified by the reference transcript. + Can be used as `chunk_cluster_count` in long-form clustering mode. + max_num_speakers (int): + The upper bound for the number of speakers in each session. + max_rp_threshold (float): + Limits the range of parameter search. + The clustering performance can vary based on this range. + The default value is 0.15. + sparse_search_volume (int): + The number of p_values considered during NME analysis. + The default is 30. Lower values speed up the NME-analysis but might lead to poorer parameter estimations. Values below 20 are not recommended. 
+ est_num_of_spk_enhanced (int): + The number of speakers estimated from enhanced speaker counting. + If the value is -1, the enhanced speaker counting is skipped. + fixed_thres (float): + If a `fixed_thres` value is provided, the NME-analysis process will be skipped. + This value should be optimized on a development set for best results. + By default, it is set to -1.0, and the function performs NME-analysis to estimate the threshold. + kmeans_random_trials (int): + The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1. + + Returns: + Y (LongTensor): + Speaker labels (clustering output) in integer format for the segments in the given input embeddings. + """ + nmesc = NMESC( + mat, + max_num_speakers=max_num_speakers, + max_rp_threshold=max_rp_threshold, + sparse_search=self.sparse_search, + sparse_search_volume=sparse_search_volume, + fixed_thres=fixed_thres, + nme_mat_size=self.nme_mat_size, + maj_vote_spk_count=self.maj_vote_spk_count, + parallelism=self.parallelism, + cuda=self.cuda, + device=self.device, + ) + # If there are less than `min_samples_for_nmesc` segments, est_num_of_spk is 1. + if mat.shape[0] > self.min_samples_for_nmesc: + est_num_of_spk, p_hat_value = nmesc.forward() + affinity_mat = getAffinityGraphMat(mat, p_hat_value) + else: + nmesc.fixed_thres = max_rp_threshold + est_num_of_spk, p_hat_value = nmesc.forward() + affinity_mat = mat + + # `n_clusters` is number of speakers estimated from spectral clustering. + if oracle_num_speakers > 0: + n_clusters = int(oracle_num_speakers) + elif est_num_of_spk_enhanced > 0: + n_clusters = int(est_num_of_spk_enhanced.item()) + else: + n_clusters = int(est_num_of_spk.item()) + + spectral_model = SpectralClustering( + n_clusters=n_clusters, n_random_trials=kmeans_random_trials, cuda=self.cuda, device=self.device + ) + Y = spectral_model.forward(affinity_mat) + return Y + def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor: """ A function wrapper designed for inference in exported script format. @@ -1166,21 +1259,18 @@ def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor: See `forward_infer` function for the argument information. Returns: - Y (LongTensor): - Speaker labels for the segments in the given input embeddings. + (LongTensor): Speaker labels for the segments in the given input embeddings. 
""" embeddings_in_scales = param_dict['embeddings'] timestamps_in_scales = param_dict['timestamps'] multiscale_segment_counts = param_dict['multiscale_segment_counts'] multiscale_weights = param_dict['multiscale_weights'] - oracle_num_speakers = int(param_dict['oracle_num_speakers'].item()) max_num_speakers = int(param_dict['max_num_speakers'].item()) enhanced_count_thres = int(param_dict['enhanced_count_thres'].item()) sparse_search_volume = int(param_dict['sparse_search_volume'].item()) max_rp_threshold = float(param_dict['max_rp_threshold'].item()) fixed_thres = float(param_dict['fixed_thres'].item()) - return self.forward_infer( embeddings_in_scales=embeddings_in_scales, timestamps_in_scales=timestamps_in_scales, @@ -1201,72 +1291,65 @@ def forward_infer( multiscale_segment_counts: torch.LongTensor, multiscale_weights: torch.Tensor, oracle_num_speakers: int = -1, - max_rp_threshold: float = 0.15, max_num_speakers: int = 8, + max_rp_threshold: float = 0.15, enhanced_count_thres: int = 40, sparse_search_volume: int = 30, fixed_thres: float = -1.0, kmeans_random_trials: int = 1, ) -> torch.LongTensor: """ - Calculate affinity matrix using timestamps and speaker embeddings, run NME analysis to estimate the best - p-value and perform spectral clustering based on the estimated p-value and the calculated affinity matrix. + Calculate the affinity matrix using timestamps and speaker embeddings, run NME analysis to estimate the best + p-value, and perform spectral clustering based on the estimated p-value and the calculated affinity matrix. Caution: - For the sake of compatibility with libtorch, python boolean `False` is replaced with `torch.LongTensor(-1)`. + For compatibility with libtorch, python boolean `False` has been replaced with `torch.LongTensor(-1)`. Args: - Dict containing following keys associated with tensors. - embeddings (Tensor): - Concatenated Torch tensor containing embeddings in multiple scales - This tensor has dimensions of (Number of base segments) x (Embedding Dimension) - timestamps (Tensor): - Concatenated Torch tensor containing timestamps in multiple scales. - This tensor has dimensions of (Total number of segments all scales) x 2 + embeddings_in_scales (Tensor): + List containing concatenated Torch tensor embeddings across multiple scales. + The length of the list is equal to the number of scales. + Each tensor has dimensions of (Number of base segments) x (Embedding Dimension). + timestamps_in_scales (Tensor): + List containing concatenated Torch tensor timestamps across multiple scales. + The length of the list is equal to the number of scales. + Each tensor has dimensions of (Total number of segments across all scales) x 2. Example: - >>> timestamps_in_scales = \ - >>> torch.tensor([0.4, 1.4], [0.9, 1.9], [1.4, 2.4], ... [121.2, 122.2]]) - + >>> timestamps_in_scales[0] = \ + torch.Tensor([[0.4, 1.4], [0.9, 1.9], [1.4, 2.4], ... [121.2, 122.2]]) multiscale_segment_counts (LongTensor): - Concatenated Torch tensor containing number of segments per each scale - This tensor has dimensions of (Number of scales) + A Torch tensor containing the number of segments for each scale. + The tensor has dimensions of (Number of scales). Example: >>> multiscale_segment_counts = torch.LongTensor([31, 52, 84, 105, 120]) - multiscale_weights (Tensor): - Multi-scale weights that are used when affinity scores are merged. + Multi-scale weights used when merging affinity scores. 
Example: >>> multiscale_weights = torch.tensor([1.4, 1.3, 1.2, 1.1, 1.0]) - oracle_num_speakers (int): - The number of speakers in a session from the reference transcript + The number of speakers in a session as given by the reference transcript. max_num_speakers (int): - The upper bound for the number of speakers in each session + The upper bound for the number of speakers in each session. max_rp_threshold (float): Limits the range of parameter search. - Clustering performance can vary depending on this range. - Default is 0.15. + The clustering performance can vary based on this range. + The default value is 0.15. enhanced_count_thres (int): - For the short audio recordings, clustering algorithm cannot - accumulate enough amount of speaker profile for each cluster. - Thus, function `getEnhancedSpeakerCount` employs anchor embeddings - (dummy representations) to mitigate the effect of cluster sparsity. - enhanced_count_thres = 80 is recommended. + For shorter audio recordings, the clustering algorithm might not accumulate enough speaker profiles for each cluster. + Thus, the function `getEnhancedSpeakerCount` uses anchor embeddings (dummy representations) to mitigate the effects of cluster sparsity. + A value of 80 is recommended for `enhanced_count_thres`. sparse_search_volume (int): - Number of p_values we search during NME analysis. - Default is 30. The lower the value, the faster NME-analysis becomes. - Lower than 20 might cause a poor parameter estimation. + The number of p_values considered during NME analysis. + The default is 30. Lower values speed up the NME-analysis but might lead to poorer parameter estimations. Values below 20 are not recommended. fixed_thres (float): - If fixed_thres value is provided, NME-analysis process will be skipped. - This value should be optimized on a development set to obtain a quality result. - Default is None and performs NME-analysis to estimate the threshold. + If a `fixed_thres` value is provided, the NME-analysis process will be skipped. + This value should be optimized on a development set for best results. + By default, it is set to -1.0, and the function performs NME-analysis to estimate the threshold. kmeans_random_trials (int): - Number of random trials for initializing k-means clustering. More trials - will result in a more stable clustering result. Default is 1. + The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1. Returns: - Y (LongTensor): - Speaker labels for the segments in the given input embeddings. + (LongTensor): Speaker labels for the segments in the provided input embeddings. 
""" self.embeddings_in_scales, self.timestamps_in_scales = split_input_data( embeddings_in_scales, timestamps_in_scales, multiscale_segment_counts @@ -1286,42 +1369,19 @@ def forward_infer( max_num_speakers = oracle_num_speakers mat = getMultiScaleCosAffinityMatrix( - multiscale_weights, self.embeddings_in_scales, self.timestamps_in_scales, self.device + multiscale_weights=multiscale_weights, + embeddings_in_scales=self.embeddings_in_scales, + timestamps_in_scales=self.timestamps_in_scales, + device=self.device, ) - nmesc = NMESC( - mat, - max_num_speakers=max_num_speakers, + return self.forward_unit_infer( + mat=mat, + oracle_num_speakers=oracle_num_speakers, max_rp_threshold=max_rp_threshold, - sparse_search=self.sparse_search, + max_num_speakers=max_num_speakers, sparse_search_volume=sparse_search_volume, + est_num_of_spk_enhanced=est_num_of_spk_enhanced, + kmeans_random_trials=kmeans_random_trials, fixed_thres=fixed_thres, - nme_mat_size=self.nme_mat_size, - maj_vote_spk_count=self.maj_vote_spk_count, - parallelism=self.parallelism, - cuda=self.cuda, - device=self.device, ) - - # If there are less than `min_samples_for_nmesc` segments, est_num_of_spk is 1. - if mat.shape[0] > self.min_samples_for_nmesc: - est_num_of_spk, p_hat_value = nmesc.forward() - affinity_mat = getAffinityGraphMat(mat, p_hat_value) - else: - nmesc.fixed_thres = max_rp_threshold - est_num_of_spk, p_hat_value = nmesc.forward() - affinity_mat = mat - - # n_clusters is number of speakers estimated from spectral clustering. - if oracle_num_speakers > 0: - n_clusters = int(oracle_num_speakers) - elif est_num_of_spk_enhanced > 0: - n_clusters = int(est_num_of_spk_enhanced.item()) - else: - n_clusters = int(est_num_of_spk.item()) - - spectral_model = SpectralClustering( - n_clusters=n_clusters, n_random_trials=kmeans_random_trials, cuda=self.cuda, device=self.device - ) - Y = spectral_model.forward(affinity_mat) - return Y diff --git a/nemo/collections/asr/parts/utils/online_clustering.py b/nemo/collections/asr/parts/utils/online_clustering.py index 9620a87144b9..23ebe6c6dbbf 100644 --- a/nemo/collections/asr/parts/utils/online_clustering.py +++ b/nemo/collections/asr/parts/utils/online_clustering.py @@ -31,16 +31,17 @@ # https://arxiv.org/pdf/2003.02405.pdf and the implementation from # https://github.com/tango4j/Auto-Tuning-Spectral-Clustering. 
-from typing import List, Set, Tuple - -import numpy as np +from typing import List, Tuple import torch from nemo.collections.asr.parts.utils.offline_clustering import ( NMESC, + SpeakerClustering, SpectralClustering, + get_scale_interpolated_embs, getAffinityGraphMat, getCosAffinityMatrix, + split_input_data, ) from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment @@ -234,7 +235,7 @@ def calculate_removable_counts(removable_counts_mat: torch.Tensor, remain_count: rem_labels = remain_count_rem % (num_clus - ind) removable_counts_mat[removable_count_args[: (num_clus - ind)]] -= num_labels removable_counts_mat[removable_count_args[:rem_labels]] -= 1 - return removable_counts_mat + return removable_counts_mat.int() def get_merge_quantity( @@ -333,8 +334,12 @@ def merge_vectors( bypass_inds_list.append(k) bypass_inds = torch.tensor(bypass_inds_list) selected_inds = torch.tensor(selected_inds_list) - merged_vecs = torch.vstack((emb_ndx[bypass_inds], avg_emb)) - merged_clus_labels = torch.hstack((pre_cluster_labels[bypass_inds], merged_clus_labels[0])) + if bypass_inds.shape[0] == 0: + merged_vecs = avg_emb.unsqueeze(0) + merged_clus_labels = merged_clus_labels.unsqueeze(0) + else: + merged_vecs = torch.vstack((emb_ndx[bypass_inds], avg_emb)) + merged_clus_labels = torch.hstack((pre_cluster_labels[bypass_inds], merged_clus_labels[0])) return merged_vecs, merged_clus_labels @@ -478,7 +483,7 @@ def get_first_arg_index(mat: torch.Tensor, label: int) -> int: Label which we want to find the first occuring index Returns: - (int) The first index of the given label + (int): The first index of the given label """ return int(torch.where(mat == label)[0][0]) @@ -967,11 +972,11 @@ def update_speaker_history_buffer( ) # Merge the segments in the history buffer - for spk_idx, target_num in enumerate(list(class_target_vol)): + for spk_idx, sub_cluster_num in enumerate(list(class_target_vol)): merged_embs, merged_clus_labels, _ = run_reducer( pre_embs=pre_embs, target_spk_idx=spk_idx, - merge_quantity=target_num, + merge_quantity=sub_cluster_num.item(), pre_clus_labels=pre_clus_labels, ) total_emb.append(merged_embs) @@ -1037,7 +1042,6 @@ def get_reduced_mat(self, emb_in: torch.Tensor, base_segment_indexes: torch.Tens Boolean that indicates whether there is a new set of segments. Depending on the VAD timestamps, the number of subsegments can be ocassionally decreased. If `add_new=True`, then it adds the newly acquired cluster labels. - """ margin_seg_n = emb_in.shape[0] - (self.current_n + self.history_n) if len(self.Y_fullhist) == 0 and margin_seg_n > 0: @@ -1172,8 +1176,8 @@ def forward_infer( self.enhanced_count_thres = enhanced_count_thres self.sparse_search_volume = sparse_search_volume self.fixed_thres = fixed_thres - # Merge the closest embeddings and reduce the size of the embedding count. + # Merge the closest embeddings and reduce the size of the embedding count. 
if cuda and (curr_emb.device == torch.device("cpu") or base_segment_indexes.device == torch.device("cpu")): raise ValueError(f"CUDA is enabled but the input {curr_emb} or {base_segment_indexes} is not on the GPU.") diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 413b24f97e81..5d3a0bf4274e 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -28,10 +28,10 @@ from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal +from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data from nemo.utils import logging - """ This file contains all the utility functions required for speaker embeddings part in diarization scripts """ @@ -464,9 +464,8 @@ def perform_clustering( logging.warning("cuda=False, using CPU for eigen decomposition. This might slow down the clustering process.") cuda = False - speaker_clustering = SpeakerClustering(cuda=cuda) + speaker_clustering = LongFormSpeakerClustering(cuda=cuda) - # If True, export torch script module and save it to the base folder. if clustering_params.get('export_script_module', False): speaker_clustering = torch.jit.script(speaker_clustering) torch.jit.save(speaker_clustering, 'speaker_clustering_script.pt') @@ -492,6 +491,8 @@ def perform_clustering( max_num_speakers=int(clustering_params.max_num_speakers), max_rp_threshold=float(clustering_params.max_rp_threshold), sparse_search_volume=int(clustering_params.sparse_search_volume), + chunk_cluster_count=clustering_params.get('chunk_cluster_count', None), + embeddings_per_chunk=clustering_params.get('embeddings_per_chunk', None), ) del uniq_embs_and_timestamps @@ -499,8 +500,8 @@ def perform_clustering( torch.cuda.empty_cache() else: gc.collect() - timestamps = speaker_clustering.timestamps_in_scales[base_scale_idx] + cluster_labels = cluster_labels.cpu().numpy() if len(cluster_labels) != timestamps.shape[0]: raise ValueError("Mismatch of length between cluster_labels and timestamps.") diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py index fea232cee31b..f48292d27981 100644 --- a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/asr/test_diar_utils.py @@ -20,6 +20,7 @@ from scipy.optimize import linear_sum_assignment as scipy_linear_sum_assignment from nemo.collections.asr.data.audio_to_label import repeat_signal +from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering from nemo.collections.asr.parts.utils.offline_clustering import ( SpeakerClustering, get_scale_interpolated_embs, @@ -563,6 +564,23 @@ def test_get_k_neighbors_connections(self, p_value: int, N: int, mask_method: st elif mask_method == 'drop': assert all(binarized_affinity_mat.sum(dim=0) <= float(p_value)) + @pytest.mark.unit + @pytest.mark.parametrize("Y_aggr", [torch.tensor([0, 1, 0, 1])]) + @pytest.mark.parametrize("chunk_cluster_count, embeddings_per_chunk", [(2, 50)]) + @pytest.mark.parametrize("window_range_list", [[[0, 1], [1, 2], [2, 3], [3, 4]]]) + @pytest.mark.parametrize( + "absolute_merge_mapping", + [[[torch.tensor([]), torch.tensor([0, 2])], [torch.tensor([]), torch.tensor([1, 3])]]], + ) + @pytest.mark.parametrize("org_len", [4]) + def test_unpack_labels( + self, Y_aggr, window_range_list, absolute_merge_mapping, 
chunk_cluster_count, embeddings_per_chunk, org_len + ): + expected_result = Y_aggr + longform_speaker_clustering = LongFormSpeakerClustering(cuda=False) + output = longform_speaker_clustering.unpack_labels(Y_aggr, window_range_list, absolute_merge_mapping, org_len) + assert torch.equal(output, expected_result) + class TestSpeakerClustering: """ @@ -679,7 +697,7 @@ def test_offline_speaker_clustering_very_short_cpu( @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_gpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed + self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -704,6 +722,102 @@ def test_offline_speaker_clustering_very_short_gpu( assert Y_out.shape[0] == mc[-1] assert all(permuted_Y == gt) + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)]) + @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)]) + @pytest.mark.parametrize("seed", [0]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_longform_speaker_clustering_cpu( + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + chunk_cluster_count, + embeddings_per_chunk, + jit_script, + seed, + ): + em, ts, mc, mw, spk_ts, gt = generate_toy_data( + n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed + ) + longform_speaker_clustering = LongFormSpeakerClustering(cuda=False) + if jit_script: + longform_speaker_clustering = torch.jit.script(longform_speaker_clustering) + else: + assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering) + Y_out = longform_speaker_clustering.forward_infer( + embeddings_in_scales=em, + timestamps_in_scales=ts, + multiscale_segment_counts=mc, + multiscale_weights=mw, + oracle_num_speakers=-1, + max_num_speakers=n_spks, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=SSV, + max_rp_threshold=0.15, + fixed_thres=-1.0, + chunk_cluster_count=chunk_cluster_count, + embeddings_per_chunk=embeddings_per_chunk, + ) + permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) + permuted_Y = permuted_Y.to(gt.device) + + # mc[-1] is the number of base scale segments + assert Y_out.shape[0] == mc[-1] + assert all(permuted_Y == gt) + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)]) + @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)]) + @pytest.mark.parametrize("seed", [0]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_longform_speaker_clustering_gpu( + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + chunk_cluster_count, + embeddings_per_chunk, + jit_script, + seed, + ): + em, ts, mc, mw, spk_ts, gt = generate_toy_data( + n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed + ) + longform_speaker_clustering = LongFormSpeakerClustering(cuda=True) + + if jit_script: + longform_speaker_clustering = torch.jit.script(longform_speaker_clustering) + else: + assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering) + 
+ Y_out = longform_speaker_clustering.forward_infer( + embeddings_in_scales=em, + timestamps_in_scales=ts, + multiscale_segment_counts=mc, + multiscale_weights=mw, + oracle_num_speakers=-1, + max_num_speakers=n_spks, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=SSV, + max_rp_threshold=0.15, + fixed_thres=-1.0, + chunk_cluster_count=chunk_cluster_count, + embeddings_per_chunk=embeddings_per_chunk, + ) + permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) + permuted_Y = permuted_Y.to(gt.device) + + # mc[-1] is the number of base scale segments + assert Y_out.shape[0] == mc[-1] + assert all(permuted_Y == gt) + @pytest.mark.run_only_on('GPU') @pytest.mark.unit @pytest.mark.parametrize("n_spks", [1, 2, 3])
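To tie the pieces together, a hedged end-to-end sketch of the new entry point, loosely following the CPU unit test above: single-scale toy data with two synthetic speakers, and `embeddings_per_chunk` kept deliberately small so the long-form chunking path is exercised. Shapes follow the docstrings in longform_clustering.py; the embedding dimension, segment timing, and noise level are arbitrary:

import torch
from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering

torch.manual_seed(0)
n_segments, emb_dim = 300, 192
centers = torch.randn(2, emb_dim)                        # two well-separated "speakers"
spk_ids = torch.arange(n_segments) % 2
embs = centers[spk_ids] + 0.1 * torch.randn(n_segments, emb_dim)

starts = torch.arange(n_segments, dtype=torch.float32) * 0.75
timestamps = torch.stack([starts, starts + 1.5], dim=1)  # (n_segments, 2)

clusterer = LongFormSpeakerClustering(cuda=False)
labels = clusterer.forward_infer(
    embeddings_in_scales=embs,                           # (n_segments, emb_dim), single scale
    timestamps_in_scales=timestamps,
    multiscale_segment_counts=torch.LongTensor([n_segments]),
    multiscale_weights=torch.tensor([1.0]),
    oracle_num_speakers=-1,
    max_num_speakers=2,
    max_rp_threshold=0.15,
    sparse_search_volume=5,
    fixed_thres=-1.0,
    chunk_cluster_count=4,                               # must exceed max_num_speakers
    embeddings_per_chunk=50,                             # far below the 10000 default, to force chunking
)
assert labels.shape[0] == n_segments                     # one label per base-scale segment

With `oracle_num_speakers=-1`, the final speaker count is estimated by NME analysis on the reduced embeddings, so on this toy data the returned labels should recover the two blobs up to a permutation of label indices.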