From a4ea69ea1c0c15547a7bfe85762cd459950be9c7 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 5 Apr 2023 06:47:46 -0700 Subject: [PATCH] [Cherry-pick] Fix virtual function issue with CTC decoder (#3230) (#3238) Summary: Currently, creating a CTCDecoder object by passing a language model to the `lm` argument without assigning it to a variable elsewhere causes `RuntimeError: Tried to call pure virtual function "LM::start"`. According to discussions on PyBind11 ( https://github.com/pybind/pybind11/discussions/4013 and https://github.com/pybind/pybind11/pull/2839 ), this is due to the Python object being garbage-collected by the time it is used by code implemented in C++. The C++ code attempts to call methods defined in Python, which override the base pure virtual functions, but the object that provides these overrides has already been deleted by the garbage collector, as the original object is not reference counted. This commit fixes this by simply assigning the given `lm` object as an attribute of the CTCDecoder class. 
Address https://github.com/pytorch/audio/issues/3218 Pull Request resolved: https://github.com/pytorch/audio/pull/3230 Reviewed By: hwangjeff Differential Revision: D44642989 Pulled By: mthrok fbshipit-source-id: a90af828c7c576bc0eb505164327365ebaadc471 --- .../models/decoder/ctc_decoder_test.py | 16 ++++++++++++++++ torchaudio/models/decoder/_ctc_decoder.py | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py b/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py index f794f92ff9..87dc93ffd3 100644 --- a/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py +++ b/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py @@ -169,3 +169,19 @@ def test_index_to_tokens(self, tokens): expected_tokens = ["|", "f", "|", "o", "a"] self.assertEqual(tokens, expected_tokens) + + def test_lm_lifecycle(self): + """Passing lm without assiging it to a vaiable won't cause runtime error + + https://github.com/pytorch/audio/issues/3218 + """ + from torchaudio.models.decoder import ctc_decoder + + from .ctc_decoder_utils import CustomZeroLM + + decoder = ctc_decoder( + lexicon=get_asset_path("decoder/lexicon.txt"), + tokens=get_asset_path("decoder/tokens.txt"), + lm=CustomZeroLM(), + ) + decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32)) diff --git a/torchaudio/models/decoder/_ctc_decoder.py b/torchaudio/models/decoder/_ctc_decoder.py index d9fa5165d8..33daa09ec9 100644 --- a/torchaudio/models/decoder/_ctc_decoder.py +++ b/torchaudio/models/decoder/_ctc_decoder.py @@ -269,6 +269,12 @@ def __init__( ) else: self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions) + # https://github.com/pytorch/audio/issues/3218 + # If lm is passed like rvalue reference, the lm object gets garbage collected, + # and later call to the lm fails. + # This ensures that lm object is not deleted as long as the decoder is alive. 
+ # https://github.com/pybind/pybind11/discussions/4013 + self.lm = lm def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: idxs = (g[0] for g in it.groupby(idxs))