
Adding interCTC loss to hybrid models #6215

Merged: 7 commits merged into NVIDIA:main on Mar 16, 2023
Conversation

@Kipok (Collaborator) commented Mar 15, 2023

What does this PR do ?

Adding interCTC loss to hybrid models. To solve the issue that CTC and hybrid models use different names for the objects we need (e.g., wer vs ctc_wer), I'm now directly passing all necessary "main" class variables to the setup_interctc method (except for self.cfg). I'm not sure there is much sense in keeping interctc as a mixin, given that it no longer accesses anything from the main class, but I'm not making any changes to that for now.

Additionally, there were 2 slight bugs in the original code. First, inter_ctc_loss was reported incorrectly in the logs (the main loss was logged there, not inter_ctc). Note that this is purely a logging issue and does not affect training. Second, the main loss was not multiplied by the correct coefficient, which does affect training to some extent. The calculation should have been something like loss = final_loss * 0.7 + inter_loss * 0.3, while it was just loss = final_loss + inter_loss * 0.3.
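A minimal sketch of the weighting fix described above, assuming a single intermediate layer with weight 0.3 (final_loss and inter_loss stand for the final and intermediate CTC losses):

def combine_interctc_losses(final_loss, inter_loss, inter_loss_weight=0.3):
    # before the fix the code effectively did: final_loss + inter_loss * inter_loss_weight,
    # leaving the final loss with an implicit weight of 1.0
    # after the fix the final loss is scaled by the remaining (1 - inter_loss_weight)
    return final_loss * (1 - inter_loss_weight) + inter_loss * inter_loss_weight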

Collection: ASR

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 
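A hedged usage sketch (not from the PR itself): enabling interCTC on a hybrid model through its config. The top-level interctc key matches the config lookup shown later in this PR (self.cfg.get("interctc")); the loss_weights / apply_at_layers field names and the config path are assumptions.

from omegaconf import OmegaConf, open_dict

# hypothetical config path; any hybrid RNNT-CTC model config would work similarly
cfg = OmegaConf.load("conf/fastconformer_hybrid_transducer_ctc_bpe.yaml")

with open_dict(cfg):
    # apply an intermediate CTC loss with weight 0.3 to the output of encoder layer 8;
    # the remaining (1 - 0.3) weight goes to the final loss, as described above
    cfg.model.interctc = {"loss_weights": [0.3], "apply_at_layers": [8]}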

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@Kipok Kipok requested review from titu1994 and VahidooX March 15, 2023 23:27
@github-actions github-actions bot added the ASR label Mar 15, 2023
@VahidooX (Collaborator) left a comment

LGTM! Some minor comments.

tests/collections/asr/test_asr_interctc_models.py (outdated; resolved)
import torch
from omegaconf import DictConfig, ListConfig

from nemo.collections.asr.metrics.wer import CTCDecoding, CTCDecodingConfig
Reviewer comment:

Unused imports here.

@@ -104,7 +104,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.setup_optimization_flags()

# setting up interCTC loss (from InterCTCMixin)
self.setup_interctc()
self.setup_interctc(self._wer, self.encoder, self.decoder, self.loss)
Reviewer comment:

The order of inputs should be PyTorch modules first, then the loss, then the metrics. This is by convention in RNNT.
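For illustration only, the call from the diff above reordered to match that convention (the merged PR ultimately switched to passing attribute names instead of objects, per the later review comments):

# modules first (encoder, decoder), then loss, then metrics (WER)
self.setup_interctc(self.encoder, self.decoder, self.loss, self._wer)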

@@ -88,6 +88,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# setting the RNNT decoder as the default one
self.use_rnnt_decoder = True

# setting up interCTC loss (from InterCTCMixin)
self.setup_interctc(self.ctc_wer, self.encoder, self.ctc_decoder, self.ctc_loss)
Reviewer comment:

Follow the above convention.

@@ -535,7 +538,6 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
tensorboard_logs['val_ctc_loss'] = ctc_loss
tensorboard_logs['val_rnnt_loss'] = loss_value
loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss
tensorboard_logs['val_loss'] = loss_value
Reviewer comment:

Revert the above removals - the RNNT loss is calculated only optionally in inference mode; it is not set at all if the flag is False (which is the default). Below, you can check and then add more to the value if interCTC is enabled, or ignore it if the RNNT loss is not supposed to be logged.
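A minimal sketch of the suggested check (the compute_eval_loss flag name is an assumption based on NeMo's RNNT models, not taken from this diff; the body mirrors the lines above):

if self.compute_eval_loss:  # assumed flag name
    # only combine and log the validation losses when eval loss is enabled
    tensorboard_logs['val_ctc_loss'] = ctc_loss
    tensorboard_logs['val_rnnt_loss'] = loss_value
    loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss
    tensorboard_logs['val_loss'] = loss_value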

@@ -582,7 +579,9 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
ctc_wer_num = torch.stack([x['val_wer_num_ctc'] for x in outputs]).sum()
ctc_wer_denom = torch.stack([x['val_wer_denom_ctc'] for x in outputs]).sum()
tensorboard_logs['val_wer_ctc'] = ctc_wer_num.float() / ctc_wer_denom
return {**val_loss_log, 'log': tensorboard_logs}
metrics = {**val_loss_log, 'log': tensorboard_logs}
self.finalize_interctc_metrics(metrics, outputs, prefix="val_")
Reviewer comment:

Losses for RNNT are optional in val/test mode, so check first and then add.

interctc_config = self.cfg.get("interctc")
if interctc_config is not None:
# if interctc is in the config, we want to check that it indeed defines
# the required keys and nothing else - that's automatically done by
# matching with keyword arguments in self._process_config_values
self._process_config_values(**interctc_config)
self._interctc_params['wer'] = wer
Reviewer comment:

Do you need to keep a reference to these objects? It can leak memory. Use a weak reference instead. Note that other objects like Decoding do keep references to these objects, but those are changeable + visible directly as part of the model code, whereas this mixin is not. It is easy to forget that this mixin holds references to modules, and it can be missed during a change of model vocabulary (which updates decoding + metric).

I would prefer to simply not register modules like this inside of this mixin.
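A minimal sketch of the weak-reference alternative mentioned above, using Python's standard weakref module (illustrative only; the merged PR instead stores attribute names rather than object references):

import weakref

# keep proxies rather than strong references, so the mixin neither keeps the
# objects alive nor silently goes stale when the model replaces them
self._interctc_params['wer'] = weakref.proxy(wer)
self._interctc_params['decoder'] = weakref.proxy(decoder)
self._interctc_params['loss'] = weakref.proxy(loss)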

"""
if not self.is_interctc_enabled():
return []

# note that we have a loop here, because tensors can be defined from
# submodules of encoder (e.g., that's the case in Jasper)
total_registry = {}
for module_registry in AccessMixin.get_module_registry(self.encoder).values():
for key, value in module_registry.items():
for module_registry in AccessMixin.get_module_registry(self._interctc_params['encoder']).values():
Reviewer comment:

You can just do self.encoder here - don't need to keep the encoder in a dict and have a reference to it

@@ -154,7 +169,9 @@ def get_captured_interctc_tensors(self) -> List[Tuple[torch.Tensor, torch.Tensor
raise RuntimeError(
"Make sure encoder.forward is called exactly one time before interCTC loss is computed."
)
captured_tensors.append((self.decoder(encoder_output=layer_outputs[0]), layer_lengths[0]))
captured_tensors.append(
(self._interctc_params['decoder'](encoder_output=layer_outputs[0]), layer_lengths[0])
Reviewer comment:

You can just use self.decoder here

):
inter_loss_value = self.loss(
inter_loss_value = self._interctc_params['loss'](
Reviewer comment:

Why not use self.loss?

loss_value += inter_loss_value * loss_weight
if compute_wer:
self._wer.update(
self._interctc_params['wer'].update(
Reviewer comment:

Why not use self._wer? You can check and set self._wer or self.wer (or better yet, add an alias of self.wer = self._wer as a property for CTC models).
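A small, self-contained sketch of the aliasing idea (hypothetical; not part of this PR):

class CTCModelSketch:
    """Hypothetical fragment showing a wer property that aliases _wer."""

    def __init__(self, wer_metric):
        self._wer = wer_metric

    @property
    def wer(self):
        # mixins can refer to the metric through a common public name
        return self._wer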

VahidooX previously approved these changes Mar 16, 2023

@VahidooX (Collaborator) left a comment

LGTM!
Just one unused import of CTCDecoding in test_asr_interctc_models.py.

titu1994 previously approved these changes Mar 16, 2023

@titu1994 (Collaborator) left a comment

Looks much better, thanks! One minor change, then let's merge.

@@ -104,7 +104,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.setup_optimization_flags()

# setting up interCTC loss (from InterCTCMixin)
self.setup_interctc()
self.setup_interctc('decoder', 'loss', '_wer')
Reviewer comment:

Can you use keyword args here? It's kind of hard to tell what the inputs to this function are.
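A sketch of the keyword-argument form (the parameter names decoder_name, loss_name, and wer_name are assumptions; the actual names are whatever setup_interctc defines):

# attribute names passed as keyword arguments so the call is self-describing
self.setup_interctc(decoder_name='decoder', loss_name='loss', wer_name='_wer')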

@@ -88,6 +88,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# setting the RNNT decoder as the default one
self.use_rnnt_decoder = True

# setting up interCTC loss (from InterCTCMixin)
self.setup_interctc('ctc_decoder', 'ctc_loss', 'ctc_wer')
Reviewer comment:

Same as above

@Kipok Kipok dismissed stale reviews from titu1994 and VahidooX via 80e65c5 March 16, 2023 20:32
@Kipok Kipok mentioned this pull request Mar 16, 2023
@titu1994 (Collaborator) left a comment

Thanks !

@titu1994 titu1994 merged commit 4d096a1 into NVIDIA:main Mar 16, 2023
@github-actions github-actions bot mentioned this pull request Mar 17, 2023
titu1994 pushed a commit to titu1994/NeMo that referenced this pull request Mar 24, 2023
* Add interctc functionality to hybrid models

Signed-off-by: Igor Gitman <[email protected]>

* Fix bugs with interctc loss

Signed-off-by: Igor Gitman <[email protected]>

* Update configs

Signed-off-by: Igor Gitman <[email protected]>

* Minor cleanup + use attribute names instead of objects in setup

Signed-off-by: Igor Gitman <[email protected]>

* Correctly handle compute_eval_loss=False

Signed-off-by: Igor Gitman <[email protected]>

* Add compute_eval_loss=False test cases

Signed-off-by: Igor Gitman <[email protected]>

* Remove unused import, add keyword args

Signed-off-by: Igor Gitman <[email protected]>

---------

Signed-off-by: Igor Gitman <[email protected]>
hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request Jun 2, 2023