clustering: use dict in forward
This allows Triton inference using human-readable dict keys for parameter
names instead of the cryptic INPUT_x notation.

Signed-off-by: Viraj Karandikar <[email protected]>
virajkarandikar committed Nov 11, 2022
1 parent d01a991 commit dafa2de
Showing 2 changed files with 35 additions and 33 deletions.
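The refactor leans on TorchScript's support for Dict[str, torch.Tensor] arguments. A minimal sketch of the pattern (the module and keys below are illustrative, not the actual NeMo class):

    from typing import Dict

    import torch


    class DictForward(torch.nn.Module):
        # Illustrative only: a forward() that takes a name -> tensor dict.
        # torch.jit.script can compile this, so the exported model exposes
        # human-readable input names instead of positional INPUT__x slots.
        def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
            embeddings = param_dict['embeddings']  # looked up by key, not position
            weights = param_dict['multiscale_weights']
            return embeddings * weights


    scripted = torch.jit.script(DictForward())
    out = scripted({'embeddings': torch.ones(3, 2), 'multiscale_weights': torch.full((3, 2), 0.5)})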
42 changes: 21 additions & 21 deletions nemo/collections/asr/parts/utils/nmesc_clustering.py
@@ -1107,31 +1107,26 @@ def __init__(
         self.timestamps_in_scales: List[torch.Tensor] = [torch.Tensor(0)]
         self.device = torch.device("cuda") if self.cuda else torch.device("cpu")

-    def forward(
-        self,
-        embeddings_in_scales: torch.Tensor,
-        timestamps_in_scales: torch.Tensor,
-        multiscale_segment_counts: torch.LongTensor,
-        multiscale_weights: torch.Tensor,
-        oracle_num_speakers: torch.LongTensor,
-        max_num_speakers: torch.LongTensor,
-        enhanced_count_thres: torch.LongTensor = torch.LongTensor([80]),
-        sparse_search_volume: torch.LongTensor = torch.LongTensor([30]),
-        max_rp_threshold: torch.Tensor = torch.Tensor([0.15]),
-        fixed_thres: torch.Tensor = torch.Tensor([-1.0]),
-    ) -> torch.LongTensor:
+    def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor:

         """
         Calculate affinity matrix using timestamps and speaker embeddings, run NME analysis to estimate the best
         p-value and perform spectral clustering based on the estimated p-value and the calculated affinity matrix.
-        Caution:
-            For the sake of compatibility with libtorch, python boolean `False` is replaced with `torch.LongTensor(-1)`.
+        Note:
+            Dict is used to allow easy inference of the exported jit model in Triton server using easy to understand
+            naming convention.
+            See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#special-conventions-for-pytorch-backend
         Args:
-            embeddings_in_scales (Tensor):
+            Dict containing following keys associated with tensors.
+            embeddings (Tensor):
                 Concatenated Torch tensor containing embeddings in multiple scales
                 This tensor has dimensions of (Number of base segments) x (Embedding Dimension)
-            timestamps_in_scales (Tensor):
+            timestamps (Tensor):
                 Concatenated Torch tensor containing timestamps in multiple scales.
                 This tensor has dimensions of (Total number of segments all scales) x 2
             Example:
@@ -1177,12 +1172,17 @@ def forward(
                 Speaker label for each segment.
         """

-        oracle_num_speakers = int(oracle_num_speakers.item())
-        max_num_speakers = int(max_num_speakers.item())
-        enhanced_count_thres = int(enhanced_count_thres.item())
-        sparse_search_volume = int(sparse_search_volume.item())
-        max_rp_threshold = float(max_rp_threshold.item())
-        fixed_thres = float(fixed_thres.item())
+        embeddings_in_scales = param_dict['embeddings']
+        timestamps_in_scales = param_dict['timestamps']
+        multiscale_segment_counts = param_dict['multiscale_segment_counts']
+        multiscale_weights = param_dict['multiscale_weights']
+
+        oracle_num_speakers = int(param_dict['oracle_num_speakers'].item())
+        max_num_speakers = int(param_dict['max_num_speakers'].item())
+        enhanced_count_thres = int(param_dict['enhanced_count_thres'].item())
+        sparse_search_volume = int(param_dict['sparse_search_volume'].item())
+        max_rp_threshold = float(param_dict['max_rp_threshold'].item())
+        fixed_thres = float(param_dict['fixed_thres'].item())

         self.embeddings_in_scales, self.timestamps_in_scales = split_input_data(
             embeddings_in_scales, timestamps_in_scales, multiscale_segment_counts
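A note on the one-element tensors in this hunk: every input to the exported graph must be a tensor, so scalar options travel as shape-[1] tensors and are unwrapped with .item() inside forward(), while the sentinel -1 / -1.0 stands in for Python False or None (per the libtorch-compatibility caution removed above). A small illustration of that round trip:

    import torch

    # Caller side: scalar knobs wrapped as 1-element tensors; -1 means "not set".
    oracle_num_speakers = torch.LongTensor([-1])
    max_rp_threshold = torch.tensor([0.15])

    # Callee side (inside forward): unwrap back to plain Python scalars.
    n_spk = int(oracle_num_speakers.item())   # -> -1
    thres = float(max_rp_threshold.item())    # -> 0.15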
26 changes: 14 additions & 12 deletions nemo/collections/asr/parts/utils/speaker_utils.py
@@ -464,18 +464,20 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste
         else:
             num_speakers = -1

-        cluster_labels = speaker_clustering.forward(
-            embeddings_in_scales=uniq_embs_and_timestamps['embeddings'],
-            timestamps_in_scales=uniq_embs_and_timestamps['timestamps'],
-            multiscale_segment_counts=uniq_embs_and_timestamps['multiscale_segment_counts'],
-            multiscale_weights=uniq_embs_and_timestamps['multiscale_weights'],
-            oracle_num_speakers=torch.LongTensor([num_speakers]),
-            max_num_speakers=torch.LongTensor([clustering_params.max_num_speakers]),
-            enhanced_count_thres=torch.LongTensor([80]),
-            sparse_search_volume=torch.LongTensor([clustering_params.sparse_search_volume]),
-            max_rp_threshold=torch.tensor([clustering_params.max_rp_threshold]),
-            fixed_thres=torch.tensor([-1.0]),
-        )
+        clustering_param_dict = {
+            'embeddings': uniq_embs_and_timestamps['embeddings'],
+            'timestamps': uniq_embs_and_timestamps['timestamps'],
+            'multiscale_segment_counts': uniq_embs_and_timestamps['multiscale_segment_counts'],
+            'multiscale_weights': uniq_embs_and_timestamps['multiscale_weights'],
+            'oracle_num_speakers': torch.LongTensor([num_speakers]),
+            'max_num_speakers': torch.LongTensor([clustering_params.max_num_speakers]),
+            'enhanced_count_thres': torch.LongTensor([80]),
+            'sparse_search_volume': torch.LongTensor([clustering_params.sparse_search_volume]),
+            'max_rp_threshold': torch.tensor([clustering_params.max_rp_threshold]),
+            'fixed_thres': torch.tensor([-1.0]),
+        }
+
+        cluster_labels = speaker_clustering.forward(clustering_param_dict)

         base_scale_idx = uniq_embs_and_timestamps['multiscale_segment_counts'].shape[0] - 1
         timestamps = speaker_clustering.timestamps_in_scales[base_scale_idx]
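With the dict-keyed forward in place, a Triton client can address the exported model's inputs by those same names, following the PyTorch-backend convention linked in the docstring. A hypothetical client call over HTTP; the model name, shapes, dtypes, and output name are deployment-specific assumptions, not part of this commit:

    import numpy as np
    import tritonclient.http as httpclient

    client = httpclient.InferenceServerClient(url="localhost:8000")


    def named_input(name, array, dtype):
        # Inputs are addressed by the forward() dict keys, not INPUT__x slots.
        inp = httpclient.InferInput(name, list(array.shape), dtype)
        inp.set_data_from_numpy(array)
        return inp


    # Illustrative numbers: one scale, 100 segments, 192-dim embeddings.
    inputs = [
        named_input("embeddings", np.random.rand(100, 192).astype(np.float32), "FP32"),
        named_input("timestamps", np.random.rand(100, 2).astype(np.float32), "FP32"),
        named_input("multiscale_segment_counts", np.array([100], dtype=np.int64), "INT64"),
        named_input("multiscale_weights", np.array([1.0], dtype=np.float32), "FP32"),
        named_input("oracle_num_speakers", np.array([-1], dtype=np.int64), "INT64"),
        named_input("max_num_speakers", np.array([8], dtype=np.int64), "INT64"),
        named_input("enhanced_count_thres", np.array([80], dtype=np.int64), "INT64"),
        named_input("sparse_search_volume", np.array([30], dtype=np.int64), "INT64"),
        named_input("max_rp_threshold", np.array([0.15], dtype=np.float32), "FP32"),
        named_input("fixed_thres", np.array([-1.0], dtype=np.float32), "FP32"),
    ]

    result = client.infer(model_name="speaker_clustering", inputs=inputs)
    cluster_labels = result.as_numpy("OUTPUT__0")  # output name assumed; check the model config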
