InterCTC loss and stochastic depth implementation #6013

Merged (20 commits, Feb 18, 2023)
4 changes: 4 additions & 0 deletions docs/source/asr/api.rst
@@ -113,6 +113,10 @@ Mixins
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.parts.mixins.interctc_mixin.InterCTCMixin
:show-inheritance:
:members:

Datasets
--------

95 changes: 74 additions & 21 deletions docs/source/asr/configs.rst
@@ -1,12 +1,12 @@
NeMo ASR Configuration Files
============================

This section describes the NeMo configuration file setup that is specific to models in the ASR collection. For general information
about how to set up and run experiments that is common to all NeMo models (e.g. Experiment Manager and PyTorch Lightning trainer
parameters), see the :doc:`../core/core` section.

The model section of the NeMo ASR configuration files generally requires information about the dataset(s) being used, the preprocessor
for audio files, parameters for any augmentation being performed, as well as the model architecture specification. The sections on
this page cover each of these in more detail.

Example configuration files for all of the NeMo ASR scripts can be found in the
@@ -17,8 +17,8 @@ Dataset Configuration
---------------------

Training, validation, and test parameters are specified using the ``train_ds``, ``validation_ds``, and
``test_ds`` sections in the configuration file, respectively. Depending on the task, there may be arguments specifying the sample rate
of the audio files, the vocabulary of the dataset (for character prediction), whether or not to shuffle the dataset, and so on. You may
also decide to leave fields such as the ``manifest_filepath`` blank, to be specified via the command-line at runtime.

Any initialization parameter that is accepted for the Dataset class used in the experiment can be set in the config file.
@@ -80,7 +80,7 @@ Preprocessor Configuration
--------------------------

If you are loading audio files for your experiment, you will likely want to use a preprocessor to convert from the
raw audio signal to features (e.g. mel-spectrogram or MFCC). The ``preprocessor`` section of the config specifies the audio
preprocessor to be used via the ``_target_`` field, as well as any initialization parameters for that preprocessor.

An example of specifying a preprocessor is as follows:
@@ -97,7 +97,7 @@ An example of specifying a preprocessor is as follows:
...
# Other parameters for the preprocessor

Refer to the `Audio Preprocessors <./api.html#Audio Preprocessors>`__ API section for the preprocessor options, expected arguments,
and defaults.

Augmentation Configurations
@@ -179,7 +179,7 @@ The following example sets up a ``SentencePiece Tokenizer`` at a path specified
dir: "<path to the directory that contains the custom tokenizer files>"
type: "bpe" # can be "bpe" or "wpe"

The Aggregate (``agg``) tokenizer feature makes it possible to combine tokenizers in order to train multilingual
models. The config file would look like this:

.. code-block:: yaml
@@ -188,21 +188,21 @@ models. The config file would look like this:
    ...
    tokenizer:
      type: "agg" # aggregate tokenizer
      langs:
        en:
          dir: "<path to the directory that contains the tokenizer files>"
          type: "bpe" # can be "bpe" or "wpe"
        es:
          dir: "<path to the directory that contains the tokenizer files>"
          type: "bpe" # can be "bpe" or "wpe"

In the above config file, each language is associated with its own pre-trained tokenizer, which gets assigned
a token id range in the order the tokenizers are listed. To train a multilingual model, one needs to populate the
``lang`` field in the manifest file, allowing the routing of each sample to the correct tokenizer. At inference time,
the routing is done based on the inferred token id range.
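
For illustration, here is a minimal sketch of how such a manifest could be written. This is a hedged example:
``audio_filepath``, ``duration``, and ``text`` are the usual NeMo manifest keys, the ``lang`` values match the
tokenizer names in the ``agg`` config above, and all paths and text are hypothetical:

.. code-block:: python

    import json

    # Hypothetical manifest rows: one JSON object per line. The "lang" key
    # routes each sample to the matching tokenizer in the "agg" config.
    rows = [
        {"audio_filepath": "audio/en_0001.wav", "duration": 3.2,
         "text": "hello world", "lang": "en"},
        {"audio_filepath": "audio/es_0001.wav", "duration": 2.7,
         "text": "hola mundo", "lang": "es"},
    ]

    with open("train_manifest.json", "w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")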

For models which utilize sub-word tokenization, we share the decoder module (``ConvASRDecoder``) with character tokenization models.
All parameters are shared, but for models which utilize sub-word encoding, there are minor differences when setting up the config. For
such models, the tokenizer is utilized to fill in the missing information when the model is constructed automatically.

For example, a decoder config corresponding to a sub-word tokenization model should look similar to the following:
@@ -221,7 +221,7 @@ For example, a decoder config corresponding to a sub-word tokenization model sho
Model Architecture Configurations
---------------------------------

Each configuration file should describe the model architecture being used for the experiment. Models in the NeMo ASR collection need
an ``encoder`` section and a ``decoder`` section, with the ``_target_`` field specifying the module to use for each.

Here is the list of the parameters in the model section which are shared among most of the ASR models:
@@ -478,7 +478,7 @@ A Citrinet-512 config should look similar to the following:
se: ${model.model_defaults.se}
se_context_size: ${model.model_defaults.se_context_size}

As mentioned above, Citrinet uses the ``ConvASRDecoder`` as the decoder layer similar to QuartzNet. Only the configuration must be
changed slightly as Citrinet utilizes sub-word tokenization.

.. note::
@@ -499,8 +499,8 @@ The ``SqueezeExcite`` block within a :class:`~nemo.collections.asr.modules.conv_
Conformer-CTC
~~~~~~~~~~~~~

The config files for Conformer-CTC model contain character-based encoding and sub-word encoding at
``<NeMo_git_root>/examples/asr/conf/conformer/conformer_ctc_char.yaml`` and ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_ctc_bpe.yaml``
respectively. Some components of the configs of `Conformer-CTC <./models.html#Conformer-CTC>`__ include the following sections:

* ``train_ds``, ``validation_ds``, and ``test_ds``
@@ -510,10 +510,11 @@ respectively. Some components of the configs of `Conformer-CTC <./models.html#Co
* ``trainer``
* ``exp_manager``

These sections are similar to those of other ASR models like `QuartzNet <./models.html#QuartzNet>`__. There should be a tokenizer section where you can
specify the tokenizer if you want to use sub-word encoding instead of character-based encoding.


The encoder section includes the details about the Conformer-CTC encoder architecture. You may find more information in the
config files and also :ref:`nemo.collections.asr.modules.ConformerEncoder <conformer-encoder-api>`.

Squeezeformer-CTC
@@ -546,6 +547,58 @@ The encoder section includes the details about the RNN-based encoder architecture
config files and also :ref:`nemo.collections.asr.modules.RNNEncoder <rnn-encoder-api>`.


InterCTC Config
---------------

Review thread:

Collaborator: How about the stochastic depth docs?

Collaborator (author): What's the best place to put them in? They are currently in the code for ``ConformerEncoder`` and only supported there (but for any model that's using it).

Collaborator: The same place looks good to me. You may mention in the description that it is only supported for Conformer-based models.

Collaborator (author): But this one is called "CTC Configurations", while stochastic depth is for both CTC and transducer (although only for conformer-based versions now).

Collaborator: I meant creating a new section in the same file titled "Stochastic Depth" or something like this. You may also rename this section to "InterCTC Loss" instead of "CTC Configurations"?

Collaborator: Agreed. Rename to "InterCTC Config" and add a new "Stochastic Depth" config section. Note that only Conformer is supported now; we can add more models in the future.

All CTC-based models also support `InterCTC loss <https://arxiv.org/abs/2102.03216>`_. To use it, you need to
specify two parameters, as in the example below:
.. code-block:: yaml

    model:
      # ...
      interctc:
        loss_weights: [0.3]
        apply_at_layers: [8]

This can be used to reproduce the default setup from the paper (assuming the total number of layers is 18).
You can also specify multiple CTC losses from different layers. For example, to get two losses from layers 3 and 8 with
weights 0.1 and 0.3, specify:

.. code-block:: yaml

    model:
      # ...
      interctc:
        loss_weights: [0.1, 0.3]
        apply_at_layers: [3, 8]

Note that the final-layer CTC loss weight is automatically computed so that all weights sum to 1
(0.6 in the example above).
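
To make the weighting concrete, here is a minimal sketch of the resulting loss combination. It only illustrates
the arithmetic described above; the function and variable names are hypothetical, not NeMo internals:

.. code-block:: python

    loss_weights = [0.1, 0.3]               # interctc.loss_weights
    final_weight = 1.0 - sum(loss_weights)  # 0.6, computed automatically

    def combined_loss(final_ctc_loss, intermediate_ctc_losses):
        # intermediate_ctc_losses[i] is the CTC loss computed on the output
        # of apply_at_layers[i] (layers 3 and 8 in this example).
        loss = final_weight * final_ctc_loss
        for weight, intermediate_loss in zip(loss_weights, intermediate_ctc_losses):
            loss += weight * intermediate_loss
        return loss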


Stochastic Depth Config
-----------------------

`Stochastic Depth <https://arxiv.org/abs/1603.09382>`_ is a useful technique for regularizing ASR model training.
Currently it is only supported for :ref:`nemo.collections.asr.modules.ConformerEncoder <conformer-encoder-api>`. To
use it, specify the following parameters in the encoder config file to reproduce the default setup from the paper:

.. code-block:: yaml

    model:
      # ...
      encoder:
        # ...
        stochastic_depth_drop_prob: 0.3
        stochastic_depth_mode: linear # linear or uniform
        stochastic_depth_start_layer: 0

See :ref:`documentation of ConformerEncoder <conformer-encoder-api>` for more details. Note that stochastic depth
is supported for both CTC and Transducer model variants (or any other kind of model/loss that uses a
Conformer as the encoder).
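
For intuition, here is a small sketch of how per-layer drop probabilities could follow from these parameters.
It shows one plausible reading of the ``linear`` and ``uniform`` modes, not NeMo's exact schedule:

.. code-block:: python

    def layer_drop_probs(num_layers: int, drop_prob: float,
                         mode: str = "linear", start_layer: int = 0):
        """Illustrative per-layer drop probabilities for stochastic depth.

        Layers before ``start_layer`` are never dropped. In "uniform" mode
        every remaining layer is dropped with probability ``drop_prob``; in
        "linear" mode the probability ramps up so that the final layer
        reaches ``drop_prob``.
        """
        active = num_layers - start_layer
        probs = []
        for layer in range(num_layers):
            if layer < start_layer:
                probs.append(0.0)
            elif mode == "uniform":
                probs.append(drop_prob)
            else:  # linear ramp toward drop_prob at the last layer
                probs.append(drop_prob * (layer - start_layer + 1) / active)
        return probs

    # With the config above and an 18-layer Conformer encoder:
    print(layer_drop_probs(num_layers=18, drop_prob=0.3))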


Transducer Configurations
-------------------------

9 changes: 5 additions & 4 deletions examples/asr/asr_ctc/README.md
@@ -1,8 +1,8 @@
# ASR with CTC Models

This directory contains example scripts to train ASR models using Connectionist Temporal Classification Loss.

Currently supported models are -

* Character based CTC model
* Subword based CTC model
@@ -21,8 +21,9 @@ graph TD
C --> E[Model]
B --> |Init| E[Model]
E --> |Constructor| F1(Change Vocabulary)
-    F1 --> F2(Setup Adapters if available)
-    F2 --> G(Setup Train + Validation + Test Data loaders)
+    F1 --> F2(Setup InterCTC if available)
+    F2 --> F3(Setup Adapters if available)
+    F3 --> G(Setup Train + Validation + Test Data loaders)
G --> H(Setup Optimization)
H --> I[Maybe init from pretrained]
I --> J["trainer.fit(model)"]
@@ -128,12 +128,26 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: null
num_classes: -1
vocabulary: []

# config for InterCTC loss: https://arxiv.org/abs/2102.03216
# specify loss weights and which layers to use for InterCTC
# e.g., to reproduce the paper results, set loss_weights: [0.3]
# and apply_at_layers: [8] (assuming 18 layers). Note that final
# layer loss coefficient is automatically adjusted (to 0.7 in above example)
interctc:
loss_weights: []
apply_at_layers: []

optim:
name: adamw
lr: 2.0
@@ -138,6 +138,11 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null # Currently only null is supported for export.
14 changes: 14 additions & 0 deletions examples/asr/conf/conformer/conformer_ctc_bpe.yaml
@@ -137,12 +137,26 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: null
num_classes: -1
vocabulary: []

# config for InterCTC loss: https://arxiv.org/abs/2102.03216
# specify loss weights and which layers to use for InterCTC
# e.g., to reproduce the paper results, set loss_weights: [0.3]
# and apply_at_layers: [8] (assuming 18 layers). Note that final
# layer loss coefficient is automatically adjusted (to 0.7 in above example)
interctc:
loss_weights: []
apply_at_layers: []

optim:
name: adamw
lr: 2.0
14 changes: 14 additions & 0 deletions examples/asr/conf/conformer/conformer_ctc_char.yaml
@@ -112,12 +112,26 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: null
num_classes: -1
vocabulary: ${model.labels}

# config for InterCTC loss: https://arxiv.org/abs/2102.03216
# specify loss weights and which layers to use for InterCTC
# e.g., to reproduce the paper results, set loss_weights: [0.3]
# and apply_at_layers: [8] (assuming 18 layers). Note that final
# layer loss coefficient is automatically adjusted (to 0.7 in above example)
interctc:
loss_weights: []
apply_at_layers: []

optim:
name: adamw
lr: 2.0
5 changes: 5 additions & 0 deletions examples/asr/conf/conformer/conformer_transducer_bpe.yaml
@@ -141,6 +141,11 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null # Currently only null is supported for export.
5 changes: 5 additions & 0 deletions examples/asr/conf/conformer/conformer_transducer_char.yaml
@@ -136,6 +136,11 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null # Currently only null is supported for export.
14 changes: 14 additions & 0 deletions examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml
@@ -119,12 +119,26 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: null
num_classes: -1
vocabulary: []

# config for InterCTC loss: https://arxiv.org/abs/2102.03216
# specify loss weights and which layers to use for InterCTC
# e.g., to reproduce the paper results, set loss_weights: [0.3]
# and apply_at_layers: [8] (assuming 18 layers). Note that final
# layer loss coefficient is automatically adjusted (to 0.7 in above example)
interctc:
loss_weights: []
apply_at_layers: []

optim:
name: adamw
lr: 1e-3
@@ -128,6 +128,11 @@ model:
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 0

decoder:
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null # Currently only null is supported for export.
@@ -142,8 +147,8 @@ model:

# if a large vocabulary size is desired, you may wish to use SampleRNNTJoint module
# _target_: nemo.collections.asr.modules.SampledRNNTJoint
# n_samples: 500 # Specifies the minimum number of tokens to sample from the vocabulary space, excluding
# the RNNT blank token. If a given value is larger than the entire vocabulary size, then the full
# vocabulary will be used
joint:
_target_: nemo.collections.asr.modules.RNNTJoint