Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup RNN-T greedy decoding #7926

Merged
merged 47 commits into from
Jan 16, 2024
Merged

Speedup RNN-T greedy decoding #7926

merged 47 commits into from
Jan 16, 2024

Conversation

artbataev
Copy link
Collaborator

@artbataev artbataev commented Nov 21, 2023

What does this PR do ?

New algorithm for greedy batched decoding for RNN-Transducer.
With large batch sizes (e.g., 128) the expected speedup for large Fast Conformer-Transducer (full evaluation time including Encoder) is 1.7x-1.9x (when using speech_to_text_eval.py). For small batch sizes, e.g., 16, the observed speedup is ~1.3x.
The original algorithm is preserved and can be enabled by using loop_labels=False

E.g., on my local machine, with bf16, bs=128, Fast Conformer-Transducer Large, full test-other decoding

Algorithm Greedy Greedy + Alignments
Current NeMo 45 sec 1 min 38 sec
Proposed 24 sec 30 sec

Collection: [ASR]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

# default - new decoding algorithm
python examples/asr/speech_to_text_eval.py \
   model_path=<nemo_model.nemo> \
   dataset_manifest=<manifest> \
   batch_size=128 \
   output_filename=<output_mainfest_path> 

# previous algorithm is preserved and can be used with `loop_labels=false`
python examples/asr/speech_to_text_eval.py \
   model_path=<nemo_model.nemo> \
   dataset_manifest=<manifest> \
   batch_size=128 \
   output_filename=<output_mainfest_path> \
   rnnt_decoding.greedy.loop_labels=false

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@github-actions github-actions bot added the ASR label Nov 21, 2023
Copy link
Contributor

github-actions bot commented Dec 9, 2023

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions bot added the stale label Dec 9, 2023
@artbataev artbataev removed the stale label Dec 13, 2023
Signed-off-by: Vladimir Bataev <[email protected]>
@artbataev
Copy link
Collaborator Author

jenkins

Copy link
Contributor

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@artbataev
Copy link
Collaborator Author

jenkins

GNroy
GNroy previously approved these changes Jan 12, 2024
Copy link
Collaborator

@GNroy GNroy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but see comments.
I'd like to especially commend your tests, thanks for improving NeMo!

nemo/collections/asr/modules/rnnt.py Outdated Show resolved Hide resolved
nemo/collections/asr/modules/rnnt.py Outdated Show resolved Hide resolved
nemo/collections/asr/modules/rnnt_abstract.py Outdated Show resolved Hide resolved
nemo/collections/asr/modules/rnnt_abstract.py Outdated Show resolved Hide resolved

# Use the following commented print statements to check
# the alignment of other algorithms compared to the default
print("Text", hyp.text)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the following commented print statements

not commented

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was copied from the code nearby.
I reworked the test: instead of just printing the alignment, I use non-batched greedy decoding as a reference, and check if the batched version returns the same results.

Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
@artbataev
Copy link
Collaborator Author

jenkins

@artbataev artbataev requested a review from GNroy January 15, 2024 13:54
GNroy
GNroy previously approved these changes Jan 15, 2024
Copy link
Collaborator

@GNroy GNroy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent work. Minor comments on inline documentation of the actual decoding loop and explain what is loop labels.

I also want to ask why the separation of joint into 3 functions - it seems ok but for example allows HAT to use less memory efficient path which can cause oom.

Finally, excellent tests, much better coverage of cases than before

@@ -138,9 +138,9 @@ def return_hat_ilm(self):
def return_hat_ilm(self, hat_subtract_ilm):
self._return_hat_ilm = hat_subtract_ilm

def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to have the similar API name across the RNNT Joints, is it necessary to change this ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The API is changed for all Joints, starting from AbstractRNNTJoint (see details in Slack)

Copy link
Collaborator Author

@artbataev artbataev Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it is the following:

class AbstractRNNTJoint(NeuralModule, ABC):
    @abstractmethod
    def project_encoder(self, encoder_output):
        raise NotImplementedError()  # can be Linear or identity

    @abstractmethod
    def project_prednet(self, encoder_output):
        raise NotImplementedError()  # can be Linear or identity

  @abstractmethod
  def joint_after_projection(self, f, g):
     """This is the main method that one should implement for Joint""" 
    raise NotImplementedError()
   
  def joint(self, f, g):
     """Full joint computation. Not abstract anymore!"""
    return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))

g = self.pred(g)
g.unsqueeze_(dim=1) # (B, 1, U, H)

f = f.unsqueeze(dim=2) # (B, T, 1, H)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove the preemptive enc() pred() ? This is shown to be equivalent to RNNT and saves a ton of memory

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inplace unsqueeze_ does not save memory.

Due to separating projections I needed to replace in-place unsqueeze_ operation with unsqueeze. There is no overhead in memory.
According to the documentation https://pytorch.org/docs/stable/generated/torch.unsqueeze.html

The returned tensor shares the same underlying data with this tensor.

You can check it manually:

import torch

device = torch.device('cuda:0')

def print_allocated(device, prefix=""):
    allocated_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"{prefix}{allocated_mb:.0f}MB")


print_allocated(device, prefix="Before: ")  # Should be 0MB

# allocate memory ~projection result
data = torch.rand([128, 30 * 1000 // 10 // 8, 640], device=device)
print_allocated(device, prefix="After project encoder output: ")  # 118MB

# apply unsqueeze
data2 = data.unsqueeze(-1)  # unsqueeze returns a new tensor, but storage is the same (only metadata is new!)
print_allocated(device, prefix="After Unsqueeze: ")  # same, 118MB

"""
return self.pred(prednet_output)

def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert name change

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is essential to separate projections from other joint computations. It introduces no memory/computational overhead. See details in slack

@@ -28,6 +28,45 @@ class AbstractRNNTJoint(NeuralModule, ABC):
"""

@abstractmethod
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Any:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert name change. It's fine to keep joint

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comments above

@@ -545,6 +573,7 @@ def __init__(
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
loop_labels: bool = True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain in docstring what this isc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, missed the class docstring before)

if self.preserve_frame_confidence
else None,
)
advance_mask = torch.logical_and(blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document line

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment

.squeeze(1)
.squeeze(1)
)
more_scores, more_labels = logits.max(-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment (above this line)


# stage 4: to avoid looping, go to next frame after max_symbols emission
if self.max_symbols is not None:
force_blank_mask = torch.logical_and(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment above



class BatchedHyps:
"""Class to store batched hypotheses (labels, time_indices, scores) for efficient RNNT decoding"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very neat, this is done so that jit compile is happy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep) There is also a test that torch.jit is fine with this structure :)

return hypotheses


def return_empty_hypotheses(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Empty hys might be needed for beam search init and temp placeholders now that I remember

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this function, this was used only when max_symbols=0 for the new decoding algorithm

Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
@artbataev
Copy link
Collaborator Author

jenkins

Signed-off-by: Vladimir Bataev <[email protected]>
@artbataev
Copy link
Collaborator Author

jenkins

Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After detailed explanation, the changes make sense design wize. If you could put those explanation in the PR itself it will help for future discussion. Thanks again for the significant speedup !

@artbataev
Copy link
Collaborator Author

Pasting here the discussion from Slack about Joint refactoring, joint_after_projection, non-inplace unsqueeze.

Main points:

  • (1) separating projections in Joint from other operations is not only needed for the "memory vs speed" tradeoff. It is also helpful for speed optimization without additional memory usage (this is done not only for (2))
  • (2) The immediate projection of encoder output is a tiny overhead, and I'm sure it is negligible compared to other operations for the RNN-T system
  • (3) inplace unsqueeze_ does not save memory, memory consumption is the same with unsqueeze, and there is no overhead for the change
  • (4) implementation of separation of projections – I tried to preserve compatibility, readability, and usability for inheritance.

1) separating projections in Joint from other operations is helpful in many cases.

Even in the original encoder algorithm, when we loop over encoder frames, we can project the frame immediately (one-by-one => no memory overhead), but this will save computations: for each encoder frame, multiple evaluations for Joint are used => we waste time when recalculating the encoder vector's projection.
The new algorithm is even more sensitive to operations in Joint, and I see a substantial speedup for separating projections

2) The immediate projection of encoder output is a tiny overhead

I see the speedup from projecting the encoder output immediately.
So, what's the overhead, and is it significant?

  • This could be considered a significant overhead when we used tiny encoders with linear memory/time complexity. For modern encoders with quadratic complexity (due to attention)
  • for bs 128, 30 sec, subsampling 8, joint_hidden=640, fp32, the size of tensor will be ~118MB, for bf16 – ~59MB
    • To compare with the memory consumption of one piece of the encoder, I tried a MultiHeadAttention block (used by Conformer). It uses ~2129MB memory (one block!) due to quadratic complexity. I'm sure that 118MB or even 59MB is a tiny piece of memory compared to modern encoders (it's from Conformer large, not x+large!)
  • From a practical point of view, I can easily fit bs 256, fp32 to my desktop GPU (LibriSpeech test-other, Fast Conformer Large), and we are targeting bf16, bs 128, and smaller.
  • Comparison with CTC system: We project to the final output with vocabulary size dimension (not one-by-one!), which is larger than RNN-T projection, and do not optimize this for better memory usage, sacrificing the speed

Given all these facts, it is acceptable to project the encoder output immediately. If we need a robust memory consumption optimization, we can use a separate flag (preserve_memory), but I don't think it is now required.

3) in-place unsqueeze_ does not save memory (no overhead after separating projections)

Due to separating projections, I needed to replace the in-place unsqueeze_ operation with unsqueeze. There is no overhead in memory.
According to the documentation https://pytorch.org/docs/stable/generated/torch.unsqueeze.html

The returned tensor shares the same underlying data with this tensor.

4) Implementation of separation of projections.

I think the acceptable solution should:

  • expose projections as the public API (AbstractRNNTJoint)
  • should be developer-friendly
  • should not lead to unnecessary code duplication
  • should not break checkpoints
  • should not introduce any significant overhead

I considered several possibilities.
a) we can duplicate the code for joint in joint_after_projection, but there I do not think it is a good practice to maintain the same code in 2 places (it must be the same except applying projections)

b) using enc() and pred() as functions: undesirable, since it will break the checkpoints.

c) use enc and pred with type annotations in public API

class AbstractRNNTJoint(NeuralModule, ABC):
   enc: Callable # can be Identity
   pred: Callable # can be Identity
   
  @abstractmethod
  def joint_after_projection(self, f, g):
     """This is the main method that one should implement for Joint""" 
    raise NotImplementedError()
   
  def joint(self, f, g):
     # not abstract anymore!
    return self.joint_after_projection(self.enc(f), self.pred(g))

d) separate abstract project_prednet and project_encoder methods – the current solution
I think it is better since if project_prednet and project_encoder are not implemented, a clear error will indicate this.

I prefer the last one because there is no overhead for the current implementation (see (3)).

@artbataev artbataev merged commit 410f092 into main Jan 16, 2024
15 checks passed
@artbataev artbataev deleted the speedup_rnnt_greedy_decoding branch January 16, 2024 20:31
jubick1337 pushed a commit that referenced this pull request Jan 17, 2024
* Add structure for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add faster decoding algo

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify max_symbols support. More speedup

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Filtering only when necessary

Signed-off-by: Vladimir Bataev <[email protected]>

* Move max_symbols check to the end of loop

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support returning prediction network states

Signed-off-by: Vladimir Bataev <[email protected]>

* Support preserve_alignments flag

Signed-off-by: Vladimir Bataev <[email protected]>

* Support confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Partial fix for jit compatibility

Signed-off-by: Vladimir Bataev <[email protected]>

* Support switching between decoding algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix switching algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix max symbols per step

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests. Preserve torch.jit compatibility for BatchedHyps

Signed-off-by: Vladimir Bataev <[email protected]>

* Separate projection from Joint calculation in decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix config instantiation

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix after main merge

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Speedup alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Test alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests for alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add more tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Avoid common package modification

Signed-off-by: Vladimir Bataev <[email protected]>

* Support Stateless prediction network

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve stateless decoder support. Separate alignments and confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step=0

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Batched Hyps/Alignments: lengths -> current_lengths

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify indexing

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Rework test for greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Document loop_labels

Signed-off-by: Vladimir Bataev <[email protected]>

* Raise ValueError if max_symbols_per_step <= 0

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
minitu pushed a commit to minitu/NeMo that referenced this pull request Jan 19, 2024
* Add structure for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add faster decoding algo

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify max_symbols support. More speedup

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Filtering only when necessary

Signed-off-by: Vladimir Bataev <[email protected]>

* Move max_symbols check to the end of loop

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support returning prediction network states

Signed-off-by: Vladimir Bataev <[email protected]>

* Support preserve_alignments flag

Signed-off-by: Vladimir Bataev <[email protected]>

* Support confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Partial fix for jit compatibility

Signed-off-by: Vladimir Bataev <[email protected]>

* Support switching between decoding algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix switching algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix max symbols per step

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests. Preserve torch.jit compatibility for BatchedHyps

Signed-off-by: Vladimir Bataev <[email protected]>

* Separate projection from Joint calculation in decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix config instantiation

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix after main merge

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Speedup alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Test alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests for alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add more tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Avoid common package modification

Signed-off-by: Vladimir Bataev <[email protected]>

* Support Stateless prediction network

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve stateless decoder support. Separate alignments and confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step=0

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Batched Hyps/Alignments: lengths -> current_lengths

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify indexing

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Rework test for greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Document loop_labels

Signed-off-by: Vladimir Bataev <[email protected]>

* Raise ValueError if max_symbols_per_step <= 0

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
stevehuang52 pushed a commit that referenced this pull request Jan 31, 2024
* Add structure for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add faster decoding algo

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify max_symbols support. More speedup

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Filtering only when necessary

Signed-off-by: Vladimir Bataev <[email protected]>

* Move max_symbols check to the end of loop

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support returning prediction network states

Signed-off-by: Vladimir Bataev <[email protected]>

* Support preserve_alignments flag

Signed-off-by: Vladimir Bataev <[email protected]>

* Support confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Partial fix for jit compatibility

Signed-off-by: Vladimir Bataev <[email protected]>

* Support switching between decoding algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix switching algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix max symbols per step

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests. Preserve torch.jit compatibility for BatchedHyps

Signed-off-by: Vladimir Bataev <[email protected]>

* Separate projection from Joint calculation in decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix config instantiation

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix after main merge

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Speedup alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Test alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests for alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add more tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Avoid common package modification

Signed-off-by: Vladimir Bataev <[email protected]>

* Support Stateless prediction network

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve stateless decoder support. Separate alignments and confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step=0

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Batched Hyps/Alignments: lengths -> current_lengths

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify indexing

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Rework test for greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Document loop_labels

Signed-off-by: Vladimir Bataev <[email protected]>

* Raise ValueError if max_symbols_per_step <= 0

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: stevehuang52 <[email protected]>
ssh-meister pushed a commit to ssh-meister/NeMo that referenced this pull request Feb 15, 2024
* Add structure for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add faster decoding algo

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify max_symbols support. More speedup

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Filtering only when necessary

Signed-off-by: Vladimir Bataev <[email protected]>

* Move max_symbols check to the end of loop

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support returning prediction network states

Signed-off-by: Vladimir Bataev <[email protected]>

* Support preserve_alignments flag

Signed-off-by: Vladimir Bataev <[email protected]>

* Support confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Partial fix for jit compatibility

Signed-off-by: Vladimir Bataev <[email protected]>

* Support switching between decoding algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix switching algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix max symbols per step

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests. Preserve torch.jit compatibility for BatchedHyps

Signed-off-by: Vladimir Bataev <[email protected]>

* Separate projection from Joint calculation in decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix config instantiation

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix after main merge

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Speedup alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Test alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests for alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add more tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Avoid common package modification

Signed-off-by: Vladimir Bataev <[email protected]>

* Support Stateless prediction network

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve stateless decoder support. Separate alignments and confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step=0

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Batched Hyps/Alignments: lengths -> current_lengths

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify indexing

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Rework test for greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Document loop_labels

Signed-off-by: Vladimir Bataev <[email protected]>

* Raise ValueError if max_symbols_per_step <= 0

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Sasha Meister <[email protected]>
pablo-garay pushed a commit that referenced this pull request Mar 19, 2024
* Add structure for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add faster decoding algo

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify max_symbols support. More speedup

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Filtering only when necessary

Signed-off-by: Vladimir Bataev <[email protected]>

* Move max_symbols check to the end of loop

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support returning prediction network states

Signed-off-by: Vladimir Bataev <[email protected]>

* Support preserve_alignments flag

Signed-off-by: Vladimir Bataev <[email protected]>

* Support confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Partial fix for jit compatibility

Signed-off-by: Vladimir Bataev <[email protected]>

* Support switching between decoding algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix switching algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix max symbols per step

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests. Preserve torch.jit compatibility for BatchedHyps

Signed-off-by: Vladimir Bataev <[email protected]>

* Separate projection from Joint calculation in decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix config instantiation

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix after main merge

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Speedup alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Test alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests for alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add more tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Avoid common package modification

Signed-off-by: Vladimir Bataev <[email protected]>

* Support Stateless prediction network

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve stateless decoder support. Separate alignments and confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step=0

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Batched Hyps/Alignments: lengths -> current_lengths

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify indexing

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Rework test for greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Document loop_labels

Signed-off-by: Vladimir Bataev <[email protected]>

* Raise ValueError if max_symbols_per_step <= 0

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Pablo Garay <[email protected]>
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024
* Add structure for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add faster decoding algo

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify max_symbols support. More speedup

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Filtering only when necessary

Signed-off-by: Vladimir Bataev <[email protected]>

* Move max_symbols check to the end of loop

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support returning prediction network states

Signed-off-by: Vladimir Bataev <[email protected]>

* Support preserve_alignments flag

Signed-off-by: Vladimir Bataev <[email protected]>

* Support confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Partial fix for jit compatibility

Signed-off-by: Vladimir Bataev <[email protected]>

* Support switching between decoding algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix switching algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix max symbols per step

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests. Preserve torch.jit compatibility for BatchedHyps

Signed-off-by: Vladimir Bataev <[email protected]>

* Separate projection from Joint calculation in decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix config instantiation

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix after main merge

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Speedup alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Test alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests for alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add more tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Avoid common package modification

Signed-off-by: Vladimir Bataev <[email protected]>

* Support Stateless prediction network

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve stateless decoder support. Separate alignments and confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step=0

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Batched Hyps/Alignments: lengths -> current_lengths

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify indexing

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Rework test for greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Document loop_labels

Signed-off-by: Vladimir Bataev <[email protected]>

* Raise ValueError if max_symbols_per_step <= 0

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants