
[vllm + v5 fix] handle TokenizersBackend fallback properly for v5#44255

Merged
ArthurZucker merged 48 commits intomainfrom
bad_models_update
Mar 4, 2026

Conversation


@itazap itazap commented Feb 24, 2026

What this PR does

Given the different issues that were noticed by @hmellor on vLLM, we wanted to make sure we did not end up with major breaking changes. We ran a full test suite (code can be found in #44298) and the results showed 22 model converters that had potential issues.

We compare what `AutoTokenizer.from_pretrained(model_id)._tokenizer` would be in v4 vs. in v5, reporting the diff in the JSON serialization.

The full report is available here:
report.html
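A minimal sketch of that comparison (hypothetical helper, not the actual test-suite code; the real script serializes the backend tokenizer and diffs the two versions):

```python
import difflib
import json

def diff_tokenizer_json(v4_json: str, v5_json: str) -> list[str]:
    # Normalize key order before diffing so only real changes show up.
    a = json.dumps(json.loads(v4_json), indent=2, sort_keys=True).splitlines()
    b = json.dumps(json.loads(v5_json), indent=2, sort_keys=True).splitlines()
    return list(difflib.unified_diff(a, b, fromfile="v4", tofile="v5", lineterm=""))

# Toy example of the kind of change flagged in the report:
v4 = json.dumps({"normalizer": None, "model": {"type": "Unigram"}})
v5 = json.dumps({"normalizer": {"type": "Precompiled"}, "model": {"type": "Unigram"}})
for line in diff_tokenizer_json(v4, v5):
    print(line)
```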

Here is an explanation for each converter:

Important fixes

  • ReformerConverter/PegasusConverter/T5Converter/MBart50Converter/XLMRobertaConverter/NllbConverter: this is the biggest "change": we did miss the precompiled_charsmap, and it is required for very specific whitespace cases. For all of these, we only changed the ones that had XNLI mismatches. The motivation is to have an explicit normalizer vs. something arbitrary. This follows the default in: https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/convert_slow_tokenizer.py#L629-L638
  • MBart50Converter: was not usable / could not convert before so that's a big fix. We checked the expected results with v4, that's why the integration tests are changed.
  • GemmaConverter: we missed a Split, and in v5 named tokens (mask_token) are always special. This is a known change for v5 and would require relaxing the constraint that named tokens are special. Will consider this in the future.
  • BertConverter: do_lower_case=True by default, which we missed when porting to v5.

All of these are unigram models that need the precompiled charsmap for very specific whitespace edge cases.
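To illustrate the kind of whitespace handling at stake, here is a minimal toy pipeline (my own sketch, not the converter code) showing the explicit Replace normalizer plus Metaspace pre-tokenizer these unigram converters rely on; the precompiled_charsmap covers rarer unicode whitespace on top of this and is omitted here:

```python
from tokenizers import Regex, normalizers, pre_tokenizers

# Collapse runs of spaces explicitly, as the v4 converters did,
# instead of relying on an arbitrary default.
norm = normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")])
pre = pre_tokenizers.Metaspace(replacement="▁")

text = norm.normalize_str("a    b")
print(text)                        # the space run is collapsed to one space
print(pre.pre_tokenize_str(text))  # pieces carry the ▁ word marker
```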

Not that important, could affect small stuff

  • BlenderbotConverter: it had a RobertaProcessing but it should not
  • BigBirdConverter: Should not have the strip normalizer?
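For context on what a stray RobertaProcessing post-processor does, a minimal sketch with a hypothetical four-token vocab (not Blenderbot's real vocab): it wraps every encoding in `<s> ... </s>`.

```python
from tokenizers import Tokenizer, models, processors

# Tiny word-level tokenizer; RobertaProcessing adds <s>/</s> around each encoding.
vocab = {"<s>": 0, "</s>": 1, "hi": 2, "<unk>": 3}
tok = Tokenizer(models.WordLevel(vocab, unk_token="<unk>"))
tok.post_processor = processors.RobertaProcessing(sep=("</s>", 1), cls=("<s>", 0))
print(tok.encode("hi").ids)  # the word id is wrapped with <s>=0 and </s>=1
```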

Mild-concern fixes / absolutely safe to ignore

  • Qwen2Converter: for both the pre_tokenizer and the decoder, the trim_offsets and use_regex flags are useless. Serialization of ByteLevel for the decoder and pre_tokenizer is misleading; huggingface/tokenizers#1960 will address this, but TL;DR: behavior is the same.
  • TikTokenConverter: Same as Qwen2Converter, the decoder flags are useless
  • LlamaConverter: this IS expected, as most of the old llama models had the legacy flag with the metaspace issue (link it); the prepend normalizer is problematic.
  • CLIPConverter: the split pattern has "<\|startoftext\|>|<\|endoftext\|>" which we added. It does not make a difference as it's part of the special tokens, but we are removing this.
  • OpenAIGPTConverter: behavior was identical but for fairness it needs BertNormalizer
  • MarkupLMConverter: seems safe to ignore; in v4 it did not have RobertaProcessing either
  • SeamlessM4TConverter: weird
  • RobertaConverter: default template is different + add_prefix_space=False
  • GPT2Converter: equivalent diff / safe to ignore. pre_tokenizer=ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True) is equivalent to the default Regex: Sequence(pretokenizers=[Split(pattern=Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]|\s[\r\n]..."), behavior=Removed, invert=True), ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=False)]) link to tokenizers byte_level and default
  • DebertaConverter: equivalent diff / safe to ignore: TemplateProcessing is serialized in a slightly different order.
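The GPT2Converter equivalence can be checked directly. A sketch using the classic GPT-2 split pattern (shorter than the exact v5 pattern quoted above, but the mechanism is the same): `ByteLevel(use_regex=True)` applies this regex internally, so it should produce the same pieces as an explicit Split followed by `ByteLevel(use_regex=False)`.

```python
from tokenizers import Regex, pre_tokenizers

# The classic GPT-2 split pattern that ByteLevel(use_regex=True) applies internally.
gpt2_pat = Regex(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
explicit = pre_tokenizers.Sequence([
    pre_tokenizers.Split(gpt2_pat, behavior="isolated"),
    pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
])
builtin = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)

text = "Hello world's 123!"
a = [piece for piece, _ in builtin.pre_tokenize_str(text)]
b = [piece for piece, _ in explicit.pre_tokenize_str(text)]
print(a == b)  # both paths yield the same pieces
```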

UPDATE TO: https://github.com/huggingface/transformers/pull/44179/changes (deepseek v2, internlm2)

Models with incorrect tokenizer_class in tokenizer_config.json that should use TokenizersBackend

SentencePiece / TikToken models:

  • we don't properly read all the tokenizer parameters from a sentencepiece tokenizer.model file, so we need to extract that info and create a generic fallback that will create a _tokenizer object from the extracted data.

Script to compare roundtrip tokenizer outputs:
#!/usr/bin/env python
import argparse
import json
import re

from datasets import load_dataset
from tqdm import tqdm

import transformers
from transformers import AutoTokenizer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str)
    parser.add_argument("version_tag", type=str)
    parser.add_argument("--num-samples", type=int, default=1000)
    parser.add_argument("--compare-to", type=str, default=None)
    args = parser.parse_args()

    safe_model = re.sub(r"[^0-9A-Za-z_.\-]+", "_", args.model_name.replace("/", "-").replace(" ", "_"))
    out_path = f"xlni_{safe_model}_{args.version_tag}"

    print(f"Transformers: {transformers.__version__}")
    print(f"Loading tokenizer: {args.model_name} (trust_remote_code=True, use_fast=True)")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=True)
    print(f"  Loaded {type(tok).__name__}")

    if hasattr(tok, "_tokenizer") and tok._tokenizer is not None:
        with open(f"{out_path}_tokenizer.txt", "w", encoding="utf-8") as f:
            f.write(str(tok._tokenizer))

    results = []

    for lang in ["ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", "ur", "vi", "zh"]:
        ds = load_dataset("facebook/xnli", lang, split="validation")
        limit = min(args.num_samples, len(ds))
        for i, ex in enumerate(tqdm(ds, total=limit, desc=f"XNLI [{lang}]")):
            if i >= limit:
                break
            for field in ("premise", "hypothesis"):
                text = ex.get(field, "")
                if not text:
                    continue
                ids = tok.encode(text, add_special_tokens=False)
                decoded = tok.decode(ids, skip_special_tokens=True)
                results.append(
                    {
                        "language": lang,
                        "index": i,
                        "field": field,
                        "text": text,
                        "ids": ids,
                        "decoded": decoded,
                        "roundtrip_ok": decoded == text,
                    }
                )

    try:
        code_ds = load_dataset("code_search_net", "python", split="train")
        limit = min(args.num_samples, len(code_ds))
        for i, ex in enumerate(tqdm(code_ds, total=limit, desc="Code [code_search_net/python]")):
            if i >= limit:
                break
            for field in ("func_code_string", "func_documentation_string", "code", "docstring"):
                text = ex.get(field)
                if not text:
                    continue
                ids = tok.encode(text, add_special_tokens=False)
                decoded = tok.decode(ids, skip_special_tokens=True)
                results.append(
                    {
                        "language": "code",
                        "source": "code_search_net/python",
                        "index": i,
                        "field": field,
                        "text": text,
                        "ids": ids,
                        "decoded": decoded,
                        "roundtrip_ok": decoded == text,
                    }
                )
    except Exception as e:
        print(f"Could not load code_search_net/python: {e}")

    if args.compare_to:
        with open(args.compare_to, "r", encoding="utf-8") as f:
            prev_payload = json.load(f)
        prev_results = prev_payload.get("results", [])
        n = min(len(prev_results), len(results))
        if n:
            prev_ok = sum(1 for r in prev_results[:n] if r.get("roundtrip_ok"))
            curr_ok = sum(1 for r in results[:n] if r.get("roundtrip_ok"))
            print("\n--- Roundtrip comparison ---")
            print(f"Samples compared: {n}")
            print(f"Prev roundtrip OK: {prev_ok}/{n} ({100 * prev_ok / n:.1f}%)")
            print(f"Curr roundtrip OK: {curr_ok}/{n} ({100 * curr_ok / n:.1f}%)")
            diffs = []
            for i in range(n):
                a = prev_results[i]
                b = results[i]
                if a.get("index") != b.get("index") or a.get("field") != b.get("field"):
                    break
                if a.get("roundtrip_ok") != b.get("roundtrip_ok") or a.get("decoded") != b.get("decoded"):
                    diffs.append((a, b))
            print(f"Changed samples:   {len(diffs)}")
            if diffs:
                print("\nFirst few changed samples:")
                for a, b in diffs[:10]:
                    print(f"  #{a['index']} prev_ok={a['roundtrip_ok']} curr_ok={b['roundtrip_ok']}")
                    prev_dec = a.get("decoded") or ""
                    curr_dec = b.get("decoded") or ""
                    max_len = max(len(prev_dec), len(curr_dec))
                    pos = 0
                    while (
                        pos < max_len
                        and pos < len(prev_dec)
                        and pos < len(curr_dec)
                        and prev_dec[pos] == curr_dec[pos]
                    ):
                        pos += 1
                    radius = 40
                    start = max(0, pos - radius)
                    end = min(max_len, pos + radius)
                    prev_window = prev_dec[start:end]
                    curr_window = curr_dec[start:end]
                    marker = " " * max(0, pos - start) + "^"
                    print(f"    prev_decoded: {prev_window!r}")
                    print(f"    curr_decoded: {curr_window!r}")
                    print(f"    diff_window : {marker}")

    payload = {
        "model_name": args.model_name,
        "model_name_sanitized": safe_model,
        "transformers_version": transformers.__version__,
        "num_samples": args.num_samples,
        "num_results": len(results),
        "results": results,
    }

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    print(f"\nSaved {len(results)} roundtrip entries to '{out_path}'.")


if __name__ == "__main__":
    raise SystemExit(main())
(uv_env) (base) itazaporozhets@Huggingfaces-MacBook-Pro-2 transformers % python scripts/compare_tokenizers.py xlangai/OpenCUA-7B v5 --compare-to xlni_xlangai-OpenCUA-7B_v4_use_fast
/Users/itazaporozhets/Documents/Repos/transformers/uv_env/lib/python3.11/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
Transformers: 5.3.0.dev0
Loading tokenizer: xlangai/OpenCUA-7B (trust_remote_code=True, use_fast=True)
Could not extract SentencePiece model from /Users/itazaporozhets/.cache/huggingface/hub/models--xlangai--OpenCUA-7B/snapshots/a2efb7d2b104d477a4a2666a357e79550a28aafc/tiktoken.model using sentencepiece library due to Error parsing message with type 'sentencepiece.ModelProto'. Falling back to TikToken extractor.
  Loaded TokenizersBackend
XNLI languages: [ar] [bg] [de] [el] [en] [es] [fr] [hi] [ru] [sw] [th] [tr] [ur] [vi] [zh]
XNLI: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15000/15000 [00:01<00:00, 8747.38it/s]
Code [code_search_net/python]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2027.26it/s]

--- Roundtrip comparison ---
Samples compared: 32000
Prev roundtrip OK: 32000/32000 (100.0%)
Curr roundtrip OK: 32000/32000 (100.0%)
Changed samples:   0

Saved 32000 roundtrip entries to 'xlni_xlangai-OpenCUA-7B_v5'.
(uv_env) (base) itazaporozhets@Huggingfaces-MacBook-Pro-2 transformers % python scripts/compare_tokenizers.py internlm/internlm2-chat-7b v5 --compare-to xlni_internlm-internlm2-chat-7b_v4_use_fast
Transformers: 5.3.0.dev0
Loading tokenizer: internlm/internlm2-chat-7b (trust_remote_code=True, use_fast=True)
Unrecognized keys in `rope_parameters` for 'rope_type'='dynamic': {'rope_theta'}
  Loaded TokenizersBackend
XNLI languages: [ar] [bg] [de] [el] [en] [es] [fr] [hi] [ru] [sw] [th] [tr] [ur] [vi] [zh]
XNLI: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15000/15000 [00:01<00:00, 7579.54it/s]
Code [code_search_net/python]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1971.22it/s]

--- Roundtrip comparison ---
Samples compared: 32000
Prev roundtrip OK: 32000/32000 (100.0%)
Curr roundtrip OK: 32000/32000 (100.0%)
Changed samples:   0

Saved 32000 roundtrip entries to 'xlni_internlm-internlm2-chat-7b_v5'.

(uv_env) (base) itazaporozhets@Huggingfaces-MacBook-Pro-2 transformers % python scripts/compare_tokenizers.py stepfun-ai/Step-3.5-Flash v5 --compare-to xlni_stepfun-ai-Step-3.5-Flash_v4_use_fast
Transformers: 5.3.0.dev0
Loading tokenizer: stepfun-ai/Step-3.5-Flash (trust_remote_code=True, use_fast=True)
  Loaded TokenizersBackend
XNLI languages: [ar] [bg] [de] [el] [en] [es] [fr] [hi] [ru] [sw] [th] [tr] [ur] [vi] [zh]
XNLI: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15000/15000 [00:01<00:00, 8374.95it/s]
Code [code_search_net/python]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1728.44it/s]

--- Roundtrip comparison ---
Samples compared: 32000
Prev roundtrip OK: 32000/32000 (100.0%)
Curr roundtrip OK: 32000/32000 (100.0%)
Changed samples:   0

Saved 32000 roundtrip entries to 'xlni_stepfun-ai-Step-3.5-Flash_v5'.
(uv_env) (base) itazaporozhets@Huggingfaces-MacBook-Pro-2 transformers % python scripts/compare_tokenizers.py ai21labs/Jamba-tiny-dev v5 --compare-to xlni_ai21labs-Jamba-tiny-dev_v4_use_fast
Transformers: 5.3.0.dev0
Loading tokenizer: ai21labs/Jamba-tiny-dev (trust_remote_code=True, use_fast=True)
  Loaded TokenizersBackend
XNLI languages: [ar] [bg] [de] [el] [en] [es] [fr] [hi] [ru] [sw] [th] [tr] [ur] [vi] [zh]
XNLI: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15000/15000 [00:01<00:00, 10206.83it/s]
Code [code_search_net/python]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1798.66it/s]

--- Roundtrip comparison ---
Samples compared: 32000
Prev roundtrip OK: 32000/32000 (100.0%)
Curr roundtrip OK: 32000/32000 (100.0%)
Changed samples:   0

Saved 32000 roundtrip entries to 'xlni_ai21labs-Jamba-tiny-dev_v5'.

(uv_env) (base) itazaporozhets@Huggingfaces-MacBook-Pro-2 transformers % python scripts/compare_tokenizers.py adept/fuyu-8b  v5 --compare-to  xlni_adept-fuyu-8b_v4_use_fast
Transformers: 5.3.0.dev0
Loading tokenizer: adept/fuyu-8b (trust_remote_code=True, use_fast=True)
  Loaded TokenizersBackend
XNLI languages: [ar] [bg] [de] [el] [en] [es] [fr] [hi] [ru] [sw] [th] [tr] [ur] [vi] [zh]
XNLI: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 5958.00it/s]
Code [code_search_net/python]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1367.24it/s]

--- Roundtrip comparison ---
Samples compared: 32000
Prev roundtrip OK: 32000/32000 (100.0%)
Curr roundtrip OK: 32000/32000 (100.0%)
Changed samples:   0

Saved 32000 roundtrip entries to 'xlni_adept-fuyu-8b_v5'.
(uv_env) (base) itazaporozhets@Huggingfaces-MacBook-Pro-2 transformers % python scripts/compare_tokenizers.py microsoft/Phi-3-mini-4k-instruct v5 --compare-to xlni_microsoft-Phi-3-mini-4k-instruct_v4
Transformers: 5.3.0.dev0
Loading tokenizer: microsoft/Phi-3-mini-4k-instruct (trust_remote_code=True, use_fast=True)
config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 967/967 [00:00<00:00, 443kB/s]
configuration_phi3.py: 11.2kB [00:00, 5.54MB/s]
A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:
- configuration_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
  Loaded TokenizersBackend
XNLI languages: [ar] [bg] [de] [el] [en] [es] [fr] [hi] [ru] [sw] [th] [tr] [ur] [vi] [zh]
XNLI: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15000/15000 [00:01<00:00, 9145.49it/s]
Code [code_search_net/python]:   0%|                                                                                                                                            | 0/1000 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (4599 > 4096). Running this sequence through the model will result in indexing errors
Code [code_search_net/python]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1624.12it/s]

--- Roundtrip comparison ---
Samples compared: 32000
Prev roundtrip OK: 32000/32000 (100.0%)
Curr roundtrip OK: 32000/32000 (100.0%)
Changed samples:   0

Saved 32000 roundtrip entries to 'xlni_microsoft-Phi-3-mini-4k-instruct_v5'.
(uv_env) (base) itazaporozhets@Huggingfaces-MacBook-Pro-2 transformers % python scripts/compare_tokenizers.py mucai/vip-llava-7b v5 --compare-to xlni_mucai-vip-llava-7b_v4
Transformers: 5.3.0.dev0
Loading tokenizer: mucai/vip-llava-7b (trust_remote_code=True, use_fast=True)
config.json: 1.09kB [00:00, 626kB/s]
  Loaded TokenizersBackend
XNLI languages: [ar] [bg] [de] [el] [en] [es] [fr] [hi] [ru] [sw] [th] [tr] [ur] [vi] [zh]
XNLI: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15000/15000 [00:01<00:00, 9549.95it/s]
Code [code_search_net/python]:   0%|                                                                                                                                            | 0/1000 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (4609 > 2048). Running this sequence through the model will result in indexing errors
Code [code_search_net/python]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1680.68it/s]

--- Roundtrip comparison ---
Samples compared: 32000
Prev roundtrip OK: 32000/32000 (100.0%)
Curr roundtrip OK: 32000/32000 (100.0%)
Changed samples:   0

Saved 32000 roundtrip entries to 'xlni_mucai-vip-llava-7b_v5'.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@itazap itazap requested a review from ArthurZucker February 24, 2026 15:11

@ArthurZucker ArthurZucker left a comment


from testing:

  • olmo2
  • qwen2
    and probably many others are affected 😢

I'll make a list of arch and test them but at least let's add these 2!

@ArthurZucker

Also, before we can merge, IMO we need to make sure the auto conversion is foolproof; otherwise we are ignoring code but the conversion is still not correct

@itazap itazap requested a review from ArthurZucker February 24, 2026 15:28
@itazap itazap changed the title [vllm + v5 fix] update deepseek v2 tokenizer class for v5 [vllm + v5 fix] handle TokenizersBackend fallback properly for v5 Feb 26, 2026

@ArthurZucker ArthurZucker left a comment


nice, let's add tests to validate which model ids we fixed on the hub

Comment on lines +641 to +642
Similar to convert_from_spm method, but used when converting directly from proto and vocab/merges.
(convert_from_spm requires some class attrs like byte_fallback, unk_piece, precompiled_charsmap, etc.)

Suggested change
Similar to convert_from_spm method, but used when converting directly from proto and vocab/merges.
(convert_from_spm requires some class attrs like byte_fallback, unk_piece, precompiled_charsmap, etc.)
Similar to convert_from_spm method, but used only when there is no `model_type` class, i.e. there is no matching class in `TOKENIZERS_MAPPING` and we just create a tokenizer instead of extracting stuff from the sentencepiece file

func name can probably be updated as well


@ArthurZucker ArthurZucker left a comment


LGTM, let's just have better names for each function and a bit more tests

@itazap itazap force-pushed the bad_models_update branch from 575d566 to 2e03e15 Compare March 2, 2026 16:33

@ArthurZucker ArthurZucker left a comment


Ty for addressing the comments!

local_kwargs.setdefault("bos_token", proto_spec.bos_piece or "<s>")
if proto_spec.eos_id >= 0:
    local_kwargs.setdefault("eos_token", proto_spec.eos_piece or "</s>")


should we look for unk id as well?


let's move these to be done inside:

extractor.extract(cls.model, **local_kwargs)

if possible (whatever we can do with the proto there?)

@ArthurZucker

run-slow: auto, gemma, lasr, llama, siglip2, t5

@github-actions

github-actions bot commented Mar 3, 2026

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/auto", "models/gemma", "models/lasr", "models/llama", "models/siglip2", "models/t5"]
quantizations: []

@itazap itazap requested a review from ArthurZucker March 3, 2026 12:32
@github-actions

github-actions bot commented Mar 3, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN b7d01aac workflow commit (merge commit)
PR ed4c16b6 branch commit (from PR)
main 28d02a31 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !


@ArthurZucker ArthurZucker left a comment


Ty!
I think GPT2 potentially needs to be removed from the list, as in this repo: https://huggingface.co/Open4bits/granite-4.0-h-tiny-mlx-fp16/tree/main

-pre_tokenizer:		ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True)
+pre_tokenizer:		Sequence(pretokenizers=[Split(pattern=Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]..."), behavior=Removed, invert=True), ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=False)])

When I check our old converter, indeed it does not have a split pattern. It's a bit weird.
Similarly:

-pre_tokenizer:		ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True)
+pre_tokenizer:		Sequence(pretokenizers=[Digits(individual_digits=True), ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True)])

appears in some of the GPT2Tokenizer (mapped).

Pegasus has a similar issue:

-normalizer:		Sequence(normalizers=[Replace(pattern=Regex("\n"), content=" "), Replace(pattern=Regex(" {2,}"), content=" ")])
-pre_tokenizer:		Metaspace(replacement="▁", prepend_scheme=always, split=True)
+normalizer:		Sequence(normalizers=[Precompiled(precompiled_charsmap="ALQCAACEAAAAAACAAQAAgMz8AgC4BQAAhyIAgMzkAgC4PQAAeyIAgMzsAgC4BQAAiyIAgMw8AADNvAAAmwkAgJ4JAIChCQCAgx0A..."), Replace(pattern=Regex(" {2,}"), content=" ")])
+pre_tokenizer:		Sequence(pretokenizers=[WhitespaceSplit(), Metaspace(replacement="▁", prepend_scheme=always, split=True)])

Pattern8 for LlamaConverter is fine, we did a bug fix!
Pattern9: Clip has a small issue:

        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(
                        r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""
                    ),
                    behavior="removed",
                    invert=True,
                ),
                pre_tokenizers.ByteLevel(add_prefix_space=False),
            ]
        )

should not have r"""<\|startoftext\|>|<\|endoftext\|>

BlenderbotConverter needs updated post processor to roberta processing I think


-normalizer:		Sequence(normalizers=[Strip(strip_left=False, strip_right=True), Replace(pattern=Regex(" {2,}"), content="▁"), Precompiled(precompiled_charsmap="...")])
-pre_tokenizer:		Metaspace(replacement="▁", prepend_scheme=always, split=True)
+normalizer:		Precompiled(precompiled_charsmap="...")
+pre_tokenizer:		Sequence(pretokenizers=[WhitespaceSplit(), Metaspace(replacement="▁", prepend_scheme=always, split=True)])

is what most T5 have as an issue.
I think if we add the regex, we need the normalizer to be strip + replace + precompiled !

return None

# normalizer
_normalizers = [normalizers.Replace(" ", "▁")]

that's only for some model not all of them (ex gpt2 uses Ġ )


ah my bad, sentencepiece never used Ġ!
So ignore this comment, probably

# decoder
if byte_fallback:
    tokenizer.decoder = decoders.Sequence(
        [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]

Comment on lines +123 to +131
local_kwargs["tokenizer_padding"] = tok_from_file.padding
local_kwargs["tokenizer_truncation"] = tok_from_file.truncation
# Preserve truncation and padding baked into tokenizer.json so that classes
# with a custom __init__ that rebuild the backend tokenizer from scratch
# can still access these settings.
if tok_from_file.truncation is not None:
    local_kwargs["_json_truncation"] = tok_from_file.truncation
if tok_from_file.padding is not None:
    local_kwargs["_json_padding"] = tok_from_file.padding

not sure we need both tokenizer_padding and _json_truncation, which are the same

local_kwargs.setdefault("bos_token", proto_spec.bos_piece or "<s>")
if proto_spec.eos_id >= 0:
    local_kwargs.setdefault("eos_token", proto_spec.eos_piece or "</s>")


let's move these to be done inside:

extractor.extract(cls.model, **local_kwargs)

if possible (whatever we can do with the proto there?)


_truncation = self._tokenizer.truncation

_truncation = kwargs.pop("tokenizer_truncation", None) or self._tokenizer.truncation or _json_truncation

see my comment

normalizers.StripAccents(),
]
)
self._tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)

good catch

@itazap itazap force-pushed the bad_models_update branch from 02766a3 to df12cc4 Compare March 4, 2026 09:32
@github-actions

github-actions bot commented Mar 4, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, bert, blenderbot, gemma, gpt_neox, lasr, llama, mbart50, nllb, openai, pegasus, reformer, siglip2, t5, xlm_roberta

@ArthurZucker ArthurZucker merged commit fd6bc38 into main Mar 4, 2026
27 checks passed
@ArthurZucker ArthurZucker deleted the bad_models_update branch March 4, 2026 15:13