[WIP][text] add tokens #2201

Closed
wants to merge 6 commits
22 changes: 22 additions & 0 deletions test/wenet/text/test_bpe_tokenizer.py
@@ -92,3 +92,25 @@ def test_consistency(bpe_tokenizer):
     text = "WENET IS GREAT"
     assert text == bpe_tokenizer.tokens2text(bpe_tokenizer.text2tokens(text))
     assert text == bpe_tokenizer.detokenize(bpe_tokenizer.tokenize(text)[1])[0]
+
+
+def test_add_tokens(bpe_tokenizer):
+    tokenizer = bpe_tokenizer
+    tokenizer.upper = False
+
+    special_tokens = ["<s>", "</s>", "▁wenet", "WENET", "▁WENET"]
+    tokenizer.add_tokens(special_tokens)
+
+    # text = "wenethappy IT'S OKAY wenethappy hawenethappy wenethappy 好"
+    text = "种 wenet OK HELLO HAWENET WENETSPEECH"
+    # text = "▁wenet OK HELLO HAWENET SPEECHWENET"
+    expected = [
+        '种', '▁wenet', '▁O', 'K', '▁HE', 'LL', 'O', '▁HA', 'WENET', '▁WENET',
+        'S', 'PE', 'E', 'CH'
+    ]
+    tokens, labels = tokenizer.tokenize(text)
+    print(tokens)
+    print(labels)
+    assert tokenizer.vocab_size() == 5002 + len(special_tokens)
+    assert len(tokens) == len(expected)
+    assert all(h == r for (h, r) in zip(tokens, expected))
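
The test relies on the bpe_tokenizer pytest fixture defined earlier in this file, which the diff does not show. A minimal sketch of what such a fixture looks like, with placeholder paths standing in for the repository's real test assets (whose symbol table has 5002 entries, hence the assertion above):

import pytest

from wenet.text.bpe_tokenizer import BpeTokenizer


@pytest.fixture
def bpe_tokenizer():
    # Placeholder paths, not the repository's actual test resources.
    return BpeTokenizer(bpe_model="test/resources/bpe.model",
                        symbol_table="test/resources/units.txt")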
176 changes: 0 additions & 176 deletions test/wenet/text/test_wenet_tokenzier.py

This file was deleted.

4 changes: 4 additions & 0 deletions wenet/text/base_tokenizer.py
@@ -39,3 +39,7 @@ def vocab_size(self) -> int:
     @abstractproperty
     def symbol_table(self) -> Dict[T, int]:
         raise NotImplementedError("abstract method")
+
+    @abstractmethod
+    def add_tokens(self, tokens: List[T]) -> int:
+        raise NotImplementedError("abstract method")
44 changes: 41 additions & 3 deletions wenet/text/bpe_tokenizer.py
@@ -1,5 +1,7 @@
from os import PathLike

from typing import Dict, List, Optional, Union

from wenet.text.char_tokenizer import CharTokenizer
from wenet.text.tokenize_utils import tokenize_by_bpe_model

@@ -14,31 +16,56 @@ def __init__(
         split_with_space: bool = False,
         connect_symbol: str = '',
         unk='<unk>',
+        upper: bool = True,
     ) -> None:
         super().__init__(symbol_table, non_lang_syms, split_with_space,
                          connect_symbol, unk)
         self._model = bpe_model
         # NOTE(Mddct): multiprocessing.Process() issues
         #              don't build sp here
         self.bpe_model = None
+        # NOTE(Mddct): we can handle proto, see:
+        # https://github.com/google/sentencepiece/issues/121#issuecomment-400362011
+        self.bpe_spm = None
+        self.upper = upper
+        self.extra_tokens = {}
 
     def _build_sp(self):
+        import sentencepiece as spm
         if self.bpe_model is None:
-            import sentencepiece as spm
             self.bpe_model = spm.SentencePieceProcessor()
-            self.bpe_model.load(self._model)
+            self.bpe_model.Load(self._model)
+            if len(self.extra_tokens) > 0:
+                from transformers.utils import (sentencepiece_model_pb2_new as
+                                                sentencepiece_model_pb2)
+                self.bpe_spm = sentencepiece_model_pb2.ModelProto()
+                self.bpe_spm.ParseFromString(
+                    self.bpe_model.serialized_model_proto())
+                for token_id in sorted(self.extra_tokens.items(),
+                                       key=lambda x: x[1]):
+                    new_p = sentencepiece_model_pb2.ModelProto().SentencePiece(
+                    )
+                    new_p.piece = token_id[0]
+                    new_p.score = 0
+                    self.bpe_spm.pieces.append(new_p)
+
+                self.bpe_model = spm.SentencePieceProcessor(
+                    model_proto=self.bpe_spm.SerializeToString())
 
     def text2tokens(self, line: str) -> List[str]:
         self._build_sp()
         line = line.strip()
+        line = line.upper() if self.upper else line
         if self.non_lang_syms_pattern is not None:
-            parts = self.non_lang_syms_pattern.split(line.upper())
+            parts = self.non_lang_syms_pattern.split(line)
             parts = [w for w in parts if len(w.strip()) > 0]
         else:
             parts = [line]
 
         tokens = []
         for part in parts:
+            if part == '':
+                continue
             if part in self.non_lang_syms:
                 tokens.append(part)
             else:
@@ -49,3 +76,14 @@ def tokens2text(self, tokens: List[str]) -> str:
         self._build_sp()
         text = super().tokens2text(tokens)
         return text.replace("▁", ' ').strip()
+
+    def add_tokens(self, tokens: List[str]) -> int:
+        added_tokens = 0
+        for token in tokens:
+            token = token.upper() if self.upper else token
+            if token not in self.symbol_table:
+                self.symbol_table[token] = len(self.symbol_table)
+                added_tokens += 1
+                self.char_dict[len(self.char_dict)] = token
+                self.extra_tokens[token] = self.symbol_table[token]
+        return added_tokens
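
Note how the two pieces fit together: add_tokens only records the new entries in symbol_table, char_dict and extra_tokens, while _build_sp patches the SentencePiece model lazily by appending the extra pieces to the serialized ModelProto and rebuilding the processor from it, following the trick in the sentencepiece issue linked above. A standalone sketch of that mechanism, using the same protobuf module bundled with transformers and a placeholder model path:

import sentencepiece as spm
from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2

sp = spm.SentencePieceProcessor()
sp.Load("bpe.model")  # placeholder path

proto = sentencepiece_model_pb2.ModelProto()
proto.ParseFromString(sp.serialized_model_proto())
for piece in ["▁wenet", "WENET"]:  # pieces assumed absent from the model
    new_piece = sentencepiece_model_pb2.ModelProto.SentencePiece()
    new_piece.piece = piece
    new_piece.score = 0
    proto.pieces.append(new_piece)

# Rebuild the processor in memory from the extended proto; no modified model
# file needs to be written to disk.
sp = spm.SentencePieceProcessor(model_proto=proto.SerializeToString())

Rebuilding from the serialized proto keeps the change in memory only, which matches the lazy _build_sp design (the processor is never built in __init__ because of the multiprocessing note).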
9 changes: 9 additions & 0 deletions wenet/text/char_tokenizer.py
@@ -76,3 +76,12 @@ def vocab_size(self) -> int:
     @property
     def symbol_table(self) -> Dict[str, int]:
         return self._symbol_table
+
+    def add_tokens(self, tokens: List[str]) -> int:
+        n = 0
+        for token in tokens:
+            if token not in self.symbol_table.keys():
+                self.symbol_table[token] = len(self.symbol_table)
+                self.char_dict[len(self.char_dict)] = token
+                n += 1
+        return n
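
A quick sanity check of the CharTokenizer variant, assuming a tokenizer built from a symbol-table file at a placeholder path:

from wenet.text.char_tokenizer import CharTokenizer

tokenizer = CharTokenizer("test/resources/units.txt")  # placeholder path
before = tokenizer.vocab_size()
added = tokenizer.add_tokens(["<new1>", "<new2>"])
# Tokens already present in the symbol table are skipped, so the vocabulary
# grows by exactly the number reported.
assert tokenizer.vocab_size() == before + added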
3 changes: 3 additions & 0 deletions wenet/text/hugging_face_tokenizer.py
@@ -56,3 +56,6 @@ def vocab_size(self) -> int:
     def symbol_table(self) -> Dict[Type, int]:
         self._build_hugging_face()
         return self.t2i
+
+    def add_tokens(self, tokens: List[Type]) -> int:
+        return [0]
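
The HuggingFaceTokenizer stub is still a WIP placeholder (it returns a list even though the signature promises an int). One possible direction, not part of this PR, assuming _build_hugging_face() stores the underlying transformers tokenizer on self.tokenizer and that t2i mirrors its vocab, would be to delegate to the Hugging Face add_tokens API:

    def add_tokens(self, tokens: List[Type]) -> int:
        # Hypothetical sketch: let the transformers tokenizer register the new
        # tokens and report how many were actually added.
        self._build_hugging_face()
        added = self.tokenizer.add_tokens(tokens)
        self.t2i = self.tokenizer.get_vocab()
        return added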
2 changes: 1 addition & 1 deletion wenet/text/tokenize_utils.py
@@ -23,7 +23,7 @@ def tokenize_by_bpe_model(sp, txt):
     # Example:
     #   txt   = "你好 ITS'S OKAY 的"
     #   chars = ["你", "好", " ITS'S OKAY ", "的"]
-    chars = pattern.split(txt.upper())
+    chars = pattern.split(txt)
     mix_chars = [w for w in chars if len(w.strip()) > 0]
     for ch_or_w in mix_chars:
         # ch_or_w is a single CJK character (i.e., "你"), do nothing.
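
Together with the new upper flag on BpeTokenizer, this moves case handling out of the shared helper: tokenize_by_bpe_model now sees the text exactly as the caller prepared it, and upper-casing happens (optionally) in text2tokens. A rough check of the new behaviour, assuming a trained model at a placeholder path:

import sentencepiece as spm

from wenet.text.tokenize_utils import tokenize_by_bpe_model

sp = spm.SentencePieceProcessor()
sp.Load("bpe.model")  # placeholder path
# With txt.upper() removed, lower-case input reaches SentencePiece unchanged.
print(tokenize_by_bpe_model(sp, "wenet ok"))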