From 6a69c7582be8bcbdc55cdac1dc093de1fa90d560 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Wed, 29 Nov 2023 01:15:22 +0800
Subject: [PATCH 1/2] [text] fix whisper tokens and others

---
 test/wenet/text/test_parallel.py | 36 ++++++++++++++++++++++++++++++++
 wenet/bin/alignment.py           |  9 +++-----
 wenet/bin/recognize_onnx_gpu.py  |  7 +++----
 wenet/text/bpe_tokenizer.py      |  2 +-
 wenet/text/whisper_tokenizer.py  | 11 ++++++++++
 5 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/test/wenet/text/test_parallel.py b/test/wenet/text/test_parallel.py
index 28a0f37b2..54fe7a76e 100644
--- a/test/wenet/text/test_parallel.py
+++ b/test/wenet/text/test_parallel.py
@@ -25,6 +25,23 @@ def test_whisper_tokenzier_parallel():
     assert all(h == r for (h, r) in zip(results, inputs))
 
 
+def test_whisper_tokenzier_parallel_after_property():
+
+    inputs = ["it's ok", "wenet is simple", "test for new io"]
+    tokenizer = WhisperTokenizer(False)
+
+    _ = tokenizer.vocab_size
+    _ = tokenizer.symbol_table
+    partial_tokenize = partial(consistency, tokenizer)
+    with Pool(processes=len(inputs)) as pool:
+        results = pool.map(partial_tokenize, inputs)
+
+    inputs.sort()
+    results.sort()
+
+    assert all(h == r for (h, r) in zip(results, inputs))
+
+
 def test_bpe_tokenzier_parallel():
 
     symbol_table_path = "test/resources/librispeech.words.txt"
@@ -40,3 +57,22 @@ def test_bpe_tokenzier_parallel():
     results.sort()
 
     assert all(h == r for (h, r) in zip(results, inputs))
+
+
+def test_bpe_tokenizer_parallel_after_property():
+    symbol_table_path = "test/resources/librispeech.words.txt"
+    bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel"
+
+    inputs = ["WENR IS SIMPLE", "GOOD"]
+    tokenizer = BpeTokenizer(bpe_model, symbol_table_path)
+    _ = tokenizer.vocab_size
+    _ = tokenizer.symbol_table
+
+    partial_tokenize = partial(consistency, tokenizer)
+    with Pool(processes=len(inputs)) as pool:
+        results = pool.map(partial_tokenize, inputs)
+
+    inputs.sort()
+    results.sort()
+
+    assert all(h == r for (h, r) in zip(results, inputs))
diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py
index b8e8728a1..810d56e7b 100644
--- a/wenet/bin/alignment.py
+++ b/wenet/bin/alignment.py
@@ -28,10 +28,10 @@
 import math
 
 from wenet.dataset.dataset import Dataset
-from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
 from wenet.utils.ctc_utils import force_align
 from wenet.utils.common import get_subsample
 from wenet.utils.init_model import init_model
+from wenet.utils.init_tokenizer import init_tokenizer
 
 
 def generator_textgrid(maxtime, lines, output):
@@ -183,7 +183,6 @@ def get_labformat(timestamp, subsample):
             char_dict[int(arr[1])] = arr[0]
     eos = len(char_dict) - 1
 
-    symbol_table = read_symbol_table(args.dict)
 
     # Init dataset and data loader
     ali_conf = copy.deepcopy(configs['dataset_conf'])
@@ -202,14 +201,12 @@
     ali_conf['fbank_conf']['dither'] = 0.0
     ali_conf['batch_conf']['batch_type'] = "static"
     ali_conf['batch_conf']['batch_size'] = args.batch_size
-    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)
+    tokenizer = init_tokenizer(ali_conf, args.dict, args.bpe_model, args.non_lang_syms)
 
     ali_dataset = Dataset(args.data_type,
                           args.input_file,
-                          symbol_table,
+                          tokenizer,
                           ali_conf,
-                          args.bpe_model,
-                          non_lang_syms,
                           partition=False)
 
     ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)
diff --git a/wenet/bin/recognize_onnx_gpu.py b/wenet/bin/recognize_onnx_gpu.py
index fabc6713d..de06aef57 100644
--- a/wenet/bin/recognize_onnx_gpu.py
+++ b/wenet/bin/recognize_onnx_gpu.py
@@ -47,8 +47,8 @@
 
 from wenet.dataset.dataset import Dataset
 from wenet.utils.common import IGNORE_ID
-from wenet.utils.file_utils import read_symbol_table
 from wenet.utils.config import override_config
+from wenet.utils.init_tokenizer import init_tokenizer
 
 import onnxruntime as rt
 import multiprocessing
@@ -118,7 +118,6 @@ def main():
         configs = override_config(configs, args.override_config)
 
     reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
-    symbol_table = read_symbol_table(args.dict)
     test_conf = copy.deepcopy(configs['dataset_conf'])
     test_conf['filter_conf']['max_length'] = 102400
     test_conf['filter_conf']['min_length'] = 0
@@ -136,11 +135,11 @@
     test_conf['batch_conf']['batch_type'] = "static"
     test_conf['batch_conf']['batch_size'] = args.batch_size
 
+    tokenizer = init_tokenizer(test_conf, args.dict, args.bpe_model)
     test_dataset = Dataset(args.data_type,
                            args.test_data,
-                           symbol_table,
+                           tokenizer,
                            test_conf,
-                           args.bpe_model,
                            partition=False)
 
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
diff --git a/wenet/text/bpe_tokenizer.py b/wenet/text/bpe_tokenizer.py
index de1b504a9..8ac507704 100644
--- a/wenet/text/bpe_tokenizer.py
+++ b/wenet/text/bpe_tokenizer.py
@@ -8,7 +8,7 @@ class BpeTokenizer(CharTokenizer):
 
     def __init__(
         self,
-        bpe_model: PathLike,
+        bpe_model: Union[PathLike, str],
         symbol_table: Union[str, PathLike, Dict],
         non_lang_syms: Optional[Union[str, PathLike, List]] = None,
         split_with_space: bool = False,
diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py
index a5ad7e9c6..cb118a3b6 100644
--- a/wenet/text/whisper_tokenizer.py
+++ b/wenet/text/whisper_tokenizer.py
@@ -33,6 +33,16 @@ def __init__(
         # TODO(Mddct): add special tokens, like non_lang_syms
         del self.non_lang_syms
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state['tokenizer']
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        recovery = {'tokenizer': None}
+        self.__dict__.update(recovery)
+
     def _build_tiktoken(self):
         if self.tokenizer is None:
             from whisper.tokenizer import get_tokenizer
@@ -87,6 +97,7 @@ def vocab_size(self) -> int:
         self._build_tiktoken()
         return len(self.t2i)
 
+    @property
     def symbol_table(self) -> Dict[str, int]:
         self._build_tiktoken()
         return self.t2i

From 9389c74ca877dccef2b15921876cc8b2c1620ea9 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Wed, 29 Nov 2023 11:24:45 +0800
Subject: [PATCH 2/2] [text] fix tokenizer init in vkw example and onnx2horizonbin

---
 examples/vkw2021/s0/local/vkw_kws_results.py |  9 +++++----
 tools/onnx2horizonbin.py                     |  9 +++++----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/examples/vkw2021/s0/local/vkw_kws_results.py b/examples/vkw2021/s0/local/vkw_kws_results.py
index 45ac39e7d..5fc9fcd88 100644
--- a/examples/vkw2021/s0/local/vkw_kws_results.py
+++ b/examples/vkw2021/s0/local/vkw_kws_results.py
@@ -24,8 +24,9 @@
 from torch.utils.data import DataLoader
 
 from wenet.dataset.dataset import Dataset
-from wenet.transformer.asr_model import init_asr_model
+from wenet.utils.init_model import init_model
+from wenet.utils.init_tokenizer import init_tokenizer
 from wenet.utils.checkpoint import load_checkpoint
 from wenet.utils.common import get_subsample
 from wenet.utils.common import remove_duplicates_and_blank
 
@@ -186,11 +187,11 @@ def get_labformat_frames(timestamp, subsample, char_dict):
     cv_conf['speed_perturb'] = False
     cv_conf['spec_aug'] = False
 
+    tokenizer = init_tokenizer(cv_conf, args.symbol_table, args.bpe_model)
     cv_dataset = Dataset(args.data_type,
                         args.input_data,
-                         symbol_table,
+                         tokenizer,
                          cv_conf,
-                         None,
                          partition=False)
 
     cv_data_loader = DataLoader(cv_dataset,
@@ -205,7 +206,7 @@ def get_labformat_frames(timestamp, subsample, char_dict):
     print("word_unit_list has the size of %d" % (len(word_unit_list)))
 
     # Init asr model from configs
-    model = init_asr_model(configs)
+    model, configs = init_model(args, configs)
     load_checkpoint(model, args.checkpoint)
     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
     device = torch.device('cuda' if use_cuda else 'cpu')
diff --git a/tools/onnx2horizonbin.py b/tools/onnx2horizonbin.py
index a94b647fb..e0db5e137 100755
--- a/tools/onnx2horizonbin.py
+++ b/tools/onnx2horizonbin.py
@@ -51,6 +51,7 @@
 from wenet.utils.checkpoint import load_checkpoint
 from wenet.utils.file_utils import read_symbol_table
 from wenet.utils.init_model import init_model
+from wenet.utils.init_tokenizer import init_tokenizer
 from wenet.bin.export_onnx_cpu import to_numpy
 from wenet.bin.export_onnx_bpu import export_encoder, export_ctc
 
@@ -80,9 +81,9 @@ def save_data(tensor, dirs, prefix):
 def make_calibration_data(enc, args, conf):
     conf['shuffle'] = True
     logger.info(conf)
+    tokenizer = init_tokenizer(conf, args.symbol_table, args.bpe_model)
     dataset = Dataset(
-        "shard", args.cali_datalist, args.symbol_table, conf,
-        bpe_model=args.bpe_model, non_lang_syms=None, partition=False)
+        "shard", args.cali_datalist, tokenizer, conf, partition=False)
     dataloader = DataLoader(dataset, batch_size=None, num_workers=0)
 
     subsampling = enc.embed.subsampling_rate
@@ -148,9 +149,9 @@
 
 def check_wer(enc, ctc, args, conf):
     conf['shuffle'] = False
+    tokenizer = init_tokenizer(conf, args.symbol_table, args.bpe_model)
     dataset = Dataset(
-        "shard", args.wer_datalist, args.symbol_table, conf,
-        bpe_model=args.bpe_model, non_lang_syms=None, partition=False)
+        "shard", args.wer_datalist, tokenizer, conf, partition=False)
     dataloader = DataLoader(dataset, batch_size=None, num_workers=0)
     char_dict = {v: k for k, v in args.symbol_table.items()}
     eos = len(char_dict) - 1
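
A minimal usage sketch (not part of either commit) of what the new __getstate__/__setstate__ hooks in wenet/text/whisper_tokenizer.py enable: __getstate__ drops the lazily built tiktoken object before pickling and __setstate__ restores the instance with tokenizer set to None, so a WhisperTokenizer can be handed to multiprocessing workers even after vocab_size or symbol_table has been accessed; the next property access rebuilds tiktoken via _build_tiktoken(). The sketch only relies on the API visible in the diffs above; the helper function name and the final assert are illustrative.

import pickle
from functools import partial
from multiprocessing import Pool

from wenet.text.whisper_tokenizer import WhisperTokenizer


def vocab_size_in_worker(tokenizer: WhisperTokenizer, _: int) -> int:
    # The worker receives an unpickled copy whose 'tokenizer' attribute is None;
    # the property access below rebuilds tiktoken lazily via _build_tiktoken().
    return tokenizer.vocab_size


if __name__ == "__main__":
    tok = WhisperTokenizer(False)
    _ = tok.vocab_size                       # forces _build_tiktoken()
    clone = pickle.loads(pickle.dumps(tok))  # tiktoken state is dropped by __getstate__
    with Pool(processes=2) as pool:
        sizes = pool.map(partial(vocab_size_in_worker, tok), range(2))
    assert sizes == [clone.vocab_size] * 2

The two new *_after_property tests in test/wenet/text/test_parallel.py exercise exactly this scenario through multiprocessing.Pool.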