InterCTC loss and stochastic depth implementation #6013
Conversation
The idea is good, but the code is becoming complicated.
We have two options:
- Keep it inside ctc_models.py, but add a mixin class that deals with the InterCTC parts of the code, including any and all functions needed by it. The ctc_models.py class then simply calls these functions.
- Write an entirely separate class for ctc_models.py, subclass it, and override the parts we need, which are primarily forward and parts of training_step and validation_step.
A bad option is to merge it as is right now: that would significantly complicate the ctc_models.py training step and validation step, and cause issues with long-audio inference, where we don't want InterCTC tensors to waste memory.
@VahidooX What is your preference?
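The mixin option above can be pictured roughly as follows. This is a minimal sketch in plain Python: `InterCTCMixinSketch`, `TinyCTCModel`, and all method names are hypothetical illustrations, not the actual NeMo API, and the loss weighting shown is one common convention from the InterCTC paper.

```python
# Sketch of the mixin option: all InterCTC-specific logic lives in a mixin,
# and the CTC model class only calls into it. Names are illustrative.

class InterCTCMixinSketch:
    """Holds the InterCTC-specific pieces of the training logic."""

    def setup_interctc(self, loss_weights):
        self._interctc_loss_weights = list(loss_weights)

    def is_interctc_enabled(self):
        return len(getattr(self, "_interctc_loss_weights", [])) > 0

    def add_interctc_losses(self, main_loss, intermediate_losses):
        # One common weighting: total = (1 - sum(w)) * main + sum(w_i * inter_i)
        weights = self._interctc_loss_weights
        total = (1 - sum(weights)) * main_loss
        for w, inter in zip(weights, intermediate_losses):
            total += w * inter
        return total


class TinyCTCModel(InterCTCMixinSketch):
    """The model class stays simple and just delegates to the mixin."""

    def training_step_loss(self, main_loss, intermediate_losses):
        if self.is_interctc_enabled():
            return self.add_interctc_losses(main_loss, intermediate_losses)
        return main_loss
```

With a single intermediate weight of 0.3, a main loss of 1.0 and an intermediate loss of 2.0 combine to 0.7 * 1.0 + 0.3 * 2.0 = 1.3.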
@@ -524,6 +536,23 @@ def forward(
    encoded_len = encoder_output[1]
    log_probs = self.decoder(encoder_output=encoded)
    greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)
    # generating decoding results for intermediate layers if necessary
Add a new line after greedy preds, plus a docstring explaining which paper this implements and what it does.
@@ -524,6 +536,23 @@ def forward(
    encoded_len = encoder_output[1]
    log_probs = self.decoder(encoder_output=encoded)
    greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)
    # generating decoding results for intermediate layers if necessary
    if self.intermediate_loss_weights:
Be explicit and check len(self.intermediate_loss_weights) > 0.
if self.intermediate_loss_weights:
    # we assume that encoder has to have property called "captured_layer_outputs"
    # which is a list with the same length as loss weights
    if len(self.encoder.captured_layer_outputs) != len(self.intermediate_loss_weights):
Don't assume. First and foremost, you should not be attaching tensors with grads to a module - ever. You can instead use the tensor registry framework in NeMo; most models already support it because other parts need it. So do the following: in the training_step, enable the tensor registry, let modules register their forward activations (or a subset of them), and then access the registry here. Empty the registry at the end of the training step to prevent a memory leak.
So do a hasattr check, then raise a proper error (ValueError/RuntimeError) which details that
- the provided InterCTC intermediate loss weights list is not empty,
- but the model does not add anything to the registry.
Also, this will waste memory during inference if the user does not explicitly request InterCTC outputs, so put a self.training check above this part.
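The flow described above can be sketched generically with a plain dict-backed registry. This is only an illustration of the pattern (enable, register during forward, consume, reset); NeMo's actual AccessMixin API differs in its details, and every name here is hypothetical.

```python
# Generic sketch of the suggested registry flow: enable the registry for the
# training step, let the encoder register intermediate outputs during its
# forward pass, read them when computing losses, and clear everything at the
# end so no tensors are kept alive across steps. Names are illustrative.

class TensorRegistrySketch:
    def __init__(self):
        self._enabled = False
        self._store = {}

    def set_enabled(self, enabled):
        self._enabled = enabled

    def register(self, name, value):
        # Modules call this from forward; it is a no-op unless enabled,
        # so inference does not pay any memory cost.
        if self._enabled:
            self._store.setdefault(name, []).append(value)

    def get(self, name):
        return self._store.get(name, [])

    def reset(self):
        self._store.clear()


def training_step_sketch(registry, run_forward, compute_loss):
    registry.set_enabled(True)      # enable before the forward pass
    run_forward(registry)           # modules register their activations
    loss = compute_loss(registry)   # consume the registered tensors here
    registry.reset()                # empty the registry: no memory leak
    registry.set_enabled(False)
    return loss
```

Because registration is gated on the enabled flag, leaving the registry disabled outside training gives the self.training behavior the comment asks for.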
)
self.intermediate_decoding_results = [None] * len(self.encoder.captured_layer_outputs)
for idx, captured_output in enumerate(self.encoder.captured_layer_outputs):
    self.intermediate_decoding_results[idx] = [
Do not set values - especially CUDA tensors - as attributes on a module during training. They will not be garbage collected properly.
    target_lengths=transcript_len,
    input_lengths=intermediate_result[1],
)
tensorboard_logs[f"inter_ctc_loss{idx}"] = loss_value.detach()
This should be inter_loss_value
raise ValueError('stochastic_depth_mode has to be one of ["linear", "uniform"].')
self.layer_drop_probs = layer_drop_probs
self.capture_output_at_layers = capture_output_at_layers
if self.capture_output_at_layers is None:
Don't cache tensors inside a module during the forward pass. Use the tensor registry.
++
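For reference, the two stochastic depth modes validated above can be computed as follows. This is a sketch of one common convention (later layers dropped more often in "linear" mode); the exact schedule in the PR's code may differ, and the function name is hypothetical.

```python
def compute_layer_drop_probs(num_layers, drop_prob, mode):
    """Per-layer drop probabilities for stochastic depth (sketch).

    "linear" scales the probability with depth, so deeper layers are
    dropped more often; "uniform" uses the same probability everywhere.
    One common linear convention is p_l = (l / L) * p_final.
    """
    if mode == "uniform":
        return [drop_prob] * num_layers
    if mode == "linear":
        # p_l = (l / L) * p_final for layer l in 1..L
        return [l / num_layers * drop_prob for l in range(1, num_layers + 1)]
    raise ValueError('stochastic_depth_mode has to be one of ["linear", "uniform"].')
```

For example, with 4 layers and a final drop probability of 0.2, "linear" mode ramps from 0.05 at the first layer up to 0.2 at the last.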
@@ -478,6 +523,17 @@ def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time
    cache_last_channel_next=cache_last_channel_next,
    cache_last_time_next=cache_last_time_next,
)
if self.training:
Add a docstring here for stochastic depth explaining what's being done.
Also add the condition stochastic_depth_drop_prob > 0.0 right here, after self.training.
if should_drop:
    # that's not efficient, but it's hard to implement distributed
    # version of dropping layers without deadlock or random seed meddling
    # so multiplying the signal by 0 to ensure all weights get gradients
This is fine.
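The multiply-by-zero trick discussed above can be illustrated with plain numbers. This is a sketch only (the real code operates on CUDA tensors inside a Conformer residual block, and the function name is hypothetical), but it shows the key property: the layer is always computed, so distributed training never sees unused parameters on some ranks, yet a dropped layer contributes nothing to the output.

```python
def residual_with_stochastic_depth(x, layer_output, should_drop, training):
    """Sketch of dropping a residual branch without skipping computation.

    Skipping the layer entirely on a random subset of steps can deadlock
    distributed data parallel training (some ranks would have parameters
    with no gradients). Instead, the layer output is always computed and
    multiplied by 0 when dropped, so every weight still receives a (zero)
    gradient and the residual path passes the input through unchanged.
    """
    if training and should_drop:
        layer_output = layer_output * 0.0  # zero out instead of skipping
    return x + layer_output
```

In eval mode the drop decision is ignored, matching the self.training gate requested in the review.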
@@ -487,6 +543,9 @@ def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time
    _, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
    pad_mask, att_mask = self._create_masks(max_audio_length, length, audio_signal.device)

    if lth in self.capture_output_at_layers:
Use tensor registry here.
@@ -496,6 +555,9 @@ def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time

    audio_signal = torch.transpose(audio_signal, 1, 2)

    for captured_output in self.captured_layer_outputs:
Do the above when you use the tensor registry.
Looks a lot better, minor comments and docstring additions to InterCTCMixin class, then can be merged.
@@ -536,6 +540,9 @@ def training_step(self, batch, batch_nb):
    if AccessMixin.is_access_enabled():
        AccessMixin.reset_registry(self)

    if self.interctc_enabled:
Should be a function in InterCTCMixin, not a variable assigned to self.
if len(self.encoder.capture_output_at_layers) != len(self.intermediate_loss_weights):
    raise ValueError('Length of encoder.capture_output_at_layers has to match intermediate_loss_weights')

def finalize_interctc_metrics(self, metrics, outputs, prefix):
Methods need docstrings
    [x[f"{prefix}final_ctc_loss"] for x in outputs]
).mean()

def get_captured_tensors(self):
Let's have the keyword interctc somewhere in the name of all methods to avoid name collisions. Or make them private if they won't be used outside of this class.
# if intermediate_loss_weights was set, the encoder has to register
# layer_output_X and layer_length_X tensors. We need to apply decoder
# to each of them and compute CTC loss.
module_registry = AccessMixin.get_module_registry(self.encoder)['']  # key for encoder
[''] ?
That's the key that AccessMixin has when assigning tensors to the current module. I don't have control over it :)
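The shape being discussed can be pictured with a plain dict. This is only an illustrative sketch of the nesting (tensors registered by a module itself landing under the empty-string key, i.e. the module's own relative name), not the real AccessMixin data structure.

```python
# Illustrative shape of a per-module registry: entries registered by the
# encoder module itself appear under the empty-string key, while a submodule
# would register under its dotted relative name. Keys and values are made up.
module_registry = {
    "": {  # tensors registered by the encoder module itself
        "layer_output_3": ["<tensor>"],
        "layer_length_3": ["<tensor>"],
    },
    # "layers.3": {...}  # a submodule would appear under its own name
}

# hence the [''] lookup to get the encoder's own registered tensors
encoder_tensors = module_registry[""]
```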
docs/source/asr/configs.rst (outdated)
@@ -546,6 +547,38 @@ The encoder section includes the details about the RNN-based encoder architectur
config files and also :ref:`nemo.collections.asr.modules.RNNEncoder <rnn-encoder-api>`.

CTC Configurations
@titu1994 not sure if that's the best place to put these docs
It is for the configuration of ASR models; doesn't look that bad to me.
Same, looks fine
LGTM, just minor comments!
docs/source/asr/configs.rst (outdated)
@@ -546,6 +547,38 @@ The encoder section includes the details about the RNN-based encoder architectur
config files and also :ref:`nemo.collections.asr.modules.RNNEncoder <rnn-encoder-api>`.

CTC Configurations
It is for the configuration of ASR models; doesn't look that bad to me.
CTC Configurations
------------------

All CTC-based models also support `InterCTC loss <https://arxiv.org/abs/2102.03216>`_. To use it, you need to specify
How about the stochastic depth docs?
What's the best place to put them? They are currently documented in the code for ConformerEncoder and only supported there (but for any model that uses it).
The same place looks good to me. You may mention in the description that it is only supported for Conformer-based models.
But this one is called "CTC Configurations", while stochastic depth is for both CTC and transducer (although only for conformer-based versions now)
I meant creating a new section in the same file titled "Stochastic Depth" or something like this.
You may also rename this section to "InterCTC Loss" instead of "CTC Configurations"?
Agreed. Rename this section to InterCTC Config and add a new Stochastic Depth config section. Note that only Conformer is supported for now; we can add more models in the future.
Minor comments. The PR is ready to merge after that
docs/source/asr/configs.rst (outdated)
@@ -546,6 +547,38 @@ The encoder section includes the details about the RNN-based encoder architectur
config files and also :ref:`nemo.collections.asr.modules.RNNEncoder <rnn-encoder-api>`.

CTC Configurations
Same, looks fine
CTC Configurations
------------------

All CTC-based models also support `InterCTC loss <https://arxiv.org/abs/2102.03216>`_. To use it, you need to specify
Agreed. Rename this section to InterCTC Config and add a new Stochastic Depth config section. Note that only Conformer is supported for now; we can add more models in the future.
After review, seems I was mistaken. All cases are covered correctly.
LGTM, ready to merge
Signed-off-by: Igor Gitman <[email protected]>
* Some simplifications
* Add tests for stochastic depth
* Fix tests for stochastic depth
* Add interctc loss and logs
* Fix a few issues
* Add interctc loss tests
* Add docs
* Add training_step test for interctc
* Refactoring with AccessMixin WIP
* Separate interctc logic into a mixin
* Fix tests
* Fix some lint errors
* Small refactoring
* Add more docs, fix PR comments
* Add other encoder support + more refactoring
* Add more config examples
* Move stochastic depth setup to utils
* Add interctc_enabled setter + more docs
* Fix a few doc strings for better web display
* Update CTC flow diagram

Signed-off-by: Igor Gitman <[email protected]>
What does this PR do?
Adds intermediate CTC loss and stochastic depth as described in https://arxiv.org/abs/2102.03216.
The current implementation is only for the Conformer encoder, but I'm not really sure how to write generic code for this case. Please let me know if you have some ideas here.
Collection: ASR
Changelog
Usage
To use, specify parameters in the config. E.g., for stochastic depth:
For intermediate CTC loss:
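The original config snippets are not shown above, but the two feature configurations can be sketched roughly as follows. This is a hypothetical illustration using parameter names that appear in this PR's code (stochastic_depth_drop_prob, stochastic_depth_mode, intermediate_loss_weights, capture_output_at_layers); the actual YAML layout and defaults are in the PR's docs.

```python
# Hypothetical sketch (plain dicts, not actual YAML) of the two features'
# parameters, using names seen in this PR's code. Values are made up.
stochastic_depth_cfg = {
    "stochastic_depth_drop_prob": 0.3,  # maximum per-layer drop probability
    "stochastic_depth_mode": "linear",  # "linear" or "uniform"
}

interctc_cfg = {
    "intermediate_loss_weights": [0.3],  # one weight per captured layer
    "capture_output_at_layers": [8],     # encoder layers whose output is decoded
}

# the two lists must have matching lengths, as enforced in the PR's code
assert len(interctc_cfg["intermediate_loss_weights"]) == len(
    interctc_cfg["capture_output_at_layers"]
)
```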
I've added the docs to ConformerEncoder, but I'm not sure how to add docs for the new intermediate_loss_weights parameter of the EncDecCTCModel. Please let me know what's the right place to put those docs in.
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.
Additional Information