Accept None as an argument to decoder_lengths in GreedyBatchedCTCInfer::forward

GreedyCTCInfer::forward already allowed for this, so the two classes did not
implement exactly the same interface. Now they do.

Also warn about not passing in the decoder_lengths argument. It is
likely an error on the user's part not to pass it in explicitly.
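
A minimal sketch of what the new default means (the toy shapes below are an assumption; only the defaulting expression mirrors the diff): when decoder_lengths is None, the batched decoder now assumes every batch element spans the full time dimension.

import torch

# Assumed toy shapes: batch B=3, time T=5, vocab V=4 (illustrative only).
decoder_output = torch.randn(3, 5, 4)

decoder_lengths = None  # the argument value this commit now accepts
if decoder_lengths is None:
    # Same default the batched decoder applies: every element is assumed
    # to cover the full time dimension.
    decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(
        decoder_output.shape[0]
    )

print(decoder_lengths)  # tensor([5, 5, 5])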
galv committed May 17, 2024
1 parent 5a68d2a commit b44faec
Showing 2 changed files with 33 additions and 9 deletions.
26 changes: 21 additions & 5 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -54,6 +54,8 @@ def _states_to_device(dec_state, device='cpu'):

return dec_state

_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."


class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
"""A greedy CTC decoder.
@@ -145,7 +147,9 @@ def __init__(

@typecheck()
def forward(
self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
self,
decoder_output: torch.Tensor,
decoder_lengths: Optional[torch.Tensor],
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-regressively.
@@ -158,6 +162,9 @@ def forward(
Returns:
packed list containing batch number of sentences (Hypotheses).
"""
if decoder_lengths is None:
logging.warning(_DECODER_LENGTHS_NONE_WARNING)

with torch.inference_mode():
hypotheses = []
# Process each sequence independently
@@ -204,7 +211,7 @@ def forward(
return (packed_result,)

@torch.no_grad()
def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
# x: [T, D]
# out_len: [seq_len]

@@ -234,7 +241,7 @@ def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
return hypothesis

@torch.no_grad()
def _greedy_decode_labels(self, x: torch.Tensor, out_len: torch.Tensor):
def _greedy_decode_labels(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
# x: [T]
# out_len: [seq_len]

@@ -361,7 +368,9 @@ def __init__(

@typecheck()
def forward(
self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
self,
decoder_output: torch.Tensor,
decoder_lengths: Optional[torch.Tensor],
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-regressively.
@@ -374,11 +383,18 @@ def forward(
Returns:
packed list containing batch number of sentences (Hypotheses).
"""

input_decoder_lengths = decoder_lengths

if decoder_lengths is None:
logging.warning(_DECODER_LENGTHS_NONE_WARNING)
decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])

if decoder_output.ndim == 2:
hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
else:
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, input_decoder_lengths)
return (packed_result,)

@torch.no_grad()
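The warning exists because a None length makes every padded frame past an utterance's true end participate in decoding. A self-contained sketch of the standard CTC greedy collapse (the toy labels and lengths are assumptions, not taken from this commit) shows how that can change a hypothesis:

import torch

def ctc_greedy_collapse(labels, blank):
    # Standard CTC greedy rule: collapse repeated labels, then drop blanks.
    out, prev = [], None
    for t in labels.tolist():
        if t != prev and t != blank:
            out.append(t)
        prev = t
    return out

blank = 0
# Toy batch: two label sequences padded to T=6; true lengths are 6 and 3.
labels = torch.tensor([[1, 1, 0, 2, 2, 3],
                       [2, 0, 3, 1, 1, 1]])
true_lengths = torch.tensor([6, 3])

# Explicit lengths: padding is ignored.
print([ctc_greedy_collapse(labels[b, : int(true_lengths[b])], blank) for b in range(2)])
# -> [[1, 2, 3], [2, 3]]

# decoder_lengths=None: every row is treated as full length, so the padded
# frames of the second row leak into its hypothesis.
print([ctc_greedy_collapse(labels[b], blank) for b in range(2)])
# -> [[1, 2, 3], [2, 3, 1]]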
16 changes: 12 additions & 4 deletions tests/collections/asr/decoding/test_ctc_decoding.py
@@ -197,7 +197,8 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
@pytest.mark.parametrize('alignments', [False, True])
@pytest.mark.parametrize('timestamps', [False, True])
@pytest.mark.parametrize('preserve_frame_confidence', [False, True])
def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
@pytest.mark.parametrize('length_is_none', [False, True])
def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none):
cfg = CTCBPEDecodingConfig(
strategy='greedy',
preserve_alignments=alignments,
@@ -217,7 +218,10 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,
# that we always handle at least a few blanks.
input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
length = torch.randint(low=1, high=T, size=[B])
if length_is_none:
length = None
else:
length = torch.randint(low=1, high=T, size=[B])

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -240,7 +244,8 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,

@pytest.mark.unit
@pytest.mark.parametrize('timestamps', [False, True])
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
@pytest.mark.parametrize('length_is_none', [False, True])
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
cfg.strategy = 'greedy_batched'
@@ -254,7 +259,10 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
# at least a few blanks.
input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
length = torch.randint(low=1, high=T, size=[B])
if length_is_none:
length = None
else:
length = torch.randint(low=1, high=T, size=[B])

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
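The tests now exercise both paths by parametrizing over length_is_none. A standalone sketch of that style of test, using a stand-in greedy_decode helper rather than the NeMo API (the helper, its signature, and the property checked here are illustrative assumptions):

import pytest
import torch

def greedy_decode(labels, lengths, blank=0):
    # Stand-in greedy CTC label decode; mirrors the new default when lengths is None.
    if lengths is None:
        lengths = torch.full((labels.shape[0],), labels.shape[1], dtype=torch.long)
    hyps = []
    for row, n in zip(labels, lengths):
        prev, out = None, []
        for t in row[: int(n)].tolist():
            if t != prev and t != blank:
                out.append(t)
            prev = t
        hyps.append(out)
    return hyps

@pytest.mark.parametrize('length_is_none', [False, True])
def test_none_lengths_match_full_lengths(length_is_none):
    B, T, V = 4, 8, 5
    labels = torch.randint(low=0, high=V, size=(B, T))
    length = None if length_is_none else torch.full((B,), T, dtype=torch.long)
    # When explicit lengths equal the full time dimension, both paths must agree.
    assert greedy_decode(labels, length) == greedy_decode(labels, None)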
