
[Bugfix] Add num_special_tokens_to_add to MistralTokenizer, fixes #22013 (#22121)

Closed

ShUl0w wants to merge 1 commit into vllm-project:main from ShUl0w:22013-random-dataset-serve-benchmark-throws

Conversation

@ShUl0w

@ShUl0w ShUl0w commented Aug 2, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Adds the missing num_special_tokens_to_add method to MistralTokenizer so that the random-dataset serve benchmark no longer throws an AttributeError (fixes #22013).

Test Plan

Manually testing the serve benchmark without the changes and with the changes in place.

```shell
uv run vllm/entrypoints/openai/api_server.py --model /models/ministral-8b-instruct-2410 --gpu-memory-utilization 0.8 --max-model-len 256 --tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral --quantization bitsandbytes --kv-cache-dtype fp8 --max-num-seqs 1 --max-num-batched-tokens 512 --served-model-name ministral-8b-instruct-2410 --enable-auto-tool-choice --disable-log-requests

uv run benchmarks/benchmark_serving.py --backend openai-chat --model /models/ministral-8b-instruct-2410 --served-model-name ministral-8b-instruct-2410 --endpoint /v1/chat/completions --dataset-name random --num-prompts 1 --tokenizer_mode mistral --base-url "http://0.0.0.0:8000" --random-input-len 64 --random-output-len 64
```
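For context, the random dataset sampler asks the tokenizer how many special tokens it will add so that the templated prompt still fits the requested input length. A hedged sketch of that bookkeeping (the helper name is hypothetical; the real logic lives in benchmark_dataset.py):

```python
def budget_random_input_len(requested_len: int, num_special_tokens: int) -> int:
    """Hypothetical helper, not the actual benchmark code: subtract the
    tokenizer's special-token overhead from the requested input length so
    the templated prompt does not exceed it."""
    return max(requested_len - num_special_tokens, 0)


# With --random-input-len 64 and 3 Mistral special tokens, 61 content
# tokens remain, consistent with "Total input tokens: 61" reported below.
print(budget_random_input_len(64, 3))  # 61
```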

Test Result

Manually testing without the changes produces an error:

Traceback (most recent call last):
  File "/home/user/llm/benchmarks/benchmark_serving.py", line 1299, in <module>
    main(args)
  File "/home/user/vllm/.venv/lib/python3.12/site-packages/typing_extensions.py", line 2956, in wrapper
    return arg(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/user/vllm/benchmarks/benchmark_serving.py", line 774, in main
    input_requests = dataset_mapping[args.dataset_name]()
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/vllm/benchmarks/benchmark_serving.py", line 763, in <lambda>
    "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/vllm/benchmarks/benchmark_dataset.py", line 314, in sample
    num_special_tokens = tokenizer.num_special_tokens_to_add()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'MistralTokenizer' object has no attribute 'num_special_tokens_to_add'. Did you mean: 'all_special_tokens_extended'?
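The failure and the fix can be reproduced in miniature. The stub class below is a stand-in, not the real vLLM class; the added method mirrors the PR's implementation as shown in the review discussion:

```python
class MistralTokenizerStub:
    """Minimal stand-in for vLLM's MistralTokenizer; illustrative only."""
    bos_token_id = 1  # assumed present, as for Mistral models


# Before the fix: the method is simply missing.
try:
    MistralTokenizerStub().num_special_tokens_to_add()
except AttributeError as err:
    print(type(err).__name__)  # AttributeError, as in the traceback above


# After the fix: define the method (mirroring the PR's addition).
def num_special_tokens_to_add(self, pair: bool = False) -> int:
    num_tokens = 2  # [INST] and [/INST] are always added
    if getattr(self, "bos_token_id", None) is not None:
        num_tokens += 1  # plus BOS; no EOS appears to be added
    return num_tokens


MistralTokenizerStub.num_special_tokens_to_add = num_special_tokens_to_add
print(MistralTokenizerStub().num_special_tokens_to_add())  # 3
```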

Manually testing with the changes produces the desired result, i.e. a full run of the benchmark:

============ Serving Benchmark Result ============
Successful requests:                     1         
Benchmark duration (s):                  2.63      
Total input tokens:                      61        
Total generated tokens:                  64        
Request throughput (req/s):              0.38      
Output token throughput (tok/s):         24.32     
Total Token throughput (tok/s):          47.51     
---------------Time to First Token----------------
Mean TTFT (ms):                          369.13    
Median TTFT (ms):                        369.13    
P99 TTFT (ms):                           369.13    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.90     
Median TPOT (ms):                        35.90     
P99 TPOT (ms):                           35.90     
---------------Inter-token Latency----------------
Mean ITL (ms):                           35.34     
Median ITL (ms):                         35.86     
P99 ITL (ms):                            36.38     
==================================================

(Optional) Documentation Update

@github-actions

github-actions bot commented Aug 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes an AttributeError in the MistralTokenizer by adding the num_special_tokens_to_add method. While the fix unblocks the benchmark, the implementation is conceptually mismatched with the standard tokenizer API and has inconsistencies that could lead to future bugs. I've suggested a refactoring to improve clarity, correctness, and safety.

Comment on lines 523 to 531
Contributor


Severity: high

This implementation of num_special_tokens_to_add has a few issues that could lead to incorrect behavior and confusion for future developers:

  1. Conceptual Mismatch: The method name num_special_tokens_to_add is part of the Hugging Face Transformers tokenizer API and typically reflects the number of special tokens added by the encode() method. However, this implementation seems to calculate the tokens for apply_chat_template() (i.e., [INST], [/INST], and a BOS token). This is misleading, as self.encode() only adds a BOS token by default. This can lead to incorrect token calculations and truncation if used outside of the specific benchmark context.

  2. Unsupported pair Logic: The logic for pair=True is not supported by the tokenizer's __call__ method, which currently ignores the text_pair argument. This inconsistency means any code calling num_special_tokens_to_add(pair=True) will get a number that doesn't match the tokenizer's actual behavior, leading to bugs.

  3. Redundant bos_token_id Check: The bos_token_id property is typed to return an int and likely never None. The hasattr check is also redundant. This makes the logic unnecessarily complex.

To improve correctness and clarity, I suggest refactoring this method to clearly state its purpose and remove the unsupported pair logic. The implementation should be simplified to reflect the tokens added for a single-turn chat prompt.

Here is a suggested implementation with improved comments:

Suggested change

Before:

```python
def num_special_tokens_to_add(self, pair: bool = False) -> int:
    # accommodates for [INST] and [/INST], which are always added
    num_tokens = 2
    # MistralTokenizer does not appear to add an eos token
    if hasattr(self, "bos_token_id") and self.bos_token_id is not None:
        num_tokens += 1
    # SentencePiece adds <0x0A><0x0A> in this case, tekken does not
    if pair and self.is_spm:
        num_tokens += 2
    return num_tokens
```

After:

```python
def num_special_tokens_to_add(self, pair: bool = False) -> int:
    """
    Returns the number of special tokens added by the chat template for a
    single-turn conversation.

    Note: This method's behavior is specific to chat completion and does
    not reflect the special tokens added by the `encode()` method.
    """
    if pair:
        # The `__call__` method of this tokenizer does not support
        # encoding pairs of sequences.
        raise NotImplementedError(
            "Encoding token pairs is not supported by this tokenizer."
        )
    # A single-turn chat prompt is templated with <s>, [INST], and [/INST].
    # self.tokenizer.bos_id is expected to be always present.
    num_tokens = 3  # BOS + [INST] + [/INST]
    return num_tokens
```

Author


I understand the points outlined by the automatic code review; however, they deviate from what I found in my prior inspection of the respective implementations.

1 & 3: Please find my reasoning for the implementation here. I don't want to rule out a misunderstanding of the tokenizer on my end.
2: I tried to stay as close to the original implementation of the function in transformers as possible, without adding any further, possibly confusing functions (i.e. build_inputs_with_special_tokens).

As this is my first PR to the project, I'd be happy to receive guidance on how to proceed, and what changes I should add.

@patrickvonplaten
Collaborator

Hey @ShUl0w,

Thanks for the PR! I don't think the `pair: bool` argument is used anywhere - see:

```python
num_special_tokens = tokenizer.num_special_tokens_to_add()
```

I don't think we have to / should stick to the transformers design as this is not a tokenizer in transformers format.
Can we remove the pair: bool & logic and then happy to merge!

Signed-off-by: ShUl0w <37832993+ShUl0w@users.noreply.github.com>

ShUl0w force-pushed the 22013-random-dataset-serve-benchmark-throws branch from 9039be4 to be39b4c on August 4, 2025 at 05:57.
@ShUl0w
Author

ShUl0w commented Aug 4, 2025

@patrickvonplaten thanks a lot for your quick feedback! I've removed the pair: bool param & logic.
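With the pair parameter gone, the method reduces to a few lines. A sketch of what the simplified version plausibly looks like (stand-in class; not the exact committed code):

```python
class MistralTokenizerStub:
    """Stand-in for vLLM's MistralTokenizer; illustrative only."""

    bos_token_id = 1  # assumed always present for Mistral models

    def num_special_tokens_to_add(self) -> int:
        # [INST] and [/INST] are always added by the chat template.
        num_tokens = 2
        if self.bos_token_id is not None:
            num_tokens += 1  # plus BOS; no EOS appears to be added
        return num_tokens


print(MistralTokenizerStub().num_special_tokens_to_add())  # 3
```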

@ShUl0w
Author

ShUl0w commented Aug 25, 2025

Hi @patrickvonplaten, is there anything else I can/should do in this PR? Thanks for a quick info!

@juanjucm

Hi!

Is there any update on this? I'm (unsuccessfully) trying to benchmark mistralai/Magistral-Small-2507 and I've just found this PR. It would be awesome to have this bugfix shipped ☺️

Thanks in advance!!

(cc @ShUl0w @patrickvonplaten)

@hmellor
Member

hmellor commented Dec 5, 2025

Superseded by #30009

@hmellor hmellor closed this Dec 5, 2025


Development

Successfully merging this pull request may close these issues.

[Bug]: Random Dataset Serve Benchmark throws AttributeError when using MistralTokenizer

4 participants