Add file class based inference API for diarization #5945

Merged: 5 commits, Feb 28, 2023
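Taken together, the changes below turn the diarizers into importable, device-aware `torch.nn.Module`s. A minimal usage sketch of the flow this PR enables (the YAML path and manifest filename are illustrative placeholders, and this assumes a NeMo build that includes this PR):

```python
import torch
from omegaconf import OmegaConf
from nemo.collections.asr.models import NeuralDiarizer

# Illustrative config and manifest paths, not taken from this PR.
cfg = OmegaConf.load("conf/inference/diar_infer_telephonic.yaml")
cfg.diarizer.manifest_filepath = "input_manifest.json"

# Diarizers now subclass torch.nn.Module, so .to(device) works directly.
device = cfg.device or ("cuda" if torch.cuda.is_available() else "cpu")
diarizer = NeuralDiarizer(cfg=cfg).to(device)
diarizer.diarize()
```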
@@ -39,7 +39,7 @@
@hydra_runner(config_path="../conf/inference", config_name="diar_infer_meeting.yaml")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
-    sd_model = ClusteringDiarizer(cfg=cfg)
+    sd_model = ClusteringDiarizer(cfg=cfg).to(cfg.device)
    sd_model.diarize()


@@ -10,6 +10,7 @@ num_workers: 1
sample_rate: 16000
batch_size: 64
device: null # can specify a particular device, e.g. cuda:1 (defaults to cuda if available, else cpu)
+verbose: True # enable additional logging

diarizer:
  manifest_filepath: ???
@@ -10,6 +10,7 @@ num_workers: 1
sample_rate: 16000
batch_size: 64
device: null # can specify a particular device, e.g. cuda:1 (defaults to cuda if available, else cpu)
+verbose: True # enable additional logging

diarizer:
  manifest_filepath: ???
@@ -10,6 +10,7 @@ num_workers: 1
sample_rate: 16000
batch_size: 64
device: null # can specify a particular device, e.g. cuda:1 (defaults to cuda if available, else cpu)
+verbose: True # enable additional logging

diarizer:
  manifest_filepath: ???
@@ -29,7 +29,7 @@

@hydra_runner(config_path="../conf/inference", config_name="diar_infer_telephonic.yaml")
def main(cfg):
-    diarizer_model = NeuralDiarizer(cfg=cfg)
+    diarizer_model = NeuralDiarizer(cfg=cfg).to(cfg.device)
    diarizer_model.diarize()


7 changes: 4 additions & 3 deletions nemo/collections/asr/metrics/der.py
@@ -51,7 +51,7 @@ def uem_timeline_from_file(uem_file, uniq_name=''):


def score_labels(
-    AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True
+    AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True, verbose: bool = True
) -> Optional[Tuple[DiarizationErrorRate, Dict]]:
"""
Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are
Expand All @@ -63,6 +63,7 @@ def score_labels(
AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath
all_reference (list[uniq_name,Annotation]): reference annotations for score calculation
all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation
verbose (bool): Warns if RTTM file is not found.

Returns:
metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile.
@@ -101,11 +102,11 @@ def score_labels(
        )

        return metric, mapping_dict, itemized_errors
-    else:
+    elif verbose:
        logging.warning(
            "Check if each ground truth RTTM was present in the provided manifest file. Skipping calculation of Diarization Error Rate."
        )
-        return None
+    return None


def evaluate_der(audio_rttm_map_dict, all_reference, all_hypothesis, diar_eval_mode='all'):
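For reference, a sketch of calling the updated `score_labels` with the new `verbose` flag. The map and annotation lists are placeholders; in practice the diarizer builds them from the manifest and the ground-truth RTTMs:

```python
from nemo.collections.asr.metrics.der import score_labels

# audio_rttm_map, all_reference and all_hypothesis are placeholder inputs.
results = score_labels(
    audio_rttm_map,
    all_reference,
    all_hypothesis,
    collar=0.25,
    ignore_overlap=True,
    verbose=False,  # silence the warning when ground-truth RTTMs are missing
)
if results is not None:  # None when no reference RTTMs were found
    metric, mapping_dict, itemized_errors = results
```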
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/__init__.py
@@ -23,7 +23,7 @@
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
from nemo.collections.asr.models.k2_sequence_models import EncDecK2SeqModel, EncDecK2SeqModelBPE
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
-from nemo.collections.asr.models.msdd_models import EncDecDiarLabelModel
+from nemo.collections.asr.models.msdd_models import EncDecDiarLabelModel, NeuralDiarizer
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel
77 changes: 49 additions & 28 deletions nemo/collections/asr/models/clustering_diarizer.py
@@ -19,7 +19,7 @@
import tarfile
import tempfile
from copy import deepcopy
-from typing import List, Optional
+from typing import Any, List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf
@@ -72,18 +72,20 @@ def get_available_model_names(class_name):
return list(map(lambda x: x.pretrained_model_name, available_models))


-class ClusteringDiarizer(Model, DiarizationMixin):
+class ClusteringDiarizer(torch.nn.Module, Model, DiarizationMixin):
"""
Inference model Class for offline speaker diarization.
This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
Extract Embeddings, Clustering, Resegmentation and Scoring.
All the parameters are passed through config file
"""

-    def __init__(self, cfg: DictConfig, speaker_model=None):
-        cfg = model_utils.convert_model_config_to_dict_config(cfg)
-        # Convert config to support Hydra 1.0+ instantiation
-        cfg = model_utils.maybe_update_config_version(cfg)
+    def __init__(self, cfg: Union[DictConfig, Any], speaker_model=None):
+        super().__init__()
+        if isinstance(cfg, DictConfig):
+            cfg = model_utils.convert_model_config_to_dict_config(cfg)
+            # Convert config to support Hydra 1.0+ instantiation
+            cfg = model_utils.maybe_update_config_version(cfg)
        self._cfg = cfg

# Diarizer set up
@@ -100,14 +102,9 @@ def __init__(self, cfg: DictConfig, speaker_model=None):
        self.multiscale_embeddings_and_timestamps = {}
        self._init_speaker_model(speaker_model)
        self._speaker_params = self._cfg.diarizer.speaker_embeddings.parameters
-        self._speaker_dir = os.path.join(self._diarizer_params.out_dir, 'speaker_outputs')
-        shutil.rmtree(self._speaker_dir, ignore_errors=True)
-        os.makedirs(self._speaker_dir)

        # Clustering params
        self._cluster_params = self._diarizer_params.clustering.parameters
-        default_device = "cuda" if torch.cuda.is_available() else "cpu"
-        self._device = torch.device(cfg.device if cfg.device else default_device)

@classmethod
def list_available_models(cls):
@@ -119,7 +116,7 @@ def _init_vad_model(self):
"""
model_path = self._cfg.diarizer.vad.model_path
if model_path.endswith('.nemo'):
self._vad_model = EncDecClassificationModel.restore_from(model_path)
self._vad_model = EncDecClassificationModel.restore_from(model_path, map_location=self._cfg.device)
logging.info("VAD model loaded locally from {}".format(model_path))
else:
if model_path not in get_available_model_names(EncDecClassificationModel):
@@ -128,8 +125,9 @@
                )
                model_path = "vad_telephony_marblenet"
            logging.info("Loading pretrained {} model from NGC".format(model_path))
-            self._vad_model = EncDecClassificationModel.from_pretrained(model_name=model_path)
-
+            self._vad_model = EncDecClassificationModel.from_pretrained(
+                model_name=model_path, map_location=self._cfg.device
+            )
        self._vad_window_length_in_sec = self._vad_params.window_length_in_sec
        self._vad_shift_length_in_sec = self._vad_params.shift_length_in_sec
        self.has_vad_model = True
@@ -143,10 +141,12 @@ def _init_speaker_model(self, speaker_model=None):
        else:
            model_path = self._cfg.diarizer.speaker_embeddings.model_path
            if model_path is not None and model_path.endswith('.nemo'):
-                self._speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
+                self._speaker_model = EncDecSpeakerLabelModel.restore_from(model_path, map_location=self._cfg.device)
                logging.info("Speaker Model restored locally from {}".format(model_path))
            elif model_path.endswith('.ckpt'):
-                self._speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(model_path)
+                self._speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(
+                    model_path, map_location=self._cfg.device
+                )
                logging.info("Speaker Model restored locally from {}".format(model_path))
            else:
                if model_path not in get_available_model_names(EncDecSpeakerLabelModel):
@@ -155,7 +155,9 @@
                    )
                    model_path = "ecapa_tdnn"
                logging.info("Loading pretrained {} model from NGC".format(model_path))
-                self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name=model_path)
+                self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(
+                    model_name=model_path, map_location=self._cfg.device
+                )

self.multiscale_args_dict = parse_scale_configs(
self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
@@ -201,7 +203,6 @@ def _run_vad(self, manifest_file):
        shutil.rmtree(self._vad_dir, ignore_errors=True)
        os.makedirs(self._vad_dir)

-        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
@@ -214,8 +215,10 @@
data.append(get_uniqname_from_filepath(file))

        status = get_vad_stream_status(data)
-        for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader(), desc='vad', leave=True)):
-            test_batch = [x.to(self._device) for x in test_batch]
+        for i, test_batch in enumerate(
+            tqdm(self._vad_model.test_dataloader(), desc='vad', leave=True, disable=not self.verbose)
+        ):
+            test_batch = [x.to(self._vad_model.device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
@@ -258,11 +261,13 @@ def _run_vad(self, manifest_file):

        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

+        vad_params = self._vad_params if isinstance(self._vad_params, (DictConfig, dict)) else self._vad_params.dict()
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
-            postprocessing_params=self._vad_params,
+            postprocessing_params=vad_params,
            frame_length_in_sec=frame_length_in_sec,
            num_workers=self._cfg.num_workers,
+            out_dir=self._vad_dir,
        )

AUDIO_VAD_RTTM_MAP = {}
@@ -308,6 +313,7 @@ def _perform_speech_activity_detection(self):
                'window_length_in_sec': self._vad_window_length_in_sec,
                'split_duration': self._split_duration,
                'num_workers': self._cfg.num_workers,
+                'out_dir': self._diarizer_params.out_dir,
            }
            manifest_vad_input = prepare_manifest(config)
        else:
@@ -337,15 +343,17 @@ def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: int
        logging.info("Extracting embeddings for Diarization")
        self._setup_spkr_test_data(manifest_file)
        self.embeddings = {}
-        self._speaker_model = self._speaker_model.to(self._device)
        self._speaker_model.eval()
        self.time_stamps = {}

        all_embs = torch.empty([0])
        for test_batch in tqdm(
-            self._speaker_model.test_dataloader(), desc=f'[{scale_idx+1}/{num_scales}] extract embeddings', leave=True
+            self._speaker_model.test_dataloader(),
+            desc=f'[{scale_idx+1}/{num_scales}] extract embeddings',
+            leave=True,
+            disable=not self.verbose,
        ):
-            test_batch = [x.to(self._device) for x in test_batch]
+            test_batch = [x.to(self._speaker_model.device) for x in test_batch]
            audio_signal, audio_signal_len, labels, slices = test_batch
            with autocast():
                _, embs = self._speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
@@ -396,6 +404,14 @@ def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0):
"""

self._out_dir = self._diarizer_params.out_dir

self._speaker_dir = os.path.join(self._diarizer_params.out_dir, 'speaker_outputs')

if os.path.exists(self._speaker_dir):
logging.warning("Deleting previous clustering diarizer outputs.")
shutil.rmtree(self._speaker_dir, ignore_errors=True)
os.makedirs(self._speaker_dir)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved

if not os.path.exists(self._out_dir):
os.mkdir(self._out_dir)

@@ -442,20 +458,21 @@ def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0):
            AUDIO_RTTM_MAP=self.AUDIO_RTTM_MAP,
            out_rttm_dir=out_rttm_dir,
            clustering_params=self._cluster_params,
+            device=self._speaker_model.device,
+            verbose=self.verbose,
        )
+        logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir)))

        # Scoring
-        score = score_labels(
+        return score_labels(
            self.AUDIO_RTTM_MAP,
            all_reference,
            all_hypothesis,
            collar=self._diarizer_params.collar,
            ignore_overlap=self._diarizer_params.ignore_overlap,
+            verbose=self.verbose,
        )
-
-        logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir)))
-        return score

@staticmethod
def __make_nemo_file_from_folder(filename, source_dir):
with tarfile.open(filename, "w:gz") as tar:
@@ -536,3 +553,7 @@ def restore_from(
        os.chdir(cwd)

        return instance
+
+    @property
+    def verbose(self) -> bool:
+        return self._cfg.verbose
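And a sketch of the equivalent `ClusteringDiarizer` flow, exercising the `device` and `verbose` config fields added above (paths are illustrative placeholders):

```python
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer

# Illustrative paths; any of the diar_infer_*.yaml configs in this PR
# works the same way.
cfg = OmegaConf.load("conf/inference/diar_infer_meeting.yaml")
cfg.diarizer.manifest_filepath = "input_manifest.json"
cfg.device = "cuda:0"   # models are loaded with map_location=cfg.device
cfg.verbose = False     # hides tqdm bars and the skipped-DER warning

sd_model = ClusteringDiarizer(cfg=cfg).to(cfg.device)
score = sd_model.diarize()  # now returns the score_labels(...) output, or None
```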
1 change: 1 addition & 0 deletions nemo/collections/asr/models/configs/__init__.py
@@ -23,6 +23,7 @@
    EncDecClassificationDatasetConfig,
    EncDecClassificationModelConfig,
)
+from nemo.collections.asr.models.configs.diarizer_config import NeuralDiarizerInferenceConfig
from nemo.collections.asr.models.configs.matchboxnet_config import (
    EncDecClassificationModelConfigBuilder,
    MatchboxNetModelConfig,