@@ -35,6 +35,7 @@
 
 from nemo.collections.asr.modules import rnnt_abstract
 from nemo.collections.asr.parts.submodules.rnnt_loop_labels_computer import GreedyBatchedRNNTLoopLabelsComputer
+from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer
 from nemo.collections.asr.parts.utils import rnnt_utils
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
 from nemo.collections.common.parts.rnn import label_collate
@@ -2638,8 +2639,20 @@ def __init__(
 
         # Depending on availability of `blank_as_pad` support
         # switch between more efficient batch decoding technique
+        self._decoding_computer = None
         if self.decoder.blank_as_pad:
-            self._greedy_decode = self._greedy_decode_blank_as_pad
+            # batched "loop frames" is not implemented for TDT
+            self._decoding_computer = GreedyBatchedTDTLoopLabelsComputer(
+                decoder=self.decoder,
+                joint=self.joint,
+                blank_index=self._blank_index,
+                durations=self.durations,
+                max_symbols_per_step=self.max_symbols,
+                preserve_alignments=preserve_alignments,
+                preserve_frame_confidence=preserve_frame_confidence,
+                confidence_method_cfg=confidence_method_cfg,
+            )
+            self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
         else:
             self._greedy_decode = self._greedy_decode_masked
 
@@ -2685,179 +2698,33 @@ def forward(
 
         return (packed_result,)
 
-    def _greedy_decode_blank_as_pad(
+    def _greedy_decode_masked(
         self,
         x: torch.Tensor,
         out_len: torch.Tensor,
         device: torch.device,
         partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
     ):
-        if partial_hypotheses is not None:
-            raise NotImplementedError("`partial_hypotheses` support is not supported")
-
-        with torch.inference_mode():
-            # x: [B, T, D]
-            # out_len: [B]
-            # device: torch.device
-
-            # Initialize list of Hypothesis
-            batchsize = x.shape[0]
-            hypotheses = [
-                rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize)
-            ]
-
-            # Initialize Hidden state matrix (shared by entire batch)
-            hidden = None
-
-            # If alignments need to be preserved, register a danling list to hold the values
-            if self.preserve_alignments:
-                # alignments is a 3-dimensional dangling list representing B x T x U
-                for hyp in hypotheses:
-                    hyp.alignments = [[]]
-
-            # If confidence scores need to be preserved, register a danling list to hold the values
-            if self.preserve_frame_confidence:
-                # frame_confidence is a 3-dimensional dangling list representing B x T x U
-                for hyp in hypotheses:
-                    hyp.frame_confidence = [[]]
-
-            # Last Label buffer + Last Label without blank buffer
-            # batch level equivalent of the last_label
-            last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
-
-            # Mask buffers
-            blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
-
-            # Get max sequence length
-            max_out_len = out_len.max()
-
-            # skip means the number of frames the next decoding step should "jump" to. When skip == 1
-            # it means the next decoding step will just use the next input frame.
-            skip = 1
-            for time_idx in range(max_out_len):
-                if skip > 1:  # if skip > 1 at the current step, we decrement it and skip the current frame.
-                    skip -= 1
-                    continue
-                f = x.narrow(dim=1, start=time_idx, length=1)  # [B, 1, D]
-
-                # need_to_stay is a boolean indicates whether the next decoding step should remain in the same frame.
-                need_to_stay = True
-                symbols_added = 0
-
-                # Reset blank mask
-                blank_mask.mul_(False)
-
-                # Update blank mask with time mask
-                # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
-                # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len
-                blank_mask = time_idx >= out_len
-
-                # Start inner loop
-                while need_to_stay and (self.max_symbols is None or symbols_added < self.max_symbols):
-                    # Batch prediction and joint network steps
-                    # If very first prediction step, submit SOS tag (blank) to pred_step.
-                    # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
-                    if time_idx == 0 and symbols_added == 0 and hidden is None:
-                        g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
-                    else:
-                        # Perform batch step prediction of decoder, getting new states and scores ("g")
-                        g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize)
-
-                    # Batched joint step - Output = [B, V + 1 + num-big-blanks]
-                    # Note: log_normalize must not be True here since the joiner output is contanetation of both token logits and duration logits,
-                    # and they need to be normalized independently.
-                    joined = self._joint_step(f, g, log_normalize=None)
-                    logp = joined[:, 0, 0, : -len(self.durations)]
-                    duration_logp = joined[:, 0, 0, -len(self.durations) :]
-
-                    if logp.dtype != torch.float32:
-                        logp = logp.float()
-                        duration_logp = duration_logp.float()
-
-                    # get the max for both token and duration predictions.
-                    v, k = logp.max(1)
-                    dv, dk = duration_logp.max(1)
-
-                    # here we set the skip value to be the minimum of all predicted durations, hense the "torch.min(dk)" call there.
-                    # Please refer to Section 5.2 of our paper https://arxiv.org/pdf/2304.06795.pdf for explanation of this.
-                    skip = self.durations[int(torch.min(dk))]
-
-                    # this is a special case: if all batches emit blanks, we require that skip be at least 1
-                    # so we don't loop forever at the current frame.
-                    if blank_mask.all():
-                        if skip == 0:
-                            skip = 1
-
-                    need_to_stay = skip == 0
-                    del g
-
-                    # Update blank mask with current predicted blanks
-                    # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
-                    k_is_blank = k == self._blank_index
-                    blank_mask.bitwise_or_(k_is_blank)
-
-                    del k_is_blank
-                    del logp, duration_logp
-
-                    # If all samples predict / have predicted prior blanks, exit loop early
-                    # This is equivalent to if single sample predicted k
-                    if not blank_mask.all():
-                        # Collect batch indices where blanks occurred now/past
-                        blank_indices = (blank_mask == 1).nonzero(as_tuple=False)
-
-                        # Recover prior state for all samples which predicted blank now/past
-                        if hidden is not None:
-                            hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices)
-
-                        elif len(blank_indices) > 0 and hidden is None:
-                            # Reset state if there were some blank and other non-blank predictions in batch
-                            # Original state is filled with zeros so we just multiply
-                            # LSTM has 2 states
-                            hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0)
-
-                        # Recover prior predicted label for all samples which predicted blank now/past
-                        k[blank_indices] = last_label[blank_indices, 0]
-
-                        # Update new label and hidden state for next iteration
-                        last_label = k.clone().view(-1, 1)
-                        hidden = hidden_prime
-
-                        # Update predicted labels, accounting for time mask
-                        # If blank was predicted even once, now or in the past,
-                        # Force the current predicted label to also be blank
-                        # This ensures that blanks propogate across all timesteps
-                        # once they have occured (normally stopping condition of sample level loop).
-                        for kidx, ki in enumerate(k):
-                            if blank_mask[kidx] == 0:
-                                hypotheses[kidx].y_sequence.append(ki)
-                                hypotheses[kidx].timestep.append(time_idx)
-                                hypotheses[kidx].score += float(v[kidx])
-
-                    symbols_added += 1
-
-            # Remove trailing empty list of alignments at T_{am-len} x Uj
-            if self.preserve_alignments:
-                for batch_idx in range(batchsize):
-                    if len(hypotheses[batch_idx].alignments[-1]) == 0:
-                        del hypotheses[batch_idx].alignments[-1]
-
-            # Remove trailing empty list of confidence scores at T_{am-len} x Uj
-            if self.preserve_frame_confidence:
-                for batch_idx in range(batchsize):
-                    if len(hypotheses[batch_idx].frame_confidence[-1]) == 0:
-                        del hypotheses[batch_idx].frame_confidence[-1]
-
-            # Preserve states
-            for batch_idx in range(batchsize):
-                hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx)
-
-        return hypotheses
+        raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")
 
-    def _greedy_decode_masked(
+    @torch.inference_mode()
+    def _greedy_decode_blank_as_pad_loop_labels(
         self,
         x: torch.Tensor,
         out_len: torch.Tensor,
         device: torch.device,
-        partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
-    ):
-        raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")
+        partial_hypotheses: Optional[list[rnnt_utils.Hypothesis]] = None,
+    ) -> list[rnnt_utils.Hypothesis]:
+        """
+        Optimized batched greedy decoding.
+        The main idea: search for next labels for the whole batch (evaluating Joint)
+        and thus always evaluate prediction network with maximum possible batch size
+        """
+        if partial_hypotheses is not None:
+            raise NotImplementedError("`partial_hypotheses` support is not implemented")
+
+        batched_hyps, alignments, last_decoder_state = self._decoding_computer(x=x, out_len=out_len)
+        hyps = rnnt_utils.batched_hyps_to_hypotheses(batched_hyps, alignments)
+        for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
+            hyp.dec_state = state
+        return hyps
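
For readers who want the gist of what GreedyBatchedTDTLoopLabelsComputer does, the sketch below illustrates the loop-labels idea in isolation. It is only a toy reconstruction, not the NeMo implementation: the function name, the `joint` callable and its signature, the tensor shapes, and the simplified handling of the prediction-network state are assumptions made for illustration. The control flow it demonstrates is the one the new code path relies on: instead of an outer loop over frames with an inner loop over symbols, each iteration evaluates the joint once for the whole batch, every utterance keeps its own frame pointer, the pointer advances by the greedily chosen duration, and a blank forces a move of at least one frame so decoding cannot stall.

import torch


def toy_tdt_loop_labels_decode(encoder_out, out_len, joint, blank_index, durations, max_symbols=10):
    """Toy "loop labels" batched greedy decoding for TDT (illustration only, not the NeMo code).

    encoder_out: [B, T, D] encoder features
    out_len:     [B] number of valid frames per utterance
    joint:       callable(frames [B, D], last_labels [B]) -> (token_logits [B, V+1], duration_logits [B, K])
    durations:   LongTensor of the K supported durations, e.g. torch.tensor([0, 1, 2, 4])
    """
    batch_size, max_time, _ = encoder_out.shape
    device = encoder_out.device

    time_indices = torch.zeros(batch_size, dtype=torch.long, device=device)   # frame pointer per utterance
    last_labels = torch.full((batch_size,), blank_index, dtype=torch.long, device=device)  # SOS == blank
    symbols_at_frame = torch.zeros(batch_size, dtype=torch.long, device=device)
    hypotheses = [[] for _ in range(batch_size)]
    active = time_indices < out_len

    # One iteration == one (label, duration) decision for every utterance in the batch,
    # so the prediction/joint networks always run at full batch size.
    while active.any():
        batch_idx = torch.arange(batch_size, device=device)
        frames = encoder_out[batch_idx, time_indices.clamp(max=max_time - 1)]  # [B, D]
        token_logits, duration_logits = joint(frames, last_labels)

        labels = token_logits.argmax(dim=-1)               # greedy token / blank choice
        jumps = durations[duration_logits.argmax(dim=-1)]  # greedy duration choice

        is_blank = labels == blank_index
        emit = active & ~is_blank
        for b in batch_idx[emit].tolist():                 # record emitted non-blank labels
            hypotheses[b].append(int(labels[b]))
        last_labels = torch.where(emit, labels, last_labels)

        # Avoid stalling on a frame: after a blank, or after max_symbols zero-duration
        # emissions, force the frame pointer to move by at least one frame.
        symbols_at_frame = torch.where(emit & (jumps == 0), symbols_at_frame + 1, torch.zeros_like(symbols_at_frame))
        force_move = is_blank | (symbols_at_frame >= max_symbols)
        jumps = torch.where(force_move, jumps.clamp(min=1), jumps)

        time_indices = time_indices + jumps * active       # advance only utterances that are still active
        active = time_indices < out_len

    return hypotheses

Any callable with the assumed signature works as `joint` here, for example a closure over randomly initialized linear layers, which is enough to see how utterances at different frame positions stay batched together. The real computer additionally tracks per-utterance alignments, frame confidence, and batched decoder state (later split with `self.decoder.batch_split_states`), which this sketch omits for brevity.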