
InterCTC loss and stochastic depth implementation #6013

Merged: 20 commits into NVIDIA:main on Feb 18, 2023

Conversation

@Kipok (Collaborator) commented Feb 14, 2023

What does this PR do?

Adding intermediate CTC loss and stochastic depth as described in https://arxiv.org/abs/2102.03216.

The current implementation is only for the Conformer encoder, but I'm not really sure how to write generic code for this case. Please let me know if you have any ideas here.

Collection: ASR

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below

To use, specify parameters in the config. E.g., for stochastic depth:

model.encoder.stochastic_depth_mode=linear  
model.encoder.stochastic_depth_drop_prob=0.3 
model.encoder.stochastic_depth_start_layer=0

For intermediate CTC loss:

model.encoder.capture_output_at_layers=[9] 
model.intermediate_loss_weights=[0.3]
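Purely as an illustration, the same overrides can be sketched as an OmegaConf tree (NeMo configs are OmegaConf-based); the nesting simply mirrors the dotted paths above and is not itself part of this PR:

# Illustrative only: building the override tree above with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {
            "intermediate_loss_weights": [0.3],
            "encoder": {
                "stochastic_depth_mode": "linear",
                "stochastic_depth_drop_prob": 0.3,
                "stochastic_depth_start_layer": 0,
                "capture_output_at_layers": [9],
            },
        }
    }
)
print(OmegaConf.to_yaml(cfg))  # shows the resulting YAML config section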

I've added the docs to ConformerEncoder, but I'm not sure how to add docs for the new intermediate_loss_weights parameter of EncDecCTCModel. Please let me know what the right place is to put those docs.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@titu1994 (Collaborator) left a comment:

The idea is good, but the code is becoming complicated.
We have two options:

  1. Keep it inside ctc_models.py, but have a mixin class that deals with the interctc parts of the code, including any and all functions needed by it. The ctc_models.py class then simply calls these functions.
  2. Write an entirely separate class for ctc_models.py - subclass it and override the parts we need, which is primarily forward, plus parts of training_step and validation_step.

A bad option is to merge it as is right now - it would significantly complicate the ctc_models.py training step and validation step, plus cause issues with long-audio inference, where we don't want interctc tensors wasting memory.

@VahidooX What is your preference?
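For concreteness, a hedged sketch of option 1: the InterCTC bookkeeping lives in a mixin that ctc_models.py delegates to. Except for the loss-call kwargs (which match the diff further down), every helper name and the weighting scheme are illustrative assumptions, not the PR's final API:

class InterCTCMixin:
    """Illustrative sketch: the InterCTC parts live here, and the model's
    training_step/validation_step simply call these helpers."""

    def is_interctc_enabled(self) -> bool:
        # Enabled only when intermediate loss weights were configured.
        return len(getattr(self, "intermediate_loss_weights", []) or []) > 0

    def add_interctc_losses(self, main_loss, transcript, transcript_len):
        """Mix intermediate CTC losses into the main loss, weighted in the
        style of the InterCTC paper: (1 - sum(w)) * L_final + sum_i w_i * L_i."""
        if not self.is_interctc_enabled():
            return main_loss
        weights = self.intermediate_loss_weights
        total_loss = (1.0 - sum(weights)) * main_loss
        # _get_interctc_tensors is a hypothetical helper returning the
        # captured (log_probs, lengths) pair for each intermediate layer.
        for weight, (log_probs, lengths) in zip(weights, self._get_interctc_tensors()):
            inter_loss = self.loss(
                log_probs=log_probs,
                targets=transcript,
                input_lengths=lengths,
                target_lengths=transcript_len,
            )
            total_loss = total_loss + weight * inter_loss
        return total_loss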

@@ -524,6 +536,23 @@ def forward(
encoded_len = encoder_output[1]
log_probs = self.decoder(encoder_output=encoded)
greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)
# generating decoding results for intermediate layers if necessary
Collaborator:

Add a new line after greedy preds, plus a docstring explaining which paper this is implementing and what it does.

@@ -524,6 +536,23 @@ def forward(
encoded_len = encoder_output[1]
log_probs = self.decoder(encoder_output=encoded)
greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)
# generating decoding results for intermediate layers if necessary
if self.intermediate_loss_weights:
Collaborator:

Be explicit and check len(self.intermediate_loss_weights) > 0.

if self.intermediate_loss_weights:
# we assume that encoder has to have property called "captured_layer_outputs"
# which is a list with the same length as loss weights
if len(self.encoder.captured_layer_outputs) != len(self.intermediate_loss_weights):
Collaborator:

Don't assume. First and foremost, you should never attach tensors with grads to a module. You can instead use the Tensor Registry framework in NeMo. Most models already support it because other parts need it. So do the following: in the training_step, enable the tensor registry, let modules register their forward activations (or a subset of them), and then access the registry here. Empty the registry at the end of the train step to prevent a memory leak.

So do a hasattr check, then raise a proper error (ValueError/RuntimeError) detailing that:

  1. the provided CTC intermediate loss weights list is not empty,
  2. but the model does not add anything to the registry.

Also, this will waste memory during inference if the user does not explicitly request interctc outputs. So put a self.training check above this part.
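A hedged sketch of the registry flow being suggested. is_access_enabled, reset_registry, and get_module_registry appear verbatim later in this PR; the import path, set_access_enabled, register_accessible_tensor, and _forward_and_loss are assumed names for illustration only:

from nemo.core.classes.mixins import AccessMixin  # import path assumed

class SketchedCTCModel(AccessMixin):
    def training_step(self, batch, batch_nb):
        # 1) enable the tensor registry before the forward pass
        AccessMixin.set_access_enabled(access_enabled=True)
        # 2) during forward, the encoder registers intermediate activations,
        #    e.g. self.register_accessible_tensor(name=f"layer_output_{lth}", tensor=x)
        loss = self._forward_and_loss(batch)  # hypothetical helper
        # 3) read the captured tensors back out for the InterCTC losses
        registry = AccessMixin.get_module_registry(self.encoder)
        # ... compute intermediate CTC losses from `registry` here ...
        # 4) empty the registry at the end of the step to prevent a memory leak
        if AccessMixin.is_access_enabled():
            AccessMixin.reset_registry(self)
        return loss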

)
self.intermediate_decoding_results = [None] * len(self.encoder.captured_layer_outputs)
for idx, captured_output in enumerate(self.encoder.captured_layer_outputs):
self.intermediate_decoding_results[idx] = [
Collaborator:

Do not set values - especially CUDA tensors - on a module during training. They will not be garbage collected properly.

target_lengths=transcript_len,
input_lengths=intermediate_result[1],
)
tensorboard_logs[f"inter_ctc_loss{idx}"] = loss_value.detach()
Collaborator:

This should be inter_loss_value
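That is, assuming the intermediate loss is computed into its own inter_loss_value variable, the log line would presumably read:

# log the intermediate CTC loss itself, not the reused main loss_value
tensorboard_logs[f"inter_ctc_loss{idx}"] = inter_loss_value.detach()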

raise ValueError('stochastic_depth_mode has to be one of ["linear", "uniform"].')
self.layer_drop_probs = layer_drop_probs
self.capture_output_at_layers = capture_output_at_layers
if self.capture_output_at_layers is None:
Collaborator:

Don't cache tensors inside a module during the forward pass. Use the tensor registry.

Collaborator:

++

@@ -478,6 +523,17 @@ def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time
cache_last_channel_next=cache_last_channel_next,
cache_last_time_next=cache_last_time_next,
)
if self.training:
Collaborator:

Add a docstring here for stochastic depth explaining what's being done.
Also add the condition stochastic_depth_drop_prob > 0.0 right here, after self.training.

if should_drop:
# that's not efficient, but it's hard to implement distributed
# version of dropping layers without deadlock or random seed meddling
# so multiplying the signal by 0 to ensure all weights get gradients
Collaborator:

This is fine.
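For context, a hedged sketch of the mechanics under review: per-layer drop probabilities derived from stochastic_depth_mode / stochastic_depth_drop_prob / stochastic_depth_start_layer, plus the zero-multiply trick from the code comment above. The helper is illustrative, not the PR's exact implementation:

def compute_stochastic_depth_drop_probs(num_layers, mode, drop_prob, start_layer=0):
    """Illustrative: per-layer drop probabilities for stochastic depth.
    'linear' ramps the probability up with depth (later layers are dropped
    more often); 'uniform' applies drop_prob to every layer from start_layer on."""
    probs = []
    for lth in range(num_layers):
        if lth < start_layer:
            probs.append(0.0)  # never drop layers before start_layer
        elif mode == "linear":
            probs.append(drop_prob * (lth + 1) / num_layers)  # assumed ramp shape
        elif mode == "uniform":
            probs.append(drop_prob)
        else:
            raise ValueError('stochastic_depth_mode has to be one of ["linear", "uniform"].')
    return probs

# Per layer in the encoder forward (sketch of the zero-multiply trick):
#   if self.training and drop_prob > 0.0 and torch.rand(1).item() < drop_prob:
#       # multiply the layer output by 0 instead of skipping the layer, so
#       # every weight still receives a gradient and DDP ranks stay in sync
#       audio_signal = 0.0 * layer_output + residual_input

Under this assumed ramp, compute_stochastic_depth_drop_probs(17, "linear", 0.3) gives the last layer a drop probability of exactly 0.3, with shallower layers dropped proportionally less often.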

@@ -487,6 +543,9 @@ def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time
_, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
pad_mask, att_mask = self._create_masks(max_audio_length, length, audio_signal.device)

if lth in self.capture_output_at_layers:
Collaborator:

Use the tensor registry here.

@@ -496,6 +555,9 @@ def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time

audio_signal = torch.transpose(audio_signal, 1, 2)

for captured_output in self.captured_layer_outputs:
Collaborator:

Do the above when you use the tensor registry.

@titu1994 (Collaborator) left a comment:

Looks a lot better. Minor comments and docstring additions to the InterCTCMixin class, then it can be merged.

nemo/collections/asr/models/ctc_models.py (resolved review thread)
@@ -536,6 +540,9 @@ def training_step(self, batch, batch_nb):
if AccessMixin.is_access_enabled():
AccessMixin.reset_registry(self)

if self.interctc_enabled:
Collaborator:

Should be a function in InterCTCMixin, not a variable assigned to self.

nemo/collections/asr/models/ctc_models.py (3 resolved review threads)
if len(self.encoder.capture_output_at_layers) != len(self.intermediate_loss_weights):
raise ValueError('Length of encoder.capture_output_at_layers has to match intermediate_loss_weights')

def finalize_interctc_metrics(self, metrics, outputs, prefix):
Collaborator:

Methods need docstrings

[x[f"{prefix}final_ctc_loss"] for x in outputs]
).mean()

def get_captured_tensors(self):
Collaborator:

Let's have the keyword interctc somewhere in the name of all methods to avoid name collisions. Or make them private if they won't be used outside of this class.

# if intermediate_loss_weights was set, the encoder has to register
# layer_output_X and layer_length_X tensors. We need to apply decoder
# to each of them and compute CTC loss.
module_registry = AccessMixin.get_module_registry(self.encoder)[''] # key for encoder
Collaborator:

[''] ?

Collaborator Author:

That's the key that AccessMixin uses when assigning tensors to the current module. I don't have control over it :)
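To illustrate, a hedged sketch of consuming that registry inside the model. The key names layer_output_X / layer_length_X and the [''] lookup come from the snippet and comments above; the loop structure itself is an assumption:

# Sketch (inside the model): read back tensors the encoder registered.
# '' is the registry key AccessMixin uses for the module's own tensors.
module_registry = AccessMixin.get_module_registry(self.encoder)['']
captured = []
for layer_idx in self.encoder.capture_output_at_layers:  # e.g. [9]
    layer_outputs = module_registry[f"layer_output_{layer_idx}"]
    layer_lengths = module_registry[f"layer_length_{layer_idx}"]
    captured.append((layer_outputs, layer_lengths))
# each pair is then run through self.decoder and the CTC loss, as in forward()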

@@ -546,6 +547,38 @@ The encoder section includes the details about the RNN-based encoder architectur
config files and also :ref:`nemo.collections.asr.modules.RNNEncoder <rnn-encoder-api>`.


CTC Configurations
Collaborator Author:

@titu1994 I'm not sure if that's the best place to put these docs.

Collaborator:

It is for the configuration of ASR models; it doesn't look that bad to me.

Collaborator:

Same, looks fine

@VahidooX (Collaborator) left a comment:

LGTM, just minor comments!

nemo/collections/asr/modules/conformer_encoder.py (outdated, resolved)
nemo/collections/asr/models/ctc_models.py (fixed, resolved)

CTC Configurations
------------------

All CTC-based models also support `InterCTC loss <https://arxiv.org/abs/2102.03216>`_. To use it, you need to specify
Collaborator:

How about the stochastic depth docs?

Collaborator Author:

Where's the best place to put them? They are currently documented in the code for ConformerEncoder and only supported there (but for any model that uses it).

Collaborator:

The same place looks good to me. You may mention in the description that it is only supported for Conformer-based models.

Collaborator Author:

But this one is called "CTC Configurations", while stochastic depth applies to both CTC and transducer models (although only to Conformer-based versions for now).

Collaborator:

I meant creating a new section in the same file titled "Stochastic Depth" or something like this.
You may also rename this section to "InterCTC Loss" instead of "CTC Configurations"?

Collaborator:

Agreed. Rename this to "InterCTC Config" and add a new "Stochastic Depth" config section. Note that only Conformer is supported for now; we can add more models in the future.

@titu1994 (Collaborator) left a comment:

Minor comments. The PR is ready to merge after that


docs/source/asr/configs.rst (resolved)
docs/source/asr/configs.rst (outdated, resolved)
@Kipok requested a review from titu1994 on February 17, 2023 at 18:02
titu1994 previously approved these changes Feb 17, 2023
@titu1994 (Collaborator) left a comment:

After review, it seems I was mistaken. All cases are covered correctly.
LGTM, ready to merge

@Kipok merged commit 83859ec into NVIDIA:main Feb 18, 2023
titu1994 pushed a commit to titu1994/NeMo that referenced this pull request Mar 24, 2023
* Some simplifications
* Add tests for stochastic depth
* Fix tests for stochastic depth
* Add interctc loss and logs
* Fix a few issues
* Add interctc loss tests
* Add docs
* Add training_step test for interctc
* Refactoring with AccessMixin WIP
* Separate interctc logic into a mixin
* Fix tests
* Fix some lint errors
* Small refactoring
* Add more docs, fix PR comments
* Add other encoder support + more refactoring
* Add more config examples
* Move stochastic depth setup to utils
* Add interctc_enabled setter + more docs
* Fix a few doc strings for better web display
* Update CTC flow diagram

---------

Signed-off-by: Igor Gitman <[email protected]>
hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request Jun 2, 2023