Adding interCTC loss to hybrid models #6215
Conversation
LGTM! Some minor comments.
import torch
from omegaconf import DictConfig, ListConfig

from nemo.collections.asr.metrics.wer import CTCDecoding, CTCDecodingConfig
Unused imports here.
@@ -104,7 +104,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.setup_optimization_flags()

        # setting up interCTC loss (from InterCTCMixin)
-       self.setup_interctc()
+       self.setup_interctc(self._wer, self.encoder, self.decoder, self.loss)
The order of inputs should be PyTorch modules first, then loss, then metrics. This is the convention in RNNT.
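For instance, the call above would then look something like this (just illustrating the suggested ordering, not the final signature):

self.setup_interctc(self.encoder, self.decoder, self.loss, self._wer)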
@@ -88,6 +88,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # setting the RNNT decoder as the default one
        self.use_rnnt_decoder = True

+       # setting up interCTC loss (from InterCTCMixin)
+       self.setup_interctc(self.ctc_wer, self.encoder, self.ctc_decoder, self.ctc_loss)
Follow the above convention.
@@ -535,7 +538,6 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
            tensorboard_logs['val_ctc_loss'] = ctc_loss
            tensorboard_logs['val_rnnt_loss'] = loss_value
            loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss
-           tensorboard_logs['val_loss'] = loss_value
Revert the above removals - the RNNT loss is calculated only optionally in inference mode; it is not set at all if the flag is False (which is the default). Below, you can check the flag and then add to the value if interCTC is enabled, or skip it if the RNNT loss is not supposed to be logged.
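A rough sketch of that check inside validation_step (this assumes the model exposes self.compute_eval_loss the way RNNT models do; it is an illustration, not the final code):

# inside validation_step, once the RNNT/CTC eval losses have (optionally) been computed
if self.compute_eval_loss:
    tensorboard_logs['val_rnnt_loss'] = loss_value
    tensorboard_logs['val_ctc_loss'] = ctc_loss
    loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss
    tensorboard_logs['val_loss'] = loss_value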
@@ -582,7 +579,9 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
            ctc_wer_num = torch.stack([x['val_wer_num_ctc'] for x in outputs]).sum()
            ctc_wer_denom = torch.stack([x['val_wer_denom_ctc'] for x in outputs]).sum()
            tensorboard_logs['val_wer_ctc'] = ctc_wer_num.float() / ctc_wer_denom
-       return {**val_loss_log, 'log': tensorboard_logs}
+       metrics = {**val_loss_log, 'log': tensorboard_logs}
+       self.finalize_interctc_metrics(metrics, outputs, prefix="val_")
Losses for RNNT are optional in val/test mode, so check first and then add.
        interctc_config = self.cfg.get("interctc")
        if interctc_config is not None:
            # if interctc is in the config, we want to check that it indeed defines
            # the required keys and nothing else - that's automatically done by
            # matching with keyword arguments in self._process_config_values
            self._process_config_values(**interctc_config)
+       self._interctc_params['wer'] = wer
Do you need to keep a reference to these objects? It can leak memory. Use a weak reference instead. Note that other objects like Decoding do keep references to these objects, but those are changeable and visible directly as part of the model code, whereas this mixin is not.
It is easy to forget that this mixin holds references to modules, and those references can be missed during a change of model vocabulary (which updates decoding + metric).
I would prefer to simply not register modules like this inside the mixin.
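A minimal sketch of the weak-reference idea (the class and helper names here are illustrative, not the PR's code):

import weakref

class InterCTCMixinSketch:
    def setup_interctc(self, encoder, decoder, loss, wer):
        # store weak references so the mixin neither keeps these objects alive
        # nor silently goes stale when the model swaps its decoding/metric objects
        self._interctc_params = {
            'encoder': weakref.ref(encoder),
            'decoder': weakref.ref(decoder),
            'loss': weakref.ref(loss),
            'wer': weakref.ref(wer),
        }

    def _interctc_get(self, name):
        obj = self._interctc_params[name]()  # dereference; None if the target was collected
        if obj is None:
            raise RuntimeError(f"interCTC reference '{name}' is no longer alive")
        return obj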
""" | ||
if not self.is_interctc_enabled(): | ||
return [] | ||
|
||
# note that we have a loop here, because tensors can be defined from | ||
# submodules of encoder (e.g., that's the case in Jasper) | ||
total_registry = {} | ||
for module_registry in AccessMixin.get_module_registry(self.encoder).values(): | ||
for key, value in module_registry.items(): | ||
for module_registry in AccessMixin.get_module_registry(self._interctc_params['encoder']).values(): |
You can just use self.encoder here - no need to keep the encoder in a dict and hold a reference to it.
@@ -154,7 +169,9 @@ def get_captured_interctc_tensors(self) -> List[Tuple[torch.Tensor, torch.Tensor
                raise RuntimeError(
                    "Make sure encoder.forward is called exactly one time before interCTC loss is computed."
                )
-           captured_tensors.append((self.decoder(encoder_output=layer_outputs[0]), layer_lengths[0]))
+           captured_tensors.append(
+               (self._interctc_params['decoder'](encoder_output=layer_outputs[0]), layer_lengths[0])
You can just use self.decoder here
        ):
-           inter_loss_value = self.loss(
+           inter_loss_value = self._interctc_params['loss'](
Why not use self.loss?
        loss_value += inter_loss_value * loss_weight
        if compute_wer:
-           self._wer.update(
+           self._interctc_params['wer'].update(
Why not use self._wer? You could check and set self._wer or self.wer (or, better yet, add an alias self.wer = self._wer as a property for CTC models).
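A hedged sketch of the suggested alias on the CTC model class (so the mixin can always go through self.wer regardless of how the model stores the metric):

@property
def wer(self):
    # alias so mixins and external code don't have to know about the private attribute
    return self._wer

@wer.setter
def wer(self, wer):
    self._wer = wer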
Signed-off-by: Igor Gitman <[email protected]>
LGTM!
Just one unused import of CTCDecoding in test_asr_interctc_models.py.
Looks much better, thanks! One minor change, then let's merge.
@@ -104,7 +104,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.setup_optimization_flags()

        # setting up interCTC loss (from InterCTCMixin)
-       self.setup_interctc()
+       self.setup_interctc('decoder', 'loss', '_wer')
Can you use keyword args here? It's hard to tell what the inputs to this function are.
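For example, something like this (the parameter names are a guess at the mixin's signature, shown only to make the suggestion concrete):

self.setup_interctc(decoder_name='decoder', loss_name='loss', wer_name='_wer')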
@@ -88,6 +88,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # setting the RNNT decoder as the default one
        self.use_rnnt_decoder = True

+       # setting up interCTC loss (from InterCTCMixin)
+       self.setup_interctc('ctc_decoder', 'ctc_loss', 'ctc_wer')
Same as above
Signed-off-by: Igor Gitman <[email protected]>
Thanks!
* Add interctc functionality to hybrid models
* Fix bugs with interctc loss
* Update configs
* Minor cleanup + use attribute names instead of objects in setup
* Correctly handle compute_eval_loss=False
* Add compute_eval_loss=False test cases
* Remove unused import, add keyword args

Signed-off-by: Igor Gitman <[email protected]>
What does this PR do?
Adding interCTC loss to hybrid models. To solve the issue that CTC and hybrid models use different names for the things we need (e.g., wer vs ctc_wer), I'm now directly passing all necessary "main" class variables to the interctc_setup method (except for self.cfg). I'm not sure there is much sense in still keeping interCTC as a mixin, given that it no longer accesses anything from the main class, but I'm not making any changes to that for now.
Additionally, there are 2 slight bugs in the original code. First, inter_ctc_loss was incorrectly calculated in the logs (the main loss was used there, not inter_ctc). Note that this is purely a logging issue and does not affect training. Second, the main loss was not multiplied by the correct coefficient, which does affect training to some extent. So the calculation should have been something like
loss = final_loss * 0.7 + inter_loss * 0.3
while it was just
loss = final_loss + inter_loss * 0.3
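Put as a small sketch with illustrative variable names, the fix amounts to scaling the main loss by the remaining weight:

inter_loss_weight = 0.3
# the main loss has to be scaled down by (1 - weight), not added at full strength
loss = (1 - inter_loss_weight) * final_loss + inter_loss_weight * inter_loss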
Collection: ASR
Changelog
Usage
# Add a code snippet demonstrating how to use this
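A hedged illustration of enabling interCTC on a hybrid model via the config (the interctc keys follow the existing interCTC support for CTC models; the exact layout is an assumption, not copied from this PR):

from omegaconf import OmegaConf

# hypothetical minimal override showing only the relevant keys:
# apply an auxiliary CTC loss with weight 0.3 to the output of encoder layer 8
interctc_override = OmegaConf.create(
    {"model": {"interctc": {"loss_weights": [0.3], "apply_at_layers": [8]}}}
)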
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.
Additional Information