From 8144a2dad27d41768e3c5bb58c8c442f43968403 Mon Sep 17 00:00:00 2001 From: Mddct Date: Sat, 25 Nov 2023 19:12:03 +0800 Subject: [PATCH 01/19] [text] refine tokenizer --- test/test_tokenize.py | 126 ---------------- test/wenet/dataset/test_processor.py | 150 ++++++++++++++++++ test/wenet/text/test_wenet_tokenzier.py | 176 ++++++++++++++++++++++ test/wenet/text/test_whisper_tokenizer.py | 60 ++++++++ wenet/bin/recognize.py | 17 +-- wenet/dataset/dataset.py | 15 +- wenet/dataset/processor.py | 60 +------- wenet/text/base_tokenizer.py | 36 +++++ wenet/{utils => text}/tokenize_utils.py | 0 wenet/text/wenet_tokenizer.py | 87 +++++++++++ wenet/text/whisper_tokenizer.py | 71 +++++++++ wenet/utils/context_graph.py | 2 +- wenet/utils/file_utils.py | 4 +- wenet/utils/init_tokenizer.py | 16 ++ wenet/utils/train_utils.py | 37 ++--- 15 files changed, 627 insertions(+), 230 deletions(-) delete mode 100644 test/test_tokenize.py create mode 100644 test/wenet/dataset/test_processor.py create mode 100644 test/wenet/text/test_wenet_tokenzier.py create mode 100644 test/wenet/text/test_whisper_tokenizer.py create mode 100644 wenet/text/base_tokenizer.py rename wenet/{utils => text}/tokenize_utils.py (100%) create mode 100644 wenet/text/wenet_tokenizer.py create mode 100644 wenet/text/whisper_tokenizer.py create mode 100644 wenet/utils/init_tokenizer.py diff --git a/test/test_tokenize.py b/test/test_tokenize.py deleted file mode 100644 index 8f2cac1a2..000000000 --- a/test/test_tokenize.py +++ /dev/null @@ -1,126 +0,0 @@ -import pytest - -import wenet.dataset.processor as processor - -@pytest.mark.parametrize( - "symbol_table_path", - [ - "test/resources/librispeech.words.txt", - "test/resources/aishell2.words.txt" - ] -) -def test_tokenize(symbol_table_path): - txts = [ - {"txt": "震东好帅"}, - {"txt": " 吴迪也好帅 "}, - {"txt": "binbin is also handsome"}, - {"txt": " life is short i use wenet "}, - {"txt": "超哥 is the most handsome 吧"}, - {"txt": " 人生苦短i use wenet "}, - {"txt": "人生苦短I USE WENET"}, - {"txt": "zhendong ist so schön"}, - {"txt": " zhendong ist so schön "}, - {"txt": "It's okay"} - ] - if symbol_table_path == "test/resources/librispeech.words.txt": - bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" - refs = [ - {"tokens": ['震', '东', '好', '帅'], - "label": [1, 1, 1, 1]}, - {"tokens": ['吴', '迪', '也', '好', '帅'], - "label": [1, 1, 1, 1, 1]}, - {"tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], - "label": [347, 2216, 346, 2216, 2332, 143, 1990]}, - {"tokens": ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', - 'NE', 'T'], - "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366]}, - {"tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], - "label": [1, 1, 2332, 4435, 2860, 1990, 1]}, - {"tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], - "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366]}, - {"tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], - "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366]}, - {"tokens": ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', - 'Ö', 'N'], - "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, - 1, 2901]}, - {"tokens": ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', - 'Ö', 'N'], - "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, - 1, 2901]}, - {"tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], - "label": [2344, 2, 3790, 3010, 2418, 4979]} - ] - else: - bpe_model = None - refs = [ - {"tokens": ['震', '东', '好', '帅'], - "label": [4932, 80, 1059, 1375]}, - {"tokens": ['吴', '迪', '也', '好', '帅'], - "label": [656, 4540, 117, 1059, 1375]}, - {"tokens": ['b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', - 'a', 'l', 's', 'o', '▁', 'h', 'a', 'n', 'd', 's', - 'o', 'm', 'e'], - "label": [9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, - 1, 21, 7, 33, 13, 43, 35, 31, 15]}, - {"tokens": ['l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', - 'o', 'r', 't', '▁', 'i', '▁', 'u', 's', 'e', '▁', - 'w', 'e', 'n', 'e', 't'], - "label": [29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, - 1, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46]}, - {"tokens": ['超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', - 'm', 'o', 's', 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', - 'm', 'e', '▁', '吧'], - "label": [4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, - 1, 21, 7, 33, 13, 43, 35, 31, 15, 1, 647]}, - {"tokens": ['人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', - 'w', 'e', 'n', 'e', 't'], - "label": [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, - 15, 46]}, - {"tokens": ['人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', - 'W', 'E', 'N', 'E', 'T'], - "label": [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, - 16, 47]}, - {"tokens": ['z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', - 't', '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n'], - "label": [58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, - 35, 1, 43, 11, 21, 1, 33]}, - {"tokens": ['z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', - 't', '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n'], - "label": [58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, - 35, 1, 43, 11, 21, 1, 33]}, - {"tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], - "label": [24, 46, 2, 43, 1, 35, 27, 7, 56]} - ] - symbol_table = {} - with open(symbol_table_path, 'r') as f: - lines = f.readlines() - for l in lines: - l = l.strip().split() - symbol_table[l[0]] = int(l[1]) - outs = processor.tokenize( - txts, symbol_table, bpe_model, split_with_space=False - ) - for (hyp, ref) in zip(outs, refs): - assert(len(hyp["tokens"]) == len(ref["tokens"])) - assert(all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) - assert(len(hyp["label"]) == len(ref["label"])) - assert(all(h == r for h, r in zip(hyp["label"], ref["label"]))) - -@pytest.mark.parametrize("use_pbe_model", [True, False]) -def test_non_lang_symbol_tokenize(use_pbe_model): - data = [{"txt": "我是{NOISE}"}] - symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} - - if use_pbe_model: - bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" - - sample = next(processor.tokenize(data, symbol_table, bpe_model, - non_lang_syms=["{NOISE}"])) - - assert sample["tokens"] == ["我", "是", "{NOISE}"] - else: - sample = next(processor.tokenize(data, symbol_table, - non_lang_syms=["{NOISE}"])) - - assert sample["tokens"] == ["我", "是", "{NOISE}"] diff --git a/test/wenet/dataset/test_processor.py b/test/wenet/dataset/test_processor.py new file mode 100644 index 000000000..131d018c6 --- /dev/null +++ b/test/wenet/dataset/test_processor.py @@ -0,0 +1,150 @@ +import pytest + +import wenet.dataset.processor as processor +from wenet.text.wenet_tokenizer import WenetTokenizer + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/librispeech.words.txt", "test/resources/aishell2.words.txt" +]) +def test_tokenize(symbol_table_path): + txts = [{ + "txt": "震东好帅" + }, { + "txt": " 吴迪也好帅 " + }, { + "txt": "binbin is also handsome" + }, { + "txt": " life is short i use wenet " + }, { + "txt": "超哥 is the most handsome 吧" + }, { + "txt": " 人生苦短i use wenet " + }, { + "txt": "人生苦短I USE WENET" + }, { + "txt": "zhendong ist so schön" + }, { + "txt": " zhendong ist so schön " + }, { + "txt": "It's okay" + }] + if symbol_table_path == "test/resources/librispeech.words.txt": + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [1, 1, 1, 1] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [1, 1, 1, 1, 1] + }, { + "tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], + "label": [347, 2216, 346, 2216, 2332, 143, 1990] + }, { + "tokens": + ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], + "label": [1, 1, 2332, 4435, 2860, 1990, 1] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], + "label": [2344, 2, 3790, 3010, 2418, 4979] + }] + else: + bpe_model = None + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [4932, 80, 1059, 1375] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [656, 4540, 117, 1059, 1375] + }, { + "tokens": [ + 'b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', 'a', 'l', + 's', 'o', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e' + ], + "label": [ + 9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, 1, 21, 7, + 33, 13, 43, 35, 31, 15 + ] + }, { + "tokens": [ + 'l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', 'o', 'r', + 't', '▁', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', 'e', 't' + ], + "label": [ + 29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, 1, 23, 1, 48, + 43, 15, 1, 52, 15, 33, 15, 46 + ] + }, { + "tokens": [ + '超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', 'm', 'o', + 's', 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', '▁', '吧' + ], + "label": [ + 4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, 1, 21, + 7, 33, 13, 43, 35, 31, 15, 1, 647 + ] + }, { + "tokens": [ + '人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', + 'n', 'e', 't' + ], + "label": + [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46] + }, { + "tokens": [ + '人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', 'W', 'E', + 'N', 'E', 'T' + ], + "label": + [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, 16, 47] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], + "label": [24, 46, 2, 43, 1, 35, 27, 7, 56] + }] + + tokenizer = WenetTokenizer(symbol_table_path, + bpe_model, + split_with_space=False) + outs = processor.tokenize(txts, tokenizer) + for (hyp, ref) in zip(outs, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) diff --git a/test/wenet/text/test_wenet_tokenzier.py b/test/wenet/text/test_wenet_tokenzier.py new file mode 100644 index 000000000..d70e9a782 --- /dev/null +++ b/test/wenet/text/test_wenet_tokenzier.py @@ -0,0 +1,176 @@ +import pytest + +from wenet.text.wenet_tokenizer import WenetTokenizer + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/librispeech.words.txt", + "test/resources/aishell2.words.txt", +]) +def test_tokenize(symbol_table_path): + txts = [ + "震东好帅", + " 吴迪也好帅 ", + "binbin is also handsome", + " life is short i use wenet ", + "超哥 is the most handsome 吧", + " 人生苦短i use wenet ", + "人生苦短I USE WENET", + "zhendong ist so schön", + " zhendong ist so schön ", + "It's okay", + ] + if symbol_table_path == "test/resources/librispeech.words.txt": + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [1, 1, 1, 1] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [1, 1, 1, 1, 1] + }, { + "tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], + "label": [347, 2216, 346, 2216, 2332, 143, 1990] + }, { + "tokens": + ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], + "label": [1, 1, 2332, 4435, 2860, 1990, 1] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": + [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], + "label": [2344, 2, 3790, 3010, 2418, 4979] + }] + else: + bpe_model = None + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [4932, 80, 1059, 1375] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [656, 4540, 117, 1059, 1375] + }, { + "tokens": [ + 'b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', 'a', 'l', + 's', 'o', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e' + ], + "label": [ + 9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, 1, 21, 7, + 33, 13, 43, 35, 31, 15 + ] + }, { + "tokens": [ + 'l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', 'o', 'r', + 't', '▁', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', 'e', 't' + ], + "label": [ + 29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, 1, 23, 1, 48, + 43, 15, 1, 52, 15, 33, 15, 46 + ] + }, { + "tokens": [ + '超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', 'm', 'o', + 's', 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', '▁', '吧' + ], + "label": [ + 4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, 1, 21, + 7, 33, 13, 43, 35, 31, 15, 1, 647 + ] + }, { + "tokens": [ + '人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', + 'n', 'e', 't' + ], + "label": + [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46] + }, { + "tokens": [ + '人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', 'W', 'E', + 'N', 'E', 'T' + ], + "label": + [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, 16, 47] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', + '▁', 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, + 43, 11, 21, 1, 33 + ] + }, { + "tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], + "label": [24, 46, 2, 43, 1, 35, 27, 7, 56] + }] + tokenizer = WenetTokenizer(symbol_table=symbol_table_path, + bpe_model=bpe_model, + split_with_space=False) + results = [] + for line in txts: + tokens, label = tokenizer.tokenize(line) + results.append({"tokens": tokens, "label": label}) + + for (hyp, ref) in zip(results, refs): + print(hyp["tokens"], ref["tokens"]) + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) + + +@pytest.mark.parametrize("use_pbe_model", [True, False]) +def test_non_lang_symbol_tokenize(use_pbe_model): + data = ["我是{NOISE}"] + symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} + bpe_model = None + non_lang_syms = ["{NOISE}"] + expected = ["我", "是", "{NOISE}"] + + if use_pbe_model: + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + + tokenizer = WenetTokenizer(symbol_table, + bpe_model=bpe_model, + non_lang_syms=non_lang_syms) + for line in data: + tokens, _ = tokenizer.tokenize(line) + assert (all(h == r for h, r in zip(tokens, expected))) + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/librispeech.words.txt", + "test/resources/aishell2.words.txt", +]) +def test_vocab_size(symbol_table_path): + tokenizer = WenetTokenizer(symbol_table_path) + if symbol_table_path == "test/resources/librispeech.words.txt": + assert tokenizer.vocab_size() == 5002 + else: + assert tokenizer.vocab_size() == 5235 diff --git a/test/wenet/text/test_whisper_tokenizer.py b/test/wenet/text/test_whisper_tokenizer.py new file mode 100644 index 000000000..370a2c52e --- /dev/null +++ b/test/wenet/text/test_whisper_tokenizer.py @@ -0,0 +1,60 @@ +import pytest + +from wenet.text.whisper_tokenizer import WhisperTokenizer + + +@pytest.fixture(params=[False]) +def whisper_tokenizer(request): + is_multilingual = request.param + return WhisperTokenizer(is_multilingual) + + +def test_tokeniz(whisper_tokenizer): + + tokenizer = whisper_tokenizer + texts = ["life is short, i use wenet"] + expected = [{ + "tokens": [ + "b'life'", "b'is'", "b'short'", "b','", + "b'i'", "b'use'", "b'w'", "b'en'", "b'et'" + ], + "ids": [6042, 318, 1790, 11, 1312, 779, 266, 268, 316], + }] + + for i, text in enumerate(texts): + tokens, ids = tokenizer.tokenize(text) + assert len(tokens) == len(ids) + assert (all([h == r for h, r in zip(tokens, expected[i]["tokens"])])) + assert (all([h == r for h, r in zip(ids, expected[i]["ids"])])) + + +def test_detokenize(whisper_tokenizer): + tokenize = whisper_tokenizer + + inputs = [[6042, 318, 1790, 11, 1312, 779, 266, 268, 316]] + expected = [{ + "tokens": [ + "b'life'", "b'is'", "b'short'", "b','", + "b'i'", "b'use'", "b'w'", "b'en'", "b'et'" + ], + 'labels': + "life is short, i use wenet", + }] + + for i, input in enumerate(inputs): + text, tokens = tokenize.detokenize(input) + assert len(tokens) == len(expected[i]["tokens"]) + assert text == expected[i]["labels"] + assert all([h == r for h, r in zip(tokens, expected[i]["tokens"])]) + + +def test_consistency(whisper_tokenizer): + text = "whisper powered by wenet, it't great" + + assert text == whisper_tokenizer.tokens2text( + whisper_tokenizer.text2tokens(text)) + + +def test_vocab_size(whisper_tokenizer): + assert whisper_tokenizer.vocab_size( + ) == whisper_tokenizer.tokenizer.encoding.n_vocab diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 2e991f9da..c493f782a 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -24,9 +24,9 @@ from torch.utils.data import DataLoader from wenet.dataset.dataset import Dataset -from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols from wenet.utils.config import override_config from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.context_graph import ContextGraph @@ -185,7 +185,6 @@ def main(): if len(args.override_config) > 0: configs = override_config(configs, args.override_config) - symbol_table = read_symbol_table(args.dict) test_conf = copy.deepcopy(configs['dataset_conf']) test_conf['filter_conf']['max_length'] = 102400 @@ -206,14 +205,12 @@ def main(): test_conf['mfcc_conf']['dither'] = 0.0 test_conf['batch_conf']['batch_type'] = "static" test_conf['batch_conf']['batch_size'] = args.batch_size - non_lang_syms = read_non_lang_symbols(args.non_lang_syms) + tokenizer = init_tokenizer(configs, args, non_lang_syms) test_dataset = Dataset(args.data_type, args.test_data, - symbol_table, + tokenizer, test_conf, - args.bpe_model, - non_lang_syms, partition=False) test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) @@ -221,9 +218,6 @@ def main(): # Init asr model from configs model, configs = init_model(args, configs) - # Load dict - char_dict = {v: k for k, v in symbol_table.items()} - use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) @@ -264,9 +258,8 @@ def main(): context_graph=context_graph) for i, key in enumerate(keys): for mode, hyps in results.items(): - content = [char_dict[w] for w in hyps[i].tokens] - line = '{} {}'.format(key, - args.connect_symbol.join(content)) + tokens = hyps[i].tokens + line = '{} {}'.format(key, tokenizer.detokenize(tokens)[0]) logging.info('{} {}'.format(mode.ljust(max_format_len), line)) files[mode].write(line + '\n') diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index 9845cb805..9595011ec 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -19,10 +19,12 @@ from torch.utils.data import IterableDataset import wenet.dataset.processor as processor +from wenet.text.base_tokenizer import BaseTokenizer from wenet.utils.file_utils import read_lists class Processor(IterableDataset): + def __init__(self, source, f, *args, **kw): assert callable(f) self.source = source @@ -47,6 +49,7 @@ def apply(self, f): class DistributedSampler: + def __init__(self, shuffle=True, partition=True): self.epoch = -1 self.update() @@ -99,6 +102,7 @@ def sample(self, data): class DataList(IterableDataset): + def __init__(self, lists, shuffle=True, partition=True): self.lists = lists self.sampler = DistributedSampler(shuffle, partition) @@ -118,12 +122,9 @@ def __iter__(self): def Dataset(data_type, data_list_file, - symbol_table, + tokenizer: BaseTokenizer, conf, - bpe_model=None, - non_lang_syms=None, - partition=True, - whisper_tokenizer=None): + partition=True): """ Construct dataset from arguments We have two shuffle stage in the Dataset. The first is global @@ -145,9 +146,7 @@ def Dataset(data_type, else: dataset = Processor(dataset, processor.parse_raw) - dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model, - non_lang_syms, conf.get('split_with_space', False), - whisper_tokenizer) + dataset = Processor(dataset, processor.tokenize, tokenizer) filter_conf = conf.get('filter_conf', {}) dataset = Processor(dataset, processor.filter, **filter_conf) diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 5c48e286d..68e500e45 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -16,7 +16,6 @@ import logging import json import random -import re import tarfile from subprocess import PIPE, Popen from urllib.parse import urlparse @@ -26,9 +25,9 @@ import torchaudio.compliance.kaldi as kaldi import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence -from wenet.utils.tokenize_utils import tokenize_by_bpe_model +from wenet.text.base_tokenizer import BaseTokenizer -torchaudio.utils.sox_utils.set_buffer_size(16500) +# torchaudio.utils.sox_utils.set_buffer_size(16500) AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) @@ -366,12 +365,7 @@ def compute_log_mel_spectrogram(data, feat=log_spec.transpose(0, 1)) -def tokenize(data, - symbol_table, - bpe_model=None, - non_lang_syms=None, - split_with_space=False, - whisper_tokenizer=None): +def tokenize(data, tokenizer: BaseTokenizer): """ Decode text to chars or BPE Inplace operation @@ -381,55 +375,9 @@ def tokenize(data, Returns: Iterable[{key, wav, txt, tokens, label, sample_rate}] """ - if non_lang_syms is not None: - non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") - else: - non_lang_syms = {} - non_lang_syms_pattern = None - - if bpe_model is not None: - import sentencepiece as spm - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - else: - sp = None - for sample in data: assert 'txt' in sample - txt = sample['txt'].strip() - # TODO(xcsong): This is a dirty workaround for whisper tokernizer, - # refine it in the future - if whisper_tokenizer is not None: - sample['label'] = whisper_tokenizer.encode(txt) - yield sample - if non_lang_syms_pattern is not None: - parts = non_lang_syms_pattern.split(txt.upper()) - parts = [w for w in parts if len(w.strip()) > 0] - else: - parts = [txt] - - label = [] - tokens = [] - for part in parts: - if part in non_lang_syms: - tokens.append(part) - else: - if bpe_model is not None: - tokens.extend(tokenize_by_bpe_model(sp, part)) - else: - if split_with_space: - part = part.split(" ") - for ch in part: - if ch == ' ': - ch = "▁" - tokens.append(ch) - - for ch in tokens: - if ch in symbol_table: - label.append(symbol_table[ch]) - elif '' in symbol_table: - label.append(symbol_table['']) - + tokens, label = tokenizer.tokenize(sample['txt']) sample['tokens'] = tokens sample['label'] = label yield sample diff --git a/wenet/text/base_tokenizer.py b/wenet/text/base_tokenizer.py new file mode 100644 index 000000000..67055e5c3 --- /dev/null +++ b/wenet/text/base_tokenizer.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import List, Tuple + + +class BaseTokenizer(ABC): + + def tokenize(self, line: str) -> Tuple[List[str], List[int]]: + tokens = self.text2tokens(line) + ids = self.tokens2ids(tokens) + return tokens, ids + + def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: + tokens = self.ids2tokens(ids) + text = self.tokens2text(tokens) + return text, tokens + + @abstractmethod + def text2tokens(self, line: str) -> List[str]: + raise NotImplementedError("abstract method") + + @abstractmethod + def tokens2text(self, tokens: Iterable[str]) -> str: + raise NotImplementedError("abstract method") + + @abstractmethod + def tokens2ids(self, tokens: List[str]) -> List[int]: + raise NotImplementedError("abstract method") + + @abstractmethod + def ids2tokens(self, ids: List[int]) -> List[str]: + raise NotImplementedError("abstract method") + + @abstractmethod + def vocab_size(self) -> int: + raise NotImplementedError("abstract method") diff --git a/wenet/utils/tokenize_utils.py b/wenet/text/tokenize_utils.py similarity index 100% rename from wenet/utils/tokenize_utils.py rename to wenet/text/tokenize_utils.py diff --git a/wenet/text/wenet_tokenizer.py b/wenet/text/wenet_tokenizer.py new file mode 100644 index 000000000..c2eb42abd --- /dev/null +++ b/wenet/text/wenet_tokenizer.py @@ -0,0 +1,87 @@ +import re + +from collections.abc import Iterable +from os import PathLike +from typing import Dict, List, Optional, Union +from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.text.tokenize_utils import tokenize_by_bpe_model + + +class WenetTokenizer(BaseTokenizer): + """Wrapper for original wenet tokenize implementation + """ + + def __init__( + self, + symbol_table: Union[str, PathLike, Dict], + bpe_model: Optional[Union[str, PathLike]] = None, + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + split_with_space: bool = False, + connect_symbol: str = '', + ) -> None: + self.non_lang_syms_pattern = None + if non_lang_syms is not None: + self.non_lang_syms_pattern = re.compile( + r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + if not isinstance(symbol_table, Dict): + self.symbol_table = read_symbol_table(symbol_table) + else: + # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} + self.symbol_table = symbol_table + if not isinstance(non_lang_syms, List): + self.non_lang_syms = read_non_lang_symbols(non_lang_syms) + else: + # non_lang_syms=["{NOISE}"] + self.non_lang_syms = non_lang_syms + self.bpe_model = None + if bpe_model is not None: + import sentencepiece as spm + self.bpe_model = spm.SentencePieceProcessor() + self.bpe_model.load(bpe_model) + self.char_dict = {v: k for k, v in self.symbol_table.items()} + self.split_with_space = split_with_space + self.connect_symbol = connect_symbol + + def text2tokens(self, line: str) -> List[str]: + line = line.strip() + if self.non_lang_syms_pattern is not None: + parts = self.non_lang_syms_pattern.split(line.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [line] + + tokens = [] + for part in parts: + if part in self.non_lang_syms: + tokens.append(part) + else: + if self.bpe_model is not None: + tokens.extend(tokenize_by_bpe_model(self.bpe_model, part)) + else: + if self.split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "▁" + tokens.append(ch) + return tokens + + def tokens2text(self, tokens: Iterable[str]) -> str: + return self.connect_symbol.join(tokens) + + def tokens2ids(self, tokens: List[str]) -> List[int]: + ids = [] + for ch in tokens: + if ch in self.symbol_table: + ids.append(self.symbol_table[ch]) + elif '' in self.symbol_table: + ids.append(self.symbol_table['']) + return ids + + def ids2tokens(self, ids: List[int]) -> List[str]: + content = [self.char_dict[w] for w in ids] + return content + + def vocab_size(self) -> int: + return len(self.char_dict) diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py new file mode 100644 index 000000000..bd2290fbc --- /dev/null +++ b/wenet/text/whisper_tokenizer.py @@ -0,0 +1,71 @@ +from collections.abc import Iterable +from os import PathLike +from typing import List, Optional, Tuple, Union +from wenet.text.base_tokenizer import BaseTokenizer +from whisper.tokenizer import get_tokenizer + +from wenet.utils.file_utils import read_non_lang_symbols + + +class WhisperTokenizer(BaseTokenizer): + + def __init__( + self, + multilingual: bool, + num_languages: int = 99, + language: Optional[str] = None, + task: Optional[str] = None, + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + *args, + **kwargs, + ) -> None: + self.tokenizer = get_tokenizer(multilingual=multilingual, + num_languages=num_languages, + language=language, + task=task) + if not isinstance(non_lang_syms, List): + self.non_lang_syms = read_non_lang_symbols(non_lang_syms) + else: + # non_lang_syms=["{NOISE}"] + self.non_lang_syms = non_lang_syms + # TODO(Mddct): add special tokens, like non_lang_syms + del self.non_lang_syms + self.t2i = {} + self.i2t = {} + for i in range(self.tokenizer.encoding.n_vocab): + unit = str(self.tokenizer.encoding.decode_single_token_bytes(i)) + if len(unit) == 0: + unit = str(i) + unit = unit.replace(" ", "") + # unit = bytes(unit, 'utf-8') + self.t2i[unit] = i + self.i2t[i] = unit + print(len(self.t2i), len(self.i2t)) + assert len(self.t2i) == len(self.i2t) + + def tokenize(self, line: str) -> Tuple[List[str], List[int]]: + ids = self.tokenizer.encoding.encode(line) + text = [self.i2t[d] for d in ids] + return text, ids + + def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: + tokens = [self.i2t[d] for d in ids] + text = self.tokenizer.encoding.decode(ids) + return text, tokens + + def text2tokens(self, line: str) -> List[str]: + return self.tokenize(line)[0] + + def tokens2text(self, tokens: Iterable[str]) -> str: + ids = [self.t2i[t] for t in tokens] + return self.detokenize(ids)[0] + + def tokens2ids(self, tokens: List[str]) -> List[int]: + ids = [self.t2i[t] for t in tokens] + return ids + + def ids2tokens(self, ids: List[int]) -> List[str]: + return [self.tokenizer.encoding.decode([id]) for id in ids] + + def vocab_size(self) -> int: + return len(self.t2i) diff --git a/wenet/utils/context_graph.py b/wenet/utils/context_graph.py index 0f9c1f866..d3fadd3d0 100644 --- a/wenet/utils/context_graph.py +++ b/wenet/utils/context_graph.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from wenet.utils.tokenize_utils import tokenize_by_bpe_model +from wenet.text.tokenize_utils import tokenize_by_bpe_model from typing import Dict, List, Tuple from collections import deque diff --git a/wenet/utils/file_utils.py b/wenet/utils/file_utils.py index 7b7e516cc..0444c1ac2 100644 --- a/wenet/utils/file_utils.py +++ b/wenet/utils/file_utils.py @@ -39,14 +39,16 @@ def read_non_lang_symbols(non_lang_sym_path): """ if non_lang_sym_path is None: - return None + return {} else: syms = read_lists(non_lang_sym_path) non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") for sym in syms: if non_lang_syms_pattern.fullmatch(sym) is None: + class BadSymbolFormat(Exception): pass + raise BadSymbolFormat( "Non-linguistic symbols should be " "formatted in {xxx}//[xxx], consider" diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py new file mode 100644 index 000000000..81ad6ecd7 --- /dev/null +++ b/wenet/utils/init_tokenizer.py @@ -0,0 +1,16 @@ +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.text.wenet_tokenizer import WenetTokenizer +from wenet.text.whisper_tokenizer import WhisperTokenizer + + +def init_tokenizer(configs, args, non_lang_syms) -> BaseTokenizer: + if configs.get("whisper", False): + tokenizer = WhisperTokenizer( + multilingual=configs['whisper_conf']['is_multilingual'], + num_languages=configs['whisper_conf']['num_languages']) + else: + tokenizer = WenetTokenizer(args.symbol_table, args.bpe_model, + non_lang_syms, + configs.get('split_with_space', False)) + + return tokenizer diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 4c2aba516..ad3cda18a 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -37,11 +37,9 @@ from deepspeed.utils.zero_to_fp32 import ( convert_zero_checkpoint_to_fp32_state_dict ) -from whisper.tokenizer import get_tokenizer - from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import save_checkpoint -from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing @@ -213,11 +211,9 @@ def check_modify_and_save_config(args, configs): input_dim = configs['dataset_conf']['log_mel_spectrogram_conf']['num_mel_bins'] else: input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] - symbol_table = read_symbol_table(args.symbol_table) - vocab_size = len(symbol_table) configs['input_dim'] = input_dim - configs['output_dim'] = configs.get('output_dim', vocab_size) + configs['output_dim'] = configs['vocab_size'] configs['cmvn_file'] = args.cmvn configs['is_json_cmvn'] = True configs['lfmmi_dir'] = args.lfmmi_dir @@ -246,29 +242,18 @@ def init_dataset_and_dataloader(args, configs): cv_conf['spec_trim'] = False cv_conf['shuffle'] = False - symbol_table = read_symbol_table(args.symbol_table) - non_lang_syms = read_non_lang_symbols(args.non_lang_syms) - - # TODO(xcsong): This is a dirty workaround for whisper tokenizer, - # refine it in the future - if configs.get("whisper", False): - logging.info("using whisper tokenizer") - whisper_tok = get_tokenizer( - multilingual=configs['whisper_conf']['is_multilingual'], - num_languages=configs['whisper_conf']['num_languages']) - else: - whisper_tok = None - train_dataset = Dataset(args.data_type, args.train_data, symbol_table, - train_conf, args.bpe_model, non_lang_syms, True, - whisper_tok) + tokenizer = init_tokenizer(configs, args, non_lang_syms) + configs['vocab_size'] = tokenizer.vocab_size() + train_dataset = Dataset(args.data_type, + args.train_data, + tokenizer, + train_conf, + True) cv_dataset = Dataset(args.data_type, args.cv_data, - symbol_table, + tokenizer, cv_conf, - args.bpe_model, - non_lang_syms, - partition=False, - whisper_tokenizer=whisper_tok) + partition=False) # NOTE(xcsong): Why we prefer persistent_workers=True ? # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110 From 99cf7d7547e02f7d9f2ffef7af08caf773980ab0 Mon Sep 17 00:00:00 2001 From: Mddct Date: Sat, 25 Nov 2023 20:42:12 +0800 Subject: [PATCH 02/19] [text] fix flake8 --- test/wenet/text/test_whisper_tokenizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/wenet/text/test_whisper_tokenizer.py b/test/wenet/text/test_whisper_tokenizer.py index 370a2c52e..4e2e7607c 100644 --- a/test/wenet/text/test_whisper_tokenizer.py +++ b/test/wenet/text/test_whisper_tokenizer.py @@ -24,8 +24,8 @@ def test_tokeniz(whisper_tokenizer): for i, text in enumerate(texts): tokens, ids = tokenizer.tokenize(text) assert len(tokens) == len(ids) - assert (all([h == r for h, r in zip(tokens, expected[i]["tokens"])])) - assert (all([h == r for h, r in zip(ids, expected[i]["ids"])])) + assert (all((h == r for h, r in zip(tokens, expected[i]["tokens"])))) + assert (all((h == r for h, r in zip(ids, expected[i]["ids"])))) def test_detokenize(whisper_tokenizer): @@ -45,7 +45,7 @@ def test_detokenize(whisper_tokenizer): text, tokens = tokenize.detokenize(input) assert len(tokens) == len(expected[i]["tokens"]) assert text == expected[i]["labels"] - assert all([h == r for h, r in zip(tokens, expected[i]["tokens"])]) + assert all((h == r for h, r in zip(tokens, expected[i]["tokens"]))) def test_consistency(whisper_tokenizer): From 5694565e7a3ef4e9ced0f79a521809191c49d721 Mon Sep 17 00:00:00 2001 From: Mddct Date: Sat, 25 Nov 2023 20:44:03 +0800 Subject: [PATCH 03/19] [text] fix lint --- wenet/utils/train_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index ad3cda18a..ea13907ad 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -244,8 +244,8 @@ def init_dataset_and_dataloader(args, configs): tokenizer = init_tokenizer(configs, args, non_lang_syms) configs['vocab_size'] = tokenizer.vocab_size() - train_dataset = Dataset(args.data_type, - args.train_data, + train_dataset = Dataset(args.data_type, + args.train_data, tokenizer, train_conf, True) From 2418d79bcc4496187922f95c4ad30d4aa8cda768 Mon Sep 17 00:00:00 2001 From: Mddct Date: Sat, 25 Nov 2023 21:04:34 +0800 Subject: [PATCH 04/19] [text] fix unit --- wenet/text/base_tokenizer.py | 3 +-- wenet/text/wenet_tokenizer.py | 3 +-- wenet/text/whisper_tokenizer.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/wenet/text/base_tokenizer.py b/wenet/text/base_tokenizer.py index 67055e5c3..6dedf6475 100644 --- a/wenet/text/base_tokenizer.py +++ b/wenet/text/base_tokenizer.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from collections.abc import Iterable from typing import List, Tuple @@ -20,7 +19,7 @@ def text2tokens(self, line: str) -> List[str]: raise NotImplementedError("abstract method") @abstractmethod - def tokens2text(self, tokens: Iterable[str]) -> str: + def tokens2text(self, tokens: List[str]) -> str: raise NotImplementedError("abstract method") @abstractmethod diff --git a/wenet/text/wenet_tokenizer.py b/wenet/text/wenet_tokenizer.py index c2eb42abd..a65f28c2b 100644 --- a/wenet/text/wenet_tokenizer.py +++ b/wenet/text/wenet_tokenizer.py @@ -1,6 +1,5 @@ import re -from collections.abc import Iterable from os import PathLike from typing import Dict, List, Optional, Union from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols @@ -67,7 +66,7 @@ def text2tokens(self, line: str) -> List[str]: tokens.append(ch) return tokens - def tokens2text(self, tokens: Iterable[str]) -> str: + def tokens2text(self, tokens: List[str]) -> str: return self.connect_symbol.join(tokens) def tokens2ids(self, tokens: List[str]) -> List[int]: diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py index bd2290fbc..f1adc41a1 100644 --- a/wenet/text/whisper_tokenizer.py +++ b/wenet/text/whisper_tokenizer.py @@ -1,4 +1,3 @@ -from collections.abc import Iterable from os import PathLike from typing import List, Optional, Tuple, Union from wenet.text.base_tokenizer import BaseTokenizer @@ -56,7 +55,7 @@ def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: def text2tokens(self, line: str) -> List[str]: return self.tokenize(line)[0] - def tokens2text(self, tokens: Iterable[str]) -> str: + def tokens2text(self, tokens: List[str]) -> str: ids = [self.t2i[t] for t in tokens] return self.detokenize(ids)[0] From 3552b949ce788e1b2d5f5aa44e6a95614e418583 Mon Sep 17 00:00:00 2001 From: Mddct Date: Sun, 26 Nov 2023 23:37:21 +0800 Subject: [PATCH 05/19] [text] add bpe tokenizer and char tokenizer --- test/wenet/text/test_whisper_tokenizer.py | 2 +- wenet/text/bpe_tokenizer.py | 42 +++++++++++++ wenet/text/char_tokenizer.py | 74 +++++++++++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 wenet/text/bpe_tokenizer.py create mode 100644 wenet/text/char_tokenizer.py diff --git a/test/wenet/text/test_whisper_tokenizer.py b/test/wenet/text/test_whisper_tokenizer.py index 4e2e7607c..84681b8ec 100644 --- a/test/wenet/text/test_whisper_tokenizer.py +++ b/test/wenet/text/test_whisper_tokenizer.py @@ -9,7 +9,7 @@ def whisper_tokenizer(request): return WhisperTokenizer(is_multilingual) -def test_tokeniz(whisper_tokenizer): +def test_tokenize(whisper_tokenizer): tokenizer = whisper_tokenizer texts = ["life is short, i use wenet"] diff --git a/wenet/text/bpe_tokenizer.py b/wenet/text/bpe_tokenizer.py new file mode 100644 index 000000000..6dffa0989 --- /dev/null +++ b/wenet/text/bpe_tokenizer.py @@ -0,0 +1,42 @@ +from os import PathLike +from typing import Dict, List, Optional, Union +from wenet.text.char_tokenizer import CharTokenizer +from wenet.text.tokenize_utils import tokenize_by_bpe_model + + +class BpeTokenizer(CharTokenizer): + + def __init__( + self, + bpe_model: PathLike, + symbol_table: Union[str, PathLike, Dict], + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + split_with_space: bool = False, + connect_symbol: str = '', + unk='', + ) -> None: + super().__init__(symbol_table, non_lang_syms, split_with_space, + connect_symbol, unk) + import sentencepiece as spm + self.bpe_model = spm.SentencePieceProcessor() + self.bpe_model.load(bpe_model) + + def text2tokens(self, line: str) -> List[str]: + line = line.strip() + if self.non_lang_syms_pattern is not None: + parts = self.non_lang_syms_pattern.split(line.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [line] + + tokens = [] + for part in parts: + if part in self.non_lang_syms: + tokens.append(part) + else: + tokens.extend(tokenize_by_bpe_model(self.bpe_model, part)) + return tokens + + def tokens2text(self, tokens: List[str]) -> str: + text = super().tokens2text(tokens) + return text.replace("▁", ' ').strip() diff --git a/wenet/text/char_tokenizer.py b/wenet/text/char_tokenizer.py new file mode 100644 index 000000000..a4151f04e --- /dev/null +++ b/wenet/text/char_tokenizer.py @@ -0,0 +1,74 @@ +from os import PathLike +from typing import Dict, List, Optional, Union +from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.text.base_tokenizer import BaseTokenizer + + +class CharTokenizer(BaseTokenizer): + + def __init__( + self, + symbol_table: Union[str, PathLike, Dict], + non_lang_syms: Optional[Union[str, PathLike, List]] = None, + split_with_space: bool = False, + connect_symbol: str = '', + unk='', + ) -> None: + self.non_lang_syms_pattern = None + if non_lang_syms is not None: + self.non_lang_syms_pattern = re.compile( + r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + if not isinstance(symbol_table, Dict): + self.symbol_table = read_symbol_table(symbol_table) + else: + # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} + self.symbol_table = symbol_table + if not isinstance(non_lang_syms, List): + self.non_lang_syms = read_non_lang_symbols(non_lang_syms) + else: + # non_lang_syms=["{NOISE}"] + self.non_lang_syms = non_lang_syms + self.char_dict = {v: k for k, v in self.symbol_table.items()} + self.split_with_space = split_with_space + self.connect_symbol = connect_symbol + self.unk = unk + + def text2tokens(self, line: str) -> List[str]: + line = line.strip() + if self.non_lang_syms_pattern is not None: + parts = self.non_lang_syms_pattern.split(line.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [line] + + tokens = [] + for part in parts: + if part in self.non_lang_syms: + tokens.append(part) + else: + if self.split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "▁" + tokens.append(ch) + return tokens + + def tokens2text(self, tokens: List[str]) -> str: + return self.connect_symbol.join(tokens) + + def tokens2ids(self, tokens: List[str]) -> List[int]: + ids = [] + for ch in tokens: + if ch in self.symbol_table: + ids.append(self.symbol_table[ch]) + elif self.unk in self.symbol_table: + ids.append(self.symbol_table[self.unk]) + return ids + + def ids2tokens(self, ids: List[int]) -> List[str]: + content = [self.char_dict[w] for w in ids] + return content + + def vocab_size(self) -> int: + return len(self.char_dict) From 9912df9fe89b8e576e83c15b54d6279ce219bc51 Mon Sep 17 00:00:00 2001 From: Mddct Date: Sun, 26 Nov 2023 23:40:30 +0800 Subject: [PATCH 06/19] [text] add char tokenizer unit test --- test/wenet/text/test_char_tokenizer.py | 133 +++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 test/wenet/text/test_char_tokenizer.py diff --git a/test/wenet/text/test_char_tokenizer.py b/test/wenet/text/test_char_tokenizer.py new file mode 100644 index 000000000..24c9098a4 --- /dev/null +++ b/test/wenet/text/test_char_tokenizer.py @@ -0,0 +1,133 @@ +import pytest +from wenet.text.char_tokenizer import CharTokenizer + + +@pytest.fixture(params=["test/resources/aishell2.words.txt"]) +def char_tokenizer(request): + symbol_table = request.param + return CharTokenizer(symbol_table) + + +def test_tokenize(char_tokenizer): + tokenizer = char_tokenizer + txts = [ + "震东好帅", + " 吴迪也好帅 ", + "binbin is also handsome", + " life is short i use wenet ", + "超哥 is the most handsome 吧", + " 人生苦短i use wenet ", + "人生苦短I USE WENET", + "zhendong ist so schön", + " zhendong ist so schön ", + "It's okay", + ] + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [4932, 80, 1059, 1375] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [656, 4540, 117, 1059, 1375] + }, { + "tokens": [ + 'b', 'i', 'n', 'b', 'i', 'n', '▁', 'i', 's', '▁', 'a', 'l', 's', + 'o', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e' + ], + "label": [ + 9, 23, 33, 9, 23, 33, 1, 23, 43, 1, 7, 29, 43, 35, 1, 21, 7, 33, + 13, 43, 35, 31, 15 + ] + }, { + "tokens": [ + 'l', 'i', 'f', 'e', '▁', 'i', 's', '▁', 's', 'h', 'o', 'r', 't', + '▁', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', 'e', 't' + ], + "label": [ + 29, 23, 17, 15, 1, 23, 43, 1, 43, 21, 35, 41, 46, 1, 23, 1, 48, 43, + 15, 1, 52, 15, 33, 15, 46 + ] + }, { + "tokens": [ + '超', '哥', '▁', 'i', 's', '▁', 't', 'h', 'e', '▁', 'm', 'o', 's', + 't', '▁', 'h', 'a', 'n', 'd', 's', 'o', 'm', 'e', '▁', '吧' + ], + "label": [ + 4395, 736, 1, 23, 43, 1, 46, 21, 15, 1, 31, 35, 43, 46, 1, 21, 7, + 33, 13, 43, 35, 31, 15, 1, 647 + ] + }, { + "tokens": [ + '人', '生', '苦', '短', 'i', '▁', 'u', 's', 'e', '▁', 'w', 'e', 'n', + 'e', 't' + ], + "label": + [155, 2980, 3833, 3178, 23, 1, 48, 43, 15, 1, 52, 15, 33, 15, 46] + }, { + "tokens": [ + '人', '生', '苦', '短', 'I', '▁', 'U', 'S', 'E', '▁', 'W', 'E', 'N', + 'E', 'T' + ], + "label": + [155, 2980, 3833, 3178, 24, 1, 49, 44, 16, 1, 53, 16, 34, 16, 47] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', '▁', + 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, 43, + 11, 21, 1, 33 + ] + }, { + "tokens": [ + 'z', 'h', 'e', 'n', 'd', 'o', 'n', 'g', '▁', 'i', 's', 't', '▁', + 's', 'o', '▁', 's', 'c', 'h', 'ö', 'n' + ], + "label": [ + 58, 21, 15, 33, 13, 35, 33, 19, 1, 23, 43, 46, 1, 43, 35, 1, 43, + 11, 21, 1, 33 + ] + }, { + "tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], + "label": [24, 46, 2, 43, 1, 35, 27, 7, 56] + }] + results = [] + for line in txts: + tokens, label = tokenizer.tokenize(line) + results.append({"tokens": tokens, "label": label}) + + for (hyp, ref) in zip(results, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) + + +def test_detokenize(char_tokenizer): + tokenizer = char_tokenizer + idss = [ + [4932, 80, 1059, 1375], + [656, 4540, 117, 1059, 1375], + ] + + refs = [{ + "txt": "震东好帅", + "tokens": ['震', '东', '好', '帅'], + }, { + "txt": "吴迪也好帅", + "tokens": ['吴', '迪', '也', '好', '帅'], + }] + results = [] + for ids in idss: + txt, tokens = tokenizer.detokenize(ids) + results.append({"tokens": tokens, "txt": txt}) + + for (hyp, ref) in zip(results, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert len(hyp["txt"]) == len(ref["txt"]) + assert (all(h == r for h, r in zip(hyp["txt"], ref["txt"]))) + + +def test_vocab_size(char_tokenizer): + assert char_tokenizer.vocab_size() == 5235 From 266a4fabcc7659a4355038027857c0f9eba8dada Mon Sep 17 00:00:00 2001 From: Mddct Date: Sun, 26 Nov 2023 23:58:14 +0800 Subject: [PATCH 07/19] [text] add bpe tokenizer unit test --- test/wenet/text/test_bpe_tokenizer.py | 88 +++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 test/wenet/text/test_bpe_tokenizer.py diff --git a/test/wenet/text/test_bpe_tokenizer.py b/test/wenet/text/test_bpe_tokenizer.py new file mode 100644 index 000000000..c095fe752 --- /dev/null +++ b/test/wenet/text/test_bpe_tokenizer.py @@ -0,0 +1,88 @@ +import pytest +from wenet.text.bpe_tokenizer import BpeTokenizer + + +@pytest.fixture(params=[[ + "test/resources/librispeech.words.txt", + "test/resources/librispeech.train_960_unigram5000.bpemodel" +]]) +def bpe_tokenizer(request): + symbol_table, bpe_model = request.param + return BpeTokenizer(bpe_model, symbol_table) + + +def test_tokenize(bpe_tokenizer): + tokenizer = bpe_tokenizer + txts = [ + "震东好帅", + " 吴迪也好帅 ", + "binbin is also handsome", + " life is short i use wenet ", + "超哥 is the most handsome 吧", + " 人生苦短i use wenet ", + "人生苦短I USE WENET", + "zhendong ist so schön", + " zhendong ist so schön ", + "It's okay", + ] + refs = [{ + "tokens": ['震', '东', '好', '帅'], + "label": [1, 1, 1, 1] + }, { + "tokens": ['吴', '迪', '也', '好', '帅'], + "label": [1, 1, 1, 1, 1] + }, { + "tokens": ['▁B', 'IN', 'B', 'IN', '▁IS', '▁ALSO', "▁HANDSOME"], + "label": [347, 2216, 346, 2216, 2332, 143, 1990] + }, { + "tokens": ['▁LIFE', '▁IS', '▁SHORT', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [2568, 2332, 3968, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['超', '哥', '▁IS', '▁THE', '▁MOST', '▁HANDSOME', '吧'], + "label": [1, 1, 2332, 4435, 2860, 1990, 1] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": ['人', '生', '苦', '短', '▁I', '▁USE', '▁WE', 'NE', 'T'], + "label": [1, 1, 1, 1, 2152, 4699, 4833, 2926, 4366] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": + ['▁', 'Z', 'HEN', 'DO', 'NG', '▁IS', 'T', '▁SO', '▁SCH', 'Ö', 'N'], + "label": [3, 4999, 2048, 1248, 2960, 2332, 4366, 4072, 3844, 1, 2901] + }, { + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], + "label": [2344, 2, 3790, 3010, 2418, 4979] + }] + + results = [] + for line in txts: + tokens, label = tokenizer.tokenize(line) + results.append({"tokens": tokens, "label": label}) + + for (hyp, ref) in zip(results, refs): + assert (len(hyp["tokens"]) == len(ref["tokens"])) + assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"]))) + assert (len(hyp["label"]) == len(ref["label"])) + assert (all(h == r for h, r in zip(hyp["label"], ref["label"]))) + + +def test_detokenize(bpe_tokenizer): + tokenizer = bpe_tokenizer + # TODO(Mddct): more unit test + ids = [2344, 2, 3790, 3010, 2418, 4979] + expected = { + 'txt': "IT'S OKAY", + "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'] + } + txt, tokens = tokenizer.detokenize(ids) + assert txt == expected['txt'] + assert (all(h == r for h, r in zip(tokens, expected['tokens']))) + + +def test_vocab_size(bpe_tokenizer): + assert bpe_tokenizer.vocab_size() == 5002 From c2ecc7c1762b97ced04146143aaea84e3e822fac Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 13:43:06 +0800 Subject: [PATCH 08/19] [text] add WhisperTokenizer for test_whisper.py --- test/wenet/whisper/test_whisper.py | 29 +++++++++++-------- wenet/text/whisper_tokenizer.py | 2 +- wenet/utils/common.py | 5 ++-- ...onvert_whisper_to_wenet_config_and_ckpt.py | 8 +++-- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index dd7a6ad9b..06b046859 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -12,10 +12,10 @@ import numpy as np import torch.nn.functional as F -from whisper.tokenizer import get_tokenizer from whisper.audio import N_FFT, HOP_LENGTH, N_SAMPLES, N_FRAMES, pad_or_trim from wenet.dataset.processor import compute_log_mel_spectrogram +from wenet.text.whisper_tokenizer import WhisperTokenizer from wenet.transformer.embedding import WhisperPositionalEncoding from wenet.whisper.convert_whisper_to_wenet_config_and_ckpt import ( convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units @@ -108,19 +108,19 @@ def test_model(model, audio_path): checkpoint = torch.load("{}/{}.pt".format(download_root, model), map_location="cpu") multilingual = checkpoint["dims"]['n_vocab'] >= 51865 num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual) - tokenizer = get_tokenizer(multilingual, num_languages=num_languages, - language=language, task=task) + tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages, + language=language, task=task) convert_to_wenet_state_dict( checkpoint["model_state_dict"], os.path.join(download_root, 'wenet_whisper.pt') ) convert_to_wenet_units( - tokenizer, + tokenizer.tokenizer, os.path.join(download_root, 'units.txt') ) convert_to_wenet_yaml( - tokenizer, checkpoint["dims"], + tokenizer.tokenizer, checkpoint["dims"], os.path.join(download_root, 'train.yaml') ) with open("{}/train.yaml".format(download_root), 'r') as fin: @@ -132,7 +132,7 @@ def test_model(model, audio_path): wenet_model.eval() with torch.no_grad(): - dummy_tokens = tokenizer.encode("WeNet x OpenAI") + _, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI") # 3. Forward whisper.encoder mel1 = whisper.log_mel_spectrogram( @@ -173,8 +173,8 @@ def test_model(model, audio_path): rtol=1e-7, atol=1e-10) # 4. Forward whisper.decoder - whisper_tokens = torch.tensor(list(tokenizer.sot_sequence) - + [tokenizer.no_timestamps] + whisper_tokens = torch.tensor(list(tokenizer.tokenizer.sot_sequence) + + [tokenizer.tokenizer.no_timestamps] + dummy_tokens, dtype=torch.long).unsqueeze(0) # (B=1, 9) whisper_decoder_embed = whisper_model.decoder.token_embedding(whisper_tokens) @@ -273,10 +273,15 @@ def test_model(model, audio_path): # 6. Forward wenet.decoder wenet_tokens, _ = add_whisper_tokens( - tokenizer, torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1, - task_id=tokenizer.transcribe if task == "transcribe" else tokenizer.translate, # noqa - no_timestamp=True, language=language, use_prev=False - ) + tokenizer.tokenizer, + torch.tensor([dummy_tokens], dtype=torch.long), + ignore_id=-1, + task_id=tokenizer.tokenizer.transcribe + if task == "transcribe" else tokenizer.tokenizer.translate, + no_timestamp=True, + language=language, + use_prev=False) + L = wenet_tokens.size(1) tgt_mask = ~make_pad_mask( torch.tensor([L], dtype=torch.long), L).unsqueeze(1) # (B=1, 1, L) diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py index f1adc41a1..61056a0e7 100644 --- a/wenet/text/whisper_tokenizer.py +++ b/wenet/text/whisper_tokenizer.py @@ -1,7 +1,6 @@ from os import PathLike from typing import List, Optional, Tuple, Union from wenet.text.base_tokenizer import BaseTokenizer -from whisper.tokenizer import get_tokenizer from wenet.utils.file_utils import read_non_lang_symbols @@ -18,6 +17,7 @@ def __init__( *args, **kwargs, ) -> None: + from whisper.tokenizer import get_tokenizer self.tokenizer = get_tokenizer(multilingual=multilingual, num_languages=num_languages, language=language, diff --git a/wenet/utils/common.py b/wenet/utils/common.py index 824e77581..c73741a5b 100644 --- a/wenet/utils/common.py +++ b/wenet/utils/common.py @@ -20,9 +20,6 @@ import torch from torch.nn.utils.rnn import pad_sequence -from whisper.tokenizer import LANGUAGES as WhiserLanguages - -WHISPER_LANGS = tuple(WhiserLanguages.keys()) IGNORE_ID = -1 @@ -176,6 +173,8 @@ def add_whisper_tokens( ys_out (torch.Tensor) : (B, Lmax + ?) """ + from whisper.tokenizer import LANGUAGES as WhiserLanguages + WHISPER_LANGS = tuple(WhiserLanguages.keys()) if use_prev: # i.e., hotword list _prev = [tokenizer.sot_prev] diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 45c36d970..443aa07b2 100644 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -89,6 +89,9 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['decoder_conf']['key_bias'] = False configs['decoder_conf']['activation_type'] = "gelu" + configs['ctc_conf'] = {} + configs['ctc_conf']['ctc_blank_id'] = 50362 # + configs['model_conf'] = {} configs['model_conf']['ctc_weight'] = 0.3 configs['model_conf']['lsm_weight'] = 0.1 @@ -208,13 +211,12 @@ def convert_to_wenet_units(tokenizer, units_txt_path): n_vocab = tokenizer.encoding.n_vocab with open(units_txt_path, "+w") as f: for i in range(n_vocab): - unit = tokenizer.encoding.decode([i]) + unit = str(tokenizer.encoding.decode_single_token_bytes(i)) if len(unit) == 0: unit = str(i) print("can not decode id {}, convert to str({})".format(i, i)) unit = unit.replace(" ", "") - unit = bytes(unit, 'utf-8') - f.write("{} {}\n".format(str(unit), i)) + f.write("{} {}\n".format(unit, i)) f.flush() From 55da48a2b7eff46b6f860cb547548087c63b4f81 Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 16:55:49 +0800 Subject: [PATCH 09/19] [text] revert wenet/utils/file_utils.py --- wenet/utils/file_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/utils/file_utils.py b/wenet/utils/file_utils.py index 0444c1ac2..264e37e73 100644 --- a/wenet/utils/file_utils.py +++ b/wenet/utils/file_utils.py @@ -39,7 +39,7 @@ def read_non_lang_symbols(non_lang_sym_path): """ if non_lang_sym_path is None: - return {} + return [] else: syms = read_lists(non_lang_sym_path) non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") From 50422fc89e92863db78311594958aa25539416c1 Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 18:30:48 +0800 Subject: [PATCH 10/19] [text] add consistency for char and bpe unit --- test/wenet/text/test_bpe_tokenizer.py | 6 ++++++ test/wenet/text/test_char_tokenizer.py | 8 ++++++++ test/wenet/text/test_whisper_tokenizer.py | 2 +- wenet/text/whisper_tokenizer.py | 1 - 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/test/wenet/text/test_bpe_tokenizer.py b/test/wenet/text/test_bpe_tokenizer.py index c095fe752..852e2744e 100644 --- a/test/wenet/text/test_bpe_tokenizer.py +++ b/test/wenet/text/test_bpe_tokenizer.py @@ -86,3 +86,9 @@ def test_detokenize(bpe_tokenizer): def test_vocab_size(bpe_tokenizer): assert bpe_tokenizer.vocab_size() == 5002 + + +def test_consistency(bpe_tokenizer): + text = "WENET IS GREAT" + assert text == bpe_tokenizer.tokens2text(bpe_tokenizer.text2tokens(text)) + assert text == bpe_tokenizer.detokenize(bpe_tokenizer.tokenize(text)[1])[0] diff --git a/test/wenet/text/test_char_tokenizer.py b/test/wenet/text/test_char_tokenizer.py index 24c9098a4..fcf4f0762 100644 --- a/test/wenet/text/test_char_tokenizer.py +++ b/test/wenet/text/test_char_tokenizer.py @@ -131,3 +131,11 @@ def test_detokenize(char_tokenizer): def test_vocab_size(char_tokenizer): assert char_tokenizer.vocab_size() == 5235 + + +def test_consistency(char_tokenizer): + text = "大家都好帅" + + assert text == char_tokenizer.tokens2text(char_tokenizer.text2tokens(text)) + assert text == char_tokenizer.detokenize( + char_tokenizer.tokenize(text)[1])[0] diff --git a/test/wenet/text/test_whisper_tokenizer.py b/test/wenet/text/test_whisper_tokenizer.py index 84681b8ec..bee4f6cc8 100644 --- a/test/wenet/text/test_whisper_tokenizer.py +++ b/test/wenet/text/test_whisper_tokenizer.py @@ -49,7 +49,7 @@ def test_detokenize(whisper_tokenizer): def test_consistency(whisper_tokenizer): - text = "whisper powered by wenet, it't great" + text = "whisper powered by wenet, it's great" assert text == whisper_tokenizer.tokens2text( whisper_tokenizer.text2tokens(text)) diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py index 61056a0e7..2b752bec0 100644 --- a/wenet/text/whisper_tokenizer.py +++ b/wenet/text/whisper_tokenizer.py @@ -39,7 +39,6 @@ def __init__( # unit = bytes(unit, 'utf-8') self.t2i[unit] = i self.i2t[i] = unit - print(len(self.t2i), len(self.i2t)) assert len(self.t2i) == len(self.i2t) def tokenize(self, line: str) -> Tuple[List[str], List[int]]: From cf754ffbed54382ab82ebc93f756f4c716104cfc Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 20:07:56 +0800 Subject: [PATCH 11/19] [text] merge main --- test/wenet/whisper/test_whisper.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index e0e77dad7..de25746f3 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -273,22 +273,10 @@ def test_model(model, audio_path): # 6. Forward wenet.decoder wenet_tokens, _ = add_whisper_tokens( -<<<<<<< HEAD - tokenizer.tokenizer, - torch.tensor([dummy_tokens], dtype=torch.long), - ignore_id=-1, - task_id=tokenizer.tokenizer.transcribe - if task == "transcribe" else tokenizer.tokenizer.translate, - no_timestamp=True, - language=language, - use_prev=False) - -======= configs['model_conf']['special_tokens'], torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1, task=task, no_timestamp=True, language=language, use_prev=False ) ->>>>>>> main L = wenet_tokens.size(1) tgt_mask = ~make_pad_mask( torch.tensor([L], dtype=torch.long), L).unsqueeze(1) # (B=1, 1, L) From bd24277d9f502cef53728f46302f5b1320cdb035 Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 21:03:43 +0800 Subject: [PATCH 12/19] [text] add symbol table --- wenet/bin/recognize.py | 4 ++-- wenet/bin/train.py | 8 ++++++-- wenet/text/base_tokenizer.py | 8 ++++++-- wenet/text/char_tokenizer.py | 18 +++++++++++------- wenet/text/wenet_tokenizer.py | 18 +++++++++++------- wenet/text/whisper_tokenizer.py | 5 ++++- wenet/utils/init_tokenizer.py | 24 +++++++++++++++++++----- wenet/utils/train_utils.py | 6 ++---- 8 files changed, 61 insertions(+), 30 deletions(-) diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index c493f782a..1a842f8fe 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -206,7 +206,7 @@ def main(): test_conf['batch_conf']['batch_type'] = "static" test_conf['batch_conf']['batch_size'] = args.batch_size - tokenizer = init_tokenizer(configs, args, non_lang_syms) + tokenizer = init_tokenizer(configs, args.dict, args.bpe_model, args.non_lang_syms) test_dataset = Dataset(args.data_type, args.test_data, tokenizer, @@ -225,7 +225,7 @@ def main(): context_graph = None if 'decoding-graph' in args.context_bias_mode: - context_graph = ContextGraph(args.context_list_path, symbol_table, + context_graph = ContextGraph(args.context_list_path, tokenizer.symbol_table, args.bpe_model, args.context_graph_score) # TODO(Dinghao Zhou): Support RNN-T related decoding diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 0d8b4c094..33b31a6fd 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -28,6 +28,7 @@ from wenet.utils.executor import Executor from wenet.utils.config import override_config from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.train_utils import (add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args, add_trace_args, init_distributed, @@ -73,15 +74,18 @@ def main(): if len(args.override_config) > 0: configs = override_config(configs, args.override_config) + # init tokenizer + tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, non_lang_syms) + # Init env for ddp OR deepspeed world_size, local_rank, rank = init_distributed(args) # Get dataset & dataloader train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ - init_dataset_and_dataloader(args, configs) + init_dataset_and_dataloader(args, configs, tokenizer) # Do some sanity checks and save config to arsg.model_dir - configs = check_modify_and_save_config(args, configs) + configs = check_modify_and_save_config(args, configs, tokenizer.symbol_table) # Init asr model from configs model, configs = init_model(args, configs) diff --git a/wenet/text/base_tokenizer.py b/wenet/text/base_tokenizer.py index 6dedf6475..c96993309 100644 --- a/wenet/text/base_tokenizer.py +++ b/wenet/text/base_tokenizer.py @@ -1,5 +1,5 @@ -from abc import ABC, abstractmethod -from typing import List, Tuple +from abc import ABC, abstractmethod, abstractproperty +from typing import Dict, List, Tuple class BaseTokenizer(ABC): @@ -33,3 +33,7 @@ def ids2tokens(self, ids: List[int]) -> List[str]: @abstractmethod def vocab_size(self) -> int: raise NotImplementedError("abstract method") + + @abstractproperty + def symbol_table(self) -> Dict[str, int]: + raise NotImplementedError("abstract method") diff --git a/wenet/text/char_tokenizer.py b/wenet/text/char_tokenizer.py index a4151f04e..5e47be41b 100644 --- a/wenet/text/char_tokenizer.py +++ b/wenet/text/char_tokenizer.py @@ -19,16 +19,16 @@ def __init__( self.non_lang_syms_pattern = re.compile( r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") if not isinstance(symbol_table, Dict): - self.symbol_table = read_symbol_table(symbol_table) + self._symbol_table = read_symbol_table(symbol_table) else: # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} - self.symbol_table = symbol_table + self._symbol_table = symbol_table if not isinstance(non_lang_syms, List): self.non_lang_syms = read_non_lang_symbols(non_lang_syms) else: # non_lang_syms=["{NOISE}"] self.non_lang_syms = non_lang_syms - self.char_dict = {v: k for k, v in self.symbol_table.items()} + self.char_dict = {v: k for k, v in self._symbol_table.items()} self.split_with_space = split_with_space self.connect_symbol = connect_symbol self.unk = unk @@ -60,10 +60,10 @@ def tokens2text(self, tokens: List[str]) -> str: def tokens2ids(self, tokens: List[str]) -> List[int]: ids = [] for ch in tokens: - if ch in self.symbol_table: - ids.append(self.symbol_table[ch]) - elif self.unk in self.symbol_table: - ids.append(self.symbol_table[self.unk]) + if ch in self._symbol_table: + ids.append(self._symbol_table[ch]) + elif self.unk in self._symbol_table: + ids.append(self._symbol_table[self.unk]) return ids def ids2tokens(self, ids: List[int]) -> List[str]: @@ -72,3 +72,7 @@ def ids2tokens(self, ids: List[int]) -> List[str]: def vocab_size(self) -> int: return len(self.char_dict) + + @property + def symbol_table(self) -> Dict[str, int]: + return self._symbol_table diff --git a/wenet/text/wenet_tokenizer.py b/wenet/text/wenet_tokenizer.py index a65f28c2b..23ed81483 100644 --- a/wenet/text/wenet_tokenizer.py +++ b/wenet/text/wenet_tokenizer.py @@ -24,10 +24,10 @@ def __init__( self.non_lang_syms_pattern = re.compile( r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") if not isinstance(symbol_table, Dict): - self.symbol_table = read_symbol_table(symbol_table) + self._symbol_table = read_symbol_table(symbol_table) else: # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} - self.symbol_table = symbol_table + self._symbol_table = symbol_table if not isinstance(non_lang_syms, List): self.non_lang_syms = read_non_lang_symbols(non_lang_syms) else: @@ -38,7 +38,7 @@ def __init__( import sentencepiece as spm self.bpe_model = spm.SentencePieceProcessor() self.bpe_model.load(bpe_model) - self.char_dict = {v: k for k, v in self.symbol_table.items()} + self.char_dict = {v: k for k, v in self._symbol_table.items()} self.split_with_space = split_with_space self.connect_symbol = connect_symbol @@ -72,10 +72,10 @@ def tokens2text(self, tokens: List[str]) -> str: def tokens2ids(self, tokens: List[str]) -> List[int]: ids = [] for ch in tokens: - if ch in self.symbol_table: - ids.append(self.symbol_table[ch]) - elif '' in self.symbol_table: - ids.append(self.symbol_table['']) + if ch in self._symbol_table: + ids.append(self._symbol_table[ch]) + elif '' in self._symbol_table: + ids.append(self._symbol_table['']) return ids def ids2tokens(self, ids: List[int]) -> List[str]: @@ -84,3 +84,7 @@ def ids2tokens(self, ids: List[int]) -> List[str]: def vocab_size(self) -> int: return len(self.char_dict) + + @property + def symbol_table(self) -> Dict[str, int]: + return self._symbol_table diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py index 2b752bec0..3cc32270d 100644 --- a/wenet/text/whisper_tokenizer.py +++ b/wenet/text/whisper_tokenizer.py @@ -1,5 +1,5 @@ from os import PathLike -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from wenet.text.base_tokenizer import BaseTokenizer from wenet.utils.file_utils import read_non_lang_symbols @@ -67,3 +67,6 @@ def ids2tokens(self, ids: List[int]) -> List[str]: def vocab_size(self) -> int: return len(self.t2i) + + def symbol_table(self) -> Dict[str, int]: + return self.t2i diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py index 81ad6ecd7..fcafba7a4 100644 --- a/wenet/utils/init_tokenizer.py +++ b/wenet/utils/init_tokenizer.py @@ -1,16 +1,30 @@ from wenet.text.base_tokenizer import BaseTokenizer -from wenet.text.wenet_tokenizer import WenetTokenizer +from wenet.text.bpe_tokenizer import BpeTokenizer +from wenet.text.char_tokenizer import CharTokenizer from wenet.text.whisper_tokenizer import WhisperTokenizer -def init_tokenizer(configs, args, non_lang_syms) -> BaseTokenizer: +def init_tokenizer(configs, + symbol_table, + bpe_model=None, + non_lang_syms=None) -> BaseTokenizer: + # TODO: + # 1 huggface tokenizer + # 2 paraformer tokenizer + if configs.get("whisper", False): tokenizer = WhisperTokenizer( multilingual=configs['whisper_conf']['is_multilingual'], num_languages=configs['whisper_conf']['num_languages']) + elif bpe_model is None: + tokenizer = CharTokenizer(symbol_table, + non_lang_syms, + split_with_space=configs.get( + 'split_with_space', False)) else: - tokenizer = WenetTokenizer(args.symbol_table, args.bpe_model, - non_lang_syms, - configs.get('split_with_space', False)) + tokenizer = BpeTokenizer(bpe_model, + symbol_table, + split_with_space=configs.get( + 'split_with_space', False)) return tokenizer diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index a53d8e515..54c325f0a 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -39,7 +39,6 @@ ) from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import save_checkpoint -from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing @@ -161,7 +160,7 @@ def init_distributed(args): return world_size, local_rank, rank -def check_modify_and_save_config(args, configs): +def check_modify_and_save_config(args, configs, symbol_table): if args.train_engine == "torch_ddp": if args.use_amp: configs["dtype"] = "fp16" @@ -244,7 +243,7 @@ def check_modify_and_save_config(args, configs): return configs -def init_dataset_and_dataloader(args, configs): +def init_dataset_and_dataloader(args, configs, tokenizer): train_conf = configs['dataset_conf'] cv_conf = copy.deepcopy(train_conf) cv_conf['speed_perturb'] = False @@ -253,7 +252,6 @@ def init_dataset_and_dataloader(args, configs): cv_conf['spec_trim'] = False cv_conf['shuffle'] = False - tokenizer = init_tokenizer(configs, args, non_lang_syms) configs['vocab_size'] = tokenizer.vocab_size() train_dataset = Dataset(args.data_type, args.train_data, From 49994bfeb6693cee259c95f29be3c619373ac666 Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 21:39:08 +0800 Subject: [PATCH 13/19] [text] add init_tokenizer unit test --- test/wenet/utils/test_init_tokenizer.py | 41 +++++++++++++++++++++++++ wenet/bin/train.py | 3 +- 2 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 test/wenet/utils/test_init_tokenizer.py diff --git a/test/wenet/utils/test_init_tokenizer.py b/test/wenet/utils/test_init_tokenizer.py new file mode 100644 index 000000000..1b5b4af94 --- /dev/null +++ b/test/wenet/utils/test_init_tokenizer.py @@ -0,0 +1,41 @@ +import pytest + +from wenet.utils.init_tokenizer import init_tokenizer + + +def test_init_whisper_tokenizer(): + # TODO(Mddct): add configs generator + configs = {} + configs['whisper'] = True + configs['whisper_conf'] = {} + configs['whisper_conf']['is_multilingual'] = False + configs['whisper_conf']['num_languages'] = 99 + + tokenizer = init_tokenizer(configs, None) + text = "whisper powered by wenet, it's great" + + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) + + +@pytest.mark.parametrize("symbol_table_path", [ + "test/resources/aishell2.words.txt", +]) +def test_init_char_tokenizer(symbol_table_path): + configs = {} + tokenizer = init_tokenizer(configs, symbol_table_path) + + text = "大家都好帅" + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) + + +@pytest.mark.parametrize( + "symbol_table_path, bpe_model", + [("test/resources/librispeech.words.txt", + "test/resources/librispeech.train_960_unigram5000.bpemodel")]) +def test_init_bpe_tokenizer(symbol_table_path, bpe_model): + + configs = {} + tokenizer = init_tokenizer(configs, symbol_table_path, bpe_model) + text = "WENET IT'S GREAT" + + assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 33b31a6fd..7d4a9a11f 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -75,7 +75,8 @@ def main(): configs = override_config(configs, args.override_config) # init tokenizer - tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, non_lang_syms) + tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, + non_lang_syms) # Init env for ddp OR deepspeed world_size, local_rank, rank = init_distributed(args) From 301af9eb9296db3ae3e9b902fd4aebbde0fbd48a Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 27 Nov 2023 23:42:02 +0800 Subject: [PATCH 14/19] [text] uncomment --- wenet/dataset/processor.py | 2 +- wenet/utils/common.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 68e500e45..a769eba8e 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -27,7 +27,7 @@ from torch.nn.utils.rnn import pad_sequence from wenet.text.base_tokenizer import BaseTokenizer -# torchaudio.utils.sox_utils.set_buffer_size(16500) +torchaudio.utils.sox_utils.set_buffer_size(16500) AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) diff --git a/wenet/utils/common.py b/wenet/utils/common.py index 7384bad57..5da1fb341 100644 --- a/wenet/utils/common.py +++ b/wenet/utils/common.py @@ -20,6 +20,9 @@ import torch from torch.nn.utils.rnn import pad_sequence +from whisper.tokenizer import LANGUAGES as WhiserLanguages + +WHISPER_LANGS = tuple(WhiserLanguages.keys()) IGNORE_ID = -1 @@ -173,8 +176,6 @@ def add_whisper_tokens( ys_out (torch.Tensor) : (B, Lmax + ?) """ - from whisper.tokenizer import LANGUAGES as WhiserLanguages - WHISPER_LANGS = tuple(WhiserLanguages.keys()) if use_prev: # i.e., hotword list _prev = [special_tokens["sot_prev"]] From a13dedf5ad16fc98152d0a825c3c3158ac25890c Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 28 Nov 2023 00:24:33 +0800 Subject: [PATCH 15/19] [text] fix bpe model in multiprocess env --- wenet/text/bpe_tokenizer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/wenet/text/bpe_tokenizer.py b/wenet/text/bpe_tokenizer.py index 6dffa0989..de1b504a9 100644 --- a/wenet/text/bpe_tokenizer.py +++ b/wenet/text/bpe_tokenizer.py @@ -17,11 +17,19 @@ def __init__( ) -> None: super().__init__(symbol_table, non_lang_syms, split_with_space, connect_symbol, unk) - import sentencepiece as spm - self.bpe_model = spm.SentencePieceProcessor() - self.bpe_model.load(bpe_model) + self._model = bpe_model + # NOTE(Mddct): multiprocessing.Process() issues + # don't build sp here + self.bpe_model = None + + def _build_sp(self): + if self.bpe_model is None: + import sentencepiece as spm + self.bpe_model = spm.SentencePieceProcessor() + self.bpe_model.load(self._model) def text2tokens(self, line: str) -> List[str]: + self._build_sp() line = line.strip() if self.non_lang_syms_pattern is not None: parts = self.non_lang_syms_pattern.split(line.upper()) @@ -38,5 +46,6 @@ def text2tokens(self, line: str) -> List[str]: return tokens def tokens2text(self, tokens: List[str]) -> str: + self._build_sp() text = super().tokens2text(tokens) return text.replace("▁", ' ').strip() From ec2d83821ca7cbc7a973b915eecc1a753ac9018c Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 28 Nov 2023 12:01:49 +0800 Subject: [PATCH 16/19] [text] fix whisper tokenzier in multiprocess env --- wenet/text/whisper_tokenizer.py | 52 +++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py index 3cc32270d..a5ad7e9c6 100644 --- a/wenet/text/whisper_tokenizer.py +++ b/wenet/text/whisper_tokenizer.py @@ -17,11 +17,14 @@ def __init__( *args, **kwargs, ) -> None: - from whisper.tokenizer import get_tokenizer - self.tokenizer = get_tokenizer(multilingual=multilingual, - num_languages=num_languages, - language=language, - task=task) + # NOTE(Mddct): don't build here, pickle issues + self.tokenizer = None + # TODO: we don't need this in future + self.multilingual = multilingual + self.num_languages = num_languages + self.language = language + self.task = task + if not isinstance(non_lang_syms, List): self.non_lang_syms = read_non_lang_symbols(non_lang_syms) else: @@ -29,44 +32,61 @@ def __init__( self.non_lang_syms = non_lang_syms # TODO(Mddct): add special tokens, like non_lang_syms del self.non_lang_syms - self.t2i = {} - self.i2t = {} - for i in range(self.tokenizer.encoding.n_vocab): - unit = str(self.tokenizer.encoding.decode_single_token_bytes(i)) - if len(unit) == 0: - unit = str(i) - unit = unit.replace(" ", "") - # unit = bytes(unit, 'utf-8') - self.t2i[unit] = i - self.i2t[i] = unit - assert len(self.t2i) == len(self.i2t) + + def _build_tiktoken(self): + if self.tokenizer is None: + from whisper.tokenizer import get_tokenizer + self.tokenizer = get_tokenizer(multilingual=self.multilingual, + num_languages=self.num_languages, + language=self.language, + task=self.task) + self.t2i = {} + self.i2t = {} + for i in range(self.tokenizer.encoding.n_vocab): + unit = str( + self.tokenizer.encoding.decode_single_token_bytes(i)) + if len(unit) == 0: + unit = str(i) + unit = unit.replace(" ", "") + # unit = bytes(unit, 'utf-8') + self.t2i[unit] = i + self.i2t[i] = unit + assert len(self.t2i) == len(self.i2t) def tokenize(self, line: str) -> Tuple[List[str], List[int]]: + self._build_tiktoken() ids = self.tokenizer.encoding.encode(line) text = [self.i2t[d] for d in ids] return text, ids def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: + self._build_tiktoken() tokens = [self.i2t[d] for d in ids] text = self.tokenizer.encoding.decode(ids) return text, tokens def text2tokens(self, line: str) -> List[str]: + self._build_tiktoken() return self.tokenize(line)[0] def tokens2text(self, tokens: List[str]) -> str: + self._build_tiktoken() ids = [self.t2i[t] for t in tokens] return self.detokenize(ids)[0] def tokens2ids(self, tokens: List[str]) -> List[int]: + self._build_tiktoken() ids = [self.t2i[t] for t in tokens] return ids def ids2tokens(self, ids: List[int]) -> List[str]: + self._build_tiktoken() return [self.tokenizer.encoding.decode([id]) for id in ids] def vocab_size(self) -> int: + self._build_tiktoken() return len(self.t2i) def symbol_table(self) -> Dict[str, int]: + self._build_tiktoken() return self.t2i From 51a10fad657218237992f4b22a156dbcaea1e544 Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 28 Nov 2023 12:15:13 +0800 Subject: [PATCH 17/19] [text] add test unit parallel for bpe and whisper --- test/wenet/text/test_parallel.py | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 test/wenet/text/test_parallel.py diff --git a/test/wenet/text/test_parallel.py b/test/wenet/text/test_parallel.py new file mode 100644 index 000000000..28a0f37b2 --- /dev/null +++ b/test/wenet/text/test_parallel.py @@ -0,0 +1,42 @@ +from functools import partial +from multiprocessing import Pool +from wenet.text.base_tokenizer import BaseTokenizer + +from wenet.text.bpe_tokenizer import BpeTokenizer +from wenet.text.whisper_tokenizer import WhisperTokenizer + + +def consistency(tokenizer: BaseTokenizer, line: str) -> str: + return tokenizer.detokenize(tokenizer.tokenize(line)[1])[0] + + +def test_whisper_tokenzier_parallel(): + + inputs = ["it's ok", "wenet is simple", "test for new io"] + tokenizer = WhisperTokenizer(False) + + partial_tokenize = partial(consistency, tokenizer) + with Pool(processes=len(inputs)) as pool: + results = pool.map(partial_tokenize, inputs) + + inputs.sort() + results.sort() + + assert all(h == r for (h, r) in zip(results, inputs)) + + +def test_bpe_tokenzier_parallel(): + + symbol_table_path = "test/resources/librispeech.words.txt" + bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" + + inputs = ["WENR IS SIMPLE", "GOOD"] + tokenizer = BpeTokenizer(bpe_model, symbol_table_path) + partial_tokenize = partial(consistency, tokenizer) + with Pool(processes=len(inputs)) as pool: + results = pool.map(partial_tokenize, inputs) + + inputs.sort() + results.sort() + + assert all(h == r for (h, r) in zip(results, inputs)) From 1dc2d7990489954c0cda4b3952da4ecc8f3954e3 Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 28 Nov 2023 13:20:07 +0800 Subject: [PATCH 18/19] [text] fix none type in test_whisper.py --- test/wenet/whisper/test_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index de25746f3..d71716869 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -110,6 +110,7 @@ def test_model(model, audio_path): num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual) tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages, language=language, task=task) + tokenizer._build_tiktoken() convert_to_wenet_state_dict( checkpoint["model_state_dict"], From f1099a8fdcb4b4e7ef5dd4b655cc118ce0a3dd2e Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 28 Nov 2023 21:10:01 +0800 Subject: [PATCH 19/19] [text] all work --- wenet/bin/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 7d4a9a11f..c8e957439 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -76,7 +76,7 @@ def main(): # init tokenizer tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, - non_lang_syms) + args.non_lang_syms) # Init env for ddp OR deepspeed world_size, local_rank, rank = init_distributed(args)