Skip to content

Commit

Permalink
add label inference support to EncDecSpeakerLabel class (#5278)
Browse files Browse the repository at this point in the history
* add label inference support to EncDecSpeakerLabel class

Signed-off-by: nithinraok <[email protected]>

* add necessary tests

Signed-off-by: nithinraok <[email protected]>

* reflect on comments

Signed-off-by: nithinraok <[email protected]>

* grammatical correction

Signed-off-by: nithinraok <[email protected]>

* minor doc string changes

Signed-off-by: nithinraok <[email protected]>

Signed-off-by: nithinraok <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
  • Loading branch information
nithinraok and okuchaiev authored Nov 2, 2022
1 parent ba92ad2 commit a7d73f4
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 66 deletions.
53 changes: 20 additions & 33 deletions examples/speaker_tasks/recognition/extract_speaker_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,47 +35,29 @@

import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.utils.speaker_utils import embedding_normalize
from nemo.utils import logging

try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager

@contextmanager
def autocast(enabled=None):
yield


def get_embeddings(speaker_model, manifest_file, batch_size=1, embedding_dir='./', device='cuda'):
test_config = OmegaConf.create(
dict(manifest_filepath=manifest_file, sample_rate=16000, labels=None, batch_size=batch_size, shuffle=False,)
)

speaker_model.setup_test_data(test_config)
speaker_model = speaker_model.to(device)
speaker_model.eval()

all_embs = []
out_embeddings = {}

for test_batch in tqdm(speaker_model.test_dataloader()):
test_batch = [x.to(device) for x in test_batch]
audio_signal, audio_signal_len, labels, slices = test_batch
with autocast():
_, embs = speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
emb_shape = embs.shape[-1]
embs = embs.view(-1, emb_shape)
all_embs.extend(embs.cpu().detach().numpy())
del test_batch

"""
save embeddings to pickle file
Args:
speaker_model: NeMo <EncDecSpeakerLabel> model
manifest_file: path to the manifest file containing the audio file path from which the
embeddings should be extracted
batch_size: batch_size for inference
embedding_dir: path to directory to store embeddings file
device: compute device to perform operations
"""

all_embs, _, _, _ = speaker_model.batch_inference(manifest_file, batch_size=batch_size, device=device)
all_embs = np.asarray(all_embs)
all_embs = embedding_normalize(all_embs)
out_embeddings = {}

with open(manifest_file, 'r', encoding='utf-8') as manifest:
for i, line in enumerate(manifest.readlines()):
line = line.strip()
Expand Down Expand Up @@ -107,6 +89,9 @@ def main():
required=False,
help="path to .nemo speaker verification model file to extract embeddings, if not passed SpeakerNet-M model would be downloaded from NGC and used to extract embeddings",
)
parser.add_argument(
"--batch_size", type=int, default=1, required=False, help="batch size",
)
parser.add_argument(
"--embedding_dir",
type=str,
Expand All @@ -131,7 +116,9 @@ def main():
device = 'cpu'
logging.warning("Running model on CPU, for faster performance it is adviced to use atleast one NVIDIA GPUs")

get_embeddings(speaker_model, args.manifest, batch_size=1, embedding_dir=args.embedding_dir, device=device)
get_embeddings(
speaker_model, args.manifest, batch_size=args.batch_size, embedding_dir=args.embedding_dir, device=device
)


if __name__ == '__main__':
Expand Down
20 changes: 8 additions & 12 deletions examples/speaker_tasks/recognition/speaker_identification_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def main(cfg):

backend = cfg.backend.backend_model.lower()

featurizer = WaveformFeaturizer(sample_rate=sample_rate)
dataset = AudioToSpeechLabelDataset(manifest_filepath=enrollment_manifest, labels=None, featurizer=featurizer)
enroll_id2label = dataset.id2label

if backend == 'cosine_similarity':
model_path = cfg.backend.cosine_similarity.model_path
batch_size = cfg.backend.cosine_similarity.batch_size
Expand All @@ -50,13 +54,11 @@ def main(cfg):
else:
speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

enroll_embs, _, enroll_truelabels, enroll_id2label = EncDecSpeakerLabelModel.get_batch_embeddings(
speaker_model, enrollment_manifest, batch_size, sample_rate, device=device,
enroll_embs, _, enroll_truelabels, _ = speaker_model.batch_inference(
enrollment_manifest, batch_size, sample_rate, device=device,
)

test_embs, _, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
speaker_model, test_manifest, batch_size, sample_rate, device=device,
)
test_embs, _, _, _ = speaker_model.batch_inference(test_manifest, batch_size, sample_rate, device=device,)

# length normalize
enroll_embs = enroll_embs / (np.linalg.norm(enroll_embs, ord=2, axis=-1, keepdims=True))
Expand Down Expand Up @@ -84,18 +86,12 @@ def main(cfg):
else:
speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

featurizer = WaveformFeaturizer(sample_rate=sample_rate)
dataset = AudioToSpeechLabelDataset(manifest_filepath=enrollment_manifest, labels=None, featurizer=featurizer)
enroll_id2label = dataset.id2label

if speaker_model.decoder.final.out_features != len(enroll_id2label):
raise ValueError(
"number of labels mis match. Make sure you trained or finetuned neural classifier with labels from enrollement manifest_filepath"
)

_, test_logits, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
speaker_model, test_manifest, batch_size, sample_rate, device=device,
)
_, test_logits, _, _ = speaker_model.batch_inference(test_manifest, batch_size, sample_rate, device=device,)
matched_labels = test_logits.argmax(axis=-1)

with open(test_manifest, 'rb') as f1, open(out_manifest, 'w', encoding='utf-8') as f2:
Expand Down
104 changes: 83 additions & 21 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,14 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'test')

@torch.no_grad()
def get_embedding(self, path2audio_file):
def infer_file(self, path2audio_file):
"""
Returns the speaker embeddings for a provided audio file.
Args:
path2audio_file: path to audio wav file
path2audio_file: path to an audio wav file
Returns:
embs: speaker embeddings
emb: speaker embeddings (Audio representations)
logits: logits corresponding to the final layer
"""
audio, sr = librosa.load(path2audio_file, sr=None)
target_sr = self._cfg.train_ds.get('sample_rate', 16000)
Expand All @@ -441,13 +440,48 @@ def get_embedding(self, path2audio_file):
mode = self.training
self.freeze()

_, embs = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
logits, emb = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)

self.train(mode=mode)
if mode is True:
self.unfreeze()
del audio_signal, audio_signal_len
return embs
return emb, logits

def get_label(self, path2audio_file):
    """
    Returns label of path2audio_file from classes the model was trained on.

    Args:
        path2audio_file: path to audio wav file

    Returns:
        label: the entry of the model's `train_ds.labels` config selected by the
            highest-scoring logit of the first (only) item; when that config entry
            is None, the result of `logits.argmax(axis=1)` (class index) instead.
    """
    _, logits = self.infer_file(path2audio_file=path2audio_file)
    label_id = logits.argmax(axis=1)

    # Test for None *before* converting to a list: list(None) raises TypeError,
    # which previously made the index-only fallback below unreachable.
    mapped_labels = self._cfg['train_ds']['labels']
    if mapped_labels is None:
        logging.info("labels are not saved to model, hence only outputting the label id index")
        return label_id

    return list(mapped_labels)[int(label_id[0])]

def get_embedding(self, path2audio_file):
    """
    Compute the speaker embedding of a single audio file.

    Delegates to `infer_file` and discards the logits it returns alongside
    the embedding.

    Args:
        path2audio_file: path to an audio wav file
    Returns:
        emb: speaker embeddings (Audio representations)
    """
    embedding, _unused_logits = self.infer_file(path2audio_file=path2audio_file)
    return embedding

@torch.no_grad()
def verify_speakers(self, path2audio_file1, path2audio_file2, threshold=0.7):
Expand Down Expand Up @@ -478,13 +512,37 @@ def verify_speakers(self, path2audio_file1, path2audio_file2, threshold=0.7):
logging.info(" two audio files are from different speakers")
return False

@staticmethod
@torch.no_grad()
def get_batch_embeddings(speaker_model, manifest_filepath, batch_size=32, sample_rate=16000, device='cuda'):
def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, device='cuda'):
"""
Perform batch inference on EncDecSpeakerLabelModel.
To perform inference on a single audio file, one can use infer_file, get_label or get_embedding
speaker_model.eval()
if device == 'cuda':
speaker_model.to(device)
To map predicted labels, one can do
`arg_values = logits.argmax(axis=1)`
`pred_labels = list(map(lambda t : mapped_labels[t], arg_values))`
Args:
manifest_filepath: Path to manifest file
batch_size: batch size to perform batch inference
sample_rate: sample rate of audio files in manifest file
device: compute device to perform operations.
Returns:
The variables below all follow the audio file order in the manifest file.
embs: embeddings of files provided in manifest file
logits: logits of final layer of EncDecSpeakerLabel Model
gt_labels: labels from manifest file (needed for speaker enrollment and testing)
mapped_labels: Classification labels sorted in the order that they are mapped by the trained model
"""
mode = self.training
self.freeze()
self.eval()
self.to(device)
mapped_labels = self._cfg['train_ds']['labels']
if mapped_labels is not None:
mapped_labels = list(mapped_labels)

featurizer = WaveformFeaturizer(sample_rate=sample_rate)
dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer)
Expand All @@ -493,20 +551,24 @@ def get_batch_embeddings(speaker_model, manifest_filepath, batch_size=32, sample
dataset=dataset, batch_size=batch_size, collate_fn=dataset.fixed_seq_collate_fn,
)

all_logits = []
all_labels = []
all_embs = []
logits = []
embs = []
gt_labels = []

for test_batch in tqdm(dataloader):
if device == 'cuda':
test_batch = [x.to(device) for x in test_batch]
audio_signal, audio_signal_len, labels, _ = test_batch
logits, embs = speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
logit, emb = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)

all_logits.extend(logits.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_embs.extend(embs.cpu().numpy())
logits.extend(logit.cpu().numpy())
gt_labels.extend(labels.cpu().numpy())
embs.extend(emb.cpu().numpy())

self.train(mode=mode)
if mode is True:
self.unfreeze()

all_logits, true_labels, all_embs = np.asarray(all_logits), np.asarray(all_labels), np.asarray(all_embs)
logits, embs, gt_labels = np.asarray(logits), np.asarray(embs), np.asarray(gt_labels)

return all_embs, all_logits, true_labels, dataset.id2label
return embs, logits, gt_labels, mapped_labels
31 changes: 31 additions & 0 deletions tests/collections/asr/test_speaker_label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import TestCase

import pytest
Expand Down Expand Up @@ -139,3 +140,33 @@ def test_titanet_enc_dec(self):
confdict = speaker_model.to_config_dict()
instance2 = EncDecSpeakerLabelModel.from_config_dict(confdict)
self.assertTrue(isinstance(instance2, EncDecSpeakerLabelModel))


class TestEncDecSpeechLabelModel:
    """Single-file inference checks for EncDecSpeakerLabelModel loaded via `from_pretrained`."""

    @pytest.mark.unit
    def test_pretrained_titanet_embeddings(self, test_data_dir):
        """`infer_file` on 'titanet_large' yields the expected class id and embedding sum."""
        model_name = 'titanet_large'
        speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name)
        assert isinstance(speaker_model, EncDecSpeakerLabelModel)
        relative_filepath = "an4_speaker/an4/wav/an4_clstk/fash/an251-fash-b.wav"
        filename = os.path.join(test_data_dir, relative_filepath)

        emb, logits = speaker_model.infer_file(filename)

        class_id = logits.argmax(axis=-1)
        emb_sum = emb.sum()

        assert 11144 == class_id
        # Compare the absolute deviation from the reference sum (-0.2575); without
        # abs() any sufficiently negative sum would satisfy the check as well.
        assert abs(emb_sum + 0.2575) <= 1e-2

    @pytest.mark.unit
    def test_pretrained_ambernet_logits(self, test_data_dir):
        """`get_label` on 'langid_ambernet' maps the sample recording to the "en" label."""
        model_name = 'langid_ambernet'
        lang_model = EncDecSpeakerLabelModel.from_pretrained(model_name)
        assert isinstance(lang_model, EncDecSpeakerLabelModel)
        relative_filepath = "an4_speaker/an4/wav/an4_clstk/fash/an255-fash-b.wav"
        filename = os.path.join(test_data_dir, relative_filepath)

        label = lang_model.get_label(filename)

        assert label == "en"

0 comments on commit a7d73f4

Please sign in to comment.