diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index ce4385e478b2..8a884ab2932b 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -192,6 +192,7 @@ def extract(self, model_type, **kwargs) -> tuple[dict[str, int], list[tuple]]:
             AddedToken(token, normalized=False, special=special)
             for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
         ]
+        kwargs["_spm_precompiled_charsmap"] = getattr(self.proto.normalizer_spec, "precompiled_charsmap", None)
         return kwargs

@@ -635,6 +636,54 @@ class SpmConverter(Converter):
     SpmExtractor = SentencePieceExtractor
     special_tokens = {}

+    @staticmethod
+    def build_tokenizer_from_spm_proto(proto, vocab, merges=None):
+        """
+        Similar to `convert_from_spm`, but used only when there is no `model_type` class, i.e. no matching class in `TOKENIZERS_MAPPING`; in that case we build a tokenizer directly instead of extracting data from the SentencePiece file.
+        """
+        byte_fallback = proto.trainer_spec.byte_fallback
+        unk_piece = proto.trainer_spec.unk_piece
+        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+
+        # model
+        if isinstance(vocab, dict):
+            tokenizer = Tokenizer(
+                BPE(
+                    vocab=vocab,
+                    merges=merges or [],
+                    unk_token=unk_piece,
+                    fuse_unk=True,
+                    byte_fallback=byte_fallback,
+                    dropout=None,
+                )
+            )
+        elif isinstance(vocab, list) and vocab and isinstance(vocab[0], (tuple, list)):
+            tokenizer = Tokenizer(
+                Unigram(
+                    vocab=vocab,
+                    unk_id=proto.trainer_spec.unk_id,
+                    byte_fallback=byte_fallback,
+                )
+            )
+        else:
+            return None
+
+        # normalizer
+        _normalizers = [normalizers.Replace(" ", "▁")]
+        if precompiled_charsmap:
+            _normalizers.insert(0, normalizers.Precompiled(precompiled_charsmap))
+        tokenizer.normalizer = normalizers.Sequence(_normalizers)
+
+        # decoder
+        if byte_fallback:
+            tokenizer.decoder = decoders.Sequence(
+                [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]
+            )
+        else:
+            tokenizer.decoder = decoders.Sequence([decoders.Replace("▁", " ")])
+
+        return tokenizer
+
     @classmethod
     def convert_from_spm(cls, vocab=None, **kwargs):
         """
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 2abcc62bd2b2..691e1205b56d 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -171,7 +171,6 @@
     ("lighton_ocr", "Qwen2TokenizerFast" if is_tokenizers_available() else None),
     ("lilt", "RobertaTokenizer" if is_tokenizers_available() else None),
     ("longformer", "RobertaTokenizer" if is_tokenizers_available() else None),
-    ("longt5", "T5Tokenizer" if is_tokenizers_available() else None),
     ("luke", "LukeTokenizer"),
     ("lxmert", "LxmertTokenizer" if is_tokenizers_available() else None),
     ("m2m_100", "M2M100Tokenizer" if is_sentencepiece_available() else None),
@@ -342,9 +341,11 @@
 MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS: set[str] = {
     "arctic",
     "deepseek_vl",
+    "deepseek_vl_v2",
     "deepseek_vl_hybrid",
     "fuyu",
     "hyperclovax_vlm",
+    "internlm2",
     "janus",
     "jamba",
     "llava",
@@ -706,6 +707,12 @@ def from_pretrained(
                 or tokenizer_class_from_name(tokenizer_config_class + "Fast") is not None
             )
         )
+
+        # V5: Skip remote tokenizer for custom models with incorrect hub tokenizer class
+        if has_remote_code and config_model_type in MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS:
+            has_remote_code = False
+            tokenizer_auto_map = None
+
         if has_remote_code:
             # V5: Always prefer fast tokenizer (index 1), fallback to slow (index 0)
             if tokenizer_auto_map[1] is not None:
diff --git a/src/transformers/models/bert/tokenization_bert.py b/src/transformers/models/bert/tokenization_bert.py
index 9c68ad916e5f..47cb639fab40 100644
--- a/src/transformers/models/bert/tokenization_bert.py
+++ b/src/transformers/models/bert/tokenization_bert.py
@@ -48,7 +48,7 @@ class BertTokenizer(TokenizersBackend):
     Args:
         vocab (`str` or `dict[str, int]`, *optional*):
            Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
-        do_lower_case (`bool`, *optional*, defaults to `False`):
+        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
@@ -79,7 +79,7 @@ class BertTokenizer(TokenizersBackend):
     def __init__(
         self,
         vocab: str | dict[str, int] | None = None,
-        do_lower_case: bool = False,
+        do_lower_case: bool = True,
         unk_token: str = "[UNK]",
         sep_token: str = "[SEP]",
         pad_token: str = "[PAD]",
diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot.py b/src/transformers/models/blenderbot/tokenization_blenderbot.py
index f324c14eaae4..6d1951cd3783 100644
--- a/src/transformers/models/blenderbot/tokenization_blenderbot.py
+++ b/src/transformers/models/blenderbot/tokenization_blenderbot.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Tokenization class for Blenderbot."""

-from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers import Tokenizer, decoders, pre_tokenizers
 from tokenizers.models import BPE

 from ...tokenization_utils_base import AddedToken
@@ -170,12 +170,6 @@ def __init__(
             add_prefix_space=add_prefix_space,
             **kwargs,
         )
-        self._tokenizer.post_processor = processors.RobertaProcessing(
-            sep=(str(eos_token), self.eos_token_id),
-            cls=(str(bos_token), self.bos_token_id),
-            add_prefix_space=add_prefix_space,
-            trim_offsets=True,
-        )


 __all__ = ["BlenderbotTokenizer"]
diff --git a/src/transformers/models/gemma/tokenization_gemma.py b/src/transformers/models/gemma/tokenization_gemma.py
index 26323e5b953e..85f0a74bb4c8 100644
--- a/src/transformers/models/gemma/tokenization_gemma.py
+++ b/src/transformers/models/gemma/tokenization_gemma.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from tokenizers import Tokenizer, decoders, normalizers
+from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE

 from ...tokenization_utils_tokenizers import TokenizersBackend
@@ -88,6 +88,9 @@ def __init__(
                 byte_fallback=True,
             )
         )
+        self._tokenizer.pre_tokenizer = pre_tokenizers.Split(
+            pattern=" ", behavior="merged_with_previous", invert=False
+        )

         self._tokenizer.decoder = decoders.Sequence(
             [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]
@@ -102,9 +105,5 @@ def __init__(
             **kwargs,
         )

-    def _unk_id(self) -> int:
-        # Align with historical Gemma convention: pad, eos, bos, unk
-        return 3
-

 __all__ = ["GemmaTokenizer"]
diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox.py
index 2387791cc13a..325c6225f49f 100644
--- a/src/transformers/models/gpt_neox/tokenization_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox.py
@@ -127,7 +127,7 @@ def __init__(
         self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
             add_prefix_space=add_prefix_space, trim_offsets=trim_offsets
         )
-        self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
+        self._tokenizer.decoder = decoders.ByteLevel()

         super().__init__(
             errors=errors,
diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py
index 7435ef3c43cd..770877ed8347 100644
--- a/src/transformers/models/lasr/modular_lasr.py
+++ b/src/transformers/models/lasr/modular_lasr.py
@@ -16,7 +16,7 @@
 from collections.abc import Callable

 import torch
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram
 from torch import nn
@@ -46,28 +46,77 @@ def __init__(
         eos_token="</s>",
         unk_token="<unk>",
         pad_token="<pad>",
+        _spm_precompiled_charsmap=None,
         extra_ids=100,
         additional_special_tokens=None,
         vocab=None,
         vocab_file=None,
         **kwargs,
     ):
-        super().__init__(
+        self._extra_ids = extra_ids
+
+        # Handle extra_ids and additional_special_tokens
+        if additional_special_tokens is not None:
+            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
+            if len(extra_tokens) < 1:
+                additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
+            elif extra_ids > 0 and extra_ids != len(extra_tokens):
+                raise ValueError(
+                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+                    " provided to LasrTokenizer. In this case the additional_special_tokens must include the extra_ids"
+                    " tokens"
+                )
+        else:
+            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
+            additional_special_tokens = extra_tokens
+
+        # LASR vocab structure: <pad>=0, </s>=1, <unk>=2, then regular vocab, then extra_ids in reverse
+        if vocab is not None:
+            self._vocab_scores = vocab
+        else:
+            self._vocab_scores = [
+                (str(pad_token), 0.0),
+                (str(eos_token), 0.0),
+                (str(unk_token), 0.0),
+                ("▁", -2.0),  # Space token
+            ]
+            for i in range(extra_ids - 1, -1, -1):
+                self._vocab_scores.append((f"<extra_id_{i}>", 0.0))
+        self._tokenizer = Tokenizer(
+            Unigram(
+                self._vocab_scores,
+                unk_id=3,
+                byte_fallback=False,
+            )
+        )
+
+        if _spm_precompiled_charsmap is not None:
+            self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)
+
+        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.WhitespaceSplit(),
+                pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
+            ]
+        )
+        self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
+
+        TokenizersBackend.__init__(
+            self,
             eos_token=eos_token,
             unk_token=unk_token,
             pad_token=pad_token,
             extra_ids=extra_ids,
             additional_special_tokens=additional_special_tokens,
-            vocab=vocab,
-            vocab_file=vocab_file,
             **kwargs,
         )
-        self._tokenizer = Tokenizer(
-            Unigram(
-                self._vocab_scores,
-                unk_id=3,
-                byte_fallback=False,
-            )
+
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=["$A", "</s>"],
+            pair=["$A", "</s>", "$B", "</s>"],
+            special_tokens=[
+                ("</s>", self.eos_token_id),
+            ],
         )

     def _decode(
diff --git a/src/transformers/models/lasr/tokenization_lasr.py b/src/transformers/models/lasr/tokenization_lasr.py
index 0cd45c042f11..f9f59acb7960 100644
--- a/src/transformers/models/lasr/tokenization_lasr.py
+++ b/src/transformers/models/lasr/tokenization_lasr.py
@@ -21,7 +21,7 @@
 import itertools
 import re

-from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram

 from ...tokenization_utils_tokenizers import TokenizersBackend
@@ -76,6 +76,7 @@ def __init__(
         eos_token="</s>",
         unk_token="<unk>",
         pad_token="<pad>",
+        _spm_precompiled_charsmap=None,
         extra_ids=100,
         additional_special_tokens=None,
         vocab=None,
@@ -119,7 +120,8 @@ def __init__(
             )
         )

-        self._tokenizer.normalizer = None
+        if _spm_precompiled_charsmap is not None:
+            self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)

         self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
@@ -127,7 +129,6 @@ def __init__(
                 pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
             ]
         )
-
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

         super().__init__(
diff --git a/src/transformers/models/mbart50/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py
index 7c26dee15297..a6eba73c7389 100644
--- a/src/transformers/models/mbart50/tokenization_mbart50.py
+++ b/src/transformers/models/mbart50/tokenization_mbart50.py
@@ -85,6 +85,7 @@ class MBart50Tokenizer(TokenizersBackend):
     def __init__(
         self,
         vocab: str | dict | list | None = None,
+        _spm_precompiled_charsmap: str | None = None,
         src_lang=None,
         tgt_lang=None,
         eos_token="</s>",
@@ -158,19 +159,11 @@ def __init__(
             )
         )

-        # Set normalizer equivalent to Precompiled + Strip + Replace from tokenizer.json
-        # When loading from pretrained, this will be overridden by the tokenizer.json config
-        # When creating from extractor (vocab), this provides equivalent behavior
-        self._tokenizer.normalizer = normalizers.Sequence(
-            [
-                normalizers.Replace(Regex(r"[\n\r\t]"), " "),  # Precompiled converts newlines/tabs to spaces
-                normalizers.NFKC(),  # Precompiled does NFKC normalization
-                normalizers.Strip(left=False, right=True),  # Strip trailing whitespace (matches tokenizer.json)
-                normalizers.Replace(
-                    Regex(r" {2,}"), "▁"
-                ),  # Replace multiple spaces with underscore (matches tokenizer.json)
-            ]
-        )
+        normalizers_ = [normalizers.Replace(Regex(r" {2,}"), " ")]
+        if _spm_precompiled_charsmap is not None:
+            normalizers_ = [normalizers.Precompiled(_spm_precompiled_charsmap)] + normalizers_
+
+        self._tokenizer.normalizer = normalizers.Sequence(normalizers_)

         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
diff --git a/src/transformers/models/nllb/tokenization_nllb.py b/src/transformers/models/nllb/tokenization_nllb.py
index 7defd36430c0..4a4f4d67e9b9 100644
--- a/src/transformers/models/nllb/tokenization_nllb.py
+++ b/src/transformers/models/nllb/tokenization_nllb.py
@@ -99,6 +99,7 @@ def __init__(
         mask_token="<mask>",
         src_lang=None,
         tgt_lang=None,
+        _spm_precompiled_charsmap: str | None = None,
         additional_special_tokens=None,
         extra_special_tokens=None,
         legacy_behaviour=False,
@@ -139,13 +140,13 @@ def __init__(
             )
         )

-        self._tokenizer.normalizer = normalizers.Sequence(
-            [
-                normalizers.Replace(Regex(r"[\n\r\t]"), " "),
-                normalizers.NFKC(),
-                normalizers.Replace(Regex(r" {2,}"), " "),
-            ]
-        )
+        if _spm_precompiled_charsmap is not None:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [
+                    normalizers.Precompiled(_spm_precompiled_charsmap),
+                    normalizers.Replace(Regex(r" {2,}"), " "),
+                ]
+            )

         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
@@ -269,22 +270,24 @@ def set_src_lang_special_tokens(self, src_lang) -> None:
         - In default mode: Prefix=[src_lang_code], suffix = [eos]
         """
         self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
+        lang_code_token = src_lang
         if self.legacy_behaviour:
             self.prefix_tokens = []
             self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+            self._tokenizer.post_processor = processors.TemplateProcessing(
+                single=["$A", self.eos_token, lang_code_token],
+                pair=["$A", "$B", self.eos_token, lang_code_token],
+                special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
+            )
         else:
             self.prefix_tokens = [self.cur_lang_code]
             self.suffix_tokens = [self.eos_token_id]
-
-        prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
-        suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
-
-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
-            pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
-            special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
-        )
+            self._tokenizer.post_processor = processors.TemplateProcessing(
+                single=[lang_code_token, "$A", self.eos_token],
+                pair=[lang_code_token, "$A", "$B", self.eos_token],
+                special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
+            )

     def set_tgt_lang_special_tokens(self, lang: str) -> None:
         """Reset the special tokens to the target lang setting.
@@ -292,21 +295,24 @@ def set_tgt_lang_special_tokens(self, lang: str) -> None:
         - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
         """
         self.cur_lang_code = self.convert_tokens_to_ids(lang)
+        lang_code_token = lang
+
         if self.legacy_behaviour:
             self.prefix_tokens = []
             self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+            self._tokenizer.post_processor = processors.TemplateProcessing(
+                single=["$A", self.eos_token, lang_code_token],
+                pair=["$A", "$B", self.eos_token, lang_code_token],
+                special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
+            )
         else:
             self.prefix_tokens = [self.cur_lang_code]
             self.suffix_tokens = [self.eos_token_id]
-
-        prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
-        suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
-
-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
-            pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
-            special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
-        )
+            self._tokenizer.post_processor = processors.TemplateProcessing(
+                single=[lang_code_token, "$A", self.eos_token],
+                pair=[lang_code_token, "$A", "$B", self.eos_token],
+                special_tokens=[(self.eos_token, self.eos_token_id), (lang_code_token, self.cur_lang_code)],
+            )


 __all__ = ["NllbTokenizer"]
diff --git a/src/transformers/models/openai/tokenization_openai.py b/src/transformers/models/openai/tokenization_openai.py
index 8596b879dccd..143c7aedabbd 100644
--- a/src/transformers/models/openai/tokenization_openai.py
+++ b/src/transformers/models/openai/tokenization_openai.py
@@ -80,13 +80,7 @@ def __init__(

         # Set normalizer and pre-tokenizer to mimic OpenAI GPT behavior
         # OpenAI GPT uses BERT BasicTokenizer with lower_case=True
-        self._tokenizer.normalizer = normalizers.Sequence(
-            [
-                normalizers.NFD(),
-                normalizers.Lowercase(),
-                normalizers.StripAccents(),
-            ]
-        )
+        self._tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
         self._tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
         self._tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py
index c4a337dbb4f1..c0ffcf258c60 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus.py
@@ -85,6 +85,7 @@ def __init__(
         unk_token="<unk>",
         mask_token="<mask_2>",
         mask_token_sent="<mask_1>",
+        _spm_precompiled_charsmap=None,
         additional_special_tokens=None,
         offset=103,
         **kwargs,
@@ -100,9 +101,14 @@ def __init__(
         self._vocab = vocab
         self._tokenizer = Tokenizer(Unigram(vocab=vocab, unk_id=self._vocab.index((str(unk_token), 0.0), 1)))

-        self._tokenizer.normalizer = normalizers.Sequence(
-            [normalizers.Replace(Regex(r"\n"), " "), normalizers.Replace(Regex(r" {2,}"), " ")]
-        )
+        if _spm_precompiled_charsmap is not None:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [normalizers.Precompiled(_spm_precompiled_charsmap), normalizers.Replace(Regex(r" {2,}"), " ")]
+            )
+        else:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [normalizers.Replace(Regex(r"\n"), " "), normalizers.Replace(Regex(r" {2,}"), " ")]
+            )

         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
diff --git a/src/transformers/models/reformer/tokenization_reformer.py b/src/transformers/models/reformer/tokenization_reformer.py
index 06b912d4c995..3c7d79d0e6bc 100644
--- a/src/transformers/models/reformer/tokenization_reformer.py
+++ b/src/transformers/models/reformer/tokenization_reformer.py
@@ -73,6 +73,7 @@ def __init__(
         merges: str | list[str] | None = None,
         eos_token: str = "</s>",
         unk_token: str = "<unk>",
+        _spm_precompiled_charsmap: str | None = None,
         additional_special_tokens: list | None = None,
         **kwargs,
     ):
@@ -90,13 +91,13 @@ def __init__(
             )
         )

-        self._tokenizer.normalizer = normalizers.Sequence(
-            [
-                normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " "),
-                normalizers.NFC(),
-                normalizers.Strip(left=False, right=True),
-            ]
-        )
+        if _spm_precompiled_charsmap is not None:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [
+                    normalizers.Precompiled(_spm_precompiled_charsmap),
+                    normalizers.Replace(pattern=Regex(" {2,}"), content=" "),
+                ]
+            )

         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always")
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always")
diff --git a/src/transformers/models/siglip2/modular_siglip2.py b/src/transformers/models/siglip2/modular_siglip2.py
index a12f76e67910..07c25658f76a 100644
--- a/src/transformers/models/siglip2/modular_siglip2.py
+++ b/src/transformers/models/siglip2/modular_siglip2.py
@@ -82,9 +82,6 @@ def __init__(
         if backend is not None and backend.normalizer is not None:
             backend.normalizer = normalizers.Sequence([normalizers.Lowercase(), backend.normalizer])

-    def _unk_id(self) -> int:
-        raise AttributeError("_unk_id is not needed for SigLIP2.")
-

 class Siglip2TextConfig(SiglipTextConfig):
     pass
diff --git a/src/transformers/models/siglip2/tokenization_siglip2.py b/src/transformers/models/siglip2/tokenization_siglip2.py
index dc3ad98206a0..514838cbff2f 100644
--- a/src/transformers/models/siglip2/tokenization_siglip2.py
+++ b/src/transformers/models/siglip2/tokenization_siglip2.py
@@ -18,7 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from tokenizers import Tokenizer, decoders, normalizers
+from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE

 from ...tokenization_utils_tokenizers import TokenizersBackend
@@ -69,6 +69,9 @@ def __init__(
                 byte_fallback=True,
             )
         )
+        self._tokenizer.pre_tokenizer = pre_tokenizers.Split(
+            pattern=" ", behavior="merged_with_previous", invert=False
+        )

         self._tokenizer.decoder = decoders.Sequence(
             [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]
diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py
index 1f14b20b61e0..741fb8369b6b 100644
--- a/src/transformers/models/t5/tokenization_t5.py
+++ b/src/transformers/models/t5/tokenization_t5.py
@@ -15,7 +15,7 @@

 import re

-from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram

 from ...tokenization_utils_tokenizers import TokenizersBackend
@@ -74,6 +74,7 @@ def __init__(
         eos_token="</s>",
         unk_token="<unk>",
         pad_token="<pad>",
+        _spm_precompiled_charsmap=None,
         extra_ids=100,
         additional_special_tokens=None,
         **kwargs,
@@ -116,7 +117,8 @@ def __init__(
             )
         )

-        self._tokenizer.normalizer = None
+        if _spm_precompiled_charsmap is not None:
+            self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)

         self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
@@ -124,7 +126,6 @@ def __init__(
                 pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
             ]
         )
-
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

         super().__init__(
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
index db39d0364439..b0339e408c2a 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
@@ -62,6 +62,7 @@ def __init__(
         unk_token: str = "<unk>",
         pad_token: str = "<pad>",
         mask_token: str = "<mask>",
+        _spm_precompiled_charsmap: str | None = None,
         **kwargs,
     ):
         self.add_prefix_space = add_prefix_space
@@ -79,12 +80,8 @@ def __init__(

         self._tokenizer = Tokenizer(Unigram(vocab=self._vocab, unk_id=3, byte_fallback=False))

-        self._tokenizer.normalizer = normalizers.Sequence(
-            [
-                normalizers.Strip(left=False, right=True),
-                normalizers.Replace(" {2,}", "▁"),
-            ]
-        )
+        if _spm_precompiled_charsmap is not None:
+            self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)

         prepend_scheme = "always" if add_prefix_space else "never"
         self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
@@ -108,7 +105,7 @@ def __init__(

         self._tokenizer.post_processor = processors.TemplateProcessing(
             single=[str(bos_token), "$A", str(eos_token)],
-            pair=[str(bos_token), "$A", str(eos_token), "$B", str(eos_token)],
+            pair=[str(bos_token), "$A", str(eos_token), str(eos_token), "$B", str(eos_token)],
             special_tokens=[
                 (str(bos_token), self.bos_token_id),
                 (str(eos_token), self.eos_token_id),
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 03c51f54c26a..dcbe26ead886 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1669,9 +1669,14 @@ def from_pretrained(
         if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
             # mistral tokenizer names are different, but we can still convert them if
             # mistral common is not there
-            other_pattern = r"tekken\.json|tokenizer\.model\.*"
+            other_pattern = r"tekken\.json|tokenizer\.model\.*|tiktoken\.model|" + "|".join(
+                re.escape(name) for name in getattr(cls, "VOCAB_FILES_NAMES", {}).values()
+            )
             if match := re.search(other_pattern, "\n".join(remote_files)):
-                vocab_files["vocab_file"] = match.group()
+                if "spm_file" in vocab_files:
+                    vocab_files["spm_file"] = match.group()
+                else:
+                    vocab_files["vocab_file"] = match.group()

         resolved_vocab_files = {}
         for file_id, file_path in vocab_files.items():
diff --git a/src/transformers/tokenization_utils_tokenizers.py b/src/transformers/tokenization_utils_tokenizers.py
index 21c3aa494b4b..932083e29581 100644
--- a/src/transformers/tokenization_utils_tokenizers.py
+++ b/src/transformers/tokenization_utils_tokenizers.py
@@ -35,6 +35,7 @@

 from transformers.utils.hub import cached_file

+from .convert_slow_tokenizer import SpmConverter
 from .integrations.ggml import convert_gguf_tokenizer
 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
 from .tokenization_utils_base import (
@@ -115,10 +116,43 @@ def convert_to_native_format(cls, trust_remote_code=False, **kwargs):
                 local_kwargs["tokenizer_object"] = TokenizerFast.from_file(fast_tokenizer_file)
             return local_kwargs
         elif fast_tokenizer_file is not None and os.path.isfile(fast_tokenizer_file):
-            # we extract vocab / merges from the tokenizer file to pass them to __init__
-            processor = TokenizerFast.from_file(fast_tokenizer_file).post_processor
+            # we extract vocab/merges and pass decoder/pre_tokenizer/post_processor
+            # from the file so the reconstructed tokenizer matches the tokenizer.json
+            tok_from_file = TokenizerFast.from_file(fast_tokenizer_file)
+            local_kwargs["post_processor"] = tok_from_file.post_processor
+            local_kwargs["tokenizer_padding"] = tok_from_file.padding
+            local_kwargs["tokenizer_truncation"] = tok_from_file.truncation
+            # Preserve truncation and padding baked into tokenizer.json so that classes
+            # with a custom __init__ that rebuild the backend tokenizer from scratch
+            # can still access these settings.
+            if tok_from_file.truncation is not None:
+                local_kwargs["_json_truncation"] = tok_from_file.truncation
+            if tok_from_file.padding is not None:
+                local_kwargs["_json_padding"] = tok_from_file.padding
+
             with open(fast_tokenizer_file, encoding="utf-8") as tokenizer_handle:
                 tokenizer_json = json.load(tokenizer_handle)
+
+            # Extract precompiled SentencePiece charsmap from tokenizer.json normalizer
+            # when present (e.g. T5 tokenizers converted with SentencePiece >= 2.x).
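+            # The relevant tokenizer.json fragment looks like (the charsmap value is
+            # base64-encoded bytes, elided here):
+            #   {"normalizer": {"type": "Sequence", "normalizers": [{"type": "Precompiled", "precompiled_charsmap": "..."}, ...]}}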
+            normalizer_config = tokenizer_json.get("normalizer")
+            if normalizer_config:
+                if normalizer_config.get("type", None) == "Sequence":
+                    normalizer_config = normalizer_config["normalizers"]
+                elif not isinstance(normalizer_config, list):
+                    normalizer_config = [normalizer_config]
+                for normalizer in normalizer_config:
+                    if normalizer.get("type") == "Precompiled" and "precompiled_charsmap" in normalizer:
+                        import base64
+
+                        local_kwargs["_spm_precompiled_charsmap"] = base64.b64decode(
+                            normalizer["precompiled_charsmap"]
+                        )
+                        break
+
             vocab = tokenizer_json.get("model", {}).get("vocab", None)
             if cls.model is None:
                 if isinstance(vocab, list):
@@ -139,8 +170,6 @@ def convert_to_native_format(cls, trust_remote_code=False, **kwargs):
                 merges = [tuple(merge.split(" ")) if isinstance(merge, str) else tuple(merge) for merge in merges]
                 local_kwargs["merges"] = merges

-            if processor is not None:
-                local_kwargs["post_processor"] = processor
             return local_kwargs

         vocab_file = local_kwargs.get("vocab_file")
@@ -162,7 +191,11 @@ def convert_to_native_format(cls, trust_remote_code=False, **kwargs):
             try:
                 from .convert_slow_tokenizer import SentencePieceExtractor

-                local_kwargs = SentencePieceExtractor(vocab_file).extract(cls.model, **local_kwargs)
+                # 1. Extract vocab, merges, and spm_precompiled from the .model proto
+                extractor = SentencePieceExtractor(vocab_file)
+                local_kwargs = extractor.extract(cls.model, **local_kwargs)
+
+                # 2. If a model-specific converter exists, use it.
                 try:
                     from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
@@ -173,9 +206,35 @@ def convert_to_native_format(cls, trust_remote_code=False, **kwargs):
                     logger.warning(
                         f"Could not reorder vocab using converter for {cls.__name__} due to {e}. Falling back to raw SentencePiece extraction."
                     )
-                # what used to be in `convert_slow`
                 if hasattr(cls, "convert_from_spm_model"):
                     local_kwargs = cls.convert_from_spm_model(**local_kwargs)
+
+                # 3. For non-model specific tokenizers (e.g. TokenizersBackend used
+                # for MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS), build a _tokenizer
+                # from the proto so normalizer/decoder are configured correctly.
+                if "tokenizer_object" not in local_kwargs and (
+                    cls is TokenizersBackend or "__init__" not in cls.__dict__
+                ):
+                    vocab = local_kwargs.pop("vocab", None)
+                    merges = local_kwargs.pop("merges", None)
+                    tokenizer_object = SpmConverter.build_tokenizer_from_spm_proto(
+                        proto=extractor.proto,
+                        vocab=vocab,
+                        merges=merges,
+                    )
+                    if tokenizer_object is not None:
+                        local_kwargs["tokenizer_object"] = tokenizer_object
+                        # Set bos/eos tokens from proto spec if available. This is needed when
+                        # building a tokenizer_object directly from a .model file because the
+                        # tokenizer_object does not have bos/eos set.
+                        proto_spec = extractor.proto.trainer_spec
+                        if proto_spec.bos_id >= 0:
+                            local_kwargs.setdefault("bos_token", proto_spec.bos_piece or "<s>")
+                        if proto_spec.eos_id >= 0:
+                            local_kwargs.setdefault("eos_token", proto_spec.eos_piece or "</s>")
+                        if proto_spec.unk_id >= 0:
+                            local_kwargs.setdefault("unk_token", proto_spec.unk_piece or "<unk>")
+
             except Exception as e:  # TODO only catch deserialization error here!
                 logger.warning(
                     f"Could not extract SentencePiece model from {vocab_file} using sentencepiece library due to {e}. "
                )

                from .convert_slow_tokenizer import TikTokenConverter

-                local_kwargs["vocab"], local_kwargs["merges"] = TikTokenConverter(
+                converter = TikTokenConverter(
                     vocab_file=vocab_file, extra_special_tokens=local_kwargs.get("extra_special_tokens")
-                ).extract_vocab_merges_from_model(vocab_file)
-
+                )
+                local_kwargs["tokenizer_object"] = converter.converted()
             return local_kwargs

         # Fallback to standard vocab/merges files if they existed!
@@ -232,6 +291,14 @@ def _iter_special_tokens(values: Iterable[Any]) -> list[str]:

         return local_kwargs

     def __init__(self, *args, **kwargs):
+        # Truncation/padding dicts extracted from tokenizer.json by convert_to_native_format
+        # when a class with a custom __init__ rebuilds the backend tokenizer from scratch.
+        _json_truncation = kwargs.pop("_json_truncation", None)
+        _json_padding = kwargs.pop("_json_padding", None)
+        # Precompiled SentencePiece charsmap is already used by model-specific tokenizers
+        # (before calling super().__init__) and should not be stored in `init_kwargs` to keep the tokenizer serializable.
+        kwargs.pop("_spm_precompiled_charsmap", None)
+
         tokenizer_object = kwargs.pop("tokenizer_object", None)
         gguf_file = kwargs.pop("gguf_file", None)
         fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
@@ -289,8 +356,7 @@ def __init__(self, *args, **kwargs):
         if self._tokenizer is None:
             raise ValueError("The backend tokenizer is not correctly initialized.")

-        _truncation = self._tokenizer.truncation
-
+        _truncation = kwargs.pop("tokenizer_truncation", None) or self._tokenizer.truncation or _json_truncation
         if _truncation is not None:
             self._tokenizer.enable_truncation(**_truncation)
             kwargs.setdefault("max_length", _truncation["max_length"])
@@ -300,7 +366,7 @@ def __init__(self, *args, **kwargs):
         else:
             self._tokenizer.no_truncation()

-        _padding = self._tokenizer.padding
+        _padding = kwargs.pop("tokenizer_padding", None) or self._tokenizer.padding or _json_padding
         if _padding is not None:
             self._tokenizer.enable_padding(**_padding)
             kwargs.setdefault("pad_token", _padding["pad_token"])
diff --git a/tests/models/dpr/test_tokenization_dpr.py b/tests/models/dpr/test_tokenization_dpr.py
deleted file mode 100644
index 97c2d95443d2..000000000000
--- a/tests/models/dpr/test_tokenization_dpr.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright 2020 Huggingface
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import (
-    DPRContextEncoderTokenizer,
-    DPRContextEncoderTokenizerFast,
-    DPRQuestionEncoderTokenizer,
-    DPRQuestionEncoderTokenizerFast,
-    DPRReaderOutput,
-    DPRReaderTokenizer,
-    DPRReaderTokenizerFast,
-)
-from transformers.testing_utils import require_tokenizers, slow
-from transformers.tokenization_utils_base import BatchEncoding
-
-from ..bert import test_tokenization_bert
-
-
-@require_tokenizers
-class DPRContextEncoderTokenizationTest(test_tokenization_bert.BertTokenizationTest):
-    tokenizer_class = DPRContextEncoderTokenizer
-    rust_tokenizer_class = DPRContextEncoderTokenizerFast
-    test_rust_tokenizer = True
-    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
-
-
-@require_tokenizers
-class DPRQuestionEncoderTokenizationTest(test_tokenization_bert.BertTokenizationTest):
-    tokenizer_class = DPRQuestionEncoderTokenizer
-    rust_tokenizer_class = DPRQuestionEncoderTokenizerFast
-    test_rust_tokenizer = True
-    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
-
-
-@require_tokenizers
-class DPRReaderTokenizationTest(test_tokenization_bert.BertTokenizationTest):
-    tokenizer_class = DPRReaderTokenizer
-    rust_tokenizer_class = DPRReaderTokenizerFast
-    test_rust_tokenizer = True
-    test_seq2seq = False
-    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
-
-    @slow
-    def test_decode_best_spans(self):
-        tokenizer = self.tokenizer_class.from_pretrained("google-bert/bert-base-uncased")
-
-        text_1 = tokenizer.encode("question sequence", add_special_tokens=False)
-        text_2 = tokenizer.encode("title sequence", add_special_tokens=False)
-        text_3 = tokenizer.encode("text sequence " * 4, add_special_tokens=False)
-        input_ids = [[101] + text_1 + [102] + text_2 + [102] + text_3]
-        reader_input = BatchEncoding({"input_ids": input_ids})
-
-        start_logits = [[0] * len(input_ids[0])]
-        end_logits = [[0] * len(input_ids[0])]
-        relevance_logits = [0]
-        reader_output = DPRReaderOutput(start_logits, end_logits, relevance_logits)
-
-        start_index, end_index = 8, 9
-        start_logits[0][start_index] = 10
-        end_logits[0][end_index] = 10
-        predicted_spans = tokenizer.decode_best_spans(reader_input, reader_output)
-        self.assertEqual(predicted_spans[0].start_index, start_index)
-        self.assertEqual(predicted_spans[0].end_index, end_index)
-        self.assertEqual(predicted_spans[0].doc_id, 0)
-
-    @slow
-    def test_call(self):
-        tokenizer = self.tokenizer_class.from_pretrained("google-bert/bert-base-uncased")
-
-        text_1 = tokenizer.encode("question sequence", add_special_tokens=False)
-        text_2 = tokenizer.encode("title sequence", add_special_tokens=False)
-        text_3 = tokenizer.encode("text sequence", add_special_tokens=False)
-        expected_input_ids = [101] + text_1 + [102] + text_2 + [102] + text_3
-        encoded_input = tokenizer(questions=["question sequence"], titles=["title sequence"], texts=["text sequence"])
-        self.assertIn("input_ids", encoded_input)
-        self.assertIn("attention_mask", encoded_input)
-        self.assertListEqual(encoded_input["input_ids"][0], expected_input_ids)
diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index 360901205627..757431ae6baf 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -1,6 +1,7 @@
 import unittest

 from tests.test_tokenization_common import TokenizerTesterMixin
+from transformers import AutoTokenizer
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from transformers.testing_utils import (
     require_tokenizers,
@@ -35,3 +36,18 @@ def setUpClass(cls):

     def get_tokenizers(self, **kwargs):
         kwargs.setdefault("pad_token", "")
         return super().get_tokenizers(**kwargs)
+
+    def test_load_tiktoken_tokenizer(self):
+        """Test loading a Llama tokenizer from tiktoken.model file"""
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama3-tokenizer-tiktoken")
+
+        text = "This is a test"
+        tokens = tokenizer.encode(text, add_special_tokens=False)
+        decoded = tokenizer.decode(tokens, skip_special_tokens=True)
+        self.assertEqual(decoded, text)
+
+        tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama3-tokenizer-tiktoken")
+        text = "This is a test"
+        tokens = tokenizer.encode(text, add_special_tokens=False)
+        decoded = tokenizer.decode(tokens, skip_special_tokens=True)
+        self.assertEqual(decoded, text)
diff --git a/tests/models/mbart50/test_tokenization_mbart50.py b/tests/models/mbart50/test_tokenization_mbart50.py
index 52838b37e533..2069f030973f 100644
--- a/tests/models/mbart50/test_tokenization_mbart50.py
+++ b/tests/models/mbart50/test_tokenization_mbart50.py
@@ -42,10 +42,10 @@ class MBart50TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     from_pretrained_id = "facebook/mbart-large-50-one-to-many-mmt"
     tokenizer_class = MBart50Tokenizer

-    integration_expected_tokens = ['▁This', '▁is', '▁a', '▁test', '▁', '😊', '▁I', '▁was', '▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fals', 'é', '.', '▁', '生活的', '真', '谛', '是', '▁Hi', '▁Hello', '▁Hi', '▁Hello', '▁Hello', '<s>', '▁hi', '<s>', '▁there', '▁The', '▁following', '▁string', '▁should', '▁be', '▁properly', '▁en', 'code', 'd', ':', '▁Hello', '.', '▁But', '▁ir', 'd', '▁and', '▁ปี', '▁ir', 'd', '▁ด', '▁Hey', '▁how', '▁are', '▁you', '▁doing']  # fmt: skip
-    integration_expected_token_ids = [3293, 83, 10, 3034, 6, 82803, 87, 509, 103122, 23, 483, 13821, 4, 136, 903, 83, 84047, 446, 5, 6, 62668, 5364, 245875, 354, 2673, 35378, 2673, 35378, 35378, 0, 1274, 0, 2685, 581, 25632, 79315, 5608, 186, 155965, 22, 40899, 71, 12, 35378, 5, 4966, 193, 71, 136, 10249, 193, 71, 48229, 28240, 3642, 621, 398, 20594]  # fmt: skip
-    expected_tokens_from_ids = ['▁This', '▁is', '▁a', '▁test', '▁', '😊', '▁I', '▁was', '▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fals', 'é', '.', '▁', '生活的', '真', '谛', '是', '▁Hi', '▁Hello', '▁Hi', '▁Hello', '▁Hello', '<s>', '▁hi', '<s>', '▁there', '▁The', '▁following', '▁string', '▁should', '▁be', '▁properly', '▁en', 'code', 'd', ':', '▁Hello', '.', '▁But', '▁ir', 'd', '▁and', '▁ปี', '▁ir', 'd', '▁ด', '▁Hey', '▁how', '▁are', '▁you', '▁doing']  # fmt: skip
-    integration_expected_decoded_text = "This is a test 😊 I was born in 92000, and this is falsé. 生活的真谛是 Hi Hello Hi Hello Hello hi there The following string should be properly encoded: Hello. But ird and ปี ird ด Hey how are you doing"
+    integration_expected_tokens = ['▁This', '▁is', '▁a', '▁test', '▁', '😊', '▁I', '▁was', '▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fals', 'é', '.', '▁', '生活的', '真', '谛', '是', '▁Hi', '▁Hello', '▁Hi', '▁Hello', '▁Hello', '▁', '<s>', '▁hi', '<s>', '▁there', '▁The', '▁following', '▁string', '▁should', '▁be', '▁properly', '▁en', 'code', 'd', ':', '▁Hello', '.', '▁But', '▁ir', 'd', '▁and', '▁ปี', '▁ir', 'd', '▁ด', '▁Hey', '▁how', '▁are', '▁you', '▁doing']  # fmt: skip
+    integration_expected_token_ids = [3293, 83, 10, 3034, 6, 82803, 87, 509, 103122, 23, 483, 13821, 4, 136, 903, 83, 84047, 446, 5, 6, 62668, 5364, 245875, 354, 2673, 35378, 2673, 35378, 35378, 6, 0, 1274, 0, 2685, 581, 25632, 79315, 5608, 186, 155965, 22, 40899, 71, 12, 35378, 5, 4966, 193, 71, 136, 10249, 193, 71, 48229, 28240, 3642, 621, 398, 20594]  # fmt: skip
+    expected_tokens_from_ids = ['▁This', '▁is', '▁a', '▁test', '▁', '😊', '▁I', '▁was', '▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fals', 'é', '.', '▁', '生活的', '真', '谛', '是', '▁Hi', '▁Hello', '▁Hi', '▁Hello', '▁Hello', '▁', '<s>', '▁hi', '<s>', '▁there', '▁The', '▁following', '▁string', '▁should', '▁be', '▁properly', '▁en', 'code', 'd', ':', '▁Hello', '.', '▁But', '▁ir', 'd', '▁and', '▁ปี', '▁ir', 'd', '▁ด', '▁Hey', '▁how', '▁are', '▁you', '▁doing']  # fmt: skip
+    integration_expected_decoded_text = "This is a test 😊 I was born in 92000, and this is falsé. 生活的真谛是 Hi Hello Hi Hello Hello  hi there The following string should be properly encoded: Hello. But ird and ปี ird ด Hey how are you doing"


 @require_torch
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index f48980e797e6..833134c2913f 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -445,6 +445,21 @@ def get_extracted_tokenizer(self, reference_tokenizer=None):
         if _type.__name__ == "BPE" or _type.__name__ == "WordPiece":
             vocab = vocab_ids

+        # Extract precompiled SentencePiece charsmap from tokenizer.json normalizer
+        extra_kwargs = {}
+        normalizer_config = extractor.tokenizer_data.get("normalizer")
+        if normalizer_config:
+            if normalizer_config.get("type", None) == "Sequence":
+                normalizer_list = normalizer_config["normalizers"]
+            elif not isinstance(normalizer_config, list):
+                normalizer_list = [normalizer_config]
+            for normalizer in normalizer_list:
+                if normalizer.get("type") == "Precompiled" and "precompiled_charsmap" in normalizer:
+                    import base64
+
+                    extra_kwargs["_spm_precompiled_charsmap"] = base64.b64decode(normalizer["precompiled_charsmap"])
+                    break
+
         # Convert added_tokens list to added_tokens_decoder dict format
         # This matches the format used by from_pretrained() from tokenizer_config.jso
         tokenizer_from_extractor = self.tokenizer_class(
@@ -453,6 +468,7 @@ def get_extracted_tokenizer(self, reference_tokenizer=None):
             do_lower_case=False,
             keep_accents=True,
             added_tokens_decoder=added_tokens_decoder,
+            **extra_kwargs,
             **(self.from_pretrained_kwargs if self.from_pretrained_kwargs is not None else {}),
         )
@@ -637,6 +653,7 @@ def test_tokenizer_store_full_signature(self):
                 "vocab",
                 "merges",
                 "legacy",
+                "_spm_precompiled_charsmap",
                 "additional_special_tokens",  # V5: deprecated, converted to extra_special_tokens
             ]:
                 self.assertIn(parameter_name, tokenizer.init_kwargs)
diff --git a/tests/test_tokenizers_backend_mixin.py b/tests/test_tokenizers_backend_mixin.py
index e8d3b2b96bd1..abc9306ed641 100644
--- a/tests/test_tokenizers_backend_mixin.py
+++ b/tests/test_tokenizers_backend_mixin.py
@@ -3,11 +3,13 @@
 import inspect
 import shutil
 import tempfile
+import unittest
 from typing import TYPE_CHECKING

 from parameterized import parameterized

-from transformers import TokenizersBackend
+from transformers import AutoTokenizer, TokenizersBackend
+from transformers.testing_utils import require_tokenizers, slow
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -483,3 +485,84 @@ def test_local_files_only(self):
         except Exception as e:
             # if the pretrained model is not loadable how could it pass locally :)
             print(f"Could not load tokenizer model: {e}")
+
+
+@slow
+@require_tokenizers
+class TokenizersBackendV5RoundtripIntegrationTest(unittest.TestCase):
+    """Integration tests: one decode(encode(text)) check per TokenizersBackend v5 model (PR #44255)."""
+
+    ROUNDTRIP_TEST_TEXT = """This is a test 😊
+I was born in 92000, and this is falsé.
+生活的真谛是
+Hi Hello
+Hi Hello
+
+ 
+ 
+ Hello
+
+hithere
+The following string should be properly encoded: Hello.
+But ird and ปี ird ด
+Hey how are you doing"""  # noqa: W293
+
+    ADDITIONAL_ROUNDTRIP_CASES = [
+        "وقال، ماما، لقد عدت للمنزل.",
+        "لم ينطق ببنت شفة.",
+        "Он ничего не сказал.",
+        "Αυτό είναι ένα δοκιμαστικό κείμενο.",
+        "यह सिर्फ एक परीक्षण वाक्य है।",
+        "Tôi đã sống ở Việt Nam từ năm 1990.",
+        "def foo(x):\n    return x + 1\n",
+    ]
+
+    EXPECTED_XLANGAI_OPENCUA_7B = "This is a test 😊\nI was born in 92000, and this is falsé.\n生活的真谛是\nHi Hello\nHi Hello\n\n \n \n Hello\n\nhithere\nThe following string should be properly encoded: Hello.\nBut ird and ปี ird ด\nHey how are you doing"
+    EXPECTED_INTERNLM_INTERNLM2_CHAT_7B = "This is a test 😊\nI was born in 92000, and this is falsé.\n生活的真谛是\nHi Hello\nHi Hello\n\n \n \n Hello\n\nhithere\nThe following string should be properly encoded: Hello.\nBut ird and ปี ird ด\nHey how are you doing"
+    EXPECTED_STEPFUN_AI_STEP_35_FLASH = "This is a test 😊\nI was born in 92000, and this is falsé.\n生活的真谛是\nHi Hello\nHi Hello\n\n \n \n Hello\n\nhithere\nThe following string should be properly encoded: Hello.\nBut ird and ปี ird ด\nHey how are you doing"
+    EXPECTED_AI21LABS_JAMBA_TINY_DEV = "This is a test 😊\nI was born in 92000, and this is falsé.\n生活的真谛是\nHi Hello\nHi Hello\n\n \n \n Hello\n\nhithere\nThe following string should be properly encoded: Hello.\nBut ird and ปี ird ด\nHey how are you doing"
+    EXPECTED_ADEPT_FUYU_8B = "This is a test 😊\nI was born in 92000, and this is falsé.\n生活的真谛是\nHi Hello\nHi Hello\n\n \n \n Hello\n\nhithere\nThe following string should be properly encoded: Hello.\nBut ird and ปี ird ด\nHey how are you doing"
+    EXPECTED_MICROSOFT_PHI_3_MINI_4K_INSTRUCT = "This is a test 😊\nI was born in 92000, and this is falsé.\n生活的真谛是\nHi Hello\nHi Hello\n\n \n \n Hello\n \nhi there\nThe following string should be properly encoded: Hello.\nBut ird and ปี ird ด\nHey how are you doing"
+    EXPECTED_MUCAI_VIP_LLAVA_7B = "This is a test 😊\nI was born in 92000, and this is falsé.\n生活的真谛是\nHi Hello\nHi Hello\n\n \n \n Hello\n\nhithere\nThe following string should be properly encoded: Hello.\nBut ird and ปี ird ด\nHey how are you doing"
+
+    TOKENIZERS_BACKEND_V5_MODELS_WITH_EXPECTED = [
+        ("xlangai/OpenCUA-7B", EXPECTED_XLANGAI_OPENCUA_7B),
+        ("internlm/internlm2-chat-7b", EXPECTED_INTERNLM_INTERNLM2_CHAT_7B),
+        ("stepfun-ai/Step-3.5-Flash", EXPECTED_STEPFUN_AI_STEP_35_FLASH),
+        ("ai21labs/Jamba-tiny-dev", EXPECTED_AI21LABS_JAMBA_TINY_DEV),
+        ("adept/fuyu-8b", EXPECTED_ADEPT_FUYU_8B),
+        ("microsoft/Phi-3-mini-4k-instruct", EXPECTED_MICROSOFT_PHI_3_MINI_4K_INSTRUCT),
+        ("mucai/vip-llava-7b", EXPECTED_MUCAI_VIP_LLAVA_7B),
+    ]
+
+    @parameterized.expand(TOKENIZERS_BACKEND_V5_MODELS_WITH_EXPECTED)
+    def test_decode_encode_roundtrip(self, model_id: str, expected_decoded_text: str) -> None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            use_fast=True,
+        )
+        ids = tokenizer.encode(self.ROUNDTRIP_TEST_TEXT, add_special_tokens=False)
+        decoded = tokenizer.decode(ids, skip_special_tokens=True)
+        self.assertEqual(
+            decoded,
+            expected_decoded_text,
+            f"Roundtrip failed for {model_id}: got {decoded!r}",
+        )
+
+    @parameterized.expand(TOKENIZERS_BACKEND_V5_MODELS_WITH_EXPECTED)
+    def test_additional_roundtrip_cases(self, model_id: str, _expected_decoded_text: str) -> None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            use_fast=True,
+        )
+        for text in self.ADDITIONAL_ROUNDTRIP_CASES:
+            with self.subTest(text=text):
+                ids = tokenizer.encode(text, add_special_tokens=False)
+                decoded = tokenizer.decode(ids, skip_special_tokens=True)
+                self.assertEqual(
+                    decoded,
+                    text,
+                    f"Roundtrip failed for {model_id} on sample {text!r}",
+                )
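
For reviewers who want to poke at the new charsmap plumbing outside the test suite: the minimal sketch below mirrors what `convert_to_native_format` now does with a `tokenizer.json` — pull the base64-encoded `Precompiled` entry out of the serialized normalizer and rebuild the SentencePiece normalizer from it. The local `tokenizer.json` path is an assumption (any SentencePiece-derived checkpoint, e.g. a T5 one, serializes this fragment); only public `tokenizers` APIs are used.

    import base64
    import json

    from tokenizers import normalizers

    # Assumption: a locally downloaded tokenizer.json from a SentencePiece-based model.
    with open("tokenizer.json", encoding="utf-8") as f:
        tokenizer_json = json.load(f)

    # The serialized normalizer is either a single dict or a {"type": "Sequence"} wrapper.
    normalizer_config = tokenizer_json.get("normalizer") or {}
    if normalizer_config.get("type") == "Sequence":
        candidates = normalizer_config["normalizers"]
    else:
        candidates = [normalizer_config]

    charsmap = None
    for entry in candidates:
        if entry.get("type") == "Precompiled" and "precompiled_charsmap" in entry:
            # The charsmap is stored as base64 text; Precompiled wants the raw bytes.
            charsmap = base64.b64decode(entry["precompiled_charsmap"])
            break

    if charsmap is not None:
        # Same construction the model __init__s above perform with _spm_precompiled_charsmap.
        normalizer = normalizers.Precompiled(charsmap)
        print(normalizer.normalize_str("Ｈéllo　ｗｏｒｌｄ"))  # charsmaps typically fold full-width characters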