diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index b789198db512..eeac9d3c78ad 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -214,10 +214,11 @@ def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig): self.AGGREGATE_TOKENIZERS_DICT_PREFIX ][lang]['type'] - if tokenizer_cfg.get('is_canary', False): - # CanaryTokenizer easy access to spl_tokens which aggegatate - # doesn't have for now; TODO: merge both later - self.tokenizer = tokenizers.CanaryTokenizer(tokenizers_dict) + if "custom_tokenizer" in tokenizer_cfg: + # Class which implements this is usually a ModelPT, has access to the Serialization mixin by extension + self.tokenizer = self.from_config_dict( + {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "tokenizers": tokenizers_dict} + ) else: self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict) diff --git a/tests/collections/asr/test_custom_tokenizer.py b/tests/collections/asr/test_custom_tokenizer.py index acb4eb3ce95d..bd65753571ec 100644 --- a/tests/collections/asr/test_custom_tokenizer.py +++ b/tests/collections/asr/test_custom_tokenizer.py @@ -7,6 +7,7 @@ from nemo.collections.asr.parts.mixins import ASRBPEMixin from nemo.collections.common.tokenizers.canary_tokenizer import SPECIAL_TOKENS, UNUSED_SPECIAL_TOKENS, CanaryTokenizer from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model +from nemo.core import Serialization @pytest.fixture(scope="session") @@ -35,21 +36,24 @@ def test_canary_tokenizer_build_special_tokenizer(tmp_path): def test_canary_tokenizer_init_from_cfg(special_tokenizer_path, lang_tokenizer_path): - bpe_mixin = ASRBPEMixin() - bpe_mixin.register_artifact = Mock(side_effect=lambda self, x: x) + class DummyModel(ASRBPEMixin, Serialization): + pass + + model = DummyModel() + model.register_artifact = Mock(side_effect=lambda self, x: x) config = 
OmegaConf.create( { "type": "agg", - "is_canary": True, "dir": None, "langs": { - "spl_tokens": {"dir": special_tokenizer_path, "type": "bpe",}, - "en": {"dir": lang_tokenizer_path, "type": "bpe",}, + "spl_tokens": {"dir": special_tokenizer_path, "type": "bpe"}, + "en": {"dir": lang_tokenizer_path, "type": "bpe"}, }, + "custom_tokenizer": {"_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer"}, } ) - bpe_mixin._setup_aggregate_tokenizer(config) - tokenizer = bpe_mixin.tokenizer + model._setup_aggregate_tokenizer(config) + tokenizer = model.tokenizer assert isinstance(tokenizer, CanaryTokenizer) assert len(tokenizer.tokenizers_dict) == 2