Skip to content

Commit

Permalink
Working version
Browse files Browse the repository at this point in the history
Signed-off-by: Igor Gitman <[email protected]>
  • Loading branch information
Kipok committed May 9, 2023
1 parent 0016f47 commit 64fa1d3
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 29 deletions.
4 changes: 0 additions & 4 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,6 @@ def autocast():
if type(transcriptions) == tuple and len(transcriptions) == 2:
transcriptions = transcriptions[0]

from IPython import embed

embed()

# write audio transcriptions
output_filename = write_transcription(
transcriptions,
Expand Down
90 changes: 68 additions & 22 deletions nemo/collections/asr/models/confidence_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Dict, List, Optional, Union

import joblib
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
Expand All @@ -30,41 +31,47 @@
from nemo.collections.asr.metrics.wer import WER, CTCDecoding, CTCDecodingConfig
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.parts.mixins import ASRModuleMixin, InterCTCMixin
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, get_confidence_aggregation_bank
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.mixins import AccessMixin
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType
from nemo.utils import logging
from nemo.utils import logging, model_utils

__all__ = ['ConfidenceEnsembleModel']


class ConfidenceEnsembleModel(ModelPT):
def __init__(
self,
cfg: DictConfig,
trainer: 'Trainer' = None,
models: Optional[List[str]] = None,
model_selection_block_path: Optional[str] = None,
self, cfg: DictConfig, trainer: 'Trainer' = None, models: Optional[List[str]] = None,
):
super().__init__(cfg=cfg, trainer=trainer)

# either we load all models from ``models`` init parameter
# or all of them are specified in the config alongside the num_models key
#
# ideally, we'd like to directly store all models in a list, but that
# is not currently supported by the submodule logic
# so to access all the models, we do something like
#
# for model_idx in range(self.num_models):
# model = getattr(self, f"model{model_idx}")

if 'num_models' in self.cfg:
self.num_models = self.cfg.num_models
for idx, model_cfg in enumerate(self.cfg.models):
for idx in range(self.num_models):
cfg_field = f"model{idx}"
model_cfg = self.cfg[cfg_field]
model_class = model_utils.import_class_by_path(model_cfg['target'])
self.register_nemo_submodule(
name=cfg_field, config_field=cfg_field, model=ASRModel(model_cfg, trainer=trainer),
name=cfg_field, config_field=cfg_field, model=model_class(model_cfg, trainer=trainer),
)
else:
self.num_models = len(models)
OmegaConf.set_struct(self.cfg, False)
self.cfg.num_models = self.num_models
OmegaConf.set_struct(self.cfg, True)
for idx, model in enumerate(models):
cfg_field = f"model{idx}"
if model.endswith(".nemo"):
Expand All @@ -76,13 +83,17 @@ def __init__(
cfg_field, config_field=cfg_field, model=ASRModel.from_pretrained(model),
)

# registering the model selection block as an artifact
if model_selection_block_path:
self.register_artifact("model_selection_block", model_selection_block_path)
self.model_selection_block = joblib.load(model_selection_block_path)
else: # or loading from checkpoint if not specified
model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block)
self.model_selection_block = joblib.load(model_selection_block_path)
model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block)
self.model_selection_block = joblib.load(model_selection_block_path)
self.confidence = ConfidenceConfig(**self.cfg.confidence)

# making sure each model has correct confidence settings in the decoder strategy
for model_idx in range(self.num_models):
model = getattr(self, f"model{model_idx}")
decoding_cfg = model.cfg.decoding
decoding_cfg.confidence_cfg = self.confidence
# TODO: is there a way to handle hybrid model change flexibly here?
model.change_decoding_strategy(decoding_cfg)

def list_available_models(self):
pass
Expand All @@ -93,6 +104,21 @@ def setup_training_data(self):
def setup_validation_data(self):
pass

def change_attention_model(self, *args, **kwargs):
    """Forward an attention-model change to every model of the ensemble.

    All positional and keyword arguments are handed over untouched to the
    ``change_attention_model`` method of each ``model{idx}`` submodule,
    in index order from ``0`` to ``self.num_models - 1``.
    """
    # Submodels are stored as separate ``model{idx}`` attributes, so they
    # are looked up by name rather than read from a list.
    submodels = (getattr(self, f"model{idx}") for idx in range(self.num_models))
    for submodel in submodels:
        submodel.change_attention_model(*args, **kwargs)

def change_decoding_strategy(self, decoding_cfg: Optional['DictConfig'] = None, **kwargs):
    """Pass-through to the ensemble models.

    The only change here is that the ensemble's confidence settings
    (``self.confidence``) are always written into the decoding config
    before it is handed to the models, so that confidence values are
    always returned.

    Args:
        decoding_cfg: decoding config applied to every model of the
            ensemble. Its ``confidence_cfg`` field is overwritten in place
            with ``self.confidence``. If ``None`` (the default), each model
            is re-configured with its own ``cfg.decoding`` instead, with
            the ensemble confidence settings applied to it.
        **kwargs: forwarded unchanged to ``change_decoding_strategy`` of
            each model.
    """
    if decoding_cfg is not None:
        decoding_cfg.confidence_cfg = self.confidence

    for model_idx in range(self.num_models):
        model = getattr(self, f"model{model_idx}")
        if decoding_cfg is None:
            # ``None`` is the declared default, so it must not crash on the
            # attribute assignment; fall back to the model's own decoding
            # config, the same way the constructor sets up each model.
            model_decoding_cfg = model.cfg.decoding
            model_decoding_cfg.confidence_cfg = self.confidence
        else:
            model_decoding_cfg = decoding_cfg
        model.change_decoding_strategy(model_decoding_cfg, **kwargs)

@torch.no_grad()
def transcribe( # TODO: rnnt takes different parameters?
self,
Expand All @@ -116,10 +142,30 @@ def transcribe( # TODO: rnnt takes different parameters?
3. Use logistic regression to pick the "most confident" model
4. Return the output of that model
"""
hypotheses = []
for model in self.models:
hypotheses.append(model.transcribe(*args, **kwargs))

from IPython import embed

embed()
# TODO: lots of duplicate code with building ensemble script
aggr_func = get_confidence_aggregation_bank()[self.confidence.aggregation]
confidences = []
all_transcriptions = []
for model_idx in range(self.num_models):
model = getattr(self, f"model{model_idx}")
transcriptions = model.transcribe(*args, **kwargs)

model_confidences = []
for transcription in transcriptions:
if isinstance(transcription.frame_confidence[0], list):
# NeMo Transducer API returns list of lists for confidences
conf_values = [conf_value for confs in transcription.frame_confidence for conf_value in confs]
else:
conf_values = transcription.frame_confidence
model_confidences.append(aggr_func(conf_values))
confidences.append(model_confidences)
all_transcriptions.append(transcriptions)

# transposing with zip(*list)
features = np.array(list(zip(*confidences)))
model_indices = self.model_selection_block.predict(features)
final_transcriptions = []
for transcrption_idx in range(len(all_transcriptions[0])):
final_transcriptions.append(all_transcriptions[model_indices[transcrption_idx]][transcrption_idx])

return final_transcriptions
4 changes: 1 addition & 3 deletions scripts/confidence_ensembles/build_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def train_model_selection(
def main(cfg: BuildEnsembleConfig):
logging.info(f'Build ensemble config: {OmegaConf.to_yaml(cfg)}')

# TODO: does this validate arguments? Do we need the check?
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

Expand Down Expand Up @@ -181,10 +180,9 @@ def main(cfg: BuildEnsembleConfig):

# creating ensemble checkpoint
ensemble_model = ConfidenceEnsembleModel(
cfg=DictConfig({}),
cfg=DictConfig({'model_selection_block': model_selection_block_path, 'confidence': cfg.confidence,}),
trainer=None,
models=[model_cfg.model for model_cfg in cfg.ensemble],
model_selection_block_path=model_selection_block_path,
)
ensemble_model.save_to(cfg.output_path)

Expand Down

0 comments on commit 64fa1d3

Please sign in to comment.