
Commit cd0facb

Attempt at refactor...
Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko committed Jan 31, 2024
1 parent e4327fc commit cd0facb
Showing 2 changed files with 16 additions and 11 deletions.
9 changes: 5 additions & 4 deletions nemo/collections/asr/parts/mixins/mixins.py
@@ -214,10 +214,11 @@ def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig):
                 self.AGGREGATE_TOKENIZERS_DICT_PREFIX
             ][lang]['type']

-        if tokenizer_cfg.get('is_canary', False):
-            # CanaryTokenizer has easy access to spl_tokens, which the aggregate
-            # tokenizer doesn't have for now; TODO: merge both later
-            self.tokenizer = tokenizers.CanaryTokenizer(tokenizers_dict)
+        if "custom_tokenizer" in tokenizer_cfg:
+            # The class implementing this mixin is usually a ModelPT, which by extension has access to the Serialization mixin
+            self.tokenizer = self.from_config_dict(
+                {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "tokenizers": tokenizers_dict}
+            )
         else:
             self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict)
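For readers unfamiliar with the "_target_" convention: from_config_dict comes from NeMo's Serialization mixin and instantiates the class named by "_target_", forwarding the remaining keys as constructor arguments. Below is a minimal, self-contained sketch of that dispatch pattern, illustrative only; the real method also handles OmegaConf/Hydra configs and is not limited to this simple form.

# Minimal sketch of the `_target_` dispatch that from_config_dict performs.
# Illustrative only: NeMo's actual implementation does more (OmegaConf/Hydra
# handling, validation), but the core idea is the same.
import importlib


def instantiate_from_target(config: dict):
    """Instantiate the class named by config['_target_'], passing the other keys as kwargs."""
    module_path, _, class_name = config["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    kwargs = {k: v for k, v in config.items() if k != "_target_"}
    return cls(**kwargs)


# Mirrors the new branch above (assuming tokenizers_dict was already built):
# tokenizer = instantiate_from_target(
#     {
#         "_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer",
#         "tokenizers": tokenizers_dict,
#     }
# )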
18 changes: 11 additions & 7 deletions tests/collections/asr/test_custom_tokenizer.py
@@ -7,6 +7,7 @@
 from nemo.collections.asr.parts.mixins import ASRBPEMixin
 from nemo.collections.common.tokenizers.canary_tokenizer import SPECIAL_TOKENS, UNUSED_SPECIAL_TOKENS, CanaryTokenizer
 from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model
+from nemo.core import Serialization


 @pytest.fixture(scope="session")
@@ -35,21 +36,24 @@ def test_canary_tokenizer_build_special_tokenizer(tmp_path):


 def test_canary_tokenizer_init_from_cfg(special_tokenizer_path, lang_tokenizer_path):
-    bpe_mixin = ASRBPEMixin()
-    bpe_mixin.register_artifact = Mock(side_effect=lambda self, x: x)
+    class DummyModel(ASRBPEMixin, Serialization):
+        pass
+
+    model = DummyModel()
+    model.register_artifact = Mock(side_effect=lambda self, x: x)
     config = OmegaConf.create(
         {
             "type": "agg",
-            "is_canary": True,
             "dir": None,
             "langs": {
-                "spl_tokens": {"dir": special_tokenizer_path, "type": "bpe",},
-                "en": {"dir": lang_tokenizer_path, "type": "bpe",},
+                "spl_tokens": {"dir": special_tokenizer_path, "type": "bpe"},
+                "en": {"dir": lang_tokenizer_path, "type": "bpe"},
             },
+            "custom_tokenizer": {"_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer",},
         }
     )
-    bpe_mixin._setup_aggregate_tokenizer(config)
-    tokenizer = bpe_mixin.tokenizer
+    model._setup_aggregate_tokenizer(config)
+    tokenizer = model.tokenizer

     assert isinstance(tokenizer, CanaryTokenizer)
     assert len(tokenizer.tokenizers_dict) == 2
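As the test suggests, the new "custom_tokenizer" key makes the aggregate tokenizer class pluggable rather than hard-wired to CanaryTokenizer. A hypothetical user-side extension could look like the sketch below; MyAggregateTokenizer and my_pkg are illustrative names that are not part of this commit, and it assumes AggregateTokenizer keeps its tokenizers-dict constructor.

# Hypothetical example (not part of this commit): plugging a user-defined
# aggregate tokenizer in through the new "custom_tokenizer" config key.
from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer


class MyAggregateTokenizer(AggregateTokenizer):
    """Custom behavior would go here; the constructor receives the per-language tokenizers dict."""


# Referenced from the tokenizer config as, e.g.:
#   type: agg
#   langs: {...}
#   custom_tokenizer:
#     _target_: my_pkg.tokenizers.MyAggregateTokenizer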
