From 3ab6718466af071170d76dc908047ecfba2082e1 Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 28 Nov 2023 21:46:29 +0800 Subject: [PATCH] [text] refine tokenizer (#2165) * [text] refine tokenizer * [text] fix flake8 * [text] fix lint * [text] fix unit * [text] add bpe tokenizer and char tokenizer * [text] add char tokenizer unit test * [text] add bpe tokenizer unit test * [text] add WhisperTokenizer for test_whisper.py * [text] revert wenet/utils/file_utils.py * [text] add consistency for char and bpe unit * [text] merge main * [text] add symbol table * [text] add init_tokenizer unit test * [text] uncomment * [text] fix bpe model in multiprocess env * [text] fix whisper tokenzier in multiprocess env * [text] add test unit parallel for bpe and whisper * [text] fix none type in test_whisper.py * [text] all work --- test/test_tokenize.py | 126 ---------------- test/wenet/dataset/test_processor.py | 150 ++++++++++++++++++ test/wenet/text/test_bpe_tokenizer.py | 94 ++++++++++++ test/wenet/text/test_char_tokenizer.py | 141 +++++++++++++++++ test/wenet/text/test_parallel.py | 42 ++++++ test/wenet/text/test_wenet_tokenzier.py | 176 ++++++++++++++++++++++ test/wenet/text/test_whisper_tokenizer.py | 60 ++++++++ test/wenet/utils/test_init_tokenizer.py | 41 +++++ test/wenet/whisper/test_whisper.py | 17 ++- wenet/bin/recognize.py | 19 +-- wenet/bin/train.py | 9 +- wenet/dataset/dataset.py | 15 +- wenet/dataset/processor.py | 58 +------ wenet/text/base_tokenizer.py | 39 +++++ wenet/text/bpe_tokenizer.py | 51 +++++++ wenet/text/char_tokenizer.py | 78 ++++++++++ wenet/{utils => text}/tokenize_utils.py | 0 wenet/text/wenet_tokenizer.py | 90 +++++++++++ wenet/text/whisper_tokenizer.py | 92 +++++++++++ wenet/utils/context_graph.py | 2 +- wenet/utils/file_utils.py | 4 +- wenet/utils/init_tokenizer.py | 30 ++++ wenet/utils/train_utils.py | 39 ++--- 23 files changed, 1131 insertions(+), 242 deletions(-) delete mode 100644 test/test_tokenize.py create mode 100644 test/wenet/dataset/test_processor.py create mode 100644 test/wenet/text/test_bpe_tokenizer.py create mode 100644 test/wenet/text/test_char_tokenizer.py create mode 100644 test/wenet/text/test_parallel.py create mode 100644 test/wenet/text/test_wenet_tokenzier.py create mode 100644 test/wenet/text/test_whisper_tokenizer.py create mode 100644 test/wenet/utils/test_init_tokenizer.py create mode 100644 wenet/text/base_tokenizer.py create mode 100644 wenet/text/bpe_tokenizer.py create mode 100644 wenet/text/char_tokenizer.py rename wenet/{utils => text}/tokenize_utils.py (100%) create mode 100644 wenet/text/wenet_tokenizer.py create mode 100644 wenet/text/whisper_tokenizer.py create mode 100644 wenet/utils/init_tokenizer.py diff --git a/test/test_tokenize.py b/test/test_tokenize.py deleted file mode 100644 index 8f2cac1a2..000000000 --- a/test/test_tokenize.py +++ /dev/null @@ -1,126 +0,0 @@ -import pytest - -import wenet.dataset.processor as processor - -@pytest.mark.parametrize( - "symbol_table_path", - [ - "test/resources/librispeech.words.txt", - "test/resources/aishell2.words.txt" - ] -) -def test_tokenize(symbol_table_path): - txts = [ - {"txt": "震东好帅"}, - {"txt": " 吴迪也好帅 "}, - {"txt": "binbin is also handsome"}, - {"txt": " life is short i use wenet "}, - {"txt": "超哥 is the most handsome 吧"}, - {"txt": " 人生苦短i use wenet "}, - {"txt": "人生苦短I USE WENET"}, - {"txt": "zhendong ist so schön"}, - {"txt": " zhendong ist so schön "}, - {"txt": "It's okay"} - ] - if symbol_table_path == "test/resources/librispeech.words.txt": - bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" - refs = [ - {"tokens": ['震', '东', '好', '帅'], - "label": [1, 1, 1, 1]}, - {"tokens": ['吴', '迪', '也', '好', '帅'], - "label": [1, 1, 1, 1, 1]}, - {"tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], - "label": [347, 2216, 346, 2216, 2332, 143, 1990]}, - {"tokens": ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', - 'NE', 'T'], - "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366]}, - {"tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], - "label": [1, 1, 2332, 4435, 2860, 1990, 1]}, - {"tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], - "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366]}, - {"tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], - "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366]}, - {"tokens": ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', - 'Ö', 'N'], - "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, - 1, 2901]}, - {"tokens": ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', - 'Ö', 'N'], - "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, - 1, 2901]}, - {"tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], - "label": [2344, 2, 3790, 3010, 2418, 4979]} - ] - else: - bpe_model = None - refs = [ - {"tokens": ['震', '东', '好', '帅'], - "label": [4932, 80, 1059, 1375]}, - {"tokens": ['吴', '迪', '也', '好', '帅'], - "label": [656, 4540, 117, 1059, 1375]}, - {"tokens": ['b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', - 'a', 'l', 's', 'o', '▁', 'h', 'a', 'n', 'd', 's', - 'o', 'm', 'e'], - "label": [9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, - 1, 21, 7, 33, 13, 43, 35, 31, 15]}, - {"tokens": ['l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', - 'o', 'r', 't', '▁', 'i', '▁', 'u', 's', 'e', '▁', - 'w', 'e', 'n', 'e', 't'], - "label": [29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, - 1, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46]}, - {"tokens": ['超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', - 'm', 'o', 's', 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', - 'm', 'e', '▁', '吧'], - "label": [4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, - 1, 21, 7, 33, 13, 43, 35, 31, 15, 1, 647]}, - {"tokens": ['人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', - 'w', 'e', 'n', 'e', 't'], - "label": [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, - 15, 46]}, - {"tokens": ['人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', - 'W', 'E', 'N', 'E', 'T'], - "label": [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, - 16, 47]}, - {"tokens": ['z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', - 't', '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n'], - "label": [58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, - 35, 1, 43, 11, 21, 1, 33]}, - {"tokens": ['z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', - 't', '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n'], - "label": [58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, - 35, 1, 43, 11, 21, 1, 33]}, - {"tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], - "label": [24, 46, 2, 43, 1, 35, 27, 7, 56]} - ] - symbol_table = {} - with open(symbol_table_path, 'r') as f: - lines = f.readlines() - for l in lines: - l = l.strip().split() - symbol_table[l[0]] = int(l[1]) - outs = processor.tokenize( - txts, symbol_table, bpe_model, split_with_space=False - ) - for (hyp, ref) in zip(outs, refs): - assert(len(hyp["tokens"]) == len(ref["tokens"])) - assert(all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) - assert(len(hyp["label"]) == len(ref["label"])) - assert(all(h == r for h, r in zip(hyp["label"], ref["label"]))) - -@pytest.mark.parametrize("use_pbe_model", [True, False]) -def test_non_lang_symbol_tokenize(use_pbe_model): - data = [{"txt": "我是{NOISE}"}] - symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} - - if use_pbe_model: - bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" - - sample = next(processor.tokenize(data, symbol_table, bpe_model, - non_lang_syms=["{NOISE}"])) - - assert sample["tokens"] == ["我", "是", "{NOISE}"] - else: - sample = next(processor.tokenize(data, symbol_table, - non_lang_syms=["{NOISE}"])) - - assert sample["tokens"] == ["我", "是", "{NOISE}"] diff --git a/test/wenet/dataset/test_processor.py b/test/wenet/dataset/test_processor.py new file mode 100644 index 000000000..131d018c6 --- /dev/null +++ b/test/wenet/dataset/test_processor.py @@ -0,0 +1,150 @@ +import pytest + +import wenet.dataset.processor as processor +from wenet.text.wenet_tokenizer import WenetTokenizer + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/librispeech.words.txt", "test/resources/aishell2.words.txt" +]) +def test_tokenize(symbol_table_path): + txts = [{ + "txt": "震东好帅" + }, { + "txt": " 吴迪也好帅 " + }, { + "txt": "binbin is also handsome" + }, { + "txt": " life is short i use wenet " + }, { + "txt": "超哥 is the most handsome 吧" + }, { + "txt": " 人生苦短i use wenet " + }, { + "txt": "人生苦短I USE WENET" + }, { + "txt": "zhendong ist so schön" + }, { + "txt": " zhendong ist so schön " + }, { + "txt": "It's okay" + }] + if symbol_table_path == "test/resources/librispeech.words.txt": + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [1, 1, 1, 1] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [1, 1, 1, 1, 1] + }, { + "tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], + "label": [347, 2216, 346, 2216, 2332, 143, 1990] + }, { + "tokens": + ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], + "label": [1, 1, 2332, 4435, 2860, 1990, 1] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], + "label": [2344, 2, 3790, 3010, 2418, 4979] + }] + else: + bpe_model = None + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [4932, 80, 1059, 1375] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [656, 4540, 117, 1059, 1375] + }, { + "tokens": [ + 'b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', 'a', 'l', + 's', 'o', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e' + ], + "label": [ + 9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, 1, 21, 7, + 33, 13, 43, 35, 31, 15 + ] + }, { + "tokens": [ + 'l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', 'o', 'r', + 't', '▁', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', 'e', 't' + ], + "label": [ + 29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, 1, 23, 1, 48, + 43, 15, 1, 52, 15, 33, 15, 46 + ] + }, { + "tokens": [ + '超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', 'm', 'o', + 's', 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', '▁', '吧' + ], + "label": [ + 4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, 1, 21, + 7, 33, 13, 43, 35, 31, 15, 1, 647 + ] + }, { + "tokens": [ + '人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', + 'n', 'e', 't' + ], + "label": + [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46] + }, { + "tokens": [ + '人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', 'W', 'E', + 'N', 'E', 'T' + ], + "label": + [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, 16, 47] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], + "label": [24, 46, 2, 43, 1, 35, 27, 7, 56] + }] + + tokenizer = WenetTokenizer(symbol_table_path, + bpe_model, + split_with_space=False) + outs = processor.tokenize(txts, tokenizer) + for (hyp, ref) in zip(outs, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) diff --git a/test/wenet/text/test_bpe_tokenizer.py b/test/wenet/text/test_bpe_tokenizer.py new file mode 100644 index 000000000..852e2744e --- /dev/null +++ b/test/wenet/text/test_bpe_tokenizer.py @@ -0,0 +1,94 @@ +import pytest +from wenet.text.bpe_tokenizer import BpeTokenizer + + +@pytest.fixture(params=[[ + "test/resources/librispeech.words.txt", + "test/resources/librispeech.train_960_unigram5000.bpemodel" +]]) +def bpe_tokenizer(request): + symbol_table, bpe_model = request.param + return BpeTokenizer(bpe_model, symbol_table) + + +def test_tokenize(bpe_tokenizer): + tokenizer = bpe_tokenizer + txts = [ + "震东好帅", + " 吴迪也好帅 ", + "binbin is also handsome", + " life is short i use wenet ", + "超哥 is the most handsome 吧", + " 人生苦短i use wenet ", + "人生苦短I USE WENET", + "zhendong ist so schön", + " zhendong ist so schön ", + "It's okay", + ] + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [1, 1, 1, 1] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [1, 1, 1, 1, 1] + }, { + "tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], + "label": [347, 2216, 346, 2216, 2332, 143, 1990] + }, { + "tokens": ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], + "label": [1, 1, 2332, 4435, 2860, 1990, 1] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], + "label": [2344, 2, 3790, 3010, 2418, 4979] + }] + + results = [] + for line in txts: + tokens, label = tokenizer.tokenize(line) + results.append({"tokens": tokens, "label": label}) + + for (hyp, ref) in zip(results, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) + + +def test_detokenize(bpe_tokenizer): + tokenizer = bpe_tokenizer + # TODO(Mddct): more unit test + ids = [2344, 2, 3790, 3010, 2418, 4979] + expected = { + 'txt': "IT'S OKAY", + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'] + } + txt, tokens = tokenizer.detokenize(ids) + assert txt == expected['txt'] + assert (all(h == r for h, r in zip(tokens, expected['tokens']))) + + +def test_vocab_size(bpe_tokenizer): + assert bpe_tokenizer.vocab_size() == 5002 + + +def test_consistency(bpe_tokenizer): + text = "WENET IS GREAT" + assert text == bpe_tokenizer.tokens2text(bpe_tokenizer.text2tokens(text)) + assert text == bpe_tokenizer.detokenize(bpe_tokenizer.tokenize(text)[1])[0] diff --git a/test/wenet/text/test_char_tokenizer.py b/test/wenet/text/test_char_tokenizer.py new file mode 100644 index 000000000..fcf4f0762 --- /dev/null +++ b/test/wenet/text/test_char_tokenizer.py @@ -0,0 +1,141 @@ +import pytest +from wenet.text.char_tokenizer import CharTokenizer + + +@pytest.fixture(params=["test/resources/aishell2.words.txt"]) +def char_tokenizer(request): + symbol_table = request.param + return CharTokenizer(symbol_table) + + +def test_tokenize(char_tokenizer): + tokenizer = char_tokenizer + txts = [ + "震东好帅", + " 吴迪也好帅 ", + "binbin is also handsome", + " life is short i use wenet ", + "超哥 is the most handsome 吧", + " 人生苦短i use wenet ", + "人生苦短I USE WENET", + "zhendong ist so schön", + " zhendong ist so schön ", + "It's okay", + ] + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [4932, 80, 1059, 1375] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [656, 4540, 117, 1059, 1375] + }, { + "tokens": [ + 'b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', 'a', 'l', 's', + 'o', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e' + ], + "label": [ + 9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, 1, 21, 7, 33, + 13, 43, 35, 31, 15 + ] + }, { + "tokens": [ + 'l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', 'o', 'r', 't', + '▁', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', 'e', 't' + ], + "label": [ + 29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, 1, 23, 1, 48, 43, + 15, 1, 52, 15, 33, 15, 46 + ] + }, { + "tokens": [ + '超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', 'm', 'o', 's', + 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', '▁', '吧' + ], + "label": [ + 4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, 1, 21, 7, + 33, 13, 43, 35, 31, 15, 1, 647 + ] + }, { + "tokens": [ + '人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', + 'e', 't' + ], + "label": + [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46] + }, { + "tokens": [ + '人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', 'W', 'E', 'N', + 'E', 'T' + ], + "label": + [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, 16, 47] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', '▁', + 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, 43, + 11, 21, 1, 33 + ] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', '▁', + 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, 43, + 11, 21, 1, 33 + ] + }, { + "tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], + "label": [24, 46, 2, 43, 1, 35, 27, 7, 56] + }] + results = [] + for line in txts: + tokens, label = tokenizer.tokenize(line) + results.append({"tokens": tokens, "label": label}) + + for (hyp, ref) in zip(results, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) + + +def test_detokenize(char_tokenizer): + tokenizer = char_tokenizer + idss = [ + [4932, 80, 1059, 1375], + [656, 4540, 117, 1059, 1375], + ] + + refs = [{ + "txt": "震东好帅", + "tokens": ['震', '东', '好', '帅'], + }, { + "txt": "吴迪也好帅", + "tokens": ['吴', '迪', '也', '好', '帅'], + }] + results = [] + for ids in idss: + txt, tokens = tokenizer.detokenize(ids) + results.append({"tokens": tokens, "txt": txt}) + + for (hyp, ref) in zip(results, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert len(hyp["txt"]) == len(ref["txt"]) + assert (all(h == r for h, r in zip(hyp["txt"], ref["txt"]))) + + +def test_vocab_size(char_tokenizer): + assert char_tokenizer.vocab_size() == 5235 + + +def test_consistency(char_tokenizer): + text = "大家都好帅" + + assert text == char_tokenizer.tokens2text(char_tokenizer.text2tokens(text)) + assert text == char_tokenizer.detokenize( + char_tokenizer.tokenize(text)[1])[0] diff --git a/test/wenet/text/test_parallel.py b/test/wenet/text/test_parallel.py new file mode 100644 index 000000000..28a0f37b2 --- /dev/null +++ b/test/wenet/text/test_parallel.py @@ -0,0 +1,42 @@ +from functools import partial +from multiprocessing import Pool +from wenet.text.base_tokenizer import BaseTokenizer + +from wenet.text.bpe_tokenizer import BpeTokenizer +from wenet.text.whisper_tokenizer import WhisperTokenizer + + +def consistency(tokenizer: BaseTokenizer, line: str) -> str: + return tokenizer.detokenize(tokenizer.tokenize(line)[1])[0] + + +def test_whisper_tokenzier_parallel(): + + inputs = ["it's ok", "wenet is simple", "test for new io"] + tokenizer = WhisperTokenizer(False) + + partial_tokenize = partial(consistency, tokenizer) + with Pool(processes=len(inputs)) as pool: + results = pool.map(partial_tokenize, inputs) + + inputs.sort() + results.sort() + + assert all(h == r for (h, r) in zip(results, inputs)) + + +def test_bpe_tokenzier_parallel(): + + symbol_table_path = "test/resources/librispeech.words.txt" + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + + inputs = ["WENR IS SIMPLE", "GOOD"] + tokenizer = BpeTokenizer(bpe_model, symbol_table_path) + partial_tokenize = partial(consistency, tokenizer) + with Pool(processes=len(inputs)) as pool: + results = pool.map(partial_tokenize, inputs) + + inputs.sort() + results.sort() + + assert all(h == r for (h, r) in zip(results, inputs)) diff --git a/test/wenet/text/test_wenet_tokenzier.py b/test/wenet/text/test_wenet_tokenzier.py new file mode 100644 index 000000000..d70e9a782 --- /dev/null +++ b/test/wenet/text/test_wenet_tokenzier.py @@ -0,0 +1,176 @@ +import pytest + +from wenet.text.wenet_tokenizer import WenetTokenizer + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/librispeech.words.txt", + "test/resources/aishell2.words.txt", +]) +def test_tokenize(symbol_table_path): + txts = [ + "震东好帅", + " 吴迪也好帅 ", + "binbin is also handsome", + " life is short i use wenet ", + "超哥 is the most handsome 吧", + " 人生苦短i use wenet ", + "人生苦短I USE WENET", + "zhendong ist so schön", + " zhendong ist so schön ", + "It's okay", + ] + if symbol_table_path == "test/resources/librispeech.words.txt": + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [1, 1, 1, 1] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [1, 1, 1, 1, 1] + }, { + "tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], + "label": [347, 2216, 346, 2216, 2332, 143, 1990] + }, { + "tokens": + ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], + "label": [1, 1, 2332, 4435, 2860, 1990, 1] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], + "label": [2344, 2, 3790, 3010, 2418, 4979] + }] + else: + bpe_model = None + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [4932, 80, 1059, 1375] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [656, 4540, 117, 1059, 1375] + }, { + "tokens": [ + 'b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', 'a', 'l', + 's', 'o', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e' + ], + "label": [ + 9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, 1, 21, 7, + 33, 13, 43, 35, 31, 15 + ] + }, { + "tokens": [ + 'l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', 'o', 'r', + 't', '▁', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', 'e', 't' + ], + "label": [ + 29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, 1, 23, 1, 48, + 43, 15, 1, 52, 15, 33, 15, 46 + ] + }, { + "tokens": [ + '超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', 'm', 'o', + 's', 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', '▁', '吧' + ], + "label": [ + 4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, 1, 21, + 7, 33, 13, 43, 35, 31, 15, 1, 647 + ] + }, { + "tokens": [ + '人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', + 'n', 'e', 't' + ], + "label": + [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46] + }, { + "tokens": [ + '人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', 'W', 'E', + 'N', 'E', 'T' + ], + "label": + [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, 16, 47] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], + "label": [24, 46, 2, 43, 1, 35, 27, 7, 56] + }] + tokenizer = WenetTokenizer(symbol_table=symbol_table_path, + bpe_model=bpe_model, + split_with_space=False) + results = [] + for line in txts: + tokens, label = tokenizer.tokenize(line) + results.append({"tokens": tokens, "label": label}) + + for (hyp, ref) in zip(results, refs): + print(hyp["tokens"], ref["tokens"]) + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) + + +@pytest.mark.parametrize("use_pbe_model", [True, False]) +def test_non_lang_symbol_tokenize(use_pbe_model): + data = ["我是{NOISE}"] + symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} + bpe_model = None + non_lang_syms = ["{NOISE}"] + expected = ["我", "是", "{NOISE}"] + + if use_pbe_model: + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + + tokenizer = WenetTokenizer(symbol_table, + bpe_model=bpe_model, + non_lang_syms=non_lang_syms) + for line in data: + tokens, _ = tokenizer.tokenize(line) + assert (all(h == r for h, r in zip(tokens, expected))) + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/librispeech.words.txt", + "test/resources/aishell2.words.txt", +]) +def test_vocab_size(symbol_table_path): + tokenizer = WenetTokenizer(symbol_table_path) + if symbol_table_path == "test/resources/librispeech.words.txt": + assert tokenizer.vocab_size() == 5002 + else: + assert tokenizer.vocab_size() == 5235 diff --git a/test/wenet/text/test_whisper_tokenizer.py b/test/wenet/text/test_whisper_tokenizer.py new file mode 100644 index 000000000..bee4f6cc8 --- /dev/null +++ b/test/wenet/text/test_whisper_tokenizer.py @@ -0,0 +1,60 @@ +import pytest + +from wenet.text.whisper_tokenizer import WhisperTokenizer + + +@pytest.fixture(params=[False]) +def whisper_tokenizer(request): + is_multilingual = request.param + return WhisperTokenizer(is_multilingual) + + +def test_tokenize(whisper_tokenizer): + + tokenizer = whisper_tokenizer + texts = ["life is short, i use wenet"] + expected = [{ + "tokens": [ + "b'life'", "b'is'", "b'short'", "b','", + "b'i'", "b'use'", "b'w'", "b'en'", "b'et'" + ], + "ids": [6042, 318, 1790, 11, 1312, 779, 266, 268, 316], + }] + + for i, text in enumerate(texts): + tokens, ids = tokenizer.tokenize(text) + assert len(tokens) == len(ids) + assert (all((h == r for h, r in zip(tokens, expected[i]["tokens"])))) + assert (all((h == r for h, r in zip(ids, expected[i]["ids"])))) + + +def test_detokenize(whisper_tokenizer): + tokenize = whisper_tokenizer + + inputs = [[6042, 318, 1790, 11, 1312, 779, 266, 268, 316]] + expected = [{ + "tokens": [ + "b'life'", "b'is'", "b'short'", "b','", + "b'i'", "b'use'", "b'w'", "b'en'", "b'et'" + ], + 'labels': + "life is short, i use wenet", + }] + + for i, input in enumerate(inputs): + text, tokens = tokenize.detokenize(input) + assert len(tokens) == len(expected[i]["tokens"]) + assert text == expected[i]["labels"] + assert all((h == r for h, r in zip(tokens, expected[i]["tokens"]))) + + +def test_consistency(whisper_tokenizer): + text = "whisper powered by wenet, it's great" + + assert text == whisper_tokenizer.tokens2text( + whisper_tokenizer.text2tokens(text)) + + +def test_vocab_size(whisper_tokenizer): + assert whisper_tokenizer.vocab_size( + ) == whisper_tokenizer.tokenizer.encoding.n_vocab diff --git a/test/wenet/utils/test_init_tokenizer.py b/test/wenet/utils/test_init_tokenizer.py new file mode 100644 index 000000000..1b5b4af94 --- /dev/null +++ b/test/wenet/utils/test_init_tokenizer.py @@ -0,0 +1,41 @@ +import pytest + +from wenet.utils.init_tokenizer import init_tokenizer + + +def test_init_whisper_tokenizer(): + # TODO(Mddct): add configs generator + configs = {} + configs['whisper'] = True + configs['whisper_conf'] = {} + configs['whisper_conf']['is_multilingual'] = False + configs['whisper_conf']['num_languages'] = 99 + + tokenizer = init_tokenizer(configs, None) + text = "whisper powered by wenet, it's great" + + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/aishell2.words.txt", +]) +def test_init_char_tokenizer(symbol_table_path): + configs = {} + tokenizer = init_tokenizer(configs, symbol_table_path) + + text = "大家都好帅" + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) + + +@pytest.mark.parametrize( + "symbol_table_path, bpe_model", + [("test/resources/librispeech.words.txt", + "test/resources/librispeech.train_960_unigram5000.bpemodel")]) +def test_init_bpe_tokenizer(symbol_table_path, bpe_model): + + configs = {} + tokenizer = init_tokenizer(configs, symbol_table_path, bpe_model) + text = "WENET IT'S GREAT" + + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index 3e61fb2cc..d71716869 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -12,10 +12,10 @@ import numpy as np import torch.nn.functional as F -from whisper.tokenizer import get_tokenizer from whisper.audio import N_FFT, HOP_LENGTH, N_SAMPLES, N_FRAMES, pad_or_trim from wenet.dataset.processor import compute_log_mel_spectrogram +from wenet.text.whisper_tokenizer import WhisperTokenizer from wenet.transformer.embedding import WhisperPositionalEncoding from wenet.whisper.convert_whisper_to_wenet_config_and_ckpt import ( convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units @@ -108,19 +108,20 @@ def test_model(model, audio_path): checkpoint = torch.load("{}/{}.pt".format(download_root, model), map_location="cpu") multilingual = checkpoint["dims"]['n_vocab'] >= 51865 num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual) - tokenizer = get_tokenizer(multilingual, num_languages=num_languages, - language=language, task=task) + tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages, + language=language, task=task) + tokenizer._build_tiktoken() convert_to_wenet_state_dict( checkpoint["model_state_dict"], os.path.join(download_root, 'wenet_whisper.pt') ) convert_to_wenet_units( - tokenizer, + tokenizer.tokenizer, os.path.join(download_root, 'units.txt') ) convert_to_wenet_yaml( - tokenizer, checkpoint["dims"], + tokenizer.tokenizer, checkpoint["dims"], os.path.join(download_root, 'train.yaml') ) with open("{}/train.yaml".format(download_root), 'r') as fin: @@ -132,7 +133,7 @@ def test_model(model, audio_path): wenet_model.eval() with torch.no_grad(): - dummy_tokens = tokenizer.encode("WeNet x OpenAI") + _, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI") # 3. Forward whisper.encoder mel1 = whisper.log_mel_spectrogram( @@ -173,8 +174,8 @@ def test_model(model, audio_path): rtol=1e-7, atol=1e-10) # 4. Forward whisper.decoder - whisper_tokens = torch.tensor(list(tokenizer.sot_sequence) - + [tokenizer.no_timestamps] + whisper_tokens = torch.tensor(list(tokenizer.tokenizer.sot_sequence) + + [tokenizer.tokenizer.no_timestamps] + dummy_tokens, dtype=torch.long).unsqueeze(0) # (B=1, 9) whisper_decoder_embed = whisper_model.decoder.token_embedding(whisper_tokens) diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 2e991f9da..1a842f8fe 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -24,9 +24,9 @@ from torch.utils.data import DataLoader from wenet.dataset.dataset import Dataset -from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols from wenet.utils.config import override_config from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.context_graph import ContextGraph @@ -185,7 +185,6 @@ def main(): if len(args.override_config) > 0: configs = override_config(configs, args.override_config) - symbol_table = read_symbol_table(args.dict) test_conf = copy.deepcopy(configs['dataset_conf']) test_conf['filter_conf']['max_length'] = 102400 @@ -206,14 +205,12 @@ def main(): test_conf['mfcc_conf']['dither'] = 0.0 test_conf['batch_conf']['batch_type'] = "static" test_conf['batch_conf']['batch_size'] = args.batch_size - non_lang_syms = read_non_lang_symbols(args.non_lang_syms) + tokenizer = init_tokenizer(configs, args.dict, args.bpe_model, args.non_lang_syms) test_dataset = Dataset(args.data_type, args.test_data, - symbol_table, + tokenizer, test_conf, - args.bpe_model, - non_lang_syms, partition=False) test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) @@ -221,9 +218,6 @@ def main(): # Init asr model from configs model, configs = init_model(args, configs) - # Load dict - char_dict = {v: k for k, v in symbol_table.items()} - use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) @@ -231,7 +225,7 @@ def main(): context_graph = None if 'decoding-graph' in args.context_bias_mode: - context_graph = ContextGraph(args.context_list_path, symbol_table, + context_graph = ContextGraph(args.context_list_path, tokenizer.symbol_table, args.bpe_model, args.context_graph_score) # TODO(Dinghao Zhou): Support RNN-T related decoding @@ -264,9 +258,8 @@ def main(): context_graph=context_graph) for i, key in enumerate(keys): for mode, hyps in results.items(): - content = [char_dict[w] for w in hyps[i].tokens] - line = '{} {}'.format(key, - args.connect_symbol.join(content)) + tokens = hyps[i].tokens + line = '{} {}'.format(key, tokenizer.detokenize(tokens)[0]) logging.info('{} {}'.format(mode.ljust(max_format_len), line)) files[mode].write(line + '\n') diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 0d8b4c094..c8e957439 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -28,6 +28,7 @@ from wenet.utils.executor import Executor from wenet.utils.config import override_config from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.train_utils import (add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args, add_trace_args, init_distributed, @@ -73,15 +74,19 @@ def main(): if len(args.override_config) > 0: configs = override_config(configs, args.override_config) + # init tokenizer + tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, + args.non_lang_syms) + # Init env for ddp OR deepspeed world_size, local_rank, rank = init_distributed(args) # Get dataset & dataloader train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ - init_dataset_and_dataloader(args, configs) + init_dataset_and_dataloader(args, configs, tokenizer) # Do some sanity checks and save config to arsg.model_dir - configs = check_modify_and_save_config(args, configs) + configs = check_modify_and_save_config(args, configs, tokenizer.symbol_table) # Init asr model from configs model, configs = init_model(args, configs) diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index 9845cb805..9595011ec 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -19,10 +19,12 @@ from torch.utils.data import IterableDataset import wenet.dataset.processor as processor +from wenet.text.base_tokenizer import BaseTokenizer from wenet.utils.file_utils import read_lists class Processor(IterableDataset): + def __init__(self, source, f, *args, **kw): assert callable(f) self.source = source @@ -47,6 +49,7 @@ def apply(self, f): class DistributedSampler: + def __init__(self, shuffle=True, partition=True): self.epoch = -1 self.update() @@ -99,6 +102,7 @@ def sample(self, data): class DataList(IterableDataset): + def __init__(self, lists, shuffle=True, partition=True): self.lists = lists self.sampler = DistributedSampler(shuffle, partition) @@ -118,12 +122,9 @@ def __iter__(self): def Dataset(data_type, data_list_file, - symbol_table, + tokenizer: BaseTokenizer, conf, - bpe_model=None, - non_lang_syms=None, - partition=True, - whisper_tokenizer=None): + partition=True): """ Construct dataset from arguments We have two shuffle stage in the Dataset. The first is global @@ -145,9 +146,7 @@ def Dataset(data_type, else: dataset = Processor(dataset, processor.parse_raw) - dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model, - non_lang_syms, conf.get('split_with_space', False), - whisper_tokenizer) + dataset = Processor(dataset, processor.tokenize, tokenizer) filter_conf = conf.get('filter_conf', {}) dataset = Processor(dataset, processor.filter, **filter_conf) diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 5c48e286d..a769eba8e 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -16,7 +16,6 @@ import logging import json import random -import re import tarfile from subprocess import PIPE, Popen from urllib.parse import urlparse @@ -26,7 +25,7 @@ import torchaudio.compliance.kaldi as kaldi import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence -from wenet.utils.tokenize_utils import tokenize_by_bpe_model +from wenet.text.base_tokenizer import BaseTokenizer torchaudio.utils.sox_utils.set_buffer_size(16500) @@ -366,12 +365,7 @@ def compute_log_mel_spectrogram(data, feat=log_spec.transpose(0, 1)) -def tokenize(data, - symbol_table, - bpe_model=None, - non_lang_syms=None, - split_with_space=False, - whisper_tokenizer=None): +def tokenize(data, tokenizer: BaseTokenizer): """ Decode text to chars or BPE Inplace operation @@ -381,55 +375,9 @@ def tokenize(data, Returns: Iterable[{key, wav, txt, tokens, label, sample_rate}] """ - if non_lang_syms is not None: - non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") - else: - non_lang_syms = {} - non_lang_syms_pattern = None - - if bpe_model is not None: - import sentencepiece as spm - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - else: - sp = None - for sample in data: assert 'txt' in sample - txt = sample['txt'].strip() - # TODO(xcsong): This is a dirty workaround for whisper tokernizer, - # refine it in the future - if whisper_tokenizer is not None: - sample['label'] = whisper_tokenizer.encode(txt) - yield sample - if non_lang_syms_pattern is not None: - parts = non_lang_syms_pattern.split(txt.upper()) - parts = [w for w in parts if len(w.strip()) > 0] - else: - parts = [txt] - - label = [] - tokens = [] - for part in parts: - if part in non_lang_syms: - tokens.append(part) - else: - if bpe_model is not None: - tokens.extend(tokenize_by_bpe_model(sp, part)) - else: - if split_with_space: - part = part.split(" ") - for ch in part: - if ch == ' ': - ch = "▁" - tokens.append(ch) - - for ch in tokens: - if ch in symbol_table: - label.append(symbol_table[ch]) - elif '' in symbol_table: - label.append(symbol_table['']) - + tokens, label = tokenizer.tokenize(sample['txt']) sample['tokens'] = tokens sample['label'] = label yield sample diff --git a/wenet/text/base_tokenizer.py b/wenet/text/base_tokenizer.py new file mode 100644 index 000000000..c96993309 --- /dev/null +++ b/wenet/text/base_tokenizer.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod, abstractproperty +from typing import Dict, List, Tuple + + +class BaseTokenizer(ABC): + + def tokenize(self, line: str) -> Tuple[List[str], List[int]]: + tokens = self.text2tokens(line) + ids = self.tokens2ids(tokens) + return tokens, ids + + def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: + tokens = self.ids2tokens(ids) + text = self.tokens2text(tokens) + return text, tokens + + @abstractmethod + def text2tokens(self, line: str) -> List[str]: + raise NotImplementedError("abstract method") + + @abstractmethod + def tokens2text(self, tokens: List[str]) -> str: + raise NotImplementedError("abstract method") + + @abstractmethod + def tokens2ids(self, tokens: List[str]) -> List[int]: + raise NotImplementedError("abstract method") + + @abstractmethod + def ids2tokens(self, ids: List[int]) -> List[str]: + raise NotImplementedError("abstract method") + + @abstractmethod + def vocab_size(self) -> int: + raise NotImplementedError("abstract method") + + @abstractproperty + def symbol_table(self) -> Dict[str, int]: + raise NotImplementedError("abstract method") diff --git a/wenet/text/bpe_tokenizer.py b/wenet/text/bpe_tokenizer.py new file mode 100644 index 000000000..de1b504a9 --- /dev/null +++ b/wenet/text/bpe_tokenizer.py @@ -0,0 +1,51 @@ +from os import PathLike +from typing import Dict, List, Optional, Union +from wenet.text.char_tokenizer import CharTokenizer +from wenet.text.tokenize_utils import tokenize_by_bpe_model + + +class BpeTokenizer(CharTokenizer): + + def __init__( + self, + bpe_model: PathLike, + symbol_table: Union[str, PathLike, Dict], + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + split_with_space: bool = False, + connect_symbol: str = '', + unk='', + ) -> None: + super().__init__(symbol_table, non_lang_syms, split_with_space, + connect_symbol, unk) + self._model = bpe_model + # NOTE(Mddct): multiprocessing.Process() issues + # don't build sp here + self.bpe_model = None + + def _build_sp(self): + if self.bpe_model is None: + import sentencepiece as spm + self.bpe_model = spm.SentencePieceProcessor() + self.bpe_model.load(self._model) + + def text2tokens(self, line: str) -> List[str]: + self._build_sp() + line = line.strip() + if self.non_lang_syms_pattern is not None: + parts = self.non_lang_syms_pattern.split(line.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [line] + + tokens = [] + for part in parts: + if part in self.non_lang_syms: + tokens.append(part) + else: + tokens.extend(tokenize_by_bpe_model(self.bpe_model, part)) + return tokens + + def tokens2text(self, tokens: List[str]) -> str: + self._build_sp() + text = super().tokens2text(tokens) + return text.replace("▁", ' ').strip() diff --git a/wenet/text/char_tokenizer.py b/wenet/text/char_tokenizer.py new file mode 100644 index 000000000..5e47be41b --- /dev/null +++ b/wenet/text/char_tokenizer.py @@ -0,0 +1,78 @@ +from os import PathLike +from typing import Dict, List, Optional, Union +from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.text.base_tokenizer import BaseTokenizer + + +class CharTokenizer(BaseTokenizer): + + def __init__( + self, + symbol_table: Union[str, PathLike, Dict], + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + split_with_space: bool = False, + connect_symbol: str = '', + unk='', + ) -> None: + self.non_lang_syms_pattern = None + if non_lang_syms is not None: + self.non_lang_syms_pattern = re.compile( + r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + if not isinstance(symbol_table, Dict): + self._symbol_table = read_symbol_table(symbol_table) + else: + # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} + self._symbol_table = symbol_table + if not isinstance(non_lang_syms, List): + self.non_lang_syms = read_non_lang_symbols(non_lang_syms) + else: + # non_lang_syms=["{NOISE}"] + self.non_lang_syms = non_lang_syms + self.char_dict = {v: k for k, v in self._symbol_table.items()} + self.split_with_space = split_with_space + self.connect_symbol = connect_symbol + self.unk = unk + + def text2tokens(self, line: str) -> List[str]: + line = line.strip() + if self.non_lang_syms_pattern is not None: + parts = self.non_lang_syms_pattern.split(line.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [line] + + tokens = [] + for part in parts: + if part in self.non_lang_syms: + tokens.append(part) + else: + if self.split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "▁" + tokens.append(ch) + return tokens + + def tokens2text(self, tokens: List[str]) -> str: + return self.connect_symbol.join(tokens) + + def tokens2ids(self, tokens: List[str]) -> List[int]: + ids = [] + for ch in tokens: + if ch in self._symbol_table: + ids.append(self._symbol_table[ch]) + elif self.unk in self._symbol_table: + ids.append(self._symbol_table[self.unk]) + return ids + + def ids2tokens(self, ids: List[int]) -> List[str]: + content = [self.char_dict[w] for w in ids] + return content + + def vocab_size(self) -> int: + return len(self.char_dict) + + @property + def symbol_table(self) -> Dict[str, int]: + return self._symbol_table diff --git a/wenet/utils/tokenize_utils.py b/wenet/text/tokenize_utils.py similarity index 100% rename from wenet/utils/tokenize_utils.py rename to wenet/text/tokenize_utils.py diff --git a/wenet/text/wenet_tokenizer.py b/wenet/text/wenet_tokenizer.py new file mode 100644 index 000000000..23ed81483 --- /dev/null +++ b/wenet/text/wenet_tokenizer.py @@ -0,0 +1,90 @@ +import re + +from os import PathLike +from typing import Dict, List, Optional, Union +from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.text.tokenize_utils import tokenize_by_bpe_model + + +class WenetTokenizer(BaseTokenizer): + """Wrapper for original wenet tokenize implementation + """ + + def __init__( + self, + symbol_table: Union[str, PathLike, Dict], + bpe_model: Optional[Union[str, PathLike]] = None, + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + split_with_space: bool = False, + connect_symbol: str = '', + ) -> None: + self.non_lang_syms_pattern = None + if non_lang_syms is not None: + self.non_lang_syms_pattern = re.compile( + r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + if not isinstance(symbol_table, Dict): + self._symbol_table = read_symbol_table(symbol_table) + else: + # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} + self._symbol_table = symbol_table + if not isinstance(non_lang_syms, List): + self.non_lang_syms = read_non_lang_symbols(non_lang_syms) + else: + # non_lang_syms=["{NOISE}"] + self.non_lang_syms = non_lang_syms + self.bpe_model = None + if bpe_model is not None: + import sentencepiece as spm + self.bpe_model = spm.SentencePieceProcessor() + self.bpe_model.load(bpe_model) + self.char_dict = {v: k for k, v in self._symbol_table.items()} + self.split_with_space = split_with_space + self.connect_symbol = connect_symbol + + def text2tokens(self, line: str) -> List[str]: + line = line.strip() + if self.non_lang_syms_pattern is not None: + parts = self.non_lang_syms_pattern.split(line.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [line] + + tokens = [] + for part in parts: + if part in self.non_lang_syms: + tokens.append(part) + else: + if self.bpe_model is not None: + tokens.extend(tokenize_by_bpe_model(self.bpe_model, part)) + else: + if self.split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "▁" + tokens.append(ch) + return tokens + + def tokens2text(self, tokens: List[str]) -> str: + return self.connect_symbol.join(tokens) + + def tokens2ids(self, tokens: List[str]) -> List[int]: + ids = [] + for ch in tokens: + if ch in self._symbol_table: + ids.append(self._symbol_table[ch]) + elif '' in self._symbol_table: + ids.append(self._symbol_table['']) + return ids + + def ids2tokens(self, ids: List[int]) -> List[str]: + content = [self.char_dict[w] for w in ids] + return content + + def vocab_size(self) -> int: + return len(self.char_dict) + + @property + def symbol_table(self) -> Dict[str, int]: + return self._symbol_table diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py new file mode 100644 index 000000000..a5ad7e9c6 --- /dev/null +++ b/wenet/text/whisper_tokenizer.py @@ -0,0 +1,92 @@ +from os import PathLike +from typing import Dict, List, Optional, Tuple, Union +from wenet.text.base_tokenizer import BaseTokenizer + +from wenet.utils.file_utils import read_non_lang_symbols + + +class WhisperTokenizer(BaseTokenizer): + + def __init__( + self, + multilingual: bool, + num_languages: int = 99, + language: Optional[str] = None, + task: Optional[str] = None, + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + *args, + **kwargs, + ) -> None: + # NOTE(Mddct): don't build here, pickle issues + self.tokenizer = None + # TODO: we don't need this in future + self.multilingual = multilingual + self.num_languages = num_languages + self.language = language + self.task = task + + if not isinstance(non_lang_syms, List): + self.non_lang_syms = read_non_lang_symbols(non_lang_syms) + else: + # non_lang_syms=["{NOISE}"] + self.non_lang_syms = non_lang_syms + # TODO(Mddct): add special tokens, like non_lang_syms + del self.non_lang_syms + + def _build_tiktoken(self): + if self.tokenizer is None: + from whisper.tokenizer import get_tokenizer + self.tokenizer = get_tokenizer(multilingual=self.multilingual, + num_languages=self.num_languages, + language=self.language, + task=self.task) + self.t2i = {} + self.i2t = {} + for i in range(self.tokenizer.encoding.n_vocab): + unit = str( + self.tokenizer.encoding.decode_single_token_bytes(i)) + if len(unit) == 0: + unit = str(i) + unit = unit.replace(" ", "") + # unit = bytes(unit, 'utf-8') + self.t2i[unit] = i + self.i2t[i] = unit + assert len(self.t2i) == len(self.i2t) + + def tokenize(self, line: str) -> Tuple[List[str], List[int]]: + self._build_tiktoken() + ids = self.tokenizer.encoding.encode(line) + text = [self.i2t[d] for d in ids] + return text, ids + + def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: + self._build_tiktoken() + tokens = [self.i2t[d] for d in ids] + text = self.tokenizer.encoding.decode(ids) + return text, tokens + + def text2tokens(self, line: str) -> List[str]: + self._build_tiktoken() + return self.tokenize(line)[0] + + def tokens2text(self, tokens: List[str]) -> str: + self._build_tiktoken() + ids = [self.t2i[t] for t in tokens] + return self.detokenize(ids)[0] + + def tokens2ids(self, tokens: List[str]) -> List[int]: + self._build_tiktoken() + ids = [self.t2i[t] for t in tokens] + return ids + + def ids2tokens(self, ids: List[int]) -> List[str]: + self._build_tiktoken() + return [self.tokenizer.encoding.decode([id]) for id in ids] + + def vocab_size(self) -> int: + self._build_tiktoken() + return len(self.t2i) + + def symbol_table(self) -> Dict[str, int]: + self._build_tiktoken() + return self.t2i diff --git a/wenet/utils/context_graph.py b/wenet/utils/context_graph.py index 0f9c1f866..d3fadd3d0 100644 --- a/wenet/utils/context_graph.py +++ b/wenet/utils/context_graph.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from wenet.utils.tokenize_utils import tokenize_by_bpe_model +from wenet.text.tokenize_utils import tokenize_by_bpe_model from typing import Dict, List, Tuple from collections import deque diff --git a/wenet/utils/file_utils.py b/wenet/utils/file_utils.py index 7b7e516cc..264e37e73 100644 --- a/wenet/utils/file_utils.py +++ b/wenet/utils/file_utils.py @@ -39,14 +39,16 @@ def read_non_lang_symbols(non_lang_sym_path): """ if non_lang_sym_path is None: - return None + return [] else: syms = read_lists(non_lang_sym_path) non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") for sym in syms: if non_lang_syms_pattern.fullmatch(sym) is None: + class BadSymbolFormat(Exception): pass + raise BadSymbolFormat( "Non-linguistic symbols should be " "formatted in {xxx}//[xxx], consider" diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py new file mode 100644 index 000000000..fcafba7a4 --- /dev/null +++ b/wenet/utils/init_tokenizer.py @@ -0,0 +1,30 @@ +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.text.bpe_tokenizer import BpeTokenizer +from wenet.text.char_tokenizer import CharTokenizer +from wenet.text.whisper_tokenizer import WhisperTokenizer + + +def init_tokenizer(configs, + symbol_table, + bpe_model=None, + non_lang_syms=None) -> BaseTokenizer: + # TODO: + # 1 huggface tokenizer + # 2 paraformer tokenizer + + if configs.get("whisper", False): + tokenizer = WhisperTokenizer( + multilingual=configs['whisper_conf']['is_multilingual'], + num_languages=configs['whisper_conf']['num_languages']) + elif bpe_model is None: + tokenizer = CharTokenizer(symbol_table, + non_lang_syms, + split_with_space=configs.get( + 'split_with_space', False)) + else: + tokenizer = BpeTokenizer(bpe_model, + symbol_table, + split_with_space=configs.get( + 'split_with_space', False)) + + return tokenizer diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 37e8b85ff..54c325f0a 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -37,11 +37,8 @@ from deepspeed.utils.zero_to_fp32 import ( convert_zero_checkpoint_to_fp32_state_dict ) -from whisper.tokenizer import get_tokenizer - from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import save_checkpoint -from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing @@ -163,7 +160,7 @@ def init_distributed(args): return world_size, local_rank, rank -def check_modify_and_save_config(args, configs): +def check_modify_and_save_config(args, configs, symbol_table): if args.train_engine == "torch_ddp": if args.use_amp: configs["dtype"] = "fp16" @@ -213,8 +210,6 @@ def check_modify_and_save_config(args, configs): input_dim = configs['dataset_conf']['log_mel_spectrogram_conf']['num_mel_bins'] else: input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] - symbol_table = read_symbol_table(args.symbol_table) - vocab_size = len(symbol_table) if 'ctc_conf' not in configs: configs['ctc_conf'] = {} @@ -228,7 +223,7 @@ def check_modify_and_save_config(args, configs): assert 'ctc_blank_id' in configs['ctc_conf'], "PLZ set ctc_blank_id in yaml" configs['input_dim'] = input_dim - configs['output_dim'] = configs.get('output_dim', vocab_size) + configs['output_dim'] = configs['vocab_size'] configs['cmvn_file'] = args.cmvn configs['is_json_cmvn'] = True configs['lfmmi_dir'] = args.lfmmi_dir @@ -248,7 +243,7 @@ def check_modify_and_save_config(args, configs): return configs -def init_dataset_and_dataloader(args, configs): +def init_dataset_and_dataloader(args, configs, tokenizer): train_conf = configs['dataset_conf'] cv_conf = copy.deepcopy(train_conf) cv_conf['speed_perturb'] = False @@ -257,29 +252,17 @@ def init_dataset_and_dataloader(args, configs): cv_conf['spec_trim'] = False cv_conf['shuffle'] = False - symbol_table = read_symbol_table(args.symbol_table) - non_lang_syms = read_non_lang_symbols(args.non_lang_syms) - - # TODO(xcsong): This is a dirty workaround for whisper tokenizer, - # refine it in the future - if configs.get("whisper", False): - logging.info("using whisper tokenizer") - whisper_tok = get_tokenizer( - multilingual=configs['whisper_conf']['is_multilingual'], - num_languages=configs['whisper_conf']['num_languages']) - else: - whisper_tok = None - train_dataset = Dataset(args.data_type, args.train_data, symbol_table, - train_conf, args.bpe_model, non_lang_syms, True, - whisper_tok) + configs['vocab_size'] = tokenizer.vocab_size() + train_dataset = Dataset(args.data_type, + args.train_data, + tokenizer, + train_conf, + True) cv_dataset = Dataset(args.data_type, args.cv_data, - symbol_table, + tokenizer, cv_conf, - args.bpe_model, - non_lang_syms, - partition=False, - whisper_tokenizer=whisper_tok) + partition=False) # NOTE(xcsong): Why we prefer persistent_workers=True ? # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110