Confidence ensembles implementation #6614

Merged: 12 commits, May 15, 2023
14 changes: 11 additions & 3 deletions examples/asr/transcribe_speech.py
@@ -15,7 +15,7 @@
import contextlib
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional, Union
from typing import List, Optional, Union

import pytorch_lightning as pl
import torch
@@ -163,9 +163,14 @@ class TranscriptionConfig:
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False

# can be set to True to return a list of transcriptions instead of the config
# if True, this also skips writing anything to the output file
return_transcriptions: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
# the return type only names List here, since spelling out the full element type would be unwieldy
def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List]:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

for key in cfg:
@@ -299,7 +304,7 @@ def autocast():
cfg = compute_output_filename(cfg, model_name)

# if transcripts should not be overwritten and they already exist, skip the re-transcription step and return
if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
if not cfg.return_transcriptions and not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
logging.info(
f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
@@ -349,6 +354,9 @@ def autocast():
if type(transcriptions) == tuple and len(transcriptions) == 2:
transcriptions = transcriptions[0]

if cfg.return_transcriptions:
return transcriptions
Collaborator: Then how is the transcription written if confidence ensembles are enabled?

Collaborator (Author): The confidence ensemble itself is run with return_transcriptions=False, but the inner models will have it set to True. You can have a look at the tests to see how it's used here:

results = speech_to_text_eval.main(eval_cfg)

Collaborator: I would expect this will somehow be integrated with transcribe_speech.py, since this is the go-to place for offline inference with all ASR models. But this makes sense.
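For readers following the thread, a minimal sketch of the pattern being described: the hydra-wrapped main is called directly with a config whose return_transcriptions flag is set, so the transcripts come back as a Python list instead of being written to a file. This mirrors the cited test; the import, checkpoint path, and manifest below are placeholders (not taken from this PR) and assume examples/asr is importable:

# hypothetical usage sketch (paths and model name are placeholders)
import transcribe_speech  # assumes examples/asr is on the Python path

cfg = transcribe_speech.TranscriptionConfig(
    model_path="inner_model.nemo",          # hypothetical checkpoint
    dataset_manifest="test_manifest.json",  # hypothetical manifest
    return_transcriptions=True,             # return the list, skip writing the output file
)
transcriptions = transcribe_speech.main(cfg)
print(len(transcriptions))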


# write audio transcriptions
output_filename, pred_text_attr_name = write_transcription(
transcriptions,
3 changes: 3 additions & 0 deletions nemo/collections/asr/metrics/rnnt_wer.py
@@ -1268,3 +1268,6 @@ class RNNTDecodingConfig:

# beam decoding config
beam: beam_decode.BeamRNNTInferConfig = beam_decode.BeamRNNTInferConfig(beam_size=4)

# can be used to change temperature for decoding
temperature: float = 1.0
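As a usage sketch (the model variable and the value 1.5 are illustrative, not from this PR), the new field takes effect when the decoding config is re-applied through change_decoding_strategy:

from omegaconf import open_dict

decoding_cfg = model.cfg.decoding  # model: an already-loaded RNNT NeMo ASR model (assumed)
with open_dict(decoding_cfg):
    decoding_cfg.temperature = 1.5  # soften the output distribution before confidence estimation
model.change_decoding_strategy(decoding_cfg)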
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer_bpe.py
@@ -62,7 +62,7 @@ class RNNTBPEDecoding(AbstractRNNTDecoding):
The timestamps will be available in the returned Hypothesis.timestep as a dictionary.

compute_langs: a bool flag, which allows to compute language id (LID) information per token,
word, and the entire sample (most likely language id). The LIDS will be available
word, and the entire sample (most likely language id). The LIDS will be available
in the returned Hypothesis object as a dictionary

rnnt_timestamp_type: A str value, which represents the types of timestamps that should be calculated.
11 changes: 7 additions & 4 deletions nemo/collections/asr/metrics/wer.py
@@ -75,8 +75,8 @@ def word_error_rate_detail(
) -> Tuple[float, int, float, float, float]:
"""
Computes Average Word Error Rate with details (insertion rate, deletion rate, substitution rate)
between two texts represented as corresponding lists of string.
between two texts represented as corresponding lists of string.

Hypotheses and references must have the same length.
Args:
hypotheses (list): list of hypotheses
@@ -88,7 +88,7 @@ def word_error_rate_detail(
ins_rate (float): average insertion error rate
del_rate (float): average deletion error rate
sub_rate (float): average substitution error rate

"""
scores = 0
words = 0
@@ -1222,5 +1222,8 @@ class CTCDecodingConfig:
# beam decoding config
beam: ctc_beam_decoding.BeamCTCInferConfig = ctc_beam_decoding.BeamCTCInferConfig(beam_size=4)

# confidence config
# confidence config
confidence_cfg: ConfidenceConfig = ConfidenceConfig()

# can be used to change temperature for decoding
temperature: float = 1.0
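Similarly, a sketch of enabling both new CTC fields together (assumptions: an already-loaded CTC model named model, and that ConfidenceConfig exposes a preserve_frame_confidence flag, as the frame-confidence usage elsewhere in this PR suggests):

from omegaconf import open_dict
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig

decoding_cfg = model.cfg.decoding  # model: an already-loaded CTC NeMo ASR model (assumed)
with open_dict(decoding_cfg):
    decoding_cfg.confidence_cfg = ConfidenceConfig(preserve_frame_confidence=True)
    decoding_cfg.temperature = 1.5
model.change_decoding_strategy(decoding_cfg)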
203 changes: 203 additions & 0 deletions nemo/collections/asr/models/confidence_ensemble.py
@@ -0,0 +1,203 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

import joblib
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict

Check notice (Code scanning / CodeQL): Unused import. Import of 'OmegaConf' is not used.
from pytorch_lightning import Trainer

from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, get_confidence_aggregation_bank
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.core.classes import ModelPT
from nemo.utils import model_utils

__all__ = ['ConfidenceEnsembleModel']


class ConfidenceEnsembleModel(ModelPT):
def __init__(
self, cfg: DictConfig, trainer: 'Trainer' = None,
):
super().__init__(cfg=cfg, trainer=trainer)

# either we load all models from ``load_models`` cfg parameter
# or all of them are specified in the config as modelX alongside the num_models key
#
# ideally, we'd like to directly store all models in a list, but that
# is not currently supported by the submodule logic
# so to access all the models, we do something like
#
# for model_idx in range(self.num_models):
# model = getattr(self, f"model{model_idx}")

if 'num_models' in self.cfg:
self.num_models = self.cfg.num_models
for idx in range(self.num_models):
cfg_field = f"model{idx}"
model_cfg = self.cfg[cfg_field]
model_class = model_utils.import_class_by_path(model_cfg['target'])
self.register_nemo_submodule(
name=cfg_field, config_field=cfg_field, model=model_class(model_cfg, trainer=trainer),
)
else:
self.num_models = len(cfg.load_models)
with open_dict(self.cfg):
self.cfg.num_models = self.num_models
for idx, model in enumerate(cfg.load_models):
cfg_field = f"model{idx}"
if model.endswith(".nemo"):
self.register_nemo_submodule(
name=cfg_field,
config_field=cfg_field,
model=ASRModel.restore_from(model, trainer=trainer, map_location="cpu"),
)
else:
self.register_nemo_submodule(
cfg_field, config_field=cfg_field, model=ASRModel.from_pretrained(model, map_location="cpu"),
)

# registering model selection block - this is expected to be a joblib-saved
# pretrained sklearn pipeline containing standardization + logistic regression
# trained to predict "most-confident" model index from the confidence scores of all models
model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block)
self.model_selection_block = joblib.load(model_selection_block_path)
self.confidence_cfg = ConfidenceConfig(**self.cfg.confidence)

# making sure each model has correct confidence settings in the decoder strategy
for model_idx in range(self.num_models):
model = getattr(self, f"model{model_idx}")
# for now we assume users are directly responsible for matching the
# decoder type used to build the ensemble with the one used at inference
# TODO: add automatic checks for errors
if isinstance(model, EncDecHybridRNNTCTCModel):
self.update_decoding_parameters(model.cfg.decoding)
model.change_decoding_strategy(model.cfg.decoding, decoder_type="rnnt")
self.update_decoding_parameters(model.cfg.aux_ctc.decoding)
model.change_decoding_strategy(model.cfg.aux_ctc.decoding, decoder_type="ctc")
else:
self.update_decoding_parameters(model.cfg.decoding)
model.change_decoding_strategy(model.cfg.decoding)

def update_decoding_parameters(self, decoding_cfg):
"""Updating confidence/temperature parameters of the config."""
with open_dict(decoding_cfg):
decoding_cfg.confidence_cfg = self.confidence_cfg
decoding_cfg.temperature = self.cfg.temperature

def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
"""Pass-through to the ensemble models.

Note that training is not actually supported for this class!
"""
for model_idx in range(self.num_models):
getattr(self, f"model{model_idx}").setup_training_data(train_data_config)

def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
"""Pass-through to the ensemble models."""
for model_idx in range(self.num_models):
getattr(self, f"model{model_idx}").setup_validation_data(val_data_config)

def change_attention_model(
self, self_attention_model: str = None, att_context_size: List[int] = None, update_config: bool = True
):
"""Pass-through to the ensemble models."""
for model_idx in range(self.num_models):
getattr(self, f"model{model_idx}").change_attention_model(
self_attention_model, att_context_size, update_config
)

def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None):
"""Pass-through to the ensemble models.

The only change here is that we always require frame-confidence to
be returned.
"""
decoding_cfg.confidence_cfg = self.confidence_cfg
for model_idx in range(self.num_models):
model = getattr(self, f"model{model_idx}")
if isinstance(model, EncDecHybridRNNTCTCModel):
model.change_decoding_strategy(decoding_cfg, decoder_type=decoder_type)
else:
model.change_decoding_strategy(decoding_cfg)

@torch.no_grad()
def transcribe(
self,
paths2audio_files: List[str],
batch_size: int = 4,
return_hypotheses: bool = False,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
**kwargs, # any other model specific parameters are passed directly
) -> List[str]:
"""Confidence-ensemble transcribe method.

Consists of the following steps:

1. Run all models (TODO: in parallel)
2. Compute confidence for each model
3. Use logistic regression to pick the "most confident" model
4. Return the output of that model
"""
# TODO: lots of duplicate code with building ensemble script
aggr_func = get_confidence_aggregation_bank()[self.confidence_cfg.aggregation]
confidences = []
all_transcriptions = []
# always requiring to return hypothesis
# TODO: make sure to return text only if was False originally
Collaborator: Will this be resolved before merging the PR?

Collaborator (Author): I will resolve this in a follow-up PR.

return_hypotheses = True
for model_idx in range(self.num_models):
model = getattr(self, f"model{model_idx}")
transcriptions = model.transcribe(
Collaborator: Will this support transcribe_partial_audio just like transcribe_speech.py?

Collaborator (Author): Not at the moment; right now the assumption is that we transcribe the full audio and then compute the confidence. But in principle, there is a plan to add support (likely via a subclass) for transcribing only part of the utterance and then continuing the transcription with only the single most-confident model.

Collaborator: Gotcha. Yeah, theoretically every ASR model (with the extended feature) should support transcribing both the whole utterance and part of it. Please make sure to add this later. Looking forward to it!

paths2audio_files=paths2audio_files,
batch_size=batch_size,
return_hypotheses=return_hypotheses,
num_workers=num_workers,
channel_selector=channel_selector,
augmentor=augmentor,
verbose=verbose,
**kwargs,
)
if isinstance(transcriptions, tuple): # transducers return a tuple
transcriptions = transcriptions[0]

model_confidences = []
for transcription in transcriptions:
if isinstance(transcription.frame_confidence[0], list):
# NeMo Transducer API returns list of lists for confidences
conf_values = [conf_value for confs in transcription.frame_confidence for conf_value in confs]
else:
conf_values = transcription.frame_confidence
model_confidences.append(aggr_func(conf_values))
confidences.append(model_confidences)
all_transcriptions.append(transcriptions)

# transposing with zip(*list)
features = np.array(list(zip(*confidences)))
model_indices = self.model_selection_block.predict(features)
final_transcriptions = []
for transcription_idx in range(len(all_transcriptions[0])):
final_transcriptions.append(all_transcriptions[model_indices[transcription_idx]][transcription_idx])

return final_transcriptions

def list_available_models(self):
return []
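For context, a sketch of how the model_selection_block registered in __init__ might be produced: a joblib-saved sklearn pipeline with standardization followed by logistic regression, trained to map each utterance's per-model aggregated confidence scores to the index of the model to trust. This is not part of this PR; the data below is random placeholder data and the file name is arbitrary:

# illustrative training sketch for the model selection block (placeholder data)
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

features = np.random.rand(1000, 2)           # one aggregated confidence per model (2-model ensemble), 1000 utterances
labels = np.random.randint(0, 2, size=1000)  # index of the model to trust for each utterance

pipe = make_pipeline(StandardScaler(), LogisticRegression())
pipe.fit(features, labels)
joblib.dump(pipe, "model_selection_block.pkl")

At inference time, transcribe() above builds the same per-model confidence features and calls self.model_selection_block.predict(features) to pick one model per utterance.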
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -305,6 +305,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
dist_sync_on_step=True,
)

self.decoder.temperature = decoding_cfg.get('temperature', 1.0)

# Update config
with open_dict(self.cfg.decoding):
self.cfg.decoding = decoding_cfg
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/ctc_models.py
@@ -337,6 +337,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
dist_sync_on_step=True,
)

self.decoder.temperature = decoding_cfg.get('temperature', 1.0)

# Update config
with open_dict(self.cfg.decoding):
self.cfg.decoding = decoding_cfg
4 changes: 4 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -415,6 +415,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
self.joint.set_loss(self.loss)
self.joint.set_wer(self.wer)

self.joint.temperature = decoding_cfg.get('temperature', 1.0)

# Update config
with open_dict(self.cfg.decoding):
self.cfg.decoding = decoding_cfg
@@ -442,6 +444,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
dist_sync_on_step=True,
)

self.ctc_decoder.temperature = decoding_cfg.get('temperature', 1.0)

# Update config
with open_dict(self.cfg.aux_ctc.decoding):
self.cfg.aux_ctc.decoding = decoding_cfg
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -347,6 +347,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
dist_sync_on_step=True,
)

self.ctc_decoder.temperature = decoding_cfg.get('temperature', 1.0)

# Update config
with open_dict(self.cfg.aux_ctc):
self.cfg.aux_ctc.decoding = decoding_cfg
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -454,6 +454,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
self.joint.set_loss(self.loss)
self.joint.set_wer(self.wer)

self.joint.temperature = decoding_cfg.get('temperature', 1.0)

# Update config
with open_dict(self.cfg.decoding):
self.cfg.decoding = decoding_cfg
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/rnnt_models.py
@@ -442,6 +442,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
self.joint.set_loss(self.loss)
Collaborator: Does this need to be added to ASRwithTTS models?

Collaborator (Author): I think those have ASR models as direct sub-blocks, so all changes should already be included. Let me know if that's not the case.

Collaborator: Sounds right, cool.

self.joint.set_wer(self.wer)

self.joint.temperature = decoding_cfg.get('temperature', 1.0)

# Update config
with open_dict(self.cfg.decoding):
self.cfg.decoding = decoding_cfg
7 changes: 7 additions & 0 deletions nemo/collections/asr/modules/conv_asr.py
@@ -445,6 +445,9 @@ def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary=
accepted_adapters = [adapter_utils.LINEAR_ADAPTER_CLASSPATH]
self.set_accepted_adapter_types(accepted_adapters)

# to change, requires running ``model.temperature = T`` explicitly
self.temperature = 1.0

@typecheck()
def forward(self, encoder_output):
# Adapter module forward step
@@ -453,6 +456,10 @@ def forward(self, encoder_output):
encoder_output = self.forward_enabled_adapters(encoder_output)
encoder_output = encoder_output.transpose(1, 2) # [B, C, T]

if self.temperature != 1.0:
return torch.nn.functional.log_softmax(
self.decoder_layers(encoder_output).transpose(1, 2) / self.temperature, dim=-1
)
return torch.nn.functional.log_softmax(self.decoder_layers(encoder_output).transpose(1, 2), dim=-1)

def input_example(self, max_batch=1, max_dim=256):
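To make the effect of the new temperature concrete, a small self-contained check (not from the PR): dividing the logits by T > 1 before the softmax flattens the distribution, which lowers the maximum per-frame probability that the confidence measures later aggregate.

import torch

logits = torch.tensor([[[2.0, 0.5, 0.1]]])  # fake [B, T, C] decoder output
for T in (1.0, 1.5, 2.0):
    probs = torch.softmax(logits / T, dim=-1)
    print(T, round(probs.max().item(), 3))  # the max probability drops as T grows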