
Speedup RNN-T greedy decoding #7926

Merged 47 commits on Jan 16, 2024

Commits
9342489
Add structure for batched hypotheses
artbataev Nov 21, 2023
7bcc4c0
Add faster decoding algo
artbataev Nov 21, 2023
7a0942f
Simplify max_symbols support. More speedup
artbataev Nov 22, 2023
26ec40c
Clean up
artbataev Nov 22, 2023
1d556ea
Clean up
artbataev Nov 22, 2023
cf631dd
Filtering only when necessary
artbataev Nov 22, 2023
a50965d
Move max_symbols check to the end of loop
artbataev Nov 22, 2023
510eb90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2023
659cfff
Support returning prediction network states
artbataev Nov 22, 2023
40d1568
Support preserve_alignments flag
artbataev Nov 22, 2023
ca2d94b
Support confidence
artbataev Nov 22, 2023
b328fac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2023
7997bd6
Partial fix for jit compatibility
artbataev Nov 23, 2023
6f7746b
Merge branch 'main' into speedup_rnnt_greedy_decoding
artbataev Nov 23, 2023
95da9d1
Support switching between decoding algorithms
artbataev Nov 23, 2023
ef35381
Fix switching algorithms
artbataev Nov 23, 2023
ca5779d
Clean up
artbataev Nov 23, 2023
97092ff
Clean up
artbataev Nov 23, 2023
c9785ff
Fix max symbols per step
artbataev Nov 23, 2023
1e09979
Add tests. Preserve torch.jit compatibility for BatchedHyps
artbataev Nov 24, 2023
f4b7b68
Separate projection from Joint calculation in decoding
artbataev Dec 13, 2023
d67b14b
Fix config instantiation
artbataev Dec 13, 2023
c7d298d
Merge remote-tracking branch 'origin/main' into speedup_rnnt_greedy_d…
artbataev Jan 10, 2024
2ea8f7f
Fix after main merge
artbataev Jan 10, 2024
5c8e18e
Add tests for batched hypotheses
artbataev Jan 10, 2024
e8c43d0
Speedup alignments
artbataev Jan 10, 2024
ffe2a67
Test alignments
artbataev Jan 10, 2024
77bf674
Fix alignments
artbataev Jan 10, 2024
02a9bbd
Fix tests for alignments
artbataev Jan 11, 2024
83c4793
Add more tests
artbataev Jan 11, 2024
430e159
Fix confidence tests
artbataev Jan 11, 2024
266be2c
Avoid common package modification
artbataev Jan 11, 2024
ce33493
Support Stateless prediction network
artbataev Jan 11, 2024
9d545ee
Improve stateless decoder support. Separate alignments and confidence
artbataev Jan 11, 2024
9669149
Fix alignments for max_symbols_per_step
artbataev Jan 11, 2024
1dbf29e
Fix alignments for max_symbols_per_step=0
artbataev Jan 11, 2024
b4421cd
Fix tests
artbataev Jan 12, 2024
3e1ca1e
Fix test
artbataev Jan 12, 2024
1b97e33
Add comments
artbataev Jan 12, 2024
4429432
Batched Hyps/Alignments: lengths -> current_lengths
artbataev Jan 12, 2024
b7b83df
Simplify indexing
artbataev Jan 12, 2024
3df991a
Improve type annotations
artbataev Jan 15, 2024
31649fa
Rework test for greedy decoding
artbataev Jan 15, 2024
5f67c66
Document loop_labels
artbataev Jan 16, 2024
df86b17
Raise ValueError if max_symbols_per_step <= 0
artbataev Jan 16, 2024
0f4463b
Add comments
artbataev Jan 16, 2024
c38f222
Fix test
artbataev Jan 16, 2024
14 changes: 4 additions & 10 deletions nemo/collections/asr/modules/hybrid_autoregressive_transducer.py
@@ -138,9 +138,9 @@ def return_hat_ilm(self):
def return_hat_ilm(self, hat_subtract_ilm):
self._return_hat_ilm = hat_subtract_ilm

def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
Collaborator:
It would be better to have a similar API name across the RNNT Joints; is it necessary to change this?

Collaborator Author:
The API is changed for all Joints, starting from AbstractRNNTJoint (see details in Slack)

Collaborator Author (@artbataev, Jan 16, 2024):
Now it is the following:

class AbstractRNNTJoint(NeuralModule, ABC):
    @abstractmethod
    def project_encoder(self, encoder_output):
        raise NotImplementedError()  # can be Linear or identity

    @abstractmethod
    def project_prednet(self, prednet_output):
        raise NotImplementedError()  # can be Linear or identity

    @abstractmethod
    def joint_after_projection(self, f, g):
        """This is the main method that one should implement for Joint"""
        raise NotImplementedError()

    def joint(self, f, g):
        """Full joint computation. Not abstract anymore!"""
        return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))
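The refactored contract can be sketched with a minimal mock (`MockJoint`, the identity projections, and the addition-based joint below are hypothetical illustrations, not the actual NeMo classes):

```python
# Minimal sketch of the refactored contract, assuming identity projections.
# MockJoint and its addition-based joint are illustrative only.

class MockJoint:
    def project_encoder(self, encoder_output):
        return encoder_output  # identity stands in for a Linear layer

    def project_prednet(self, prednet_output):
        return prednet_output  # identity stands in for a Linear layer

    def joint_after_projection(self, f, g):
        return f + g  # placeholder for the real joint computation

    def joint(self, f, g):
        # Full joint: project both inputs, then combine.
        return self.joint_after_projection(
            self.project_encoder(f), self.project_prednet(g)
        )

print(MockJoint().joint(2, 3))  # 5
```

The point of the split: greedy decoding can call `project_encoder` once per utterance and `project_prednet` once per step, instead of re-projecting inside every `joint` call.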

"""
Compute the joint step of the network.
Compute the joint step of the network after Encoder/Decoder projection.

Here,
B = Batch size
@@ -169,14 +169,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJoin
Log softmaxed tensor of shape (B, T, U, V + 1).
Internal LM probability (B, 1, U, V) -- in case of return_ilm==True.
"""
# f = [B, T, H1]
f = self.enc(f)
f.unsqueeze_(dim=2) # (B, T, 1, H)

# g = [B, U, H2]
g = self.pred(g)
g.unsqueeze_(dim=1) # (B, 1, U, H)

f = f.unsqueeze(dim=2) # (B, T, 1, H)
Collaborator:
Why remove the preemptive enc() / pred()? This is shown to be equivalent to RNNT and saves a ton of memory.

Collaborator Author:
In-place unsqueeze_ does not save memory.

Due to separating the projections, I needed to replace the in-place unsqueeze_ operation with unsqueeze. There is no memory overhead. According to the documentation (https://pytorch.org/docs/stable/generated/torch.unsqueeze.html):

"The returned tensor shares the same underlying data with this tensor."

You can check it manually:

import torch

device = torch.device('cuda:0')

def print_allocated(device, prefix=""):
    allocated_mb = torch.cuda.max_memory_allocated(device) / 1024 / 1024
    print(f"{prefix}{allocated_mb:.0f}MB")


print_allocated(device, prefix="Before: ")  # Should be 0MB

# allocate memory ~projection result
data = torch.rand([128, 30 * 1000 // 10 // 8, 640], device=device)
print_allocated(device, prefix="After project encoder output: ")  # 118MB

# apply unsqueeze
data2 = data.unsqueeze(-1)  # unsqueeze returns a new tensor, but storage is the same (only metadata is new!)
print_allocated(device, prefix="After Unsqueeze: ")  # same, 118MB
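A quick way to confirm the storage sharing (a sketch, runnable on CPU as well) is to compare the data pointers of the original tensor and the unsqueezed view:

```python
import torch

data = torch.rand(128, 375, 640)
data2 = data.unsqueeze(-1)  # a view: [128, 375, 640, 1]

# Same underlying storage, so no extra allocation -- only metadata differs.
print(data.data_ptr() == data2.data_ptr())  # True
print(tuple(data2.shape))                   # (128, 375, 640, 1)
```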

g = g.unsqueeze(dim=1) # (B, 1, U, H)
inp = f + g # [B, T, U, H]

del f
71 changes: 60 additions & 11 deletions nemo/collections/asr/modules/rnnt.py
@@ -398,6 +398,22 @@ def batch_copy_states(

return old_states

def mask_select_states(
self, states: Optional[List[torch.Tensor]], mask: torch.Tensor
) -> Optional[List[torch.Tensor]]:
"""
Return states by mask selection
Args:
states: states for the batch
mask: boolean mask for selecting states; batch dimension should be the same as for states

Returns:
states filtered by mask
"""
if states is None:
return None
return [states[0][mask]]

def batch_score_hypothesis(
self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
@@ -1047,6 +1063,21 @@ def batch_copy_states(

return old_states

def mask_select_states(
self, states: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Return states by mask selection
Args:
states: states for the batch
mask: boolean mask for selecting states; batch dimension should be the same as for states

Returns:
states filtered by mask
"""
# LSTM in PyTorch returns a tuple of 2 tensors as a state
return states[0][:, mask], states[1][:, mask]

# Adapter method overrides
def add_adapter(self, name: str, cfg: DictConfig):
# Update the config with correct input dim
@@ -1382,9 +1413,33 @@ def forward(

return losses, wer, wer_num, wer_denom

def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor:
"""
Project the encoder output to the joint hidden dimension.

Args:
encoder_output: A torch.Tensor of shape [B, T, D]

Returns:
A torch.Tensor of shape [B, T, H]
"""
return self.enc(encoder_output)

def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor:
"""
Project the Prediction Network (Decoder) output to the joint hidden dimension.

Args:
prednet_output: A torch.Tensor of shape [B, U, D]

Returns:
A torch.Tensor of shape [B, U, H]
"""
return self.pred(prednet_output)

def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
Collaborator:
Revert name change

Collaborator Author:
It is essential to separate the projections from the rest of the joint computation. It introduces no memory or computational overhead. See details in Slack.

"""
Compute the joint step of the network.
Compute the joint step of the network after projection.

Here,
B = Batch size
@@ -1412,14 +1467,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
Returns:
Logits / log softmaxed tensor of shape (B, T, U, V + 1).
"""
# f = [B, T, H1]
f = self.enc(f)
f.unsqueeze_(dim=2) # (B, T, 1, H)

# g = [B, U, H2]
g = self.pred(g)
g.unsqueeze_(dim=1) # (B, 1, U, H)

f = f.unsqueeze(dim=2) # (B, T, 1, H)
g = g.unsqueeze(dim=1) # (B, 1, U, H)
inp = f + g # [B, T, U, H]

del f, g
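The unsqueeze-and-add step above relies on standard broadcasting; a small sketch of the shape arithmetic (shapes chosen arbitrarily):

```python
import torch

B, T, U, H = 2, 6, 4, 8
f = torch.rand(B, T, H)  # projected encoder output
g = torch.rand(B, U, H)  # projected prediction-network output

# [B, T, 1, H] + [B, 1, U, H] broadcasts to [B, T, U, H]
inp = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
print(tuple(inp.shape))  # (2, 6, 4, 8)
```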
@@ -1536,7 +1585,7 @@ def set_fuse_loss_wer(self, fuse_loss_wer, loss=None, metric=None):

@property
def fused_batch_size(self):
return self._fuse_loss_wer
return self._fused_batch_size

def set_fused_batch_size(self, fused_batch_size):
self._fused_batch_size = fused_batch_size
53 changes: 52 additions & 1 deletion nemo/collections/asr/modules/rnnt_abstract.py
@@ -28,6 +28,45 @@ class AbstractRNNTJoint(NeuralModule, ABC):
"""

@abstractmethod
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Any:
Collaborator:
Revert name change. It's fine to keep joint.

Collaborator Author:
See the comments above.

"""
Compute the joint step of the network after the projection step.
Args:
f: Output of the Encoder model after projection. A torch.Tensor of shape [B, T, H]
g: Output of the Decoder model (Prediction Network) after projection. A torch.Tensor of shape [B, U, H]

Returns:
Logits / log softmaxed tensor of shape (B, T, U, V + 1).
Arbitrary return type, preferably torch.Tensor, but not limited to (e.g., see HatJoint)
"""
raise NotImplementedError()

@abstractmethod
def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor:
"""
Project the encoder output to the joint hidden dimension.

Args:
encoder_output: A torch.Tensor of shape [B, T, D]

Returns:
A torch.Tensor of shape [B, T, H]
"""
raise NotImplementedError()

@abstractmethod
def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor:
"""
Project the Prediction Network (Decoder) output to the joint hidden dimension.

Args:
prednet_output: A torch.Tensor of shape [B, U, D]

Returns:
A torch.Tensor of shape [B, U, H]
"""
raise NotImplementedError()

def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
"""
Compute the joint step of the network.
@@ -58,7 +97,7 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
Returns:
Logits / log softmaxed tensor of shape (B, T, U, V + 1).
"""
raise NotImplementedError()
return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))

@property
def num_classes_with_blank(self):
@@ -277,3 +316,15 @@ def batch_copy_states(
(L x B x H, L x B x H)
"""
raise NotImplementedError()

def mask_select_states(self, states: Any, mask: torch.Tensor) -> Any:
"""
Return states by mask selection
Args:
states: states for the batch (preferably a list of tensors, but not limited to)
mask: boolean mask for selecting states; batch dimension should be the same as for states

Returns:
states filtered by mask (same type as `states`)
"""
raise NotImplementedError()
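The mask-selection contract can be sketched for LSTM-style states, where PyTorch stores each state tensor as [num_layers, B, H], so the batch dimension to index is 1 (shapes below are illustrative):

```python
import torch

num_layers, B, H = 2, 4, 8
h = torch.rand(num_layers, B, H)  # LSTM hidden state
c = torch.rand(num_layers, B, H)  # LSTM cell state

# Keep only the hypotheses that are still active in the batched loop.
mask = torch.tensor([True, False, True, False])
h_sel, c_sel = h[:, mask], c[:, mask]  # batch dim is 1 for LSTM states
print(tuple(h_sel.shape))  # (2, 2, 8)
```

This is why the LSTM decoder indexes `states[0][:, mask]` while the stateless decoder, whose state is batch-first, indexes `states[0][mask]`.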
5 changes: 3 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -316,6 +316,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', True),
)
else:
self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
@@ -1495,8 +1496,8 @@ class RNNTDecodingConfig:
rnnt_timestamp_type: str = "all" # can be char, word or all for both

# greedy decoding config
greedy: rnnt_greedy_decoding.GreedyRNNTInferConfig = field(
default_factory=lambda: rnnt_greedy_decoding.GreedyRNNTInferConfig()
greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field(
default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig
)

# beam decoding config