Skip to content

Commit

Permalink
[Cherry-pick] Fix virtual function issue with CTC decoder (#3230) (#3238
Browse files Browse the repository at this point in the history
)

Summary:
Currently, creating CTCDecoder object by passing a language model to
`lm` argument without assigning it to a variable elsewhere causes
`RuntimeError: Tried to call pure virtual function "LM::start"`.

According to discussions on PyBind11, (
pybind/pybind11#4013 and
pybind/pybind11#2839
) this is due to Python object garbage-collected by the time
it's used by code implemented in C++. It attempts to call
methods defined in Python, which overrides the base pure virtual
function, but the object which provides this override gets
deleted by garbage collrector, as the original object is not
reference counted.

This commit fixes this by simply assiging the given `lm` object
as an attribute of CTCDecoder class.

Address #3218

Pull Request resolved: #3230

Reviewed By: hwangjeff

Differential Revision: D44642989

Pulled By: mthrok

fbshipit-source-id: a90af828c7c576bc0eb505164327365ebaadc471
  • Loading branch information
mthrok authored Apr 5, 2023
1 parent 9df28ff commit a4ea69e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,19 @@ def test_index_to_tokens(self, tokens):

expected_tokens = ["|", "f", "|", "o", "a"]
self.assertEqual(tokens, expected_tokens)

def test_lm_lifecycle(self):
"""Passing lm without assiging it to a vaiable won't cause runtime error
https://github.com/pytorch/audio/issues/3218
"""
from torchaudio.models.decoder import ctc_decoder

from .ctc_decoder_utils import CustomZeroLM

decoder = ctc_decoder(
lexicon=get_asset_path("decoder/lexicon.txt"),
tokens=get_asset_path("decoder/tokens.txt"),
lm=CustomZeroLM(),
)
decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32))
6 changes: 6 additions & 0 deletions torchaudio/models/decoder/_ctc_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ def __init__(
)
else:
self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
# https://github.com/pytorch/audio/issues/3218
# If lm is passed like rvalue reference, the lm object gets garbage collected,
# and later call to the lm fails.
# This ensures that lm object is not deleted as long as the decoder is alive.
# https://github.com/pybind/pybind11/discussions/4013
self.lm = lm

def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
idxs = (g[0] for g in it.groupby(idxs))
Expand Down

0 comments on commit a4ea69e

Please sign in to comment.