Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: SeanNaren <[email protected]>
  • Loading branch information
SeanNaren committed Feb 22, 2023
1 parent f0ad6d2 commit 7258142
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions nemo/collections/asr/models/msdd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,13 +1176,13 @@ def get_emb_clus_infer(self, cluster_embeddings):
self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict
self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test

@torch.no_grad()
def diarize(self) -> Optional[List[Optional[List[Tuple[DiarizationErrorRate, Dict]]]]]:
"""
Launch the diarization pipeline, which starts from VAD (or an oracle VAD stamp generation), followed by initialization clustering and the multiscale diarization decoder (MSDD).
Note that the result of MSDD can include multiple speakers at the same time. Therefore, the RTTM output of MSDD needs to be based on the `make_rttm_with_overlap()`
function, which can generate overlapping timestamps. The `self.run_overlap_aware_eval()` function performs DER evaluation.
"""
torch.set_grad_enabled(False)
self.clustering_embedding.prepare_cluster_embs_infer()
self.msdd_model.pairwise_infer = True
self.get_emb_clus_infer(self.clustering_embedding)
Expand Down Expand Up @@ -1242,7 +1242,7 @@ def get_range_average(
target_clus_label_bool = target_clus_label_tensor == test_data_collection.target_spks[spk_idx]

# There are cases where there is no corresponding speaker in split range, so any(target_clus_label_bool) could be False.
if any(target_clus_label_bool) == True:
if any(target_clus_label_bool):
emb_vectors_split[:, :, spk_idx] = torch.mean(emb_seq[target_clus_label_bool], dim=0)

# In case when the loop reaches the end of the sequence
Expand Down Expand Up @@ -1354,6 +1354,7 @@ def diar_infer(
preds[:, : _preds.shape[1], :] = _preds
return preds, targets, signal_lengths

@torch.no_grad()
def run_pairwise_diarization(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Set up the parameters needed for batch inference and run batch inference. Note that each sample is a pairwise speaker input.
Expand All @@ -1370,7 +1371,6 @@ def run_pairwise_diarization(self) -> Tuple[List[torch.Tensor], List[torch.Tenso
self.out_rttm_dir = self.clustering_embedding.out_rttm_dir
self.msdd_model.setup_test_data(self.msdd_model.cfg.test_ds)
self.msdd_model.eval()
torch.set_grad_enabled(False)
cumul_sample_count = [0]
preds_list, targets_list, signal_lengths_list = [], [], []
uniq_id_list = get_uniq_id_list_from_manifest(self.msdd_model.cfg.test_ds.manifest_filepath)
Expand Down

0 comments on commit 7258142

Please sign in to comment.