Skip to content

Commit

Permalink
Fixing an issue with confidence ensembles (NVIDIA#6987)
Browse files Browse the repository at this point in the history
* Bug fix for the confidence ensembles

Signed-off-by: Igor Gitman <[email protected]>

* Relax constraints for the test

Signed-off-by: Igor Gitman <[email protected]>

---------

Signed-off-by: Igor Gitman <[email protected]>
  • Loading branch information
Kipok authored and zhehuaichen committed Oct 4, 2023
1 parent 0acb36a commit 9a8ce45
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 11 deletions.
8 changes: 6 additions & 2 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class TranscriptionConfig:

# Set to True to output greedy timestamp information (only supported models)
compute_timestamps: bool = False
# set to True if need to return full alignment information
preserve_alignment: bool = False

# Set to True to output language ID information
compute_langs: bool = False
Expand Down Expand Up @@ -230,6 +232,8 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
# we will adjust this flag if the model does not support it
compute_timestamps = cfg.compute_timestamps
compute_langs = cfg.compute_langs
# has to be True if timestamps are required
preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment

# Check whether model and decoder type match
if isinstance(asr_model, EncDecCTCModel):
Expand All @@ -252,7 +256,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it
if 'preserve_alignments' in decoding_cfg:
decoding_cfg.preserve_alignments = cfg.compute_timestamps
decoding_cfg.preserve_alignments = preserve_alignment
if 'compute_langs' in decoding_cfg:
decoding_cfg.compute_langs = cfg.compute_langs
if hasattr(asr_model, 'cur_decoder'):
Expand All @@ -267,7 +271,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
cfg.rnnt_decoding.compute_langs = cfg.compute_langs

if 'preserve_alignments' in cfg.rnnt_decoding:
cfg.rnnt_decoding.preserve_alignments = cfg.compute_timestamps
cfg.rnnt_decoding.preserve_alignments = preserve_alignment

asr_model.change_decoding_strategy(cfg.rnnt_decoding)
else:
Expand Down
9 changes: 5 additions & 4 deletions nemo/collections/asr/models/confidence_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def get_filtered_logprobs(hypothesis: Hypothesis, exclude_blank: bool) -> torch.
filtered_logprobs = logprobs[:1]
else:
filtered_logprobs = logprobs

# need to make sure logprobs are always normalized, so checking if they sum up to 1
if not torch.allclose(filtered_logprobs[0].exp().sum(), torch.tensor(1.0)):
filtered_logprobs = torch.log_softmax(filtered_logprobs, dim=1)

return filtered_logprobs


Expand Down Expand Up @@ -217,10 +222,6 @@ def update_decoding_parameters(self, decoding_cfg: DictConfig):
with open_dict(decoding_cfg):
decoding_cfg.temperature = self.cfg.temperature
decoding_cfg.preserve_alignments = True
if 'confidence_cfg' in decoding_cfg:
decoding_cfg.confidence_cfg.preserve_frame_confidence = True
else:
decoding_cfg.confidence_cfg = ConfidenceConfig(preserve_frame_confidence=True)

def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
"""Pass-through to the ensemble models.
Expand Down
6 changes: 2 additions & 4 deletions scripts/confidence_ensembles/build_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def find_best_confidence(
return best_conf_spec.to_confidence_config(), best_pipe


@hydra_runner(schema=BuildEnsembleConfig)
@hydra_runner(config_name="BuildEnsembleConfig", schema=BuildEnsembleConfig)
def main(cfg: BuildEnsembleConfig):
# silencing all messages from nemo/ptl to avoid dumping tons of configs to the stdout
logging.getLogger('pytorch_lightning').setLevel(logging.CRITICAL)
Expand All @@ -471,12 +471,10 @@ def main(cfg: BuildEnsembleConfig):
pl.seed_everything(cfg.random_seed)
cfg.transcription.random_seed = None # seed is already applied
cfg.transcription.return_transcriptions = True
# that sets preserve_alignment to True
cfg.transcription.compute_timestamps = True
cfg.transcription.preserve_alignment = True
cfg.transcription.ctc_decoding.temperature = cfg.temperature
cfg.transcription.rnnt_decoding.temperature = cfg.temperature
# this ensures that generated output is after log-softmax for consistency with CTC
cfg.transcription.rnnt_decoding.confidence_cfg.preserve_frame_confidence = True

train_confidences = []
dev_confidences = []
Expand Down
2 changes: 1 addition & 1 deletion scripts/confidence_ensembles/test_confidence_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,4 @@ def test_confidence_ensemble(tmp_path, build_args):
)

results = speech_to_text_eval.main(eval_cfg)
assert results.metric_value < 0.15 # relaxed check for better than 15% WER
assert results.metric_value < 0.20 # relaxed check for better than 20% WER

0 comments on commit 9a8ce45

Please sign in to comment.