Skip to content

Commit

Permalink
Fix bug wrt change decoding strategy for bpe models (#7762)
Browse files Browse the repository at this point in the history
* Fix bug wrt change decoding strategy for bpe models

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and web-flow committed Oct 20, 2023
1 parent 9cefa80 commit c62d84a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
with open_dict(self.cfg.decoding):
self.cfg.decoding = decoding_cfg

self.cur_decoder = "rnnt"
logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

elif decoder_type == 'ctc':
if not hasattr(self, 'ctc_decoding'):
raise ValueError("The model does not have the ctc_decoding module and does not support ctc decoding.")
Expand Down
22 changes: 22 additions & 0 deletions tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,25 @@ def test_decoding_change(self, hybrid_asr_model):
assert hybrid_asr_model.ctc_decoding.preserve_alignments is True
assert hybrid_asr_model.ctc_decoding.compute_timestamps is True
assert hybrid_asr_model.cur_decoder == "ctc"

@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_decoding_type_change(self, hybrid_asr_model):
assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyBatchedRNNTInfer)

new_strategy = DictConfig({})
new_strategy.strategy = 'greedy'
new_strategy.greedy = DictConfig({'max_symbols': 10})
hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy, decoder_type='rnnt')
assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyRNNTInfer)
assert hybrid_asr_model.cur_decoder == 'rnnt'

hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy, decoder_type='ctc')
assert isinstance(hybrid_asr_model.ctc_decoding, CTCBPEDecoding)
assert hybrid_asr_model.cur_decoder == 'ctc'

hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy, decoder_type='rnnt')
assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyRNNTInfer)
assert hybrid_asr_model.cur_decoder == 'rnnt'
22 changes: 22 additions & 0 deletions tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,28 @@ def test_decoding_change(self, hybrid_asr_model):
assert hybrid_asr_model.ctc_decoding.preserve_alignments is True
assert hybrid_asr_model.ctc_decoding.compute_timestamps is True

@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_decoding_type_change(self, hybrid_asr_model):
assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyBatchedRNNTInfer)

new_strategy = DictConfig({})
new_strategy.strategy = 'greedy'
new_strategy.greedy = DictConfig({'max_symbols': 10})
hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy, decoder_type='rnnt')
assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyRNNTInfer)
assert hybrid_asr_model.cur_decoder == 'rnnt'

hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy, decoder_type='ctc')
assert isinstance(hybrid_asr_model.ctc_decoding, CTCDecoding)
assert hybrid_asr_model.cur_decoder == 'ctc'

hybrid_asr_model.change_decoding_strategy(decoding_cfg=new_strategy, decoder_type='rnnt')
assert isinstance(hybrid_asr_model.decoding.decoding, greedy_decode.GreedyRNNTInfer)
assert hybrid_asr_model.cur_decoder == 'rnnt'

@pytest.mark.unit
def test_GreedyRNNTInferConfig(self):
IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index']
Expand Down

0 comments on commit c62d84a

Please sign in to comment.