From 75b1e78b3ce6d60a09d8044d98a96f2b5307a6db Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 21:39:08 +0800 Subject: [PATCH] [text] add init_tokenizer unit test --- test/wenet/utils/test_init_tokenizer.py | 41 +++++++++++++++++++++++++ wenet/bin/train.py | 3 +- wenet/utils/init_tokenizer.py | 4 +++ 3 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 test/wenet/utils/test_init_tokenizer.py diff --git a/test/wenet/utils/test_init_tokenizer.py b/test/wenet/utils/test_init_tokenizer.py new file mode 100644 index 0000000000..1b5b4af94a --- /dev/null +++ b/test/wenet/utils/test_init_tokenizer.py @@ -0,0 +1,41 @@ +import pytest + +from wenet.utils.init_tokenizer import init_tokenizer + + +def test_init_whisper_tokenizer(): + # TODO(Mddct): add configs generator + configs = {} + configs['whisper'] = True + configs['whisper_conf'] = {} + configs['whisper_conf']['is_multilingual'] = False + configs['whisper_conf']['num_languages'] = 99 + + tokenizer = init_tokenizer(configs, None) + text = "whisper powered by wenet, it's great" + + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/aishell2.words.txt", +]) +def test_init_char_tokenizer(symbol_table_path): + configs = {} + tokenizer = init_tokenizer(configs, symbol_table_path) + + text = "大家都好帅" + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) + + +@pytest.mark.parametrize( + "symbol_table_path, bpe_model", + [("test/resources/librispeech.words.txt", + "test/resources/librispeech.train_960_unigram5000.bpemodel")]) +def test_init_bpe_tokenizer(symbol_table_path, bpe_model): + + configs = {} + tokenizer = init_tokenizer(configs, symbol_table_path, bpe_model) + text = "WENET IT'S GREAT" + + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 33b31a6fdd..7d4a9a11fc 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -75,7 +75,8 @@ def main(): configs = override_config(configs, args.override_config) # init tokenizer - tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, non_lang_syms) + tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, + non_lang_syms) # Init env for ddp OR deepspeed world_size, local_rank, rank = init_distributed(args) diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py index fcafba7a4a..f6120db4b3 100644 --- a/wenet/utils/init_tokenizer.py +++ b/wenet/utils/init_tokenizer.py @@ -28,3 +28,7 @@ def init_tokenizer(configs, 'split_with_space', False)) return tokenizer + + +tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, + non_lang_syms)