[text] add init_tokenizer unit test
Mddct committed Nov 27, 2023
1 parent bd24277 commit 75b1e78
Showing 3 changed files with 47 additions and 1 deletion.
41 changes: 41 additions & 0 deletions test/wenet/utils/test_init_tokenizer.py
@@ -0,0 +1,41 @@
import pytest

from wenet.utils.init_tokenizer import init_tokenizer


def test_init_whisper_tokenizer():
    # TODO(Mddct): add configs generator
    configs = {}
    configs['whisper'] = True
    configs['whisper_conf'] = {}
    configs['whisper_conf']['is_multilingual'] = False
    configs['whisper_conf']['num_languages'] = 99

    tokenizer = init_tokenizer(configs, None)
    text = "whisper powered by wenet, it's great"

    assert text == tokenizer.tokens2text(tokenizer.text2tokens(text))


@pytest.mark.parametrize("symbol_table_path", [
    "test/resources/aishell2.words.txt",
])
def test_init_char_tokenizer(symbol_table_path):
    configs = {}
    tokenizer = init_tokenizer(configs, symbol_table_path)

    text = "大家都好帅"
    assert text == tokenizer.tokens2text(tokenizer.text2tokens(text))


@pytest.mark.parametrize(
    "symbol_table_path, bpe_model",
    [("test/resources/librispeech.words.txt",
      "test/resources/librispeech.train_960_unigram5000.bpemodel")])
def test_init_bpe_tokenizer(symbol_table_path, bpe_model):

    configs = {}
    tokenizer = init_tokenizer(configs, symbol_table_path, bpe_model)
    text = "WENET IT'S GREAT"

    assert text == tokenizer.tokens2text(tokenizer.text2tokens(text))
3 changes: 2 additions & 1 deletion wenet/bin/train.py
@@ -75,7 +75,8 @@ def main():
     configs = override_config(configs, args.override_config)

     # init tokenizer
-    tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, non_lang_syms)
+    tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model,
+                               non_lang_syms)

     # Init env for ddp OR deepspeed
     world_size, local_rank, rank = init_distributed(args)
4 changes: 4 additions & 0 deletions wenet/utils/init_tokenizer.py
@@ -28,3 +28,7 @@ def init_tokenizer(configs,
                                          'split_with_space', False))

     return tokenizer
+
+
+tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model,
+                           non_lang_syms)
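Note: a minimal way to run the new tests locally (illustrative only, not part of the commit) is to invoke pytest from the repository root, assuming pytest is installed and the resource files under test/resources/ are present:

import pytest

# Run only the new tokenizer tests; pytest.main returns an exit code.
pytest.main(["test/wenet/utils/test_init_tokenizer.py", "-v"])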
