-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
msdd_models.py
1415 lines (1269 loc) · 70.2 KB
/
msdd_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import pickle as pkl
import tempfile
from collections import OrderedDict
from statistics import mode
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, open_dict
from pyannote.metrics.diarization import DiarizationErrorRate
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechMSDDInferDataset, AudioToSpeechMSDDTrainDataset
from nemo.collections.asr.metrics.der import score_labels
from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy
from nemo.collections.asr.models import ClusteringDiarizer
from nemo.collections.asr.models.asr_model import ExportableEncDecModel
from nemo.collections.asr.models.clustering_diarizer import (
_MODEL_CONFIG_YAML,
_SPEAKER_MODEL,
_VAD_MODEL,
get_available_model_names,
)
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.asr.parts.utils.speaker_utils import (
audio_rttm_map,
get_embs_and_timestamps,
get_id_tup_dict,
get_scale_mapping_argmat,
get_uniq_id_list_from_manifest,
make_rttm_with_overlap,
parse_scale_configs,
)
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType
from nemo.core.neural_types.elements import ProbsType
from nemo.utils import logging
try:
    from torch.cuda.amp import autocast
except ImportError:
    import contextlib

    @contextlib.contextmanager
    def autocast(enabled=None):
        """No-op stand-in used when ``torch.cuda.amp.autocast`` cannot be imported.

        The ``enabled`` argument is accepted and ignored so that ``with autocast(enabled=...):`` call sites
        keep working unchanged.
        """
        yield
# Public API of this module.
__all__ = ["EncDecDiarLabelModel", "ClusterEmbedding", "NeuralDiarizer"]
class EncDecDiarLabelModel(ModelPT, ExportableEncDecModel):
    """
    Encoder decoder class for multiscale diarization decoder (MSDD). Model class creates training, validation methods
    for setting up data and performing the model forward pass.

    This model class expects config dict for:
        * preprocessor
        * msdd_model
        * speaker_model
    """

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="diar_msdd_telephonic",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/diar_msdd_telephonic/versions/1.0.0/files/diar_msdd_telephonic.nemo",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:diar_msdd_telephonic",
        )
        result.append(model)
        return result

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initialize an MSDD model and the specified speaker embedding model. In this init function, training and
        validation datasets are prepared.

        Args:
            cfg (DictConfig):
                MSDD model configuration.
            trainer (Trainer):
                Lightning trainer. When given, segmentation settings are parsed, the speaker embedding model is
                loaded from `diarizer.speaker_embeddings.model_path` and its config is stored in `cfg`. When None,
                the model is set up for pairwise inference and the speaker model is built from
                `cfg.speaker_model_cfg`.

        Raises:
            ValueError: If `window_length_in_sec` is a single number or a list with fewer than two entries.
        """
        self._trainer = trainer if trainer else None
        self.cfg_msdd_model = cfg

        if self._trainer:
            self._init_segmentation_info()
            self.world_size = trainer.num_nodes * trainer.num_devices
            self.emb_batch_size = self.cfg_msdd_model.emb_batch_size
            self.pairwise_infer = False
        else:
            self.world_size = 1
            self.pairwise_infer = True
        super().__init__(cfg=self.cfg_msdd_model, trainer=trainer)

        window_length_in_sec = self.cfg_msdd_model.diarizer.speaker_embeddings.parameters.window_length_in_sec
        # A single number (int or float) cannot describe a multi-scale setup. Checking the type first keeps a float
        # from reaching len() and raising a TypeError instead of the intended ValueError.
        if isinstance(window_length_in_sec, (int, float)) or len(window_length_in_sec) <= 1:
            raise ValueError("window_length_in_sec should be a list containing multiple segment (window) lengths")
        else:
            self.cfg_msdd_model.scale_n = len(window_length_in_sec)
            self.cfg_msdd_model.msdd_module.scale_n = self.cfg_msdd_model.scale_n
            self.scale_n = self.cfg_msdd_model.scale_n

        self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(self.cfg_msdd_model.preprocessor)
        self.frame_per_sec = int(1 / self.preprocessor._cfg.window_stride)
        self.msdd = EncDecDiarLabelModel.from_config_dict(self.cfg_msdd_model.msdd_module)

        if trainer is not None:
            self._init_speaker_model()
            self.add_speaker_model_config(cfg)
        else:
            self.msdd._speaker_model = EncDecSpeakerLabelModel.from_config_dict(cfg.speaker_model_cfg)

        # Call `self.save_hyperparameters` in modelPT.py again since cfg should contain speaker model's config.
        self.save_hyperparameters("cfg")

        self.loss = instantiate(self.cfg_msdd_model.loss)
        self._accuracy_test = MultiBinaryAccuracy()
        self._accuracy_train = MultiBinaryAccuracy()
        self._accuracy_valid = MultiBinaryAccuracy()

    def add_speaker_model_config(self, cfg):
        """
        Add config dictionary of the speaker model to the model's config dictionary. This is required to
        save and load speaker model with MSDD model.

        Args:
            cfg (DictConfig): DictConfig type variable that contains hyperparameters of MSDD model.
        """
        with open_dict(cfg):
            # Deep copy so the deletions below can never affect the speaker model's own config object.
            cfg_cp = copy.deepcopy(self.msdd._speaker_model.cfg)
            cfg.speaker_model_cfg = cfg_cp
            # Dataset sections are not stored with the MSDD model; tolerate speaker configs that lack them.
            for ds_key in ('train_ds', 'validation_ds'):
                if ds_key in cfg.speaker_model_cfg:
                    del cfg.speaker_model_cfg[ds_key]

    def _init_segmentation_info(self):
        """Initialize segmentation settings: window, shift and multiscale weights.
        """
        self._diarizer_params = self.cfg_msdd_model.diarizer
        self.multiscale_args_dict = parse_scale_configs(
            self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
            self._diarizer_params.speaker_embeddings.parameters.shift_length_in_sec,
            self._diarizer_params.speaker_embeddings.parameters.multiscale_weights,
        )

    def _init_speaker_model(self):
        """
        Initialize speaker embedding model with model name or path passed through config. Note that speaker embedding
        model is loaded to `self.msdd` to enable multi-gpu and multi-node training. In addition, speaker embedding
        model is also saved with msdd model when `.ckpt` files are saved.

        `model_path` may be a `.nemo` file, a `.ckpt` file or a pretrained model name; an unknown name (or None)
        falls back to the pretrained `titanet_large` model.
        """
        model_path = self.cfg_msdd_model.diarizer.speaker_embeddings.model_path
        self._diarizer_params = self.cfg_msdd_model.diarizer

        if not torch.cuda.is_available():
            rank_id = torch.device('cpu')
        elif self._trainer:
            # CUDA device indices are per node, so use the node-local rank (the global rank exceeds the local
            # device count on multi-node runs).
            rank_id = torch.device(self._trainer.local_rank)
        else:
            rank_id = None

        if model_path is not None and model_path.endswith('.nemo'):
            self.msdd._speaker_model = EncDecSpeakerLabelModel.restore_from(model_path, map_location=rank_id)
            logging.info("Speaker Model restored locally from {}".format(model_path))
        elif model_path is not None and model_path.endswith('.ckpt'):
            # Assign to `self.msdd` like the other branches; `forward` only uses `self.msdd._speaker_model`.
            self.msdd._speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(model_path, map_location=rank_id)
            logging.info("Speaker Model restored locally from {}".format(model_path))
        else:
            if model_path not in get_available_model_names(EncDecSpeakerLabelModel):
                logging.warning(
                    "requested {} model name not available in pretrained models, instead loading titanet_large".format(
                        model_path
                    )
                )
                model_path = "titanet_large"
            logging.info("Loading pretrained {} model from NGC".format(model_path))
            self.msdd._speaker_model = EncDecSpeakerLabelModel.from_pretrained(
                model_name=model_path, map_location=rank_id
            )
        self._speaker_params = self.cfg_msdd_model.diarizer.speaker_embeddings.parameters

    def __setup_dataloader_from_config(self, config):
        """
        Build a training/validation DataLoader over `AudioToSpeechMSDDTrainDataset`.

        Returns None (with a warning) when `config.manifest_filepath` is explicitly None.
        """
        featurizer = WaveformFeaturizer(
            sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=None
        )

        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None
        dataset = AudioToSpeechMSDDTrainDataset(
            manifest_filepath=config.manifest_filepath,
            emb_dir=config.emb_dir,
            multiscale_args_dict=self.multiscale_args_dict,
            soft_label_thres=config.soft_label_thres,
            featurizer=featurizer,
            window_stride=self.cfg_msdd_model.preprocessor.window_stride,
            emb_batch_size=config.emb_batch_size,
            pairwise_infer=False,
            global_rank=self._trainer.global_rank,
        )

        self.data_collection = dataset.collection
        collate_ds = dataset
        collate_fn = collate_ds.msdd_train_collate_fn
        batch_size = config['batch_size']
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=False,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def __setup_dataloader_from_config_infer(
        self, config: DictConfig, emb_dict: dict, emb_seq: dict, clus_label_dict: dict, pairwise_infer=False
    ):
        """
        Build an inference DataLoader over `AudioToSpeechMSDDInferDataset` from precomputed embeddings and
        clustering labels.

        Returns None (with a warning) when `config.manifest_filepath` is explicitly None.
        """
        shuffle = config.get('shuffle', False)

        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None

        dataset = AudioToSpeechMSDDInferDataset(
            manifest_filepath=config['manifest_filepath'],
            emb_dict=emb_dict,
            clus_label_dict=clus_label_dict,
            emb_seq=emb_seq,
            soft_label_thres=config.soft_label_thres,
            seq_eval_mode=config.seq_eval_mode,
            window_stride=self._cfg.preprocessor.window_stride,
            use_single_scale_clus=False,
            pairwise_infer=pairwise_infer,
        )
        self.data_collection = dataset.collection
        collate_ds = dataset
        collate_fn = collate_ds.msdd_infer_collate_fn
        batch_size = config['batch_size']
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        """Create the training DataLoader from `train_data_config`."""
        self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,)

    def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]):
        """Create the validation DataLoader from `val_data_layer_config`."""
        self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,)

    def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
        """Create the test DataLoader; only done in pairwise-inference mode, using the cached clustering outputs."""
        if self.pairwise_infer:
            self._test_dl = self.__setup_dataloader_from_config_infer(
                config=test_data_config,
                emb_dict=self.emb_sess_test_dict,
                emb_seq=self.emb_seq_test,
                clus_label_dict=self.clus_test_label_dict,
                pairwise_infer=self.pairwise_infer,
            )

    def setup_multiple_test_data(self, test_data_config):
        """
        MSDD does not use multiple_test_data template. This function is a placeholder for preventing error.
        """
        return None

    def test_dataloader(self):
        """Return the test DataLoader if one has been set up (otherwise None)."""
        if self._test_dl is not None:
            return self._test_dl

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of the `forward` inputs."""
        if hasattr(self.preprocessor, '_sample_rate'):
            audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            audio_eltype = AudioSignal()
        return {
            "features": NeuralType(('B', 'T'), audio_eltype),
            "feature_length": NeuralType(('B',), LengthsType()),
            "ms_seg_timestamps": NeuralType(('B', 'C', 'T', 'D'), LengthsType()),
            "ms_seg_counts": NeuralType(('B', 'C'), LengthsType()),
            "clus_label_index": NeuralType(('B', 'T'), LengthsType()),
            "scale_mapping": NeuralType(('B', 'C', 'T'), LengthsType()),
            "targets": NeuralType(('B', 'T', 'C'), ProbsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Neural types of the `forward` outputs."""
        return OrderedDict(
            {
                "probs": NeuralType(('B', 'T', 'C'), ProbsType()),
                "scale_weights": NeuralType(('B', 'T', 'C', 'D'), ProbsType()),
            }
        )

    def get_ms_emb_seq(
        self, embs: torch.Tensor, scale_mapping: torch.Tensor, ms_seg_counts: torch.Tensor
    ) -> torch.Tensor:
        """
        Reshape the given tensor and organize the embedding sequence based on the original sequence counts.
        Repeat the embeddings according to the scale_mapping information so that the final embedding sequence has
        the identical length for all scales.

        Args:
            embs (Tensor):
                Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details.
                Shape: (Total number of segments in the batch, emb_dim)
            scale_mapping (Tensor):
                The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th
                scale segment index which has the closest center distance with (n+1)-th segment in the base scale.

                Example:
                    scale_mapping_argmat[2][101] = 85

                In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with
                102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since
                multiple base scale segments (since the base scale has the shortest length) fall into the range of the
                longer segments. At the same time, each row contains N numbers of indices where N is number of
                segments in the base-scale (i.e., the finest scale).
                Shape: (batch_size, scale_n, self.diar_window_length)
            ms_seg_counts (Tensor):
                Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
                the multi-scale input matrix during forward propagating.

                Example: `batch_size=3, scale_n=6, emb_dim=192`
                    ms_seg_counts =
                     [[8,  9, 12, 16, 25, 51],
                      [11, 13, 14, 17, 25, 51],
                      [ 9,  9, 11, 16, 23, 50]]

                In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence
                without zero-padding.

        Returns:
            ms_emb_seq (Tensor):
                Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less
                repeated, while shorter scales are more frequently repeated following the scale mapping tensor.
        """
        scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0]
        split_emb_tup = torch.split(embs, ms_seg_counts.view(-1).tolist(), dim=0)
        batch_emb_list = [split_emb_tup[i : i + scale_n] for i in range(0, len(split_emb_tup), scale_n)]
        ms_emb_seq_list = []
        for batch_idx in range(batch_size):
            feats_list = []
            for scale_index in range(scale_n):
                repeat_mat = scale_mapping[batch_idx][scale_index]
                feats_list.append(batch_emb_list[batch_idx][scale_index][repeat_mat, :])
            repp = torch.stack(feats_list).permute(1, 0, 2)
            ms_emb_seq_list.append(repp)
        ms_emb_seq = torch.stack(ms_emb_seq_list)
        return ms_emb_seq

    @torch.no_grad()
    def get_cluster_avg_embs_model(
        self, embs: torch.Tensor, clus_label_index: torch.Tensor, ms_seg_counts: torch.Tensor, scale_mapping
    ) -> torch.Tensor:
        """
        Calculate the cluster-average speaker embedding based on the ground-truth speaker labels (i.e., cluster labels).

        Args:
            embs (Tensor):
                Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details.
                Shape: (Total number of segments in the batch, emb_dim)
            clus_label_index (Tensor):
                Merged ground-truth cluster labels from all scales with zero-padding. Each scale's index can be
                retrieved by using segment index in `ms_seg_counts`.
                Shape: (batch_size, maximum total segment count among the samples in the batch)
            ms_seg_counts (Tensor):
                Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
                multi-scale input tensors during forward propagating.

                Example: `batch_size=3, scale_n=6, emb_dim=192`
                    ms_seg_counts =
                     [[8,  9, 12, 16, 25, 51],
                      [11, 13, 14, 17, 25, 51],
                      [ 9,  9, 11, 16, 23, 50]]

                    Counts of merged segments: (121, 131, 118)
                    embs has shape of (370, 192)
                    clus_label_index has shape of (3, 131)

                Shape: (batch_size, scale_n)
            scale_mapping (Tensor):
                Scale mapping matrix; only its first two dimensions (batch_size, scale_n) are used here.

        Returns:
            ms_avg_embs (Tensor):
                Multi-scale cluster-average speaker embedding vectors. These embedding vectors are used as reference
                for each speaker to predict the speaker label for the given multi-scale embedding sequences.
                Speakers that do not appear at a scale get an all-zero vector.
                Shape: (batch_size, scale_n, emb_dim, self.num_spks_per_model)
        """
        scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0]
        split_emb_tup = torch.split(embs, ms_seg_counts.view(-1).tolist(), dim=0)
        batch_emb_list = [split_emb_tup[i : i + scale_n] for i in range(0, len(split_emb_tup), scale_n)]
        ms_avg_embs_list = []
        for batch_idx in range(batch_size):
            oracle_clus_idx = clus_label_index[batch_idx]
            max_seq_len = int(ms_seg_counts[batch_idx].sum())
            clus_label_index_batch = torch.split(oracle_clus_idx[:max_seq_len], ms_seg_counts[batch_idx].tolist())
            session_avg_emb_set_list = []
            for scale_index in range(scale_n):
                scale_embs = batch_emb_list[batch_idx][scale_index]
                spk_set_list = []
                for idx in range(self.cfg_msdd_model.max_num_of_spks):
                    _where = clus_label_index_batch[scale_index] == idx
                    if not torch.any(_where):
                        # Size the placeholder from the embeddings themselves so it always stacks with the
                        # per-speaker means (same width, dtype and device).
                        avg_emb = embs.new_zeros(embs.shape[1])
                    else:
                        avg_emb = torch.mean(scale_embs[_where], dim=0)
                    spk_set_list.append(avg_emb)
                session_avg_emb_set_list.append(torch.stack(spk_set_list))
            session_avg_emb_set = torch.stack(session_avg_emb_set_list)
            ms_avg_embs_list.append(session_avg_emb_set)

        ms_avg_embs = torch.stack(ms_avg_embs_list).permute(0, 1, 3, 2)
        ms_avg_embs = ms_avg_embs.float().detach().to(embs.device)
        assert (
            not ms_avg_embs.requires_grad
        ), "ms_avg_embs.requires_grad = True. ms_avg_embs should be detached from the torch graph."
        return ms_avg_embs

    @torch.no_grad()
    def get_ms_mel_feat(
        self,
        processed_signal: torch.Tensor,
        processed_signal_len: torch.Tensor,
        ms_seg_timestamps: torch.Tensor,
        ms_seg_counts: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Load acoustic feature from audio segments for each scale and save it into a torch.tensor matrix.
        In addition, create variables containing the information of the multiscale subsegmentation information.

        Note: `self.emb_batch_size` determines the number of embedding tensors attached to the computational graph.
        If `self.emb_batch_size` is greater than 0, speaker embedding models are simultaneously trained. Due to the
        constraint of GPU memory size, only a subset of embedding tensors can be attached to the computational graph.
        The graph-attached embeddings are selected by one random permutation seeded with the current epoch, so the
        attached and detached index sets are disjoint and together cover every segment. A local generator is used
        so the global torch RNG is not reseeded. Default value of `self.emb_batch_size` is 0.

        Args:
            processed_signal (Tensor):
                Zero-padded Feature input.
                Shape: (batch_size, feat_dim, the longest feature sequence length)
            processed_signal_len (Tensor):
                The actual length of feature input without zero-padding (not used by this method).
                Shape: (batch_size,)
            ms_seg_timestamps (Tensor):
                Start/end feature-frame indices of the segments in each scale; the last dimension holds
                (start, end).
                Shape: (batch_size, scale_n, number of base-scale segments, 2)
            ms_seg_counts (Tensor):
                Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
                the multi-scale input matrix during forward propagating.
                Shape: (batch_size, scale_n)

        Returns:
            ms_mel_feat (Tensor):
                Feature input stream split into the same length.
                Shape: (total number of segments, feat_dim, self.frame_per_sec * the-longest-scale-length)
            ms_mel_feat_len (Tensor):
                The actual length of feature without zero-padding.
                Shape: (total number of segments,)
            seq_len (Tensor):
                The length of the input embedding sequences (last entry of `ms_seg_counts` per sample).
                Shape: (batch_size,)
            detach_ids (tuple):
                Tuple of (attached embedding indices, detached embedding indices).
        """
        device = processed_signal.device
        _emb_batch_size = min(self.emb_batch_size, ms_seg_counts.sum().item())
        feat_dim = self.preprocessor._cfg.features
        # All segments are cropped/zero-padded to the frame count of scale index 0 (the longest scale).
        max_sample_count = int(self.multiscale_args_dict["scale_dict"][0][0] * self.frame_per_sec)
        ms_mel_feat_len_list, sequence_lengths_list, ms_mel_feat_list = [], [], []
        total_seg_count = int(torch.sum(ms_seg_counts).item())

        batch_size = processed_signal.shape[0]
        for batch_idx in range(batch_size):
            for scale_idx in range(self.scale_n):
                scale_seg_num = ms_seg_counts[batch_idx][scale_idx]
                for stt, end in ms_seg_timestamps[batch_idx][scale_idx][:scale_seg_num]:
                    stt, end = int(stt.detach().item()), int(end.detach().item())
                    end = min(end, stt + max_sample_count)
                    _features = torch.zeros(feat_dim, max_sample_count, dtype=torch.float32, device=device)
                    _features[:, : (end - stt)] = processed_signal[batch_idx][:, stt:end]
                    ms_mel_feat_list.append(_features)
                    ms_mel_feat_len_list.append(end - stt)
            sequence_lengths_list.append(ms_seg_counts[batch_idx][-1])
        ms_mel_feat = torch.stack(ms_mel_feat_list).to(device)
        ms_mel_feat_len = torch.tensor(ms_mel_feat_len_list).to(device)
        seq_len = torch.tensor(sequence_lengths_list).to(device)

        if _emb_batch_size == 0:
            attached = torch.tensor([], dtype=torch.long)
            detached = torch.arange(total_seg_count)
        else:
            # Draw ONE permutation and split it: two independent draws could overlap and leave segments in
            # neither set (those would keep all-zero embeddings in `forward`).
            generator = torch.Generator()
            generator.manual_seed(self._trainer.current_epoch)
            perm = torch.randperm(total_seg_count, generator=generator)
            attached = perm[:_emb_batch_size]
            detached = perm[_emb_batch_size:]
        detach_ids = (attached, detached)
        return ms_mel_feat, ms_mel_feat_len, seq_len, detach_ids

    def forward_infer(self, input_signal, input_signal_length, emb_vectors, targets):
        """
        Wrapper function for inference case.
        """
        preds, scale_weights = self.msdd(
            ms_emb_seq=input_signal, length=input_signal_length, ms_avg_embs=emb_vectors, targets=targets
        )
        return preds, scale_weights

    @typecheck()
    def forward(
        self, features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets
    ):
        """
        Training/validation forward pass: extract speaker embeddings for every multi-scale segment (most without
        gradients, a subset selected by `get_ms_mel_feat` with gradients), build the multi-scale embedding
        sequences and cluster-average embeddings, and run the MSDD module.

        Returns:
            preds (Tensor): Predictions returned by `self.msdd`.
            scale_weights (Tensor): Scale weights returned by `self.msdd`.
        """
        processed_signal, processed_signal_len = self.msdd._speaker_model.preprocessor(
            input_signal=features, length=feature_length
        )
        audio_signal, audio_signal_len, sequence_lengths, detach_ids = self.get_ms_mel_feat(
            processed_signal, processed_signal_len, ms_seg_timestamps, ms_seg_counts
        )
        attached_ids, detached_ids = detach_ids
        if len(attached_ids) == 1:
            # The gradient-attached pass is only run for two or more segments; compute a single attached
            # segment in the no-grad pass instead so it still receives an embedding (not an all-zero row).
            detached_ids = torch.cat([detached_ids, attached_ids.to(detached_ids.dtype)])
            attached_ids = attached_ids[:0]

        embs = None
        # For detached embeddings
        if len(detached_ids) > 0:
            with torch.no_grad():
                self.msdd._speaker_model.eval()
                _, embs_d = self.msdd._speaker_model.forward_for_export(
                    processed_signal=audio_signal[detached_ids], processed_signal_len=audio_signal_len[detached_ids]
                )
            embs = torch.zeros(audio_signal.shape[0], embs_d.shape[1]).to(embs_d.device)
            embs[detached_ids, :] = embs_d.detach()

        # For attached embeddings
        self.msdd._speaker_model.train()
        if len(attached_ids) > 0:
            _, embs_a = self.msdd._speaker_model.forward_for_export(
                processed_signal=audio_signal[attached_ids], processed_signal_len=audio_signal_len[attached_ids]
            )
            if embs is None:
                # Every segment was attached, so there was no detached pass to size the buffer from.
                embs = torch.zeros(audio_signal.shape[0], embs_a.shape[1]).to(embs_a.device)
            embs[attached_ids, :] = embs_a

        ms_emb_seq = self.get_ms_emb_seq(embs, scale_mapping, ms_seg_counts)
        ms_avg_embs = self.get_cluster_avg_embs_model(embs, clus_label_index, ms_seg_counts, scale_mapping)
        preds, scale_weights = self.msdd(
            ms_emb_seq=ms_emb_seq, length=sequence_lengths, ms_avg_embs=ms_avg_embs, targets=targets
        )
        return preds, scale_weights

    def training_step(self, batch: list, batch_idx: int):
        """Run one training step; logs loss, learning rate and the per-step F1 score."""
        features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch
        sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts.detach()])
        preds, _ = self.forward(
            features=features,
            feature_length=feature_length,
            ms_seg_timestamps=ms_seg_timestamps,
            ms_seg_counts=ms_seg_counts,
            clus_label_index=clus_label_index,
            scale_mapping=scale_mapping,
            targets=targets,
        )
        loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths)
        self._accuracy_train(preds, targets, sequence_lengths)
        torch.cuda.empty_cache()
        f1_acc = self._accuracy_train.compute()
        self.log('loss', loss, sync_dist=True)
        self.log('learning_rate', self._optimizer.param_groups[0]['lr'], sync_dist=True)
        self.log('train_f1_acc', f1_acc, sync_dist=True)
        self._accuracy_train.reset()
        return {'loss': loss}

    def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0):
        """Run one validation step; accumulates F1 statistics that are reset at epoch end."""
        features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch
        sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts])
        preds, _ = self.forward(
            features=features,
            feature_length=feature_length,
            ms_seg_timestamps=ms_seg_timestamps,
            ms_seg_counts=ms_seg_counts,
            clus_label_index=clus_label_index,
            scale_mapping=scale_mapping,
            targets=targets,
        )
        loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths)
        self._accuracy_valid(preds, targets, sequence_lengths)
        f1_acc = self._accuracy_valid.compute()
        self.log('val_loss', loss, sync_dist=True)
        self.log('val_f1_acc', f1_acc, sync_dist=True)
        return {
            'val_loss': loss,
            'val_f1_acc': f1_acc,
        }

    def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0):
        """Aggregate validation loss and F1 over the epoch, then reset the validation metric."""
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        f1_acc = self._accuracy_valid.compute()
        self._accuracy_valid.reset()

        self.log('val_loss', val_loss_mean, sync_dist=True)
        self.log('val_f1_acc', f1_acc, sync_dist=True)
        return {
            'val_loss': val_loss_mean,
            'val_f1_acc': f1_acc,
        }

    def multi_test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0):
        """Aggregate test loss and F1 over the epoch, then reset the test metric."""
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        f1_acc = self._accuracy_test.compute()
        self._accuracy_test.reset()
        self.log('test_f1_acc', f1_acc, sync_dist=True)
        return {
            'test_loss': test_loss_mean,
            'test_f1_acc': f1_acc,
        }

    def compute_accuracies(self):
        """
        Calculate F1 score and accuracy of the predicted sigmoid values.

        Returns:
            f1_score (float):
                F1 score of the estimated diarized speaker label sequences.
            simple_acc (float):
                Accuracy of predicted speaker labels: (total # of correct labels)/(total # of sigmoid values)
        """
        f1_score = self._accuracy_test.compute()
        num_correct = torch.sum(self._accuracy_test.true.bool())
        total_count = torch.prod(torch.tensor(self._accuracy_test.targets.shape))
        simple_acc = num_correct / total_count
        return f1_score, simple_acc
class ClusterEmbedding:
"""
This class is built for calculating cluster-average embeddings, segmentation and load/save of the estimated cluster labels.
The methods in this class is used for the inference of MSDD models.
Args:
cfg_diar_infer (DictConfig):
Config dictionary from diarization inference YAML file
cfg_msdd_model (DictConfig):
Config dictionary from MSDD model checkpoint file
Class Variables:
self.cfg_diar_infer (DictConfig):
Config dictionary from diarization inference YAML file
cfg_msdd_model (DictConfig):
Config dictionary from MSDD model checkpoint file
self._speaker_model (class `EncDecSpeakerLabelModel`):
This is a placeholder for class instance of `EncDecSpeakerLabelModel`
self.scale_window_length_list (list):
List containing the window lengths (i.e., scale length) of each scale.
self.scale_n (int):
Number of scales for multi-scale clustering diarizer
self.base_scale_index (int):
The index of the base-scale which is the shortest scale among the given multiple scales
"""
def __init__(self, cfg_diar_infer: DictConfig, cfg_msdd_model: DictConfig):
self.cfg_diar_infer = cfg_diar_infer
self._cfg_msdd = cfg_msdd_model
self.clus_diar_model = None
self._speaker_model = None
self.scale_window_length_list = list(
self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.window_length_in_sec
)
self.scale_n = len(self.scale_window_length_list)
self.base_scale_index = len(self.scale_window_length_list) - 1
def prepare_cluster_embs_infer(self):
    """
    Run the clustering diarizer on the MSDD test manifest and keep its embeddings and cluster labels.

    Sets `self.max_num_speakers` from the inference config, then stores the first three outputs of
    `self.run_clustering_diarizer` as `self.emb_sess_test_dict`, `self.emb_seq_test` and
    `self.clus_test_label_dict` (the fourth output is discarded).
    """
    clustering_params = self.cfg_diar_infer.diarizer.clustering.parameters
    self.max_num_speakers = clustering_params.max_num_speakers
    test_ds_cfg = self._cfg_msdd.test_ds
    clustering_outputs = self.run_clustering_diarizer(test_ds_cfg.manifest_filepath, test_ds_cfg.emb_dir)
    self.emb_sess_test_dict, self.emb_seq_test, self.clus_test_label_dict, _ = clustering_outputs
def assign_labels_to_longer_segs(self, base_clus_label_dict: Dict, session_scale_mapping_dict: Dict):
"""
In multi-scale speaker diarization system, clustering result is solely based on the base-scale (the shortest scale).
To calculate cluster-average speaker embeddings for each scale that are longer than the base-scale, this function assigns
clustering results for the base-scale to the longer scales by measuring the distance between subsegment timestamps in the
base-scale and non-base-scales.
Args:
base_clus_label_dict (dict):
Dictionary containing clustering results for base-scale segments. Indexed by `uniq_id` string.
session_scale_mapping_dict (dict):
Dictionary containing multiscale mapping information for each session. Indexed by `uniq_id` string.
Returns:
all_scale_clus_label_dict (dict):
Dictionary containing clustering labels of all scales. Indexed by scale_index in integer format.
"""
all_scale_clus_label_dict = {scale_index: {} for scale_index in range(self.scale_n)}
for uniq_id, uniq_scale_mapping_dict in session_scale_mapping_dict.items():
base_scale_clus_label = np.array([x[-1] for x in base_clus_label_dict[uniq_id]])
all_scale_clus_label_dict[self.base_scale_index][uniq_id] = base_scale_clus_label
for scale_index in range(self.scale_n - 1):
new_clus_label = []
assert (
uniq_scale_mapping_dict[scale_index].shape[0] == base_scale_clus_label.shape[0]
), "The number of base scale labels does not match the segment numbers in uniq_scale_mapping_dict"
max_index = max(uniq_scale_mapping_dict[scale_index])
for seg_idx in range(max_index + 1):
if seg_idx in uniq_scale_mapping_dict[scale_index]:
seg_clus_label = mode(base_scale_clus_label[uniq_scale_mapping_dict[scale_index] == seg_idx])
else:
seg_clus_label = 0 if len(new_clus_label) == 0 else new_clus_label[-1]
new_clus_label.append(seg_clus_label)
all_scale_clus_label_dict[scale_index][uniq_id] = new_clus_label
return all_scale_clus_label_dict
def get_base_clus_label_dict(self, clus_labels: List[str], emb_scale_seq_dict: Dict[int, dict]):
"""
Retrieve base scale clustering labels from `emb_scale_seq_dict`.
Args:
clus_labels (list):
List containing cluster results generated by clustering diarizer.
emb_scale_seq_dict (dict):
Dictionary containing multiscale embedding input sequences.
Returns:
base_clus_label_dict (dict):
Dictionary containing start and end of base scale segments and its cluster label. Indexed by `uniq_id`.
emb_dim (int):
Embedding dimension in integer.
"""
base_clus_label_dict = {key: [] for key in emb_scale_seq_dict[self.base_scale_index].keys()}
for line in clus_labels:
uniq_id = line.split()[0]
label = int(line.split()[-1].split('_')[-1])
stt, end = [round(float(x), 2) for x in line.split()[1:3]]
base_clus_label_dict[uniq_id].append([stt, end, label])
emb_dim = emb_scale_seq_dict[0][uniq_id][0].shape[0]
return base_clus_label_dict, emb_dim
def get_cluster_avg_embs(
self, emb_scale_seq_dict: Dict, clus_labels: List, speaker_mapping_dict: Dict, session_scale_mapping_dict: Dict
):
"""
MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates an average embedding vector for each cluster (speaker)
and each scale.
Args:
emb_scale_seq_dict (dict):
Dictionary containing embedding sequence for each scale. Keys are scale index in integer.
clus_labels (list):
Clustering results from clustering diarizer including all the sessions provided in input manifest files.
speaker_mapping_dict (dict):
Speaker mapping dictionary in case RTTM files are provided. This is mapping between integer based speaker index and
speaker ID tokens in RTTM files.
Example:
{'en_0638': {'speaker_0': 'en_0638_A', 'speaker_1': 'en_0638_B'},
'en_4065': {'speaker_0': 'en_4065_B', 'speaker_1': 'en_4065_A'}, ...,}
session_scale_mapping_dict (dict):
Dictionary containing multiscale mapping information for each session. Indexed by `uniq_id` string.
Returns:
emb_sess_avg_dict (dict):
Dictionary containing speaker mapping information and cluster-average speaker embedding vector.
Each session-level dictionary is indexed by scale index in integer.
output_clus_label_dict (dict):
Subegmentation timestamps in float type and Clustering result in integer type. Indexed by `uniq_id` keys.
"""
self.scale_n = len(emb_scale_seq_dict.keys())
emb_sess_avg_dict = {
scale_index: {key: [] for key in emb_scale_seq_dict[self.scale_n - 1].keys()}
for scale_index in emb_scale_seq_dict.keys()
}
output_clus_label_dict, emb_dim = self.get_base_clus_label_dict(clus_labels, emb_scale_seq_dict)
all_scale_clus_label_dict = self.assign_labels_to_longer_segs(
output_clus_label_dict, session_scale_mapping_dict
)
for scale_index in emb_scale_seq_dict.keys():
for uniq_id, _emb_tensor in emb_scale_seq_dict[scale_index].items():
if type(_emb_tensor) == list:
emb_tensor = torch.tensor(np.array(_emb_tensor))
else:
emb_tensor = _emb_tensor
clus_label_list = all_scale_clus_label_dict[scale_index][uniq_id]
spk_set = set(clus_label_list)
# Create a label array which identifies clustering result for each segment.
label_array = torch.Tensor(clus_label_list)
avg_embs = torch.zeros(emb_dim, self.max_num_speakers)
for spk_idx in spk_set:
selected_embs = emb_tensor[label_array == spk_idx]
avg_embs[:, spk_idx] = torch.mean(selected_embs, dim=0)
if speaker_mapping_dict is not None:
inv_map = {clus_key: rttm_key for rttm_key, clus_key in speaker_mapping_dict[uniq_id].items()}
else:
inv_map = None
emb_sess_avg_dict[scale_index][uniq_id] = {'mapping': inv_map, 'avg_embs': avg_embs}
return emb_sess_avg_dict, output_clus_label_dict
def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str):
"""
If no pre-existing data is provided, run clustering diarizer from scratch. This will create scale-wise speaker embedding
sequence, cluster-average embeddings, scale mapping and base scale clustering labels. Note that speaker embedding `state_dict`
is loaded from the `state_dict` in the provided MSDD checkpoint.
Args:
manifest_filepath (str):
Input manifest file for creating audio-to-RTTM mapping.
emb_dir (str):
Output directory where embedding files and timestamp files are saved.
Returns:
emb_sess_avg_dict (dict):
Dictionary containing cluster-average embeddings for each session.
emb_scale_seq_dict (dict):
Dictionary containing embedding tensors which are indexed by scale numbers.
base_clus_label_dict (dict):
Dictionary containing clustering results. Clustering results are cluster labels for the base scale segments.
"""
self.cfg_diar_infer.diarizer.manifest_filepath = manifest_filepath
self.cfg_diar_infer.diarizer.out_dir = emb_dir
# Run ClusteringDiarizer which includes system VAD or oracle VAD.
self.clus_diar_model = ClusteringDiarizer(cfg=self.cfg_diar_infer, speaker_model=self._speaker_model)
self._out_dir = self.clus_diar_model._diarizer_params.out_dir
self.out_rttm_dir = os.path.join(self._out_dir, 'pred_ovl_rttms')
os.makedirs(self.out_rttm_dir, exist_ok=True)
self.clus_diar_model._cluster_params = self.cfg_diar_infer.diarizer.clustering.parameters
self.clus_diar_model.multiscale_args_dict[
"multiscale_weights"
] = self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
self.clus_diar_model._diarizer_params.speaker_embeddings.parameters = (
self.cfg_diar_infer.diarizer.speaker_embeddings.parameters
)
clustering_params_str = json.dumps(dict(self.clus_diar_model._cluster_params), indent=4)
logging.info(f"Multiscale Weights: {self.clus_diar_model.multiscale_args_dict['multiscale_weights']}")
logging.info(f"Clustering Parameters: {clustering_params_str}")
scores = self.clus_diar_model.diarize(batch_size=self.cfg_diar_infer.batch_size)
# If RTTM (ground-truth diarization annotation) files do not exist, scores is None.
if scores is not None:
metric, speaker_mapping_dict, _ = scores
else:
metric, speaker_mapping_dict = None, None
# Get the mapping between segments in different scales.
self._embs_and_timestamps = get_embs_and_timestamps(
self.clus_diar_model.multiscale_embeddings_and_timestamps, self.clus_diar_model.multiscale_args_dict
)
session_scale_mapping_dict = self.get_scale_map(self._embs_and_timestamps)
emb_scale_seq_dict = self.load_emb_scale_seq_dict(emb_dir)
clus_labels = self.load_clustering_labels(emb_dir)
emb_sess_avg_dict, base_clus_label_dict = self.get_cluster_avg_embs(
emb_scale_seq_dict, clus_labels, speaker_mapping_dict, session_scale_mapping_dict
)
emb_scale_seq_dict['session_scale_mapping'] = session_scale_mapping_dict
return emb_sess_avg_dict, emb_scale_seq_dict, base_clus_label_dict, metric
def get_scale_map(self, embs_and_timestamps):
"""
Save multiscale mapping data into dictionary format.
Args:
embs_and_timestamps (dict):
Dictionary containing embedding tensors and timestamp tensors. Indexed by `uniq_id` string.
Returns:
session_scale_mapping_dict (dict):
Dictionary containing multiscale mapping information for each session. Indexed by `uniq_id` string.
"""
session_scale_mapping_dict = {}
for uniq_id, uniq_embs_and_timestamps in embs_and_timestamps.items():
scale_mapping_dict = get_scale_mapping_argmat(uniq_embs_and_timestamps)
session_scale_mapping_dict[uniq_id] = scale_mapping_dict
return session_scale_mapping_dict
def check_clustering_labels(self, out_dir):
"""
Check whether the laoded clustering label file is including clustering results for all sessions.
This function is used for inference mode of MSDD.
Args:
out_dir (str):
Path to the directory where clustering result files are saved.
Returns:
file_exists (bool):
Boolean that indicates whether clustering result file exists.
clus_label_path (str):
Path to the clustering label output file.
"""
clus_label_path = os.path.join(
out_dir, 'speaker_outputs', f'subsegments_scale{self.base_scale_index}_cluster.label'
)
file_exists = os.path.exists(clus_label_path)
if not file_exists:
logging.info(f"Clustering label file {clus_label_path} does not exist.")
return file_exists, clus_label_path
def load_clustering_labels(self, out_dir):
"""
Load clustering labels generated by clustering diarizer. This function is used for inference mode of MSDD.
Args:
out_dir (str):
Path to the directory where clustering result files are saved.
Returns:
emb_scale_seq_dict (dict):
List containing clustering results in string format.
"""
file_exists, clus_label_path = self.check_clustering_labels(out_dir)
logging.info(f"Loading cluster label file from {clus_label_path}")
with open(clus_label_path) as f:
clus_labels = f.readlines()
return clus_labels
def load_emb_scale_seq_dict(self, out_dir):
"""
Load saved embeddings generated by clustering diarizer. This function is used for inference mode of MSDD.
Args:
out_dir (str):
Path to the directory where embedding pickle files are saved.
Returns:
emb_scale_seq_dict (dict):
Dictionary containing embedding tensors which are indexed by scale numbers.
"""
window_len_list = list(self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.window_length_in_sec)
emb_scale_seq_dict = {scale_index: None for scale_index in range(len(window_len_list))}
for scale_index in range(len(window_len_list)):
pickle_path = os.path.join(
out_dir, 'speaker_outputs', 'embeddings', f'subsegments_scale{scale_index}_embeddings.pkl'
)
logging.info(f"Loading embedding pickle file of scale:{scale_index} at {pickle_path}")
with open(pickle_path, "rb") as input_file:
emb_dict = pkl.load(input_file)
for key, val in emb_dict.items():
emb_dict[key] = val
emb_scale_seq_dict[scale_index] = emb_dict
return emb_scale_seq_dict
class NeuralDiarizer:
"""
Class for inference based on multiscale diarization decoder (MSDD). MSDD requires initializing clustering results from
clustering diarizer. Overlap-aware diarizer requires separate RTTM generation and evaluation modules to check the effect of
overlap detection in speaker diarization.
"""
def __init__(self, cfg: DictConfig):
    """
    Set up MSDD-based inference from a diarization inference config.

    The steps run in this order:
      1. Read the MSDD inference options from `cfg.diarizer.msdd_model.parameters`, using the
         defaults below when a key is absent.
      2. Initialize the MSDD model via `_init_msdd_model`.
      3. Copy the inference settings into the model config.
      4. Build the audio-to-RTTM map from the test manifest.
      5. Create a `ClusterEmbedding` helper that shares this object's speaker model.

    Args:
        cfg (DictConfig):
            Diarization inference config.
    """
    self._cfg = cfg

    # MSDD inference options. These attributes are assigned before `_init_msdd_model` runs.
    msdd_params = cfg.diarizer.msdd_model.parameters
    self.use_speaker_model_from_ckpt = msdd_params.get('use_speaker_model_from_ckpt', True)
    self.use_clus_as_main = msdd_params.get('use_clus_as_main', False)
    self.max_overlap_spks = msdd_params.get('max_overlap_spks', 2)
    self.num_spks_per_model = msdd_params.get('num_spks_per_model', 2)
    self.use_adaptive_thres = msdd_params.get('use_adaptive_thres', True)
    self.max_pred_length = msdd_params.get('max_pred_length', 0)
    default_eval_settings = [(0.25, True), (0.25, False), (0.0, False)]
    self.diar_eval_settings = msdd_params.get('diar_eval_settings', default_eval_settings)

    self._init_msdd_model(cfg)

    # Read after `_init_msdd_model`, which receives `cfg` (it may touch it), so the lookup path is re-walked here.
    self.diar_window_length = cfg.diarizer.msdd_model.parameters.diar_window_length
    self.msdd_model.cfg = self.transfer_diar_params_to_model_params(self.msdd_model, cfg)
    self.manifest_filepath = self.msdd_model.cfg.test_ds.manifest_filepath
    self.AUDIO_RTTM_MAP = audio_rttm_map(self.manifest_filepath)

    # Clustering + embedding extraction acts as the diarization encoder for MSDD.
    # NOTE(review): `self._speaker_model` is presumably set by `_init_msdd_model` — confirm.
    self.clustering_embedding = ClusterEmbedding(cfg_diar_infer=cfg, cfg_msdd_model=self.msdd_model.cfg)
    self.clustering_embedding._speaker_model = self._speaker_model

    # Speaker limits used when turning MSDD outputs into diarization results.
    self.clustering_max_spks = self.msdd_model._cfg.max_num_of_spks
    self.overlap_infer_spk_limit = cfg.diarizer.msdd_model.parameters.get(
        'overlap_infer_spk_limit', self.clustering_max_spks
    )
def transfer_diar_params_to_model_params(self, msdd_model, cfg):
"""
Transfer the parameters that are needed for MSDD inference from the diarization inference config files