[text] add symbol table
Mddct committed Nov 27, 2023
1 parent cf754ff commit bd24277
Showing 8 changed files with 61 additions and 30 deletions.
4 changes: 2 additions & 2 deletions wenet/bin/recognize.py
@@ -206,7 +206,7 @@ def main():
     test_conf['batch_conf']['batch_type'] = "static"
     test_conf['batch_conf']['batch_size'] = args.batch_size

-    tokenizer = init_tokenizer(configs, args, non_lang_syms)
+    tokenizer = init_tokenizer(configs, args.dict, args.bpe_model, args.non_lang_syms)
     test_dataset = Dataset(args.data_type,
                            args.test_data,
                            tokenizer,
@@ -225,7 +225,7 @@ def main():

     context_graph = None
     if 'decoding-graph' in args.context_bias_mode:
-        context_graph = ContextGraph(args.context_list_path, symbol_table,
+        context_graph = ContextGraph(args.context_list_path, tokenizer.symbol_table,
                                      args.bpe_model, args.context_graph_score)

     # TODO(Dinghao Zhou): Support RNN-T related decoding
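With this change the decoding script no longer consults a separately loaded symbol table; the biasing graph reads the vocabulary straight off the tokenizer. A minimal sketch of the idea, assuming ContextGraph is importable from wenet.utils.context_graph and with configs and args standing in for the script's parsed config and CLI arguments:

from wenet.utils.context_graph import ContextGraph
from wenet.utils.init_tokenizer import init_tokenizer

# Build the tokenizer once, then share its vocabulary with the biasing graph.
tokenizer = init_tokenizer(configs, args.dict, args.bpe_model, args.non_lang_syms)
if 'decoding-graph' in args.context_bias_mode:
    context_graph = ContextGraph(args.context_list_path,
                                 tokenizer.symbol_table,  # Dict[str, int] from the new property
                                 args.bpe_model,
                                 args.context_graph_score)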
8 changes: 6 additions & 2 deletions wenet/bin/train.py
@@ -28,6 +28,7 @@
 from wenet.utils.executor import Executor
 from wenet.utils.config import override_config
 from wenet.utils.init_model import init_model
+from wenet.utils.init_tokenizer import init_tokenizer
 from wenet.utils.train_utils import (add_model_args, add_dataset_args,
                                      add_ddp_args, add_deepspeed_args,
                                      add_trace_args, init_distributed,
@@ -73,15 +74,18 @@ def main():
     if len(args.override_config) > 0:
         configs = override_config(configs, args.override_config)

+    # init tokenizer
+    tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, non_lang_syms)
+
     # Init env for ddp OR deepspeed
     world_size, local_rank, rank = init_distributed(args)

     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs)
+        init_dataset_and_dataloader(args, configs, tokenizer)

     # Do some sanity checks and save config to arsg.model_dir
-    configs = check_modify_and_save_config(args, configs)
+    configs = check_modify_and_save_config(args, configs, tokenizer.symbol_table)

     # Init asr model from configs
     model, configs = init_model(args, configs)
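The training entry point now builds the tokenizer once, up front, and threads it (or just its symbol table) through the helpers that previously constructed it internally. A condensed sketch of the reordered flow, with configs, args and non_lang_syms standing in for the script's existing variables:

# init tokenizer first, so dataset construction and the config check can share it
tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, non_lang_syms)

# the dataloaders consume the tokenizer directly ...
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
    init_dataset_and_dataloader(args, configs, tokenizer)

# ... while the config check only needs the vocabulary mapping
configs = check_modify_and_save_config(args, configs, tokenizer.symbol_table)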
8 changes: 6 additions & 2 deletions wenet/text/base_tokenizer.py
@@ -1,5 +1,5 @@
-from abc import ABC, abstractmethod
-from typing import List, Tuple
+from abc import ABC, abstractmethod, abstractproperty
+from typing import Dict, List, Tuple


 class BaseTokenizer(ABC):
@@ -33,3 +33,7 @@ def ids2tokens(self, ids: List[int]) -> List[str]:
     @abstractmethod
     def vocab_size(self) -> int:
         raise NotImplementedError("abstract method")
+
+    @abstractproperty
+    def symbol_table(self) -> Dict[str, int]:
+        raise NotImplementedError("abstract method")
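Every concrete tokenizer now has to expose its vocabulary as a read-only property rather than a bare attribute. A standalone sketch of the pattern the subclasses below follow (ToyVocab is illustrative, not a wenet class):

from typing import Dict

class ToyVocab:
    def __init__(self, symbol_table: Dict[str, int]):
        # keep the mapping private; expose it through a property
        self._symbol_table = symbol_table

    @property
    def symbol_table(self) -> Dict[str, int]:
        return self._symbol_table

vocab = ToyVocab({"<blank>": 0, "<unk>": 1, "我": 2})
assert vocab.symbol_table["<unk>"] == 1  # attribute-style access, no parentheses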
18 changes: 11 additions & 7 deletions wenet/text/char_tokenizer.py
@@ -19,16 +19,16 @@ def __init__(
         self.non_lang_syms_pattern = re.compile(
             r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
         if not isinstance(symbol_table, Dict):
-            self.symbol_table = read_symbol_table(symbol_table)
+            self._symbol_table = read_symbol_table(symbol_table)
         else:
             # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3}
-            self.symbol_table = symbol_table
+            self._symbol_table = symbol_table
         if not isinstance(non_lang_syms, List):
             self.non_lang_syms = read_non_lang_symbols(non_lang_syms)
         else:
             # non_lang_syms=["{NOISE}"]
             self.non_lang_syms = non_lang_syms
-        self.char_dict = {v: k for k, v in self.symbol_table.items()}
+        self.char_dict = {v: k for k, v in self._symbol_table.items()}
         self.split_with_space = split_with_space
         self.connect_symbol = connect_symbol
         self.unk = unk
@@ -60,10 +60,10 @@ def tokens2text(self, tokens: List[str]) -> str:
     def tokens2ids(self, tokens: List[str]) -> List[int]:
         ids = []
         for ch in tokens:
-            if ch in self.symbol_table:
-                ids.append(self.symbol_table[ch])
-            elif self.unk in self.symbol_table:
-                ids.append(self.symbol_table[self.unk])
+            if ch in self._symbol_table:
+                ids.append(self._symbol_table[ch])
+            elif self.unk in self._symbol_table:
+                ids.append(self._symbol_table[self.unk])
         return ids

     def ids2tokens(self, ids: List[int]) -> List[str]:
@@ -72,3 +72,7 @@ def ids2tokens(self, ids: List[int]) -> List[str]:

     def vocab_size(self) -> int:
         return len(self.char_dict)
+
+    @property
+    def symbol_table(self) -> Dict[str, int]:
+        return self._symbol_table
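Callers are unaffected by the rename: tokenizer.symbol_table reads the same as before, now backed by the private attribute. A small usage sketch with an in-memory table; the toy vocabulary and the constructor defaults (non_lang_syms, unk='<unk>') are assumptions for illustration:

from wenet.text.char_tokenizer import CharTokenizer

table = {"<blank>": 0, "<unk>": 1, "我": 2, "是": 3, "{NOISE}": 4}  # toy vocabulary
tokenizer = CharTokenizer(table, non_lang_syms=["{NOISE}"])

print(tokenizer.tokens2ids(["我", "是", "不"]))  # -> [2, 3, 1]; "不" falls back to <unk>
print(tokenizer.symbol_table is table)           # True: the property returns the stored dict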
18 changes: 11 additions & 7 deletions wenet/text/wenet_tokenizer.py
@@ -24,10 +24,10 @@ def __init__(
         self.non_lang_syms_pattern = re.compile(
             r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
         if not isinstance(symbol_table, Dict):
-            self.symbol_table = read_symbol_table(symbol_table)
+            self._symbol_table = read_symbol_table(symbol_table)
         else:
             # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3}
-            self.symbol_table = symbol_table
+            self._symbol_table = symbol_table
         if not isinstance(non_lang_syms, List):
             self.non_lang_syms = read_non_lang_symbols(non_lang_syms)
         else:
@@ -38,7 +38,7 @@ def __init__(
             import sentencepiece as spm
             self.bpe_model = spm.SentencePieceProcessor()
             self.bpe_model.load(bpe_model)
-        self.char_dict = {v: k for k, v in self.symbol_table.items()}
+        self.char_dict = {v: k for k, v in self._symbol_table.items()}
         self.split_with_space = split_with_space
         self.connect_symbol = connect_symbol

@@ -72,10 +72,10 @@ def tokens2text(self, tokens: List[str]) -> str:
     def tokens2ids(self, tokens: List[str]) -> List[int]:
         ids = []
         for ch in tokens:
-            if ch in self.symbol_table:
-                ids.append(self.symbol_table[ch])
-            elif '<unk>' in self.symbol_table:
-                ids.append(self.symbol_table['<unk>'])
+            if ch in self._symbol_table:
+                ids.append(self._symbol_table[ch])
+            elif '<unk>' in self._symbol_table:
+                ids.append(self._symbol_table['<unk>'])
         return ids

     def ids2tokens(self, ids: List[int]) -> List[str]:
@@ -84,3 +84,7 @@ def ids2tokens(self, ids: List[int]) -> List[str]:

     def vocab_size(self) -> int:
         return len(self.char_dict)
+
+    @property
+    def symbol_table(self) -> Dict[str, int]:
+        return self._symbol_table
5 changes: 4 additions & 1 deletion wenet/text/whisper_tokenizer.py
@@ -1,5 +1,5 @@
 from os import PathLike
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 from wenet.text.base_tokenizer import BaseTokenizer

 from wenet.utils.file_utils import read_non_lang_symbols
@@ -67,3 +67,6 @@ def ids2tokens(self, ids: List[int]) -> List[str]:

     def vocab_size(self) -> int:
         return len(self.t2i)
+
+    def symbol_table(self) -> Dict[str, int]:
+        return self.t2i
24 changes: 19 additions & 5 deletions wenet/utils/init_tokenizer.py
@@ -1,16 +1,30 @@
 from wenet.text.base_tokenizer import BaseTokenizer
-from wenet.text.wenet_tokenizer import WenetTokenizer
+from wenet.text.bpe_tokenizer import BpeTokenizer
+from wenet.text.char_tokenizer import CharTokenizer
 from wenet.text.whisper_tokenizer import WhisperTokenizer


-def init_tokenizer(configs, args, non_lang_syms) -> BaseTokenizer:
+def init_tokenizer(configs,
+                   symbol_table,
+                   bpe_model=None,
+                   non_lang_syms=None) -> BaseTokenizer:
+    # TODO:
+    # 1 huggface tokenizer
+    # 2 paraformer tokenizer
+
     if configs.get("whisper", False):
         tokenizer = WhisperTokenizer(
             multilingual=configs['whisper_conf']['is_multilingual'],
             num_languages=configs['whisper_conf']['num_languages'])
+    elif bpe_model is None:
+        tokenizer = CharTokenizer(symbol_table,
+                                  non_lang_syms,
+                                  split_with_space=configs.get(
+                                      'split_with_space', False))
     else:
-        tokenizer = WenetTokenizer(args.symbol_table, args.bpe_model,
-                                   non_lang_syms,
-                                   configs.get('split_with_space', False))
+        tokenizer = BpeTokenizer(bpe_model,
+                                 symbol_table,
+                                 split_with_space=configs.get(
+                                     'split_with_space', False))

     return tokenizer
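The factory now dispatches purely on its explicit arguments: a whisper config wins, otherwise a missing bpe_model selects the char-level tokenizer and a present one selects BPE. A hedged usage sketch; the file paths are placeholders, not values from this commit:

from wenet.utils.init_tokenizer import init_tokenizer

configs = {"split_with_space": False}

# no bpe_model -> CharTokenizer built from the dict file
char_tok = init_tokenizer(configs, "data/dict/lang_char.txt")

# bpe_model given -> BpeTokenizer wrapping the sentencepiece model
bpe_tok = init_tokenizer(configs, "data/dict/units.txt",
                         bpe_model="data/lang_char/bpe.model")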
6 changes: 2 additions & 4 deletions wenet/utils/train_utils.py
@@ -39,7 +39,6 @@
 )
 from wenet.dataset.dataset import Dataset
 from wenet.utils.checkpoint import save_checkpoint
-from wenet.utils.init_tokenizer import init_tokenizer
 from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing


@@ -161,7 +160,7 @@ def init_distributed(args):
     return world_size, local_rank, rank


-def check_modify_and_save_config(args, configs):
+def check_modify_and_save_config(args, configs, symbol_table):
     if args.train_engine == "torch_ddp":
         if args.use_amp:
             configs["dtype"] = "fp16"
@@ -244,7 +243,7 @@ def check_modify_and_save_config(args, configs):
     return configs


-def init_dataset_and_dataloader(args, configs):
+def init_dataset_and_dataloader(args, configs, tokenizer):
     train_conf = configs['dataset_conf']
     cv_conf = copy.deepcopy(train_conf)
     cv_conf['speed_perturb'] = False
@@ -253,7 +252,6 @@ def init_dataset_and_dataloader(args, configs):
     cv_conf['spec_trim'] = False
     cv_conf['shuffle'] = False

-    tokenizer = init_tokenizer(configs, args, non_lang_syms)
     configs['vocab_size'] = tokenizer.vocab_size()
     train_dataset = Dataset(args.data_type,
                             args.train_data,
