Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Autoregressive Transducer (HAT) #6260

Merged
merged 33 commits into from
Mar 24, 2023
Merged

Conversation

andrusenkoau
Copy link
Collaborator

What does this PR do ?

Add HAT model as a new joint network type (HATJoint) for RNNT model. The difference is only in decoding time -- HAT.joint.joint returns two outputs: hat_logprobs and internal_lm_logprobs (for internal lm subtraction in case of Shallow Fusion with external n-gram LM).

Collection: [ASR]

Usage

  • For HAT model training you need replace _target_: nemo.collections.asr.modules.RNNTJoint with _target_: nemo.collections.asr.modules.HATJoint in joint part of standard transducer config.

  • For Shallow Fusion with external n-gram LM use RNNT maes decoding algorithm which is able to work with HATJoint model.

# Add a code snippet demonstrating how to use this 

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@github-actions github-actions bot added the ASR label Mar 20, 2023
Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a bit of refactoring

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filename should be full "hybrid_autoregressive_transducer.py"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

from nemo.utils import logging


class HATJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is duplicating a lot of code from RNNTJoint. Would it make sense to subclass it ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great comment. I took the RNNTJoint as a parent class and left only several modifications for new HATJoint class. Check it pls.

@@ -460,7 +466,12 @@ def greedy_search(

# TODO: Figure out how to remove this hard coding afterwords
while not_blank and (symbols_added < 5):
ytu = torch.log_softmax(self.joint.joint(hi, y) / self.softmax_temperature, dim=-1) # [1, 1, 1, V + 1]
if isinstance(self.joint, HATJoint):
ytu, _ = self.joint.joint(hi, y)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This kinda logic is problematic in the long run. Why not take a bool in the HAT module that determine what self.joint returns - by default it's set and returns both items, otherwise return things in the form of RNNT so that this code doesn't need to change

Copy link
Collaborator Author

@andrusenkoau andrusenkoau Mar 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that jit compiler does not like variable outputs number. Now I made default mode -- return only logprobs (like the standard rnnt joint) and return both logprobs and internal_lm_logprobs (in case of special boolean flag). This is allowed to save more rnnt decoding code unchanged.

@@ -34,6 +34,7 @@
from omegaconf import DictConfig

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.modules.hat import HATJoint
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No modules should be imported inside of Greedy of Beam decoding libraries because it will eventually cause circular dependency

Copy link
Collaborator Author

@andrusenkoau andrusenkoau Mar 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is no longer needed due to the new default hat.joint.joint logic (the same as rnnt).

nemo/collections/asr/modules/hat.py Fixed Show fixed Hide fixed
nemo/collections/asr/modules/hat.py Fixed Show fixed Hide fixed
nemo/collections/asr/modules/hat.py Fixed Show fixed Hide fixed
nemo/collections/asr/modules/hat.py Fixed Show fixed Hide fixed
tests/collections/asr/test_asr_modules.py Fixed Show fixed Hide fixed
nemo/collections/asr/modules/hat.py Fixed Show fixed Hide fixed
nemo/collections/asr/modules/hat.py Fixed Show fixed Hide fixed
nemo/collections/asr/modules/hat.py Fixed Show fixed Hide fixed
Copy link

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CodeQL found more than 10 potential problems in the proposed changes. Check the Files changed tab for more details.

@andrusenkoau andrusenkoau marked this pull request as ready for review March 22, 2023 05:22
Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently the code is too circular for HAT import. Another thing is it requires too many modifications to an already very complicated function (mAES).

The first thing we can make more generic with dataclass and property trucks. Those changes are relatively simple but require some refactor.

The second one I dunno how to make more generic. Perhaps an abstract method inside of AbstractRNNTJoint that discussed how to do special forward of joint ? That's a heavy refactor so ignore it for now.

@@ -34,6 +34,7 @@
import torch
from tqdm import tqdm

from nemo.collections.asr.modules import hybrid_autoregressive_transducer as hat
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm so this import doesn't actually fix circular import - think of it like this

RNNTModel needs EncDecJoint, Loss, Decoding, Metric
Decoding depends on Decoder + Joint
Metric depends on Decoding.
Joint depends on loss and metric.

But now decoding itself imports the joint module. That's fine for now but can be more circular and crash in the future. I'll discuss an alternative below


res = torch.cat((label_logprob_scaled, blank_logprob), dim=-1).contiguous() # [B, T, U, V+1]

if return_ilm:
Copy link
Collaborator

@titu1994 titu1994 Mar 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, it seems incorrect to return a tuple here. Let's do this instead -
In rnnt_utils.py create a dataclass call HATJointOutput. It has just two value - a tensor for logprobs and a tensor for ilm. Both are none by default.

If return_ilm property of this class is set, you will build an object of this dataclass, put the two values and return that

More details below


def joint(
self, f: torch.Tensor, g: torch.Tensor, return_ilm: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove return_ilm from here, use the properties

beam_logp, beam_idx = torch.log_softmax(
self.joint.joint(beam_enc_out, beam_dec_out) / self.softmax_temperature, dim=-1,
).topk(self.max_candidates, dim=-1)
if isinstance(self.joint, hat.HATJoint) and self.hat_subtract_ilm:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, everywhere, simply call the self.joint.joint(with the ordinary arguments for RNNT). The output can now be either torch.Tensor - (RNNT joint, HAT without the ILM subtract) or it can be HATOutput dataclass.

import RNNT utils and then check if torch.is_tensor(output) here - this is for og RNNT. Elif self.hat_subtract_ilm and isinstance(output, HATOutput):

Then do the required code path. On else path, raise error saying could not resolve the output

@@ -1196,7 +1206,12 @@ def modified_adaptive_expansion_search(
lm_score, new_hyp.ngram_lm_state = self.compute_ngram_score(
hyp.ngram_lm_state, int(k)
)
new_hyp.score += self.ngram_lm_alpha * lm_score
if isinstance(self.joint, hat.HATJoint) and self.hat_subtract_ilm:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for everywhere else below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @titu1994! Thank you for detailed review. I tried to modify HAT related code according to your suggestions. For convenience I also added resolve_joint_output function. Check it pls.

from nemo.collections.asr.modules import rnnt
from nemo.collections.asr.parts.utils.rnnt_utils import HATJointOutput

from nemo.utils import logging

Check notice

Code scanning / CodeQL

Unused import

Import of 'logging' is not used.
Comment on lines +123 to +130
self.pred, self.enc, self.joint_net, self.blank_pred = self._joint_hat_net_modules(
num_classes=self._vocab_size, # non blank symbol
pred_n_hidden=self.pred_hidden,
enc_n_hidden=self.encoder_hidden,
joint_n_hidden=self.joint_hidden,
activation=self.activation,
dropout=jointnet.get('dropout', 0.0),
)

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class

Assignment overwrites attribute pred, which was previously defined in superclass [RNNTJoint](1). Assignment overwrites attribute enc, which was previously defined in superclass [RNNTJoint](1). Assignment overwrites attribute joint_net, which was previously defined in superclass [RNNTJoint](1).
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

Check notice

Code scanning / CodeQL

Unused import

Import of 'Tuple' is not used.
Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks very good now, could you add some tests that assert that normally forward returns tensor and hat forward with and without flag set returns either tensor or HATJointOutput.

Another thing is we support only mAES and normal beam, can you look into the complexity of the other beam algos to support hat ? If it's difficult we can leave it to another pr in the future

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see we support only basic beam and maes. Can you look into supporting HAT with other algos ? If it's simple, it can be done in this pr, if not in another pr.

Copy link
Collaborator Author

@andrusenkoau andrusenkoau Mar 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean n-gram LM fusion (with RNNT and HAT) for other decoding algorithms? Now only maes algorithm supports LM fusion. I did not do it for default beam search because it works too slow. I do not think anyone wants to use it because of speed.

BTW, all the decoding algorithms can work now with HAT model without LM fusion because HATJoint has the same default output type like RNNTJoint.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok sounds good then

@kobenaxie
Copy link

Hi @andrusenkoau , the HATJoint returns log softmaxed log_prob, but the rnnt_loss in torchaudio or rnnt_pytorch receives logits without logsoftmax, should this be unified ?

@titu1994
Copy link
Collaborator

@kobenaxie that's a template implementation of the loss using pure PyTorch, it is not used during actual training since it is super slow. Instead we use numba bases cuda compiled loss.

Also, hat during training does not return the dataclass (which the loss anyway would not accept) so it is fine

@titu1994
Copy link
Collaborator

Looks great !

@titu1994 titu1994 merged commit 2e36872 into NVIDIA:main Mar 24, 2023
@titu1994
Copy link
Collaborator

Final things to do are to add HAT decoder based conformer config to a conf dir called conf/hat_transducer/conformer/conformer_hat_bpe.yaml / char.yaml

@titu1994
Copy link
Collaborator

That can be done when release bench is cut.

@andrusenkoau
Copy link
Collaborator Author

Hi @andrusenkoau , the HATJoint returns log softmaxed log_prob, but the rnnt_loss in torchaudio or rnnt_pytorch receives logits without logsoftmax, should this be unified ?

Hi @kobenaxie, HAT logic demands to work in the probability domain in order to calculate blank probability and then scale labels probability. For the implementation simplicity we can use the rule -- logsoftmax(logsoftmax(x)) = logsoftmax(x) => it is possible to use HAT log_probs output with any rnnt loss functions which have logsoftmax calculation inside to get final model loss.

@andrusenkoau
Copy link
Collaborator Author

Looks great !

@titu1994 thank you so much for great review and help with code modification!

hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request Jun 2, 2023
* add hat joint network

Signed-off-by: andrusenkoau <[email protected]>

* add HATJoint module

Signed-off-by: andrusenkoau <[email protected]>

* add hat script

Signed-off-by: andrusenkoau <[email protected]>

* add hat decoding option

Signed-off-by: andrusenkoau <[email protected]>

* add hat related parameters to maes decoding

Signed-off-by: andrusenkoau <[email protected]>

* minor fixes

Signed-off-by: andrusenkoau <[email protected]>

* add hat decoding option

Signed-off-by: andrusenkoau <[email protected]>

* minor fixes

Signed-off-by: andrusenkoau <[email protected]>

* add hat related parameters

Signed-off-by: andrusenkoau <[email protected]>

* minor fixes

Signed-off-by: andrusenkoau <[email protected]>

* add hat to all rnnt decoding types

Signed-off-by: andrusenkoau <[email protected]>

* add test for hatjoint

Signed-off-by: andrusenkoau <[email protected]>

* combine hatjoint with all rnntjoint tests

Signed-off-by: andrusenkoau <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

Signed-off-by: andrusenkoau <[email protected]>

* rename hat file

Signed-off-by: andrusenkoau <[email protected]>

* fix hat double output

Signed-off-by: andrusenkoau <[email protected]>

* fix hat double output

Signed-off-by: andrusenkoau <[email protected]>

* fix hat double output

Signed-off-by: andrusenkoau <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py

Signed-off-by: andrusenkoau <[email protected]>

* minor fixes

Signed-off-by: andrusenkoau <[email protected]>

* add return_hat_ilm property

Signed-off-by: andrusenkoau <[email protected]>

* add HATJointOutput dataclass

Signed-off-by: andrusenkoau <[email protected]>

* add resolve_joint_output function

Signed-off-by: andrusenkoau <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add local return_hat_ilm_default variable

Signed-off-by: andrusenkoau <[email protected]>

* minor fixes

Signed-off-by: andrusenkoau <[email protected]>

---------

Signed-off-by: andrusenkoau <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants