Commit 7d59f38

Fix GreedyBatchedCTCInfer regression from GreedyCTCInfer. (NVIDIA#9347) (NVIDIA#9350) (NVIDIA#9371)

* Fix GreedyBatchedCTCInfer regression from GreedyCTCInfer. (NVIDIA#9347)

* Fix GreedyBatchedCTCInfer regression from GreedyCTCInfer.

decoder_lengths is allowed to be on the CPU even when decoder_output is on
the GPU. This matches the behavior of GreedyCTCInfer. Even though that
behavior is unintentional, existing code depends on it, including our
Jupyter notebooks.
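
To illustrate the behavior being preserved (a sketch, not code from this patch; the tensor names are stand-ins for the decoder's inputs and outputs):

import torch

B, T, V = 4, 20, 29  # batch, time steps, vocabulary size (incl. blank)
decoder_output = torch.randn(B, T, V)  # per-frame log-probs
if torch.cuda.is_available():
    decoder_output = decoder_output.cuda()  # logits on GPU
decoder_lengths = torch.full((B,), T, dtype=torch.long)  # left on CPU

# The fix moves lengths to the logits' device before decoding;
# this is a no-op when both tensors are already co-located.
decoder_lengths = decoder_lengths.to(decoder_output.device)
assert decoder_lengths.device == decoder_output.device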



* Apply isort and black reformatting



---------






(cherry picked from commit db26475)

* Add packaging to the install documentation



* Mark confidence tests as pleasefixme



---------

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Co-authored-by: Daniel Galvez <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
4 people committed Jun 4, 2024
1 parent e776933 commit 7d59f38
Showing 4 changed files with 81 additions and 10 deletions.
6 changes: 3 additions & 3 deletions README.rst
@@ -237,7 +237,7 @@ To install the nemo_toolkit, use the following installation method:
.. code-block:: bash
apt-get update && apt-get install -y libsndfile1 ffmpeg
-pip install Cython
+pip install Cython packaging
pip install nemo_toolkit['all']
Depending on the shell used, you may need to use the ``"nemo_toolkit[all]"`` specifier instead in the above command.
@@ -263,7 +263,7 @@ If you want to work with a specific version of NeMo from a particular GitHub branch
.. code-block:: bash
apt-get update && apt-get install -y libsndfile1 ffmpeg
-pip install Cython
+pip install Cython packaging
python -m pip install git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]
@@ -300,7 +300,7 @@ Run the following code:
conda install -c conda-forge pynini
# install Cython manually
-pip install cython
+pip install cython packaging
# clone the repo and install in development mode
git clone https://github.com/NVIDIA/NeMo
12 changes: 11 additions & 1 deletion nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -394,7 +394,17 @@ def forward(

if decoder_lengths is None:
logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
-decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])
+decoder_lengths = torch.tensor(
+    [decoder_output.shape[1]], dtype=torch.long, device=decoder_output.device
+).expand(decoder_output.shape[0])

+# GreedyCTCInfer::forward(), by accident, works with
+# decoder_lengths on either CPU or GPU when decoder_output is
+# on GPU. For the sake of backwards compatibility, we also
+# allow decoder_lengths to be on the CPU device. In this case,
+# we simply copy the decoder_lengths from CPU to GPU. If both
+# tensors are already on the same device, this is a no-op.
+decoder_lengths = decoder_lengths.to(decoder_output.device)

if decoder_output.ndim == 2:
hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
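
As a minimal illustration (not part of the diff) of what the None-length fallback above produces:

import torch

B, T = 4, 20  # batch size and time steps
# A one-element tensor expanded to shape (B,): with no lengths given,
# every utterance is assumed to span the full time axis.
lengths = torch.tensor([T], dtype=torch.long).expand(B)
print(lengths)  # tensor([20, 20, 20, 20])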
2 changes: 2 additions & 0 deletions tests/collections/asr/confidence/test_asr_confidence.py
@@ -72,6 +72,7 @@ def audio_and_texts(test_data_dir):


class TestASRConfidenceBenchmark:
+    @pytest.mark.pleasefixme
@pytest.mark.integration
@pytest.mark.with_downloads
@pytest.mark.parametrize('model_name', ("ctc", "rnnt"))
@@ -103,6 +104,7 @@ def test_run_confidence_benchmark(
atol=TOL,
)

+    @pytest.mark.pleasefixme
@pytest.mark.integration
@pytest.mark.with_downloads
@pytest.mark.parametrize('model_name', ("ctc", "rnnt"))
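For context, pleasefixme is a custom marker used to deselect known-broken tests. A typical wiring for such a marker looks like the following sketch (assuming a conftest.py hook; NeMo's actual pytest configuration may differ):

# conftest.py -- illustrative only
import pytest

def pytest_collection_modifyitems(config, items):
    # Skip any test carrying the pleasefixme marker at collection time.
    skip = pytest.mark.skip(reason="marked pleasefixme: known-broken test")
    for item in items:
        if "pleasefixme" in item.keywords:
            item.add_marker(skip)
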
71 changes: 65 additions & 6 deletions tests/collections/asr/decoding/test_ctc_decoding.py
@@ -200,8 +200,41 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignments
@pytest.mark.parametrize('timestamps', [False, True])
@pytest.mark.parametrize('preserve_frame_confidence', [False, True])
@pytest.mark.parametrize('length_is_none', [False, True])
+@pytest.mark.parametrize(
+    "logprobs_device",
+    [
+        torch.device("cpu"),
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(),
+                reason='CUDA required for test.',
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "length_device",
+    [
+        torch.device("cpu"),
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(),
+                reason='CUDA required for test.',
+            ),
+        ),
+    ],
+)
def test_batched_decoding_logprobs(
-    self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
+    self,
+    tmp_tokenizer,
+    alignments,
+    timestamps,
+    preserve_frame_confidence,
+    length_is_none,
+    logprobs_device,
+    length_device,
):
cfg = CTCBPEDecodingConfig(
strategy='greedy',
@@ -217,15 +250,15 @@ def test_batched_decoding_logprobs(
torch.manual_seed(1)
B, T = 4, 20
V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-input_signal = torch.randn(size=(B, T, V))
+input_signal = torch.randn(size=(B, T, V), device=logprobs_device)
# Set the blank index to a very high probability to make sure
# that we always handle at least a few blanks.
input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
if length_is_none:
length = None
else:
-length = torch.randint(low=1, high=T, size=[B])
+length = torch.randint(low=1, high=T, size=[B], device=length_device)

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -249,7 +282,33 @@
@pytest.mark.unit
@pytest.mark.parametrize('timestamps', [False, True])
@pytest.mark.parametrize('length_is_none', [False, True])
-def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
+@pytest.mark.parametrize(
+    "labels_device",
+    [
+        torch.device("cpu"),
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(),
+                reason='CUDA required for test.',
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "length_device",
+    [
+        torch.device("cpu"),
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(),
+                reason='CUDA required for test.',
+            ),
+        ),
+    ],
+)
+def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none, labels_device, length_device):
cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
cfg.strategy = 'greedy_batched'
@@ -258,15 +317,15 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none
torch.manual_seed(1)
B, T = 4, 20
V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
-input_labels = torch.randint(V, size=(B, T))
+input_labels = torch.randint(V, size=(B, T), device=labels_device)
# Set some indices to blank to make sure that we always handle
# at least a few blanks.
input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
if length_is_none:
length = None
else:
-length = torch.randint(low=1, high=T, size=[B])
+length = torch.randint(low=1, high=T, size=[B], device=length_device)

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
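As background for what these tests compare across devices, here is a minimal sketch of greedy CTC decoding (argmax per frame, merge repeats, drop blanks); it is an illustration, not NeMo's implementation:

import torch

def greedy_ctc_collapse(logprobs: torch.Tensor, blank: int) -> list:
    # Collapse per-frame argmax labels: merge repeats, then drop blanks.
    ids = logprobs.argmax(dim=-1).tolist()  # best label per frame
    out, prev = [], None
    for i in ids:
        if i != prev and i != blank:
            out.append(i)
        prev = i
    return out

# Example: T=20 frames over V=29 classes, blank as the last index.
logprobs = torch.randn(20, 29).log_softmax(dim=-1)
print(greedy_ctc_collapse(logprobs, blank=28))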
