diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index dd7a6ad9b5..7f10c83051 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -12,10 +12,10 @@ import numpy as np import torch.nn.functional as F -from whisper.tokenizer import get_tokenizer from whisper.audio import N_FFT, HOP_LENGTH, N_SAMPLES, N_FRAMES, pad_or_trim from wenet.dataset.processor import compute_log_mel_spectrogram +from wenet.text.whisper_tokenizer import WhisperTokenizer from wenet.transformer.embedding import WhisperPositionalEncoding from wenet.whisper.convert_whisper_to_wenet_config_and_ckpt import ( convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units @@ -108,8 +108,8 @@ def test_model(model, audio_path): checkpoint = torch.load("{}/{}.pt".format(download_root, model), map_location="cpu") multilingual = checkpoint["dims"]['n_vocab'] >= 51865 num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual) - tokenizer = get_tokenizer(multilingual, num_languages=num_languages, - language=language, task=task) + tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages, + language=language, task=task) convert_to_wenet_state_dict( checkpoint["model_state_dict"], @@ -132,7 +132,7 @@ def test_model(model, audio_path): wenet_model.eval() with torch.no_grad(): - dummy_tokens = tokenizer.encode("WeNet x OpenAI") + _, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI") # 3. Forward whisper.encoder mel1 = whisper.log_mel_spectrogram( @@ -173,8 +173,8 @@ def test_model(model, audio_path): rtol=1e-7, atol=1e-10) # 4. Forward whisper.decoder - whisper_tokens = torch.tensor(list(tokenizer.sot_sequence) - + [tokenizer.no_timestamps] + whisper_tokens = torch.tensor(list(tokenizer.tokenizer.sot_sequence) + + [tokenizer.tokenizer.no_timestamps] + dummy_tokens, dtype=torch.long).unsqueeze(0) # (B=1, 9) whisper_decoder_embed = whisper_model.decoder.token_embedding(whisper_tokens) @@ -273,10 +273,15 @@ def test_model(model, audio_path): # 6. Forward wenet.decoder wenet_tokens, _ = add_whisper_tokens( - tokenizer, torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1, - task_id=tokenizer.transcribe if task == "transcribe" else tokenizer.translate, # noqa - no_timestamp=True, language=language, use_prev=False - ) + tokenizer, + torch.tensor([dummy_tokens], dtype=torch.long), + ignore_id=-1, + task_id=tokenizer.tokenizer.transcribe + if task == "transcribe" else tokenizer.tokenizer.translate, + no_timestamp=True, + language=language, + use_prev=False) + L = wenet_tokens.size(1) tgt_mask = ~make_pad_mask( torch.tensor([L], dtype=torch.long), L).unsqueeze(1) # (B=1, 1, L) diff --git a/wenet/text/whisper_tokenizer.py b/wenet/text/whisper_tokenizer.py index f1adc41a1c..61056a0e72 100644 --- a/wenet/text/whisper_tokenizer.py +++ b/wenet/text/whisper_tokenizer.py @@ -1,7 +1,6 @@ from os import PathLike from typing import List, Optional, Tuple, Union from wenet.text.base_tokenizer import BaseTokenizer -from whisper.tokenizer import get_tokenizer from wenet.utils.file_utils import read_non_lang_symbols @@ -18,6 +17,7 @@ def __init__( *args, **kwargs, ) -> None: + from whisper.tokenizer import get_tokenizer self.tokenizer = get_tokenizer(multilingual=multilingual, num_languages=num_languages, language=language, diff --git a/wenet/utils/common.py b/wenet/utils/common.py index 824e775811..50a37e80b4 100644 --- a/wenet/utils/common.py +++ b/wenet/utils/common.py @@ -20,9 +20,6 @@ import torch from torch.nn.utils.rnn import pad_sequence -from whisper.tokenizer import LANGUAGES as WhiserLanguages - -WHISPER_LANGS = tuple(WhiserLanguages.keys()) IGNORE_ID = -1 @@ -176,6 +173,9 @@ def add_whisper_tokens( ys_out (torch.Tensor) : (B, Lmax + ?) """ + tokenizer = whisper.tokenizer + from whisper.tokenizer import LANGUAGES as WhiserLanguages + WHISPER_LANGS = tuple(WhiserLanguages.keys()) if use_prev: # i.e., hotword list _prev = [tokenizer.sot_prev] diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 45c36d9709..2e8e32836f 100644 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -39,11 +39,12 @@ _cpath_ = sys.path[0] sys.path.remove(_cpath_) -from whisper.tokenizer import get_tokenizer +from wenet.test.tokenizer import WhisperTokenizer sys.path.insert(0, _cpath_) def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): + tokenizer = tokenizer.tokenizer configs = {} configs['whisper'] = True configs['whisper_conf'] = {} @@ -205,6 +206,7 @@ def convert_to_wenet_units(tokenizer, units_txt_path): It does not play any role in the tokenization process, which is carried out by the tokenizer of openai-whisper. """ + tokenizer = tokenizer.tokenizer n_vocab = tokenizer.encoding.n_vocab with open(units_txt_path, "+w") as f: for i in range(n_vocab): @@ -234,7 +236,7 @@ def main(): checkpoint = torch.load(args.whisper_ckpt, map_location="cpu") multilingual = checkpoint["dims"]['n_vocab'] >= 51865 num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual) - tokenizer = get_tokenizer(multilingual=multilingual, num_languages=num_languages) + tokenizer = WhisperTokenizer(multilingual=multilingual, num_languages=num_languages) convert_to_wenet_state_dict( checkpoint["model_state_dict"],