
Commit ae8769d

Add loop_labels algorithm for TDT greedy decoding (NVIDIA#8215)
* Add `loop_labels` algorithm for TDT greedy decoding
* Use `loop_labels` by default
* Loop labels greedy decoding v2
* Add comments. Clean up
* Add comments
* Add comments
* Add tests for batched hypotheses
* Add tests for batched alignments
* Add comments
* Fix comment
* Fix test
* Add computer for TDT
* Fix TDT decoding algorithm
* Use loop frames by default for TDT
* Remove "loop frames" implementation for TDT
* Clean up
* Add comments
* Fix confidence. Use tensor for durations.

---------

Signed-off-by: Vladimir Bataev <[email protected]>
1 parent 88126f3 commit ae8769d
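
The core change, in rough terms: the old "loop frames" decoder iterates over frames in the outer loop and over emitted labels in an inner loop, while "loop labels" makes emitted labels the unit of iteration and lets every sequence track its own frame index. A minimal toy sketch of the two control flows (illustrative stand-ins only: toy_scores, BLANK, and MAX_SYMBOLS are not the NeMo API, and the real decoder batches the label search instead of looping in Python):

import torch

BLANK = 0
MAX_SYMBOLS = 3  # cap on non-blank emissions per frame, analogous to max_symbols


def toy_scores(batch_size: int, vocab: int = 5) -> torch.Tensor:
    """Stand-in for one prediction-network + joint evaluation."""
    return torch.randn(batch_size, vocab)


def greedy_loop_frames(num_frames: int, batch_size: int) -> list[list[int]]:
    """Outer loop over frames, inner loop over labels for the whole batch.

    Sequences that already emitted blank at this frame still ride along
    through the inner loop, so later iterations waste compute on them."""
    hyps: list[list[int]] = [[] for _ in range(batch_size)]
    for _ in range(num_frames):
        blank_mask = torch.zeros(batch_size, dtype=torch.bool)
        symbols_added = 0
        while not bool(blank_mask.all()) and symbols_added < MAX_SYMBOLS:
            labels = toy_scores(batch_size).argmax(dim=-1)
            blank_mask |= labels == BLANK
            for i in torch.nonzero(~blank_mask).flatten().tolist():
                hyps[i].append(int(labels[i]))
            symbols_added += 1
    return hyps


def greedy_loop_labels(num_frames: int, batch_size: int) -> list[list[int]]:
    """Single loop over emitted labels; each sequence advances its own frame.

    Every iteration evaluates the (toy) joint for all sequences at once, so
    the effective batch size stays maximal until sequences run out of frames."""
    hyps: list[list[int]] = [[] for _ in range(batch_size)]
    frame_idx = torch.zeros(batch_size, dtype=torch.long)
    symbols_at_frame = torch.zeros(batch_size, dtype=torch.long)
    while bool((frame_idx < num_frames).any()):
        labels = toy_scores(batch_size).argmax(dim=-1)
        for i in torch.nonzero(frame_idx < num_frames).flatten().tolist():
            if labels[i] == BLANK or symbols_at_frame[i] >= MAX_SYMBOLS:
                frame_idx[i] += 1  # blank (or symbol cap) advances the frame
                symbols_at_frame[i] = 0
            else:
                hyps[i].append(int(labels[i]))
                symbols_at_frame[i] += 1
    return hyps


print(greedy_loop_frames(num_frames=4, batch_size=2))
print(greedy_loop_labels(num_frames=4, batch_size=2))

The practical effect, per the docstring added in the diff below: the prediction network and Joint are always evaluated at the maximum possible batch size, rather than on a batch that effectively shrinks as sequences emit blanks.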

File tree

2 files changed: +301 −166 lines changed

nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
+33 −166
@@ -35,6 +35,7 @@
 
 from nemo.collections.asr.modules import rnnt_abstract
 from nemo.collections.asr.parts.submodules.rnnt_loop_labels_computer import GreedyBatchedRNNTLoopLabelsComputer
+from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer
 from nemo.collections.asr.parts.utils import rnnt_utils
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
 from nemo.collections.common.parts.rnn import label_collate
@@ -2638,8 +2639,20 @@ def __init__(
 
         # Depending on availability of `blank_as_pad` support
         # switch between more efficient batch decoding technique
+        self._decoding_computer = None
         if self.decoder.blank_as_pad:
-            self._greedy_decode = self._greedy_decode_blank_as_pad
+            # batched "loop frames" is not implemented for TDT
+            self._decoding_computer = GreedyBatchedTDTLoopLabelsComputer(
+                decoder=self.decoder,
+                joint=self.joint,
+                blank_index=self._blank_index,
+                durations=self.durations,
+                max_symbols_per_step=self.max_symbols,
+                preserve_alignments=preserve_alignments,
+                preserve_frame_confidence=preserve_frame_confidence,
+                confidence_method_cfg=confidence_method_cfg,
+            )
+            self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
         else:
             self._greedy_decode = self._greedy_decode_masked
 
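For TDT the Joint output carries both token logits and duration logits, which must be normalized independently; the predicted duration then determines how many frames to advance. A hedged sketch of that split, following the comments in the code removed below (shapes and names are illustrative, not the exact GreedyBatchedTDTLoopLabelsComputer interface):

import torch

# Illustrative only: the TDT joint output is the concatenation of token logits
# (vocab + blank) and duration logits; each part is normalized on its own.
durations = torch.tensor([0, 1, 2, 3, 4])  # example TDT duration set
vocab_with_blank = 129                     # V + 1, illustrative size
joined = torch.randn(8, vocab_with_blank + len(durations))  # [B, V + 1 + D]

token_logp = joined[:, : -len(durations)].log_softmax(dim=-1)
duration_logp = joined[:, -len(durations):].log_softmax(dim=-1)

labels = token_logp.argmax(dim=-1)              # best token per sequence
skip = durations[duration_logp.argmax(dim=-1)]  # frames to advance, per sequence
print(labels, skip)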
@@ -2685,179 +2698,33 @@ def forward(
 
         return (packed_result,)
 
-    def _greedy_decode_blank_as_pad(
+    def _greedy_decode_masked(
         self,
         x: torch.Tensor,
         out_len: torch.Tensor,
         device: torch.device,
         partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
     ):
-        if partial_hypotheses is not None:
-            raise NotImplementedError("`partial_hypotheses` support is not supported")
-
-        with torch.inference_mode():
-            # x: [B, T, D]
-            # out_len: [B]
-            # device: torch.device
-
-            # Initialize list of Hypothesis
-            batchsize = x.shape[0]
-            hypotheses = [
-                rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize)
-            ]
-
-            # Initialize Hidden state matrix (shared by entire batch)
-            hidden = None
-
-            # If alignments need to be preserved, register a dangling list to hold the values
-            if self.preserve_alignments:
-                # alignments is a 3-dimensional dangling list representing B x T x U
-                for hyp in hypotheses:
-                    hyp.alignments = [[]]
-
-            # If confidence scores need to be preserved, register a dangling list to hold the values
-            if self.preserve_frame_confidence:
-                # frame_confidence is a 3-dimensional dangling list representing B x T x U
-                for hyp in hypotheses:
-                    hyp.frame_confidence = [[]]
-
-            # Last Label buffer + Last Label without blank buffer
-            # batch level equivalent of the last_label
-            last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
-
-            # Mask buffers
-            blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
-
-            # Get max sequence length
-            max_out_len = out_len.max()
-
-            # skip is the number of frames the next decoding step should "jump" over. When skip == 1,
-            # the next decoding step simply uses the next input frame.
-            skip = 1
-            for time_idx in range(max_out_len):
-                if skip > 1:  # if skip > 1 at the current step, decrement it and skip the current frame.
-                    skip -= 1
-                    continue
-                f = x.narrow(dim=1, start=time_idx, length=1)  # [B, 1, D]
-
-                # need_to_stay is a boolean indicating whether the next decoding step should remain in the same frame.
-                need_to_stay = True
-                symbols_added = 0
-
-                # Reset blank mask
-                blank_mask.mul_(False)
-
-                # Update blank mask with time mask
-                # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
-                # Forcibly mask with "blank" tokens, for all samples where the current time step T > seq_len
-                blank_mask = time_idx >= out_len
-
-                # Start inner loop
-                while need_to_stay and (self.max_symbols is None or symbols_added < self.max_symbols):
-                    # Batch prediction and joint network steps
-                    # If very first prediction step, submit SOS tag (blank) to pred_step.
-                    # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
-                    if time_idx == 0 and symbols_added == 0 and hidden is None:
-                        g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
-                    else:
-                        # Perform batch step prediction of decoder, getting new states and scores ("g")
-                        g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize)
-
-                    # Batched joint step - Output = [B, V + 1 + num-big-blanks]
-                    # Note: log_normalize must not be True here since the joiner output is a concatenation of token logits
-                    # and duration logits, and they need to be normalized independently.
-                    joined = self._joint_step(f, g, log_normalize=None)
-                    logp = joined[:, 0, 0, : -len(self.durations)]
-                    duration_logp = joined[:, 0, 0, -len(self.durations) :]
-
-                    if logp.dtype != torch.float32:
-                        logp = logp.float()
-                        duration_logp = duration_logp.float()
-
-                    # get the max for both token and duration predictions.
-                    v, k = logp.max(1)
-                    dv, dk = duration_logp.max(1)
-
-                    # here we set the skip value to be the minimum of all predicted durations, hence the "torch.min(dk)" call.
-                    # Please refer to Section 5.2 of our paper https://arxiv.org/pdf/2304.06795.pdf for an explanation of this.
-                    skip = self.durations[int(torch.min(dk))]
-
-                    # this is a special case: if all batches emit blanks, we require that skip be at least 1
-                    # so we don't loop forever at the current frame.
-                    if blank_mask.all():
-                        if skip == 0:
-                            skip = 1
-
-                    need_to_stay = skip == 0
-                    del g
-
-                    # Update blank mask with current predicted blanks
-                    # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
-                    k_is_blank = k == self._blank_index
-                    blank_mask.bitwise_or_(k_is_blank)
-
-                    del k_is_blank
-                    del logp, duration_logp
-
-                    # If all samples predict / have predicted prior blanks, exit loop early
-                    # This is the batch-level equivalent of a single sample predicting blank
-                    if not blank_mask.all():
-                        # Collect batch indices where blanks occurred now/past
-                        blank_indices = (blank_mask == 1).nonzero(as_tuple=False)
-
-                        # Recover prior state for all samples which predicted blank now/past
-                        if hidden is not None:
-                            hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices)
-
-                        elif len(blank_indices) > 0 and hidden is None:
-                            # Reset state if there were some blank and other non-blank predictions in batch
-                            # Original state is filled with zeros so we just multiply
-                            # LSTM has 2 states
-                            hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0)
-
-                        # Recover prior predicted label for all samples which predicted blank now/past
-                        k[blank_indices] = last_label[blank_indices, 0]
-
-                        # Update new label and hidden state for next iteration
-                        last_label = k.clone().view(-1, 1)
-                        hidden = hidden_prime
-
-                        # Update predicted labels, accounting for time mask
-                        # If blank was predicted even once, now or in the past,
-                        # force the current predicted label to also be blank.
-                        # This ensures that blanks propagate across all timesteps
-                        # once they have occurred (normally the stopping condition of the sample-level loop).
-                        for kidx, ki in enumerate(k):
-                            if blank_mask[kidx] == 0:
-                                hypotheses[kidx].y_sequence.append(ki)
-                                hypotheses[kidx].timestep.append(time_idx)
-                                hypotheses[kidx].score += float(v[kidx])
-
-                        symbols_added += 1
-
-            # Remove trailing empty list of alignments at T_{am-len} x Uj
-            if self.preserve_alignments:
-                for batch_idx in range(batchsize):
-                    if len(hypotheses[batch_idx].alignments[-1]) == 0:
-                        del hypotheses[batch_idx].alignments[-1]
-
-            # Remove trailing empty list of confidence scores at T_{am-len} x Uj
-            if self.preserve_frame_confidence:
-                for batch_idx in range(batchsize):
-                    if len(hypotheses[batch_idx].frame_confidence[-1]) == 0:
-                        del hypotheses[batch_idx].frame_confidence[-1]
-
-            # Preserve states
-            for batch_idx in range(batchsize):
-                hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx)
-
-        return hypotheses
+        raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")
 
-    def _greedy_decode_masked(
+    @torch.inference_mode()
+    def _greedy_decode_blank_as_pad_loop_labels(
         self,
         x: torch.Tensor,
         out_len: torch.Tensor,
         device: torch.device,
-        partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
-    ):
-        raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")
+        partial_hypotheses: Optional[list[rnnt_utils.Hypothesis]] = None,
+    ) -> list[rnnt_utils.Hypothesis]:
+        """
+        Optimized batched greedy decoding.
+        The main idea: search for next labels for the whole batch (evaluating Joint)
+        and thus always evaluate prediction network with maximum possible batch size
+        """
+        if partial_hypotheses is not None:
+            raise NotImplementedError("`partial_hypotheses` support is not implemented")
+
+        batched_hyps, alignments, last_decoder_state = self._decoding_computer(x=x, out_len=out_len)
+        hyps = rnnt_utils.batched_hyps_to_hypotheses(batched_hyps, alignments)
+        for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
+            hyp.dec_state = state
+        return hyps

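One reason loop labels suits TDT in particular: the removed loop-frames code had to advance the whole batch by a single skip value, taking the minimum over all predicted durations (Section 5.2 of https://arxiv.org/pdf/2304.06795.pdf), with a floor of 1 once every sequence emitted blank so decoding could not stall on a frame. With per-sequence frame indices that compromise presumably disappears; the exact logic lives in tdt_loop_labels_computer.py, which is not shown in this diff. A toy illustration with made-up tensors:

import torch

durations = torch.tensor([0, 1, 2, 3, 4])
pred_duration_idx = torch.tensor([4, 1, 2])  # per-sequence argmax over duration logits

# Loop frames (removed code): one shared skip for the whole batch,
# throttled by the slowest sequence via the batch-wide minimum.
all_blank = True  # suppose every sequence emitted blank at this frame
shared_skip = int(durations[pred_duration_idx.min()])  # -> 1
if all_blank and shared_skip == 0:
    shared_skip = 1  # floor of 1 so decoding cannot stall on one frame

# Loop labels (hedged sketch of the per-sequence alternative):
# each sequence advances by its own predicted duration.
frame_idx = torch.tensor([10, 10, 10])
frame_idx = frame_idx + durations[pred_duration_idx]  # -> tensor([14, 11, 12])
print(shared_skip, frame_idx)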