Skip to content

Commit

Permalink
[text] add WhisperTokenizer for convert whisper and unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Nov 27, 2023
1 parent 266a4fa commit 23459fa
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 16 deletions.
25 changes: 15 additions & 10 deletions test/wenet/whisper/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import numpy as np
import torch.nn.functional as F

from whisper.tokenizer import get_tokenizer
from whisper.audio import N_FFT, HOP_LENGTH, N_SAMPLES, N_FRAMES, pad_or_trim

from wenet.dataset.processor import compute_log_mel_spectrogram
from wenet.text.whisper_tokenizer import WhisperTokenizer
from wenet.transformer.embedding import WhisperPositionalEncoding
from wenet.whisper.convert_whisper_to_wenet_config_and_ckpt import (
convert_to_wenet_yaml, convert_to_wenet_state_dict, convert_to_wenet_units
Expand Down Expand Up @@ -108,8 +108,8 @@ def test_model(model, audio_path):
checkpoint = torch.load("{}/{}.pt".format(download_root, model), map_location="cpu")
multilingual = checkpoint["dims"]['n_vocab'] >= 51865
num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual)
tokenizer = get_tokenizer(multilingual, num_languages=num_languages,
language=language, task=task)
tokenizer = WhisperTokenizer(multilingual, num_languages=num_languages,
language=language, task=task)

convert_to_wenet_state_dict(
checkpoint["model_state_dict"],
Expand All @@ -132,7 +132,7 @@ def test_model(model, audio_path):
wenet_model.eval()

with torch.no_grad():
dummy_tokens = tokenizer.encode("WeNet x OpenAI")
_, dummy_tokens = tokenizer.tokenize("WeNet x OpenAI")

# 3. Forward whisper.encoder
mel1 = whisper.log_mel_spectrogram(
Expand Down Expand Up @@ -173,8 +173,8 @@ def test_model(model, audio_path):
rtol=1e-7, atol=1e-10)

# 4. Forward whisper.decoder
whisper_tokens = torch.tensor(list(tokenizer.sot_sequence)
+ [tokenizer.no_timestamps]
whisper_tokens = torch.tensor(list(tokenizer.tokenizer.sot_sequence)
+ [tokenizer.tokenizer.no_timestamps]
+ dummy_tokens,
dtype=torch.long).unsqueeze(0) # (B=1, 9)
whisper_decoder_embed = whisper_model.decoder.token_embedding(whisper_tokens)
Expand Down Expand Up @@ -273,10 +273,15 @@ def test_model(model, audio_path):

# 6. Forward wenet.decoder
wenet_tokens, _ = add_whisper_tokens(
tokenizer, torch.tensor([dummy_tokens], dtype=torch.long), ignore_id=-1,
task_id=tokenizer.transcribe if task == "transcribe" else tokenizer.translate, # noqa
no_timestamp=True, language=language, use_prev=False
)
tokenizer,
torch.tensor([dummy_tokens], dtype=torch.long),
ignore_id=-1,
task_id=tokenizer.tokenizer.transcribe
if task == "transcribe" else tokenizer.tokenizer.translate,
no_timestamp=True,
language=language,
use_prev=False)

L = wenet_tokens.size(1)
tgt_mask = ~make_pad_mask(
torch.tensor([L], dtype=torch.long), L).unsqueeze(1) # (B=1, 1, L)
Expand Down
2 changes: 1 addition & 1 deletion wenet/text/whisper_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from os import PathLike
from typing import List, Optional, Tuple, Union
from wenet.text.base_tokenizer import BaseTokenizer
from whisper.tokenizer import get_tokenizer

from wenet.utils.file_utils import read_non_lang_symbols

Expand All @@ -18,6 +17,7 @@ def __init__(
*args,
**kwargs,
) -> None:
from whisper.tokenizer import get_tokenizer
self.tokenizer = get_tokenizer(multilingual=multilingual,
num_languages=num_languages,
language=language,
Expand Down
6 changes: 3 additions & 3 deletions wenet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
import torch
from torch.nn.utils.rnn import pad_sequence

from whisper.tokenizer import LANGUAGES as WhiserLanguages

WHISPER_LANGS = tuple(WhiserLanguages.keys())
IGNORE_ID = -1


Expand Down Expand Up @@ -176,6 +173,9 @@ def add_whisper_tokens(
ys_out (torch.Tensor) : (B, Lmax + ?)
"""
tokenizer = whisper.tokenizer
from whisper.tokenizer import LANGUAGES as WhiserLanguages
WHISPER_LANGS = tuple(WhiserLanguages.keys())
if use_prev:
# i.e., hotword list
_prev = [tokenizer.sot_prev]
Expand Down
6 changes: 4 additions & 2 deletions wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@

_cpath_ = sys.path[0]
sys.path.remove(_cpath_)
from whisper.tokenizer import get_tokenizer
from wenet.test.tokenizer import WhisperTokenizer
sys.path.insert(0, _cpath_)


def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
tokenizer = tokenizer.tokenizer
configs = {}
configs['whisper'] = True
configs['whisper_conf'] = {}
Expand Down Expand Up @@ -205,6 +206,7 @@ def convert_to_wenet_units(tokenizer, units_txt_path):
It does not play any role in the tokenization process,
which is carried out by the tokenizer of openai-whisper.
"""
tokenizer = tokenizer.tokenizer
n_vocab = tokenizer.encoding.n_vocab
with open(units_txt_path, "+w") as f:
for i in range(n_vocab):
Expand Down Expand Up @@ -234,7 +236,7 @@ def main():
checkpoint = torch.load(args.whisper_ckpt, map_location="cpu")
multilingual = checkpoint["dims"]['n_vocab'] >= 51865
num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual)
tokenizer = get_tokenizer(multilingual=multilingual, num_languages=num_languages)
tokenizer = WhisperTokenizer(multilingual=multilingual, num_languages=num_languages)

convert_to_wenet_state_dict(
checkpoint["model_state_dict"],
Expand Down

0 comments on commit 23459fa

Please sign in to comment.