From 49994bfeb6693cee259c95f29be3c619373ac666 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Mon, 27 Nov 2023 21:39:08 +0800
Subject: [PATCH] [text] add init_tokenizer unit test

---
 test/wenet/utils/test_init_tokenizer.py | 41 +++++++++++++++++++++++++
 wenet/bin/train.py                      |  3 +-
 2 files changed, 43 insertions(+), 1 deletion(-)
 create mode 100644 test/wenet/utils/test_init_tokenizer.py

diff --git a/test/wenet/utils/test_init_tokenizer.py b/test/wenet/utils/test_init_tokenizer.py
new file mode 100644
index 000000000..1b5b4af94
--- /dev/null
+++ b/test/wenet/utils/test_init_tokenizer.py
@@ -0,0 +1,41 @@
+import pytest
+
+from wenet.utils.init_tokenizer import init_tokenizer
+
+
+def test_init_whisper_tokenizer():
+    # TODO(Mddct): add configs generator
+    configs = {}
+    configs['whisper'] = True
+    configs['whisper_conf'] = {}
+    configs['whisper_conf']['is_multilingual'] = False
+    configs['whisper_conf']['num_languages'] = 99
+
+    tokenizer = init_tokenizer(configs, None)
+    text = "whisper powered by wenet, it's great"
+
+    assert text == tokenizer.tokens2text(tokenizer.text2tokens(text))
+
+
+@pytest.mark.parametrize("symbol_table_path", [
+    "test/resources/aishell2.words.txt",
+])
+def test_init_char_tokenizer(symbol_table_path):
+    configs = {}
+    tokenizer = init_tokenizer(configs, symbol_table_path)
+
+    text = "大家都好帅"
+    assert text == tokenizer.tokens2text(tokenizer.text2tokens(text))
+
+
+@pytest.mark.parametrize(
+    "symbol_table_path, bpe_model",
+    [("test/resources/librispeech.words.txt",
+      "test/resources/librispeech.train_960_unigram5000.bpemodel")])
+def test_init_bpe_tokenizer(symbol_table_path, bpe_model):
+
+    configs = {}
+    tokenizer = init_tokenizer(configs, symbol_table_path, bpe_model)
+    text = "WENET IT'S GREAT"
+
+    assert text == tokenizer.tokens2text(tokenizer.text2tokens(text))
diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index 33b31a6fd..7d4a9a11f 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -75,7 +75,8 @@ def main():
     configs = override_config(configs, args.override_config)
 
     # init tokenizer
-    tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, non_lang_syms)
+    tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model,
+                               non_lang_syms)
 
     # Init env for ddp OR deepspeed
     world_size, local_rank, rank = init_distributed(args)
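
For context, a minimal sketch of exercising init_tokenizer outside pytest,
using the same char-tokenizer inputs as test_init_char_tokenizer above. It
assumes a wenet checkout where test/resources/aishell2.words.txt exists; the
call pattern and the round-trip expectation are taken directly from the tests,
not from any additional API guarantee:

    from wenet.utils.init_tokenizer import init_tokenizer

    # Empty configs (no 'whisper' option) plus a symbol table and no bpe_model
    # exercises the same char-tokenizer path as test_init_char_tokenizer.
    configs = {}
    tokenizer = init_tokenizer(configs, "test/resources/aishell2.words.txt")

    # text2tokens followed by tokens2text should round-trip the input text,
    # which is exactly what the assertions in the tests check.
    tokens = tokenizer.text2tokens("大家都好帅")
    assert tokenizer.tokens2text(tokens) == "大家都好帅"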