Speedup RNN-T greedy decoding #7926

Merged: 47 commits merged on Jan 16, 2024

Changes shown are from 18 commits (of 47 commits).

Commits (47)
9342489
Add structure for batched hypotheses
artbataev Nov 21, 2023
7bcc4c0
Add faster decoding algo
artbataev Nov 21, 2023
7a0942f
Simplify max_symbols support. More speedup
artbataev Nov 22, 2023
26ec40c
Clean up
artbataev Nov 22, 2023
1d556ea
Clean up
artbataev Nov 22, 2023
cf631dd
Filtering only when necessary
artbataev Nov 22, 2023
a50965d
Move max_symbols check to the end of loop
artbataev Nov 22, 2023
510eb90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2023
659cfff
Support returning prediction network states
artbataev Nov 22, 2023
40d1568
Support preserve_alignments flag
artbataev Nov 22, 2023
ca2d94b
Support confidence
artbataev Nov 22, 2023
b328fac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2023
7997bd6
Partial fix for jit compatibility
artbataev Nov 23, 2023
6f7746b
Merge branch 'main' into speedup_rnnt_greedy_decoding
artbataev Nov 23, 2023
95da9d1
Support switching between decoding algorithms
artbataev Nov 23, 2023
ef35381
Fix switching algorithms
artbataev Nov 23, 2023
ca5779d
Clean up
artbataev Nov 23, 2023
97092ff
Clean up
artbataev Nov 23, 2023
c9785ff
Fix max symbols per step
artbataev Nov 23, 2023
1e09979
Add tests. Preserve torch.jit compatibility for BatchedHyps
artbataev Nov 24, 2023
f4b7b68
Separate projection from Joint calculation in decoding
artbataev Dec 13, 2023
d67b14b
Fix config instantiation
artbataev Dec 13, 2023
c7d298d
Merge remote-tracking branch 'origin/main' into speedup_rnnt_greedy_d…
artbataev Jan 10, 2024
2ea8f7f
Fix after main merge
artbataev Jan 10, 2024
5c8e18e
Add tests for batched hypotheses
artbataev Jan 10, 2024
e8c43d0
Speedup alignments
artbataev Jan 10, 2024
ffe2a67
Test alignments
artbataev Jan 10, 2024
77bf674
Fix alignments
artbataev Jan 10, 2024
02a9bbd
Fix tests for alignments
artbataev Jan 11, 2024
83c4793
Add more tests
artbataev Jan 11, 2024
430e159
Fix confidence tests
artbataev Jan 11, 2024
266be2c
Avoid common package modification
artbataev Jan 11, 2024
ce33493
Support Stateless prediction network
artbataev Jan 11, 2024
9d545ee
Improve stateless decoder support. Separate alignments and confidence
artbataev Jan 11, 2024
9669149
Fix alignments for max_symbols_per_step
artbataev Jan 11, 2024
1dbf29e
Fix alignments for max_symbols_per_step=0
artbataev Jan 11, 2024
b4421cd
Fix tests
artbataev Jan 12, 2024
3e1ca1e
Fix test
artbataev Jan 12, 2024
1b97e33
Add comments
artbataev Jan 12, 2024
4429432
Batched Hyps/Alignments: lengths -> current_lengths
artbataev Jan 12, 2024
b7b83df
Simplify indexing
artbataev Jan 12, 2024
3df991a
Improve type annotations
artbataev Jan 15, 2024
31649fa
Rework test for greedy decoding
artbataev Jan 15, 2024
5f67c66
Document loop_labels
artbataev Jan 16, 2024
df86b17
Raise ValueError if max_symbols_per_step <= 0
artbataev Jan 16, 2024
0f4463b
Add comments
artbataev Jan 16, 2024
c38f222
Fix test
artbataev Jan 16, 2024
5 changes: 4 additions & 1 deletion nemo/collections/asr/metrics/rnnt_wer.py
@@ -321,6 +321,7 @@
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', True),
)
else:
self.decoding = greedy_decode.GreedyBatchedTDTInfer(
@@ -1317,7 +1318,9 @@
rnnt_timestamp_type: str = "all" # can be char, word or all for both

# greedy decoding config
greedy: greedy_decode.GreedyRNNTInferConfig = field(default_factory=lambda: greedy_decode.GreedyRNNTInferConfig())
greedy: greedy_decode.GreedyBatchedRNNTInferConfig = field(
default_factory=lambda: greedy_decode.GreedyBatchedRNNTInferConfig()
)

# beam decoding config
beam: beam_decode.BeamRNNTInferConfig = field(default_factory=lambda: beam_decode.BeamRNNTInferConfig(beam_size=4))
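For readers of this PR, a minimal sketch of how the new flag might be toggled from a user-facing decoding config; only `greedy.loop_labels` comes from this diff, while the config layout and the model-level call are assumptions based on the usual NeMo ASR workflow.

```python
# Hedged sketch: building a decoding config that selects the new label-looping
# greedy decoder. The `loop_labels` flag is from this PR; the surrounding layout
# and the (commented) model call are assumptions about the usual NeMo ASR API.
from omegaconf import OmegaConf

decoding_cfg = OmegaConf.create(
    {
        "strategy": "greedy_batch",
        "greedy": {
            "max_symbols_per_step": 10,
            "loop_labels": True,  # True: new (faster) label loop; False: previous frame loop
        },
    }
)
print(OmegaConf.to_yaml(decoding_cfg))

# Assumed usage with an RNN-T model restored elsewhere:
# asr_model.change_decoding_strategy(decoding_cfg)
```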
192 changes: 189 additions & 3 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -50,7 +50,11 @@
logitlen_cpu = logitlen

for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis
hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long)
hyp.y_sequence = (
hyp.y_sequence.to(torch.long)
if isinstance(hyp.y_sequence, torch.Tensor)
else torch.tensor(hyp.y_sequence, dtype=torch.long)
)
hyp.length = logitlen_cpu[idx]

if hyp.dec_state is not None:
@@ -545,6 +549,7 @@
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
loop_labels: bool = True,
Review comment (Collaborator): Explain in the docstring what this is.

Reply (Author): Done, missed the class docstring before.

):
super().__init__(
decoder_model=decoder_model,
@@ -559,7 +564,12 @@
# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
if self.decoder.blank_as_pad:
self._greedy_decode = self._greedy_decode_blank_as_pad
if loop_labels:
# default (faster) algo: loop over labels
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
else:
# previous algo: loop over frames
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
self._greedy_decode = self._greedy_decode_masked

@@ -607,7 +617,182 @@

return (packed_result,)

def _greedy_decode_blank_as_pad(
@torch.inference_mode()
def _greedy_decode_blank_as_pad_loop_labels(
self,
x: torch.Tensor,
out_len: torch.Tensor,
device: torch.device,
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
GNroy marked this conversation as resolved.
"""
Optimized batched greedy decoding.
The main idea: search for next labels for the whole batch (evaluating Joint)
and thus always evaluate prediction network with maximum possible batch size
"""
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not implemented")

batch_size, max_time, _ = x.shape

# Initialize empty hypotheses and all necessary tensors
batched_hyps = rnnt_utils.BatchedHyps(
batch_size=batch_size, init_length=max_time, device=x.device, float_dtype=x.dtype
)
time_indices = torch.zeros([batch_size], dtype=torch.long, device=device) # always of batch_size
active_indices = torch.arange(batch_size, dtype=torch.long, device=device) # initial: all indices
labels = torch.full([batch_size], fill_value=self._blank_index, dtype=torch.long, device=device)
state = None

# init additional structs for hypotheses: last decoder state, alignments, frame_confidence
last_decoder_state = [None for _ in range(batch_size)]
if self.preserve_alignments:
alignments: List[List[List[tuple[torch.Tensor, torch.Tensor]]]] = [
[[] for _ in range(out_len[i].item())] for i in range(batch_size)
]
if self.preserve_frame_confidence:
frame_confidence: List[List[List[torch.Tensor]]] = [
[[] for _ in range(out_len[i].item())] for i in range(batch_size)
]

# loop while there are active indices
while (current_batch_size := active_indices.shape[0]) > 0:
# stage 1: get decoder (prediction network) output
if state is None:
decoder_output, state, *_ = self._pred_step(self._SOS, state, batch_size=current_batch_size)
Review comment (Collaborator): Great work in documenting inline. Please be super explicit in documenting every line of decoding logic so that a future reader has full knowledge of what is going on at every line. It is necessary because the decoding loop is super complicated for RNN-T.

Reply (Author): Added comments for this `if`.

else:
decoder_output, state, *_ = self._pred_step(labels.unsqueeze(1), state, batch_size=current_batch_size)

# stage 2: get joint output, iteratively seeking for non-blank labels
# blank label in `labels` tensor means "end of hypothesis" (for this index)
logits = (
self._joint_step(
x[active_indices, time_indices[active_indices]].unsqueeze(1),
decoder_output,
log_normalize=True if self.preserve_frame_confidence else None,
)
.squeeze(1)
.squeeze(1)
)
scores, labels = logits.max(-1)

# search for non-blank labels using joint, advancing time indices for blank labels
# checking max_symbols is not needed, since we already forced advancing time indices for such cases
blank_mask = labels == self._blank_index
advance_mask = torch.logical_and(blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices]))
Review comment (Collaborator): Document this line.

Reply (Author): Added a comment.

if self.preserve_alignments:
for i, global_batch_idx in enumerate(active_indices.cpu().numpy()):
alignments[global_batch_idx][time_indices[global_batch_idx]].append(
(logits[i].cpu(), labels[i].cpu())
)
if self.preserve_frame_confidence:
confidence = self._get_confidence(logits)
for i, global_batch_idx in enumerate(active_indices.cpu().numpy()):
frame_confidence[global_batch_idx][time_indices[global_batch_idx]].append(confidence[i])
while advance_mask.any(): # .item()?
advance_indices = active_indices[advance_mask]
time_indices[advance_indices] += 1
logits = (
self._joint_step(
x[advance_indices, time_indices[advance_indices]].unsqueeze(1),
decoder_output[advance_mask],
log_normalize=True if self.preserve_frame_confidence else None,
)
.squeeze(1)
.squeeze(1)
)
more_scores, more_labels = logits.max(-1)
Review comment (Collaborator): Document.

Reply (Author): Added a comment (above this line).

labels[advance_mask] = more_labels
scores[advance_mask] = more_scores
if self.preserve_alignments:
for i, global_batch_idx in enumerate(advance_indices.cpu().numpy()):
alignments[global_batch_idx][time_indices[global_batch_idx]].append(
(logits[i].cpu(), more_labels[i].cpu())
)
if self.preserve_frame_confidence:
confidence = self._get_confidence(logits)
for i, global_batch_idx in enumerate(advance_indices.cpu().numpy()):
frame_confidence[global_batch_idx][time_indices[global_batch_idx]].append(confidence[i])
blank_mask = labels == self._blank_index
advance_mask = torch.logical_and(
blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices])
)

# stage 3: filter labels and state, store hypotheses
# the only case when there are blank labels in predictions is when we have found the end for some utterances
if blank_mask.any():
non_blank_mask = ~blank_mask
labels = labels[non_blank_mask]
scores = scores[non_blank_mask]

# select states for hyps that became inactive (is it necessary?)
inactive_global_indices = active_indices[blank_mask]
inactive_inner_indices = torch.arange(current_batch_size, device=device, dtype=torch.long)[blank_mask]
for idx, batch_idx in zip(inactive_global_indices.cpu().numpy(), inactive_inner_indices.cpu().numpy()):
last_decoder_state[idx] = self.decoder.batch_select_state(state, batch_idx)

# update active indices and state
active_indices = active_indices[non_blank_mask]
if isinstance(state, torch.Tensor):
state = state[non_blank_mask]
elif isinstance(state, (tuple, list)):
# tuple is immutable, convert temporary to list
state = list(state)
for i in range(len(state)):
state[i] = state[i][:, non_blank_mask]
state = tuple(state)
# store hypotheses
batched_hyps.add_results_(
active_indices, labels, time_indices[active_indices].clone(), scores,
)

# stage 4: to avoid looping, go to next frame after max_symbols emission
if self.max_symbols is not None:
force_blank_mask = torch.logical_and(
Review comment (Collaborator): Document.

Reply (Author): Added a comment above.

torch.logical_and(
labels != self._blank_index,
batched_hyps.last_timestep_lasts[active_indices] >= self.max_symbols,
),
batched_hyps.last_timestep[active_indices] == time_indices[active_indices],
)
if (self.preserve_alignments or self.preserve_frame_confidence) and force_blank_mask.any():
# we do not need output for forced blank in a general case, since hidden state will be the same
# but for preserving alignments/confidence we need a tensor with logits
force_blank_indices = active_indices[force_blank_mask]
logits = (
self._joint_step(
x[force_blank_indices, time_indices[force_blank_indices]].unsqueeze(1),
decoder_output,
log_normalize=True if self.preserve_frame_confidence else None,
)
.squeeze(1)
.squeeze(1)
)
if self.preserve_alignments:
for i, global_batch_idx in enumerate(force_blank_indices.cpu().numpy()):
alignments[global_batch_idx][time_indices[global_batch_idx]].append(
(logits[i].cpu(), torch.tensor(self._blank_index, device=device))
)
if self.preserve_frame_confidence:
confidence = self._get_confidence(logits)
for i, global_batch_idx in enumerate(force_blank_indices.cpu().numpy()):
frame_confidence[global_batch_idx][time_indices[global_batch_idx]].append(confidence[i])
time_indices[active_indices[force_blank_mask]] += 1

hyps = batched_hyps.to_hyps()
# preserve last decoder state (is it necessary?)
for i, last_state in enumerate(last_decoder_state):
# assert last_state is not None
hyps[i].dec_state = last_state
if self.preserve_alignments:
for i, alignment in enumerate(alignments):
hyps[i].alignments = alignment
if self.preserve_frame_confidence:
for i, current_frame_confidence in enumerate(frame_confidence):
hyps[i].frame_confidence = current_frame_confidence
return hyps

def _greedy_decode_blank_as_pad_loop_frames(
self,
x: torch.Tensor,
out_len: torch.Tensor,
Expand Down Expand Up @@ -1737,9 +1922,9 @@
# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
if self.decoder.blank_as_pad:
self._greedy_decode = self._greedy_decode_blank_as_pad

Check warning (Code scanning / CodeQL): Overwriting attribute in super-class or sub-class. Assignment overwrites attribute _greedy_decode, which was previously defined in superclass GreedyBatchedRNNTInfer.
else:
self._greedy_decode = self._greedy_decode_masked

Check warning (Code scanning / CodeQL): Overwriting attribute in super-class or sub-class. Assignment overwrites attribute _greedy_decode, which was previously defined in superclass GreedyBatchedRNNTInfer.
self._SOS = blank_index - len(big_blank_durations)

def _greedy_decode_blank_as_pad(
Expand Down Expand Up @@ -2207,6 +2392,7 @@
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
loop_labels: bool = True

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
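To complement the inline comments requested above, here is a toy, self-contained sketch of the label-looping idea with stand-in prediction and joint networks. It is an illustration under heavy simplifications (no state filtering, alignments, or confidence, and a naive max-symbols guard), not the code added in this PR.

```python
# Hedged sketch of "loop over labels": the outer loop iterates over emitted labels
# for the whole batch, so the prediction network is always called with the largest
# possible batch. (The real implementation additionally skips blank frames in an
# inner loop without re-running the prediction network; omitted here for brevity.)
import torch

torch.manual_seed(0)
B, T, D, V, BLANK = 3, 5, 8, 6, 5      # batch, frames, feature dim, vocab size (last id = blank)
MAX_SYMBOLS = 3                        # max non-blank emissions per frame (prevents infinite loops)

enc = torch.randn(B, T, D)             # stand-in for encoder output
out_len = torch.tensor([5, 3, 4])      # valid number of frames per utterance

pred = torch.nn.Embedding(V, D)        # toy (stateless) prediction network
joint = torch.nn.Linear(2 * D, V)      # toy joint network

time_idx = torch.zeros(B, dtype=torch.long)           # per-utterance frame index
symbols_at_frame = torch.zeros(B, dtype=torch.long)   # emissions at the current frame
last_label = torch.full((B,), BLANK, dtype=torch.long)
active = torch.arange(B)                               # utterances still being decoded
hyps = [[] for _ in range(B)]

while active.numel() > 0:
    # one batched prediction-network call for every still-active utterance
    dec_out = pred(last_label[active])
    logits = joint(torch.cat([enc[active, time_idx[active]], dec_out], dim=-1))
    labels = logits.argmax(-1)

    # force a blank where the per-frame emission budget is exhausted
    labels = torch.where(
        symbols_at_frame[active] >= MAX_SYMBOLS, torch.full_like(labels, BLANK), labels
    )

    blank = labels == BLANK
    emit_idx = active[~blank]
    for i, lab in zip(emit_idx.tolist(), labels[~blank].tolist()):
        hyps[i].append(lab)                 # non-blank: emit a label, stay on the same frame
    last_label[emit_idx] = labels[~blank]
    symbols_at_frame[emit_idx] += 1

    adv_idx = active[blank]
    time_idx[adv_idx] += 1                  # blank: advance to the next frame
    symbols_at_frame[adv_idx] = 0

    active = active[time_idx[active] < out_len[active]]  # retire finished utterances

print(hyps)
```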
69 changes: 69 additions & 0 deletions nemo/collections/asr/parts/utils/rnnt_utils.py
@@ -220,3 +220,72 @@ def select_k_expansions(
k_expansions.append([(k_best_exp_idx, k_best_exp)])

return k_expansions


class BatchedHyps:
"""Class to store batched hypotheses for efficient RNNT decoding"""

def __init__(
self,
batch_size: int,
init_length: int,
device: Optional[torch.device] = None,
float_dtype: Optional[torch.dtype] = None,
):
if init_length <= 0:
raise ValueError(f"init_length must be > 0, got {init_length}")
self.max_length = init_length
self.lengths = torch.zeros(batch_size, device=device, dtype=torch.long)
self.transcript = torch.zeros((batch_size, self.max_length), device=device, dtype=torch.long)
self.timesteps = torch.zeros((batch_size, self.max_length), device=device, dtype=torch.long)
self.scores = torch.zeros(batch_size, device=device, dtype=float_dtype)
# tracking last timestep of each hyp to avoid infinite looping (when max symbols per frame is restricted)
self.last_timestep_lasts = torch.zeros(batch_size, device=device, dtype=torch.long)
self.last_timestep = torch.full((batch_size,), -1, device=device, dtype=torch.long)

def _allocate_more(self):
"""
Allocate 2x space for tensors, similar to common C++ std::vector implementations
to maintain O(1) insertion time complexity
"""
self.transcript = torch.cat((self.transcript, torch.zeros_like(self.transcript)), dim=-1)
self.timesteps = torch.cat((self.timesteps, torch.zeros_like(self.timesteps)), dim=-1)
self.max_length *= 2

def add_results_(
self, active_indices: torch.Tensor, labels: torch.Tensor, time_indices: torch.Tensor, scores: torch.Tensor
):
"""
Add results (inplace) from a decoding step to the batched hypotheses
Args:
active_indices: tensor with indices of active hypotheses (indices should be within the original batch_size)
labels: non-blank labels to add
time_indices: tensor of time index for each label
scores: label scores
"""
# we assume that all tensors have the same first dimension, and labels are non-blanks
if active_indices.shape[0] == 0:
return # nothing to add
if self.lengths.max().item() >= self.max_length:
self._allocate_more()
self.scores[active_indices] += scores
self.transcript.view(-1)[active_indices * self.max_length + self.lengths[active_indices]] = labels
self.timesteps.view(-1)[active_indices * self.max_length + self.lengths[active_indices]] = time_indices
self.last_timestep_lasts[active_indices] = torch.where(
self.last_timestep[active_indices] == time_indices, self.last_timestep_lasts[active_indices] + 1, 1
)
self.last_timestep[active_indices] = time_indices
self.lengths[active_indices] += 1

def to_hyps(self) -> List[Hypothesis]:
hypotheses = [
Hypothesis(
score=self.scores[i].item(),
y_sequence=self.transcript[i, : self.lengths[i]],
timestep=self.timesteps[i, : self.lengths[i]],
alignments=None,
dec_state=None,
)
for i in range(self.scores.shape[0])
]
return hypotheses
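A small usage sketch for the new class (assuming it is imported from nemo.collections.asr.parts.utils.rnnt_utils, the file modified here): one decoding step adds a label only for the first utterance, while the second stays empty.

```python
# Hedged usage sketch for BatchedHyps: two hypotheses, one decoding step.
import torch
from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps

hyps = BatchedHyps(batch_size=2, init_length=4, device=torch.device("cpu"), float_dtype=torch.float32)

hyps.add_results_(
    active_indices=torch.tensor([0]),   # only utterance 0 emitted a non-blank label
    labels=torch.tensor([7]),           # the emitted token id
    time_indices=torch.tensor([3]),     # frame at which it was emitted
    scores=torch.tensor([-0.25]),       # its log-probability score
)

result = hyps.to_hyps()
print(result[0].y_sequence)  # tensor([7])
print(result[1].y_sequence)  # tensor([], dtype=torch.int64) -- nothing emitted yet
```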
2 changes: 1 addition & 1 deletion nemo/collections/common/parts/rnn.py
@@ -229,7 +229,7 @@ def forward(
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
x, h = self.lstm(x, h)

if self.dropout:
if self.dropout is not None:
x = self.dropout(x)

return x, h
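The diff does not state the motivation for this one-line change; a plausible reading, offered here as an assumption, is that an explicit None check is clearer than relying on the truthiness of an Optional attribute. A minimal sketch of the pattern (not NeMo code):

```python
# Minimal sketch: an optional dropout submodule guarded by an explicit None check,
# mirroring the `if self.dropout is not None:` change above. Class and names are
# illustrative only.
from typing import Optional

import torch
from torch import nn


class TinyLSTM(nn.Module):
    def __init__(self, hidden: int = 8, dropout: float = 0.0):
        super().__init__()
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        # keep the attribute Optional: None when no dropout is requested
        self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout > 0.0 else None

    def forward(self, x: torch.Tensor):
        x, h = self.lstm(x)
        if self.dropout is not None:  # explicit check instead of `if self.dropout:`
            x = self.dropout(x)
        return x, h


y, _ = TinyLSTM(dropout=0.1)(torch.randn(2, 5, 8))
print(y.shape)  # torch.Size([2, 5, 8])
```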