-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Confidence ensembles implementation #6614
Changes from all commits
faf3ae0
b708c1e
c077adf
cd46388
7dd63d9
ab057df
99d9abd
b669840
4264c7a
ef15746
4d3dc25
83af9d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, List, Optional, Union | ||
|
||
import joblib | ||
import numpy as np | ||
import torch | ||
from omegaconf import DictConfig, OmegaConf, open_dict | ||
Check notice Code scanning / CodeQL Unused import
Import of 'OmegaConf' is not used.
|
||
from pytorch_lightning import Trainer | ||
|
||
from nemo.collections.asr.models.asr_model import ASRModel | ||
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel | ||
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, get_confidence_aggregation_bank | ||
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType | ||
from nemo.core.classes import ModelPT | ||
from nemo.utils import model_utils | ||
|
||
__all__ = ['ConfidenceEnsembleModel'] | ||
|
||
|
||
class ConfidenceEnsembleModel(ModelPT):
    """Ensemble of ASR models that transcribes with every member model and, per
    utterance, returns the output of the model a pretrained selection block
    (sklearn pipeline) judges most confident.

    Submodels are registered as attributes ``model0 .. model{N-1}`` because the
    submodule logic does not currently support storing models in a list.
    """

    def __init__(
        self, cfg: DictConfig, trainer: 'Trainer' = None,
    ):
        super().__init__(cfg=cfg, trainer=trainer)

        # either we load all models from ``load_models`` cfg parameter
        # or all of them are specified in the config as modelX alongside the num_models key
        #
        # ideally, we'd like to directly store all models in a list, but that
        # is not currently supported by the submodule logic
        # so to access all the models, we do something like
        #
        # for model_idx in range(self.num_models):
        #     model = getattr(self, f"model{model_idx}")

        if 'num_models' in self.cfg:
            # models are fully described inside this config as model0..modelN-1
            self.num_models = self.cfg.num_models
            for idx in range(self.num_models):
                cfg_field = f"model{idx}"
                model_cfg = self.cfg[cfg_field]
                model_class = model_utils.import_class_by_path(model_cfg['target'])
                self.register_nemo_submodule(
                    name=cfg_field, config_field=cfg_field, model=model_class(model_cfg, trainer=trainer),
                )
        else:
            # models are loaded from checkpoints/pretrained names listed in ``load_models``
            self.num_models = len(cfg.load_models)
            with open_dict(self.cfg):
                self.cfg.num_models = self.num_models
            for idx, model in enumerate(cfg.load_models):
                cfg_field = f"model{idx}"
                if model.endswith(".nemo"):
                    self.register_nemo_submodule(
                        name=cfg_field,
                        config_field=cfg_field,
                        model=ASRModel.restore_from(model, trainer=trainer, map_location="cpu"),
                    )
                else:
                    self.register_nemo_submodule(
                        cfg_field, config_field=cfg_field, model=ASRModel.from_pretrained(model, map_location="cpu"),
                    )

        # registering model selection block - this is expected to be a joblib-saved
        # pretrained sklearn pipeline containing standardization + logistic regression
        # trained to predict "most-confident" model index from the confidence scores of all models
        model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block)
        self.model_selection_block = joblib.load(model_selection_block_path)
        self.confidence_cfg = ConfidenceConfig(**self.cfg.confidence)

        # making sure each model has correct confidence settings in the decoder strategy
        for model_idx in range(self.num_models):
            model = getattr(self, f"model{model_idx}")
            # for now we assume users are directly responsible for matching
            # decoder type when building ensemble with inference type
            # TODO: add automatic checks for errors
            if isinstance(model, EncDecHybridRNNTCTCModel):
                # hybrid models carry two decoders; update both
                self.update_decoding_parameters(model.cfg.decoding)
                model.change_decoding_strategy(model.cfg.decoding, decoder_type="rnnt")
                self.update_decoding_parameters(model.cfg.aux_ctc.decoding)
                model.change_decoding_strategy(model.cfg.aux_ctc.decoding, decoder_type="ctc")
            else:
                self.update_decoding_parameters(model.cfg.decoding)
                model.change_decoding_strategy(model.cfg.decoding)
|
||
def update_decoding_parameters(self, decoding_cfg): | ||
"""Updating confidence/temperature parameters of the config.""" | ||
with open_dict(decoding_cfg): | ||
decoding_cfg.confidence_cfg = self.confidence_cfg | ||
decoding_cfg.temperature = self.cfg.temperature | ||
|
||
def setup_training_data(self, train_data_config: Union[DictConfig, Dict]): | ||
"""Pass-through to the ensemble models. | ||
|
||
Note that training is not actually supported for this class! | ||
""" | ||
for model_idx in range(self.num_models): | ||
getattr(self, f"model{model_idx}").setup_training_data(train_data_config) | ||
|
||
def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]): | ||
"""Pass-through to the ensemble models.""" | ||
for model_idx in range(self.num_models): | ||
getattr(self, f"model{model_idx}").setup_validation_data(val_data_config) | ||
|
||
def change_attention_model( | ||
self, self_attention_model: str = None, att_context_size: List[int] = None, update_config: bool = True | ||
): | ||
"""Pass-through to the ensemble models.""" | ||
for model_idx in range(self.num_models): | ||
getattr(self, f"model{model_idx}").change_attention_model( | ||
self_attention_model, att_context_size, update_config | ||
) | ||
|
||
def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None): | ||
"""Pass-through to the ensemble models. | ||
|
||
The only change here is that we always require frame-confidence to | ||
be returned. | ||
""" | ||
decoding_cfg.confidence_cfg = self.confidence_cfg | ||
for model_idx in range(self.num_models): | ||
model = getattr(self, f"model{model_idx}") | ||
if isinstance(model, EncDecHybridRNNTCTCModel): | ||
model.change_decoding_strategy(decoding_cfg, decoder_type=decoder_type) | ||
else: | ||
model.change_decoding_strategy(decoding_cfg) | ||
|
||
@torch.no_grad() | ||
def transcribe( | ||
self, | ||
paths2audio_files: List[str], | ||
batch_size: int = 4, | ||
return_hypotheses: bool = False, | ||
num_workers: int = 0, | ||
channel_selector: Optional[ChannelSelectorType] = None, | ||
augmentor: DictConfig = None, | ||
verbose: bool = True, | ||
**kwargs, # any other model specific parameters are passed directly | ||
) -> List[str]: | ||
"""Confidence-ensemble transcribe method. | ||
|
||
Consists of the following steps: | ||
|
||
1. Run all models (TODO: in parallel) | ||
2. Compute confidence for each model | ||
3. Use logistic regression to pick the "most confident" model | ||
4. Return the output of that model | ||
""" | ||
# TODO: lots of duplicate code with building ensemble script | ||
aggr_func = get_confidence_aggregation_bank()[self.confidence_cfg.aggregation] | ||
confidences = [] | ||
all_transcriptions = [] | ||
# always requiring to return hypothesis | ||
# TODO: make sure to return text only if was False originally | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this be resolved before merging the PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will resolve this in a follow-up PR(s) |
||
return_hypotheses = True | ||
for model_idx in range(self.num_models): | ||
model = getattr(self, f"model{model_idx}") | ||
transcriptions = model.transcribe( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will this support transcribe_partial_audio just like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not at the moment, right now the assumption is that we transcribe the full thing and then get the confidence. But in principal, there is a plan to add support (likely with a subclass) that will transcribe only part of the utterance and then continue with the transcription of only the single most-confident model There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gotcha. Yeah theoretically all the ASR model (with extended feature) should support transcribing whole utterance and part of it. Please make sure to add this later. Looking forward to it! |
||
paths2audio_files=paths2audio_files, | ||
batch_size=batch_size, | ||
return_hypotheses=return_hypotheses, | ||
num_workers=num_workers, | ||
channel_selector=channel_selector, | ||
augmentor=augmentor, | ||
verbose=verbose, | ||
**kwargs, | ||
) | ||
if isinstance(transcriptions, tuple): # transducers return a tuple | ||
transcriptions = transcriptions[0] | ||
|
||
model_confidences = [] | ||
for transcription in transcriptions: | ||
if isinstance(transcription.frame_confidence[0], list): | ||
# NeMo Transducer API returns list of lists for confidences | ||
conf_values = [conf_value for confs in transcription.frame_confidence for conf_value in confs] | ||
else: | ||
conf_values = transcription.frame_confidence | ||
model_confidences.append(aggr_func(conf_values)) | ||
confidences.append(model_confidences) | ||
all_transcriptions.append(transcriptions) | ||
|
||
# transposing with zip(*list) | ||
features = np.array(list(zip(*confidences))) | ||
model_indices = self.model_selection_block.predict(features) | ||
final_transcriptions = [] | ||
for transcrption_idx in range(len(all_transcriptions[0])): | ||
final_transcriptions.append(all_transcriptions[model_indices[transcrption_idx]][transcrption_idx]) | ||
|
||
return final_transcriptions | ||
|
||
def list_available_models(self): | ||
return [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -442,6 +442,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): | |
self.joint.set_loss(self.loss) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to be added to ASRwithTTS models? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think those have asr models as direct subblocks, so all changes should already be included. Let me know if that's not the case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds right, cool |
||
self.joint.set_wer(self.wer) | ||
|
||
self.joint.temperature = decoding_cfg.get('temperature', 1.0) | ||
|
||
# Update config | ||
with open_dict(self.cfg.decoding): | ||
self.cfg.decoding = decoding_cfg | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then how do we write the transcription if confidence ensembles are enabled?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The confidence ensemble itself is run with return_transcription=False, but inner models will have it set to True. You can have a look at the tests to see how it's used here:
NeMo/scripts/confidence_ensembles/test_confidence_ensembles.py
Line 99 in 83af9d8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect this will somehow be integrated with
transcribe_speech.py
since this is the go-to place for all ASR models offline inference. but this makes sense.