Merged

Changes from all commits (29 commits)
c0c15bb  Add test for SentencePiece not adding special tokens to strings (beneyal, Feb 22, 2022)
08d47a7  Add SentencePieceStringConversionMixin to fix issue 15003 (beneyal, Feb 22, 2022)
6c82c09  Fix conversion from tokens to string for most SentencePiece tokenizers (beneyal, Feb 22, 2022)
7bfe422  Fix MarianTokenizer, adjust SentencePiece test to accommodate vocab (beneyal, Feb 22, 2022)
cb4c824  Fix DebertaV2Tokenizer (beneyal, Feb 22, 2022)
e6795b9  Ignore LayoutXLMTokenizer in SentencePiece string conversion test (beneyal, Feb 22, 2022)
17e0921  Run 'make style' and 'make quality' (beneyal, Feb 22, 2022)
c29bcdd  Clean convert_tokens_to_string test (beneyal, Feb 24, 2022)
e11cf2a  Remove commented out code (beneyal, Feb 24, 2022)
fb1c273  Improve robustness of convert_tokens_to_string test (beneyal, Feb 24, 2022)
91413e5  Inline and remove SentencePieceStringConversionMixin (beneyal, Feb 24, 2022)
0743ae0  Run 'make style' and 'make quality' (beneyal, Feb 24, 2022)
bad0f43  Revert removal of space in convert_tokens_to_string (beneyal, Feb 25, 2022)
8cb264e  Remove redundant import (Feb 25, 2022)
c14ebcb  Revert test text to original (Feb 25, 2022)
f021c2d  Uncomment the lowercasing of the reverse_text variable (Feb 25, 2022)
cee809c  Mimic Rust tokenizer behavior for tokenizers (beneyal, Mar 4, 2022)
adc06ff  Fix accidentally skipping test in wrong tokenizer (beneyal, Mar 4, 2022)
6b3cd77  Add test for equivalent Rust and slow tokenizer behavior (beneyal, Mar 4, 2022)
e94d85e  Override _decode in BigBirdTokenizer to mimic Rust behavior (beneyal, Mar 4, 2022)
3273b1a  Override _decode in FNetTokenizer to mimic Rust behavior (beneyal, Mar 4, 2022)
a89dba6  Override _decode in XLNetTokenizer to mimic Rust behavior (beneyal, Mar 4, 2022)
cb53d92  Merge 'main' into the 15003 fix branch (beneyal, Apr 28, 2022)
205e133  Remove unused 're' import (beneyal, Apr 28, 2022)
c271a63  Update DebertaV2Tokenizer to mimic Rust tokenizer (beneyal, Apr 29, 2022)
0a8c149  Merge branch 'main' of https://github.com/huggingface/transformers in… (beneyal, Nov 2, 2022)
073b749  Deberta tokenizer now behaves like Albert and its `convert_tokens_to_… (beneyal, Nov 2, 2022)
048bc65  Ignore problematic tests in Deberta V2 (beneyal, Nov 2, 2022)
21c3f6d  Add comment on why the Deberta V2 tests are skipped (beneyal, Nov 2, 2022)
18 changes: 17 additions & 1 deletion src/transformers/models/albert/tokenization_albert.py
@@ -250,7 +250,23 @@ def _convert_id_to_token(self, index):
         return self.sp_model.IdToPiece(index)

     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
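For context, a minimal sketch of the behavior this loop fixes (not part of the diff; it assumes the albert-base-v2 checkpoint is available, and the exact pieces shown are illustrative):

    from transformers import AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    tokens = tokenizer.tokenize("Hello [MASK] world")  # e.g. ['▁hello', '[MASK]', '▁world']

    # Previously convert_tokens_to_string fed '[MASK]' straight into
    # sp_model.decode(), which does not know the special token and mangles it.
    # Now runs of ordinary pieces go through SentencePiece and special tokens
    # are spliced back in verbatim.
    print(tokenizer.convert_tokens_to_string(tokens))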
23 changes: 19 additions & 4 deletions src/transformers/models/barthez/tokenization_barthez.py
@@ -263,6 +263,25 @@ def _convert_id_to_token(self, index):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index)

+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -278,10 +297,6 @@ def __setstate__(self, d):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
@@ -151,8 +151,17 @@ def _convert_id_to_token(self, index):

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
Review comment on lines 152 to +164:

SaulLu (Contributor), May 3, 2022:
I would propose to take the same solution as the tokenizers that have a fast version, because if we add a fast version later we will be blocked, and the decoding method is quite important for Bert generation. In general I think it is easier to use the same behavior for all the tokenizers we are currently changing (and maybe replace the Rust test with a hardcoded test).

beneyal (Contributor, Author):
You mean like, for example, the convert_tokens_to_string method in AlbertTokenizer?

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
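The equivalence SaulLu asks for is what the "Add test for equivalent Rust and slow tokenizer behavior" commit checks. A minimal sketch of that property, assuming the albert-base-v2 checkpoint; the example text is illustrative:

    from transformers import AlbertTokenizer, AlbertTokenizerFast

    slow = AlbertTokenizer.from_pretrained("albert-base-v2")
    fast = AlbertTokenizerFast.from_pretrained("albert-base-v2")

    ids = slow.encode("A sentence with a [MASK] token.")
    # After this PR the slow (SentencePiece) and fast (Rust) tokenizers
    # should decode special tokens identically:
    assert slow.decode(ids) == fast.decode(ids)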
62 changes: 60 additions & 2 deletions src/transformers/models/big_bird/tokenization_big_bird.py
@@ -16,6 +16,7 @@


 import os
+import re
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple

@@ -182,8 +183,65 @@ def _convert_id_to_token(self, index):

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space before [MASK] and [SEP]
+        if spaces_between_special_tokens:
+            text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
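A quick illustration of the Rust-mimicking regex above, in plain Python with illustrative values:

    import re

    sub_texts = ["An example", "[MASK]", "sentence", "[SEP]"]
    # " ".join() puts a space before every special token; the Rust tokenizer
    # emits none before [MASK] and [SEP], so strip those back out:
    text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
    print(text)  # An example[MASK] sentence[SEP]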
23 changes: 19 additions & 4 deletions src/transformers/models/camembert/tokenization_camembert.py
@@ -261,6 +261,25 @@ def _convert_id_to_token(self, index):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)

+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -276,10 +295,6 @@ def __setstate__(self, d):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
26 changes: 23 additions & 3 deletions src/transformers/models/deberta_v2/tokenization_deberta_v2.py
@@ -146,7 +146,9 @@ def __init__(
         self.do_lower_case = do_lower_case
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
-        self._tokenizer = SPMTokenizer(vocab_file, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs)
+        self._tokenizer = SPMTokenizer(
+            vocab_file, self.all_special_tokens, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs
+        )

     @property
     def vocab_size(self):
@@ -291,7 +293,9 @@ class SPMTokenizer:
     BPE-dropout.
     """

-    def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None
+    ):
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
@@ -312,6 +316,7 @@ def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None):
         # self.vocab['[UNK]'] = 3

         self.spm = spm
+        self.special_tokens = special_tokens

     def __getstate__(self):
         state = self.__dict__.copy()
@@ -339,7 +344,22 @@ def convert_ids_to_tokens(self, ids):

     def decode(self, tokens, start=-1, end=-1, raw_text=None):
         if raw_text is None:
-            return self.spm.decode_pieces([t for t in tokens])
+            current_sub_tokens = []
+            out_string = ""
+            prev_is_special = False
+            for token in tokens:
+                # make sure that special tokens are not decoded using sentencepiece model
+                if token in self.special_tokens:
+                    if not prev_is_special:
+                        out_string += " "
+                    out_string += self.spm.decode_pieces(current_sub_tokens) + token
+                    prev_is_special = True
+                    current_sub_tokens = []
+                else:
+                    current_sub_tokens.append(token)
+                    prev_is_special = False
+            out_string += self.spm.decode_pieces(current_sub_tokens)
+            return out_string.strip()
         else:
             words = self.split_to_words(raw_text)
             word_tokens = [self.tokenize(w) for w in words]
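Note the signature change: SPMTokenizer now takes the special tokens as a second positional argument. A hypothetical direct construction, for illustration only (normally DebertaV2Tokenizer builds it internally; the model path here is made up):

    spm_tokenizer = SPMTokenizer(
        "spm.model",                                      # hypothetical SentencePiece model path
        ["[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"],   # special tokens to splice in verbatim
        split_by_punct=False,
    )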
62 changes: 61 additions & 1 deletion src/transformers/models/fnet/tokenization_fnet.py
@@ -15,6 +15,7 @@
 """ Tokenization classes for FNet model."""

 import os
+import re
 import unicodedata
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -213,7 +214,66 @@ def _convert_id_to_token(self, index):
         return self.sp_model.IdToPiece(index)

     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space after <unk>
+        if spaces_between_special_tokens:
+            text = re.sub(r"(<unk>) ", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
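The FNet regex is the mirror image of BigBird's: the Rust tokenizer emits no space after <unk>, so the one introduced by " ".join() is dropped. In plain Python, with illustrative values:

    import re

    sub_texts = ["hello", "<unk>", "world"]
    print(re.sub(r"(<unk>) ", r"\1", " ".join(sub_texts)))  # hello <unk>world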
16 changes: 13 additions & 3 deletions src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -218,9 +218,19 @@ def _convert_id_to_token(self, index: int) -> str:
             return self.id_to_lang_token[index]
         return self.decoder.get(index, self.unk_token)

-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
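The same splice pattern recurs across all the tokenizers in this PR. A self-contained sketch, with a stub standing in for sp_model.decode (assumption: the real decoder joins pieces, maps the '▁' word markers to spaces, and drops the leading space):

    def fake_sp_decode(pieces):
        # Stub for SentencePieceProcessor.decode: join pieces, map '▁' to a
        # space, and drop the leading space, roughly what the real decoder does.
        return "".join(pieces).replace("▁", " ").lstrip()

    def convert_tokens_to_string(tokens, special_tokens):
        current_sub_tokens, out_string = [], ""
        for token in tokens:
            if token in special_tokens:
                # Flush ordinary pieces through the (stub) decoder, then splice
                # the special token in verbatim so it cannot be mangled.
                out_string += fake_sp_decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += fake_sp_decode(current_sub_tokens)
        return out_string.strip()

    print(convert_tokens_to_string(["▁Hello", "▁world", "</s>"], {"</s>"}))
    # Hello world</s>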
16 changes: 12 additions & 4 deletions src/transformers/models/marian/tokenization_marian.py
@@ -265,10 +265,18 @@ def decode(self, token_ids, **kwargs):

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
-        if self._decode_use_source_tokenizer:
-            return self.spm_source.DecodePieces(tokens)
-        else:
-            return self.spm_target.DecodePieces(tokens)
+        sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += sp_model.decode_pieces(current_sub_tokens)
+        return out_string.strip()

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
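A design note on the Marian variant above: it appends the separating space after each special token (token + " ") instead of tracking prev_is_special, so a sequence like ['▁Hello', '</s>', '▁world'] would decode to "Hello</s> world" under this loop (assuming decode_pieces strips the leading word marker).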
21 changes: 18 additions & 3 deletions src/transformers/models/mbart50/tokenization_mbart50.py
@@ -232,9 +232,24 @@ def _convert_id_to_token(self, index: int) -> str:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)

-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):