diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 3445ad4be6ef4..97e9496999def 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -325,10 +325,6 @@ def autocast(): if type(transcriptions) == tuple and len(transcriptions) == 2: transcriptions = transcriptions[0] - from IPython import embed - - embed() - # write audio transcriptions output_filename = write_transcription( transcriptions, diff --git a/nemo/collections/asr/models/confidence_ensemble.py b/nemo/collections/asr/models/confidence_ensemble.py index 86451cd923480..1e7bb56227779 100644 --- a/nemo/collections/asr/models/confidence_ensemble.py +++ b/nemo/collections/asr/models/confidence_ensemble.py @@ -19,6 +19,7 @@ from typing import Dict, List, Optional, Union import joblib +import numpy as np import torch from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer @@ -30,41 +31,47 @@ from nemo.collections.asr.metrics.wer import WER, CTCDecoding, CTCDecodingConfig from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.parts.mixins import ASRModuleMixin, InterCTCMixin +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, get_confidence_aggregation_bank from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.classes.mixins import AccessMixin from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType -from nemo.utils import logging +from nemo.utils import logging, model_utils __all__ = ['ConfidenceEnsembleModel'] class ConfidenceEnsembleModel(ModelPT): def __init__( - self, - cfg: DictConfig, - trainer: 'Trainer' = None, - models: Optional[List[str]] = None, - model_selection_block_path: Optional[str] = None, + self, cfg: DictConfig, trainer: 'Trainer' = None, models: Optional[List[str]] = None, ): super().__init__(cfg=cfg, trainer=trainer) # either we load all models from ``models`` init parameter # or all of them are specified in the config alongside the num_models key + # # ideally, we'd like to directly store all models in a list, but that # is not currently supported by the submodule logic + # so to access all the models, we do something like + # + # for model_idx in range(self.num_models): + # model = getattr(self, f"model{model_idx}") if 'num_models' in self.cfg: self.num_models = self.cfg.num_models - for idx, model_cfg in enumerate(self.cfg.models): + for idx in range(self.num_models): cfg_field = f"model{idx}" + model_cfg = self.cfg[cfg_field] + model_class = model_utils.import_class_by_path(model_cfg['target']) self.register_nemo_submodule( - name=cfg_field, config_field=cfg_field, model=ASRModel(model_cfg, trainer=trainer), + name=cfg_field, config_field=cfg_field, model=model_class(model_cfg, trainer=trainer), ) else: self.num_models = len(models) + OmegaConf.set_struct(self.cfg, False) self.cfg.num_models = self.num_models + OmegaConf.set_struct(self.cfg, True) for idx, model in enumerate(models): cfg_field = f"model{idx}" if model.endswith(".nemo"): @@ -76,13 +83,17 @@ def __init__( cfg_field, config_field=cfg_field, model=ASRModel.from_pretrained(model), ) - # registering the model selection block as an artifact - if model_selection_block_path: - self.register_artifact("model_selection_block", model_selection_block_path) - self.model_selection_block = joblib.load(model_selection_block_path) - else: # or loading from checkpoint if not specified - model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block) - self.model_selection_block = joblib.load(model_selection_block_path) + model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block) + self.model_selection_block = joblib.load(model_selection_block_path) + self.confidence = ConfidenceConfig(**self.cfg.confidence) + + # making sure each model has correct confidence settings in the decoder strategy + for model_idx in range(self.num_models): + model = getattr(self, f"model{model_idx}") + decoding_cfg = model.cfg.decoding + decoding_cfg.confidence_cfg = self.confidence + # TODO: is there a way to handle hybrid model change flexibly here? + model.change_decoding_strategy(decoding_cfg) def list_available_models(self): pass @@ -93,6 +104,21 @@ def setup_training_data(self): def setup_validation_data(self): pass + def change_attention_model(self, *args, **kwargs): + """Pass-through to the ensemble models.""" + for model_idx in range(self.num_models): + getattr(self, f"model{model_idx}").change_attention_model(*args, **kwargs) + + def change_decoding_strategy(self, decoding_cfg: DictConfig = None, **kwargs): + """Pass-through to the ensemble models. + + The only change here is that we always require frame-confidence + to be returned. + """ + decoding_cfg.confidence_cfg = self.confidence + for model_idx in range(self.num_models): + getattr(self, f"model{model_idx}").change_decoding_strategy(decoding_cfg, **kwargs) + @torch.no_grad() def transcribe( # TODO: rnnt takes different parameters? self, @@ -116,10 +142,30 @@ def transcribe( # TODO: rnnt takes different parameters? 3. Use logistic regression to pick the "most confident" model 4. Return the output of that model """ - hypotheses = [] - for model in self.models: - hypotheses.append(model.transcribe(*args, **kwargs)) - - from IPython import embed - - embed() + # TODO: lots of duplicate code with building ensemble script + aggr_func = get_confidence_aggregation_bank()[self.confidence.aggregation] + confidences = [] + all_transcriptions = [] + for model_idx in range(self.num_models): + model = getattr(self, f"model{model_idx}") + transcriptions = model.transcribe(*args, **kwargs) + + model_confidences = [] + for transcription in transcriptions: + if isinstance(transcription.frame_confidence[0], list): + # NeMo Transducer API returns list of lists for confidences + conf_values = [conf_value for confs in transcription.frame_confidence for conf_value in confs] + else: + conf_values = transcription.frame_confidence + model_confidences.append(aggr_func(conf_values)) + confidences.append(model_confidences) + all_transcriptions.append(transcriptions) + + # transposing with zip(*list) + features = np.array(list(zip(*confidences))) + model_indices = self.model_selection_block.predict(features) + final_transcriptions = [] + for transcrption_idx in range(len(all_transcriptions[0])): + final_transcriptions.append(all_transcriptions[model_indices[transcrption_idx]][transcrption_idx]) + + return final_transcriptions diff --git a/scripts/confidence_ensembles/build_ensemble.py b/scripts/confidence_ensembles/build_ensemble.py index e128673c1f0be..323970f7df634 100644 --- a/scripts/confidence_ensembles/build_ensemble.py +++ b/scripts/confidence_ensembles/build_ensemble.py @@ -128,7 +128,6 @@ def train_model_selection( def main(cfg: BuildEnsembleConfig): logging.info(f'Build ensemble config: {OmegaConf.to_yaml(cfg)}') - # TODO: does this validate arguments? Do we need the check? if is_dataclass(cfg): cfg = OmegaConf.structured(cfg) @@ -181,10 +180,9 @@ def main(cfg: BuildEnsembleConfig): # creating ensemble checkpoint ensemble_model = ConfidenceEnsembleModel( - cfg=DictConfig({}), + cfg=DictConfig({'model_selection_block': model_selection_block_path, 'confidence': cfg.confidence,}), trainer=None, models=[model_cfg.model for model_cfg in cfg.ensemble], - model_selection_block_path=model_selection_block_path, ) ensemble_model.save_to(cfg.output_path)