
[Feature] limit thinking tokens (hard limit)#20859

Open
llsj14 wants to merge 95 commits intovllm-project:mainfrom
llsj14:feat/thinking-budget

Conversation

@llsj14
Contributor

@llsj14 llsj14 commented Jul 12, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Related Issue:

Purpose

  • Support limiting thinking tokens and forcing thinking end tokens via logit processors.
  • This feature is designed to prevent uncontrolled long reasoning loops and to enforce an explicit thinking limit (budget), i.e., a hard limit. This differs from reasoning_effort and the soft-limiting technique used in gpt-oss.
  • This PR aims to address service pain points:
    • The soft-limiting technique (e.g., reasoning_effort in gpt-oss) is only available in certain models. For other models, controlling thinking tokens in the current vLLM implementation requires making two separate API calls (example with the Qwen model). This is a major pain point because, when making two API calls, there is no guarantee they will be routed to the same server node unless the system has prefix-aware routing.
    • Even with soft limiting techniques based on prompting, models often generate repetitive reasoning content or produce tokens related to the prompt instructions themselves, which negatively affects output quality. This issue has been reported by the service team, and it highlights the clear need for a hard limit in addition to soft limiting.
  • discussions: https://vllm-dev.slack.com/archives/C07QQ8DAXMK/p1755614191019169

Implementation

  • If the number of thinking tokens exceeds the thinking_token_budget sampling parameter, the logits processor will forcibly insert the thinking end token IDs to terminate the thinking section.
  • This feature extends the built-in logits processor.
  • By default, if no reasoning config is explicitly set, this logits processor has no effect on the workflow. In other words, the feature is disabled by default and remains a non-intrusive addition to the existing implementation.
  • This feature also works correctly with multiple tokens for thinking start and end token IDs.
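As a minimal sketch of the mechanism above (pure Python for clarity; the actual processor operates on torch logits tensors inside vLLM, and the token IDs below are hypothetical):

```python
NEG_INF = float("-inf")

def enforce_thinking_budget(logits, num_think_tokens, budget, forced_pos, end_ids):
    """Once the thinking budget is exhausted, mask every logit except the
    next think-end token ID so that sampling must emit it. Supports end
    sequences of multiple token IDs, forced one per decode step."""
    if num_think_tokens < budget or forced_pos >= len(end_ids):
        return logits  # budget not reached, or end sequence already emitted
    forced = end_ids[forced_pos]
    return [0.0 if tok == forced else NEG_INF for tok in range(len(logits))]

# Budget of 3 reached: force end_ids[0] this step, end_ids[1] the next step.
logits = [0.1] * 8
step1 = enforce_thinking_budget(logits, 3, 3, forced_pos=0, end_ids=[5, 6])
step2 = enforce_thinking_budget(logits, 4, 3, forced_pos=1, end_ids=[5, 6])
```

Because only the forced token keeps a finite logit, both greedy and stochastic sampling are guaranteed to pick it, which is why the PR describes this as a hard limit.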

Test

Integration Test

pytest tests/v1/logits_processors/test_correctness.py

Online Serving Test

  • server script
python3 -m vllm.entrypoints.openai.api_server \
     --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
     --port 8085 \
     --trust-remote-code \
     --gpu-memory-utilization 0.8 \
     --reasoning-parser "deepseek_r1" \
     --reasoning-config '{"think_start_str": "<think>", "think_end_str": " Now formulate the final answer.</think>"}'
  • request
curl http://localhost:8085/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
  "messages": [
    { "role": "system", "content": "9.11 and 9.8, which is greater?" }
  ],
  "thinking_token_budget": 10,
  "max_tokens": 50
}'
  • response
{"id":"chatcmpl-2f99acbe8d0d47fa9ffbaa35aecb9274","object":"chat.completion","created":1755611380,"model":"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"choices":[{"index":0,"message":{"role":"assistant","content":"\n\nTo determine which number is greater between **9.11** and **9.8**, let's compare them step by step.\n\n1. **Compare the Whole",
"refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],
"reasoning_content":"First, I'll compare the two numbers, Now formulate the final answer."},
"logprobs":null,"finish_reason":"length","stop_reason":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":18,"total_tokens":68,"completion_tokens":50,"prompt_tokens_details":null},"prompt_logprobs":null,"kv_transfer_params":null}

Checking overhead

Model Size: 1.5B (DeepSeek-R1-Distill-Qwen-1.5B)
Max tokens: 512, min_tokens: 512
thinking_token_budgets: 20

The results show that the overhead is almost zero at the median.
In practice, thinking_token_budget would be larger than 20, since it serves as a hard limit on top of soft limiting (prompting), which makes the overhead even smaller.

[overhead benchmark plot]

(Optional) Documentation Update


Note

Introduces hard “thinking” token limiting and wires it through config, APIs, and sampling.

  • Adds ReasoningConfig with think_start_str/think_end_str and derived think_start_token_ids/think_end_token_ids; initializes IDs in VllmConfig.__post_init__ and exposes via --reasoning-config CLI
  • Adds thinking_token_budget to SamplingParams and OpenAI chat request; plumbs through request construction
  • New builtin ThinkingTokenBudgetLogitsProcessor that tracks per-request state and, once budget is reached after a think_start, forces think_end token IDs (argmax-variant) while masking others
  • Registers processor in logits pipeline and test harness; extends tests to validate budget counting, end-token forcing, and isolates it from other processors

Written by Cursor Bugbot for commit f1aefbb022ad5c04af4e88163d979e1517da178c. This will update automatically on new commits. Configure here.



@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added deepseek Related to DeepSeek models frontend v1 labels Jul 12, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @llsj14, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial feature to manage and limit the length of 'thinking' or 'reasoning' phases in large language models that employ explicit reasoning tokens. By allowing users to set a max_think_tokens budget, the system can prevent uncontrolled long reasoning loops, ensuring more predictable and efficient model behavior. The core of this feature is a new logits processor that monitors token generation within designated thinking sections and intervenes to terminate them if the specified limit is exceeded.

Highlights

  • New max_think_tokens parameter: Introduced a max_think_tokens parameter in SamplingParams and exposed it via the OpenAI protocol's ChatCompletionRequest. This allows users to specify a maximum token limit for the 'thinking' phase of models that utilize explicit reasoning tokens.
  • ReasoningConfig and Dynamic Token ID Management: Added a new ReasoningConfig class to vllm/config.py to encapsulate think_start_token_id and think_end_token_id. These IDs are now dynamically populated in GpuModelRunner based on the configured reasoning backend (e.g., DeepSeek R1), ensuring the system correctly identifies and manages reasoning sections.
  • MaxThinkTokensLogitsProcessor Implementation: Implemented a new MaxThinkTokensLogitsProcessor in vllm/v1/sample/logits_processor.py. This processor actively monitors the number of tokens generated within a thinking section. If the max_think_tokens limit is reached, it modifies the logits to forcibly generate the think_end_token_id, effectively terminating the reasoning loop.
  • Enhanced State Tracking for Logits Processors: Modified the AddedRequest tuple in vllm/v1/sample/logits_processor.py and vllm/v1/worker/gpu_input_batch.py to include prompt_tok_ids. This provides logits processors, especially the new MaxThinkTokensLogitsProcessor, with more complete context for tracking token counts from the beginning of a request's generation.
  • Integration Across the Stack: The new max_think_tokens parameter and the ReasoningConfig are integrated throughout the system, from the API request parsing to the SamplingParams, GpuInputBatch, and finally into the LogitsProcessorManager to ensure the thinking token limit is enforced during the token generation process.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@llsj14 llsj14 force-pushed the feat/thinking-budget branch from c13ccf9 to 3a072f0 Compare July 12, 2025 09:25
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a feature to limit the number of "thinking" tokens generated by a model, which is a great way to prevent uncontrolled reasoning loops and manage computational budgets. The implementation adds a max_think_tokens parameter and a corresponding MaxThinkTokensLogitsProcessor to enforce this limit. I've identified a couple of issues related to correctness, particularly in edge cases and state management, which I've detailed below. Addressing these will make the feature more robust.

@llsj14 llsj14 force-pushed the feat/thinking-budget branch 3 times, most recently from 35cad4f to 4d64881 Compare July 14, 2025 04:56
@mergify

mergify bot commented Jul 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @llsj14.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 14, 2025
@llsj14 llsj14 force-pushed the feat/thinking-budget branch from 4d64881 to 3c4fc40 Compare July 14, 2025 05:09
@mergify mergify bot removed the needs-rebase label Jul 14, 2025
@llsj14 llsj14 force-pushed the feat/thinking-budget branch 5 times, most recently from d5b9de1 to 4c4251d Compare July 14, 2025 06:12
Collaborator


It seems more appropriate to split this into separate files.

Contributor Author


I think it’s good to separate the files, but I’m just concerned about the divergence of different kinds of logits processors at the moment, since some are declared in the ops directory (e.g., bad words, penalties, top-k, top-p), while the built-in logits processors are declared in this logits_processor.py file.

Collaborator


You can probably create a logit_processors dir, then put the different logits processors there.

The default ones can just live under logit_processors/__init__.py, and the others can each have their own file.

Contributor Author


Good, I’ll update it.

Contributor Author


This has already been addressed in this PR, so I will update it as soon as the PR is merged.

Contributor Author

@llsj14 llsj14 Aug 20, 2025


As the PR is merged, I moved my implementation of ThinkingBudgetLogitsProcessors into v1/sample/logits_processor/builtin.py.

vllm/config.py Outdated
Collaborator


Let's not introduce another class for this here. I think we can coupled this with the reasoning parser.

Contributor Author

@llsj14 llsj14 Jul 14, 2025


It was quite hard to pass the reasoning parser information to the logits processors. If I don't use ReasoningConfig, I would still need to pass the reasoning parser object to the logits processor anyway, so that it can get the think start/end token IDs.

Collaborator

@aarnphm aarnphm left a comment


quick drive by comments on configuration.

Collaborator


Can we introduce some heuristic with reasoning_effort. I'm thinking:

  • low -> 1024
  • medium -> 2048
  • high -> 8192

Then we can also have this as additional extra_body for users to override if they have custom context length set to vllm server here.
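The suggested heuristic can be sketched as follows (hypothetical names; an explicit extra_body budget is assumed to override the effort level, per the scenarios below):

```python
# Hypothetical mapping from the review suggestion; not vLLM's actual API.
EFFORT_TO_BUDGET = {"low": 1024, "medium": 2048, "high": 8192}

def resolve_thinking_budget(reasoning_effort=None, thinking_tokens_budget=None):
    """An explicit per-request budget (via extra_body) wins over the
    effort-level heuristic; None means no thinking limit at all."""
    if thinking_tokens_budget is not None:
        return thinking_tokens_budget
    if reasoning_effort is not None:
        return EFFORT_TO_BUDGET[reasoning_effort]
    return None
```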

Contributor Author

@llsj14 llsj14 Jul 14, 2025


Sounds reasonable. So the user should only provide "reasoning_effort": [low, medium, high] as the sampling parameter? What I’m a bit concerned about is that it’s hard to control at the token level, and it’s only configurable when the server loads.

Collaborator

@aarnphm aarnphm Jul 14, 2025


reasoning_effort is mostly for the OpenAI-compatible endpoint. If users want more control, we then respect thinking_token_budget (or some other name) in the body instead of reasoning_effort.

Two scenarios:

  • Users who already uses reasoning_effort from openai frontend: nothing changes for them
  • If they want to increase the thinking budget, knowing that the model context length supports it:
    client.chat.completions.create(..., 
                                   reasoning_effort="medium", # we ignore reasoning_effort here for thinking_tokens_budget
                                   extra_body={"thinking_tokens_budget": 16384}
                                  )

Collaborator


also this should be included in the max_tokens calculation as well

Contributor Author


Thank you for your feedback. I added this parameter as thinking_tokens_budget.

Contributor Author

@llsj14 llsj14 Aug 20, 2025


I had applied reasoning_effort, but it became a sampling parameter for the soft limit of thinking tokens, used by the chat_template.jinja file.
So I broke the connection between reasoning_effort and thinking_budget_tokens.

vllm/config.py Outdated
Collaborator


ditto

Member


2 things:

  • You've put this between additional_config and the comment above explaining what it is
  • There's no need to make this config Optional; you can default-construct the actual config as follows:
Suggested change
reasoning_config: Optional[ReasoningConfig] = None
reasoning_config: ReasoningConfig = field(default_factory=ReasoningConfig)

Collaborator


let's avoid changing this, I don't think this is related to this PR.

Contributor Author

@llsj14 llsj14 Jul 14, 2025


I changed this part because I needed the think start/end token IDs from the reasoning parser for the logits processor, which needs the starting point and the end point of thinking mode.
I referenced this as reasoning_parser.think_start_token_id for both the Qwen and DeepSeek models.

Collaborator


let's avoid changing this, I don't think this is related to this PR.

+1.

Also, as shown in hunyuan_a13b_reasoning_parser.py, think_start_ids consists of three token IDs. Using reasoning_parser.think_start_token_id directly doesn't seem like a good approach; I suggest using a @property instead.

Contributor Author


@chaunceyjiang Yes, I’ll update this for extensibility.
For now, I just wanted this PR to support only the Qwen and DeepSeek models, which use a single token ID to start and finish thinking mode. I think we'll need a different workflow for reasoning models that require multiple token IDs; for example, they may need a partial prefill after forcing multiple tokens at the end. In that case, I'm not sure if using only logits processors is the right approach. Maybe we'll need partial-prefill workflows or some help from guided decoding. What do you think?

Collaborator


I don't think structured outputs is relevant here.

I think frontend-related features should be using logit processor, to avoid performance issue. But the new logit processor should be performant enough.

Contributor Author

@llsj14 llsj14 Jul 15, 2025


It is quite hard to handle multiple think end tokens using a logits processor. That’s why I’m also considering implementing this feature in the serving_chat part, the scheduler, or with guided decoding.

There are several ways to implement this, each with its own drawbacks:

  1. Logits processor: I would have to enforce multiple think end tokens across multiple decode steps, which means some performance degradation (though it may still be reasonable).
  2. serving_chat: I could make the reasoning_parser count think tokens and enforce think end tokens. This would be quite easy to implement, but with the current implementation it seems hard to make the reasoning_parser check the sampling parameters of every request, and it is challenging to implement in the non-streaming API.
  3. Scheduler: Similar to the verification stage of speculative decoding, we could enforce multiple tokens and make the forward step perform a partial prefill. However, it seems quite difficult and complex to make only part of the requests in a batch build a KV cache for multiple tokens. @rishitdholakia13's implementation appears to follow this approach, but if we need to handle multiple tokens, it would get more complex.
  4. Guided decoding: Guided decoding or structured outputs have similar needs. for example, forcing certain tokens. But I think it’s also complex to manage given the prior implementations and the use of external libraries.

Contributor Author


I decided to apply multiple think end tokens using logits processors. The methods I described above (options 2–4) are difficult to implement at the moment. So, the logits processors will produce multiple think end tokens across multiple forward steps.

Contributor Author

@llsj14 llsj14 Jul 16, 2025


With this new commit, I made this feature work with start/end tokens defined as token sequences (multiple tokens).
Since the reasoning parsers do not have the same property, I needed a new config argument to get the think start/end strings (e.g., think_end_str="\n\nConsidering the limited time by the user, I have to give the solution based on the thinking directly now.\n</think>\n\n").

Contributor


Thank you for the insights. I wanted to ask: what would be the advantage of using a logits processor? That is, what do we gain by forcing the change in the logits values, compared to just inserting the required token into the new_token_ids list in gpu_model_runner.py?

Contributor Author

@llsj14 llsj14 Jul 17, 2025


Yes, directly inserting tokens into new_token_ids is a valid approach. I used a LogitsProcessor to preserve the sampling flow while still guiding the model to pick the desired token. For multiple end tokens, directly forcing them requires careful KV cache handling, like partial prefill and num_lookahead_slots, which can be tricky. So I designed the LogitsProcessor to force think end tokens one by one. That said, it's also possible to optimize by forcing all at once with partial prefill.
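The one-token-per-step forcing described here can be sketched as a tiny state machine (hypothetical shape; the real logic lives inside the logits processor's per-request state):

```python
class EndSequenceForcer:
    """Emit a multi-token think-end sequence one token per decode step,
    then hand control back to normal sampling."""

    def __init__(self, end_ids):
        self.end_ids = end_ids
        self.pos = 0  # index of the next end token to force

    def next_forced_token(self):
        if self.pos >= len(self.end_ids):
            return None  # sequence complete: resume normal sampling
        tok = self.end_ids[self.pos]
        self.pos += 1
        return tok

# e.g. an end string that tokenizes to three IDs
forcer = EndSequenceForcer([11, 22, 33])
emitted = [forcer.next_forced_token() for _ in range(4)]
```

Forcing one token per forward step keeps the KV cache consistent without partial prefill, at the cost of one decode step per forced token.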


@aarnphm
Collaborator

aarnphm commented Jul 14, 2025

fyi #19912

Contributor


Can we rename this to thinking_budget? It would help provide consistency in naming, since the max thinking here refers to the thinking budget provided by the user.

Contributor Author


Yeah, that’s possible. But I’m also thinking that a min_think_tokens option may be added in the future, which would force the model to generate at least min_think_tokens 'think' tokens.

Contributor Author


Similar to your recommendation, I renamed it to thinking_token_budget.

@mergify

mergify bot commented Jul 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @llsj14.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@llsj14
Contributor Author

llsj14 commented Jul 18, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces a new feature to limit thinking tokens by sampling parameters, which aims to prevent uncontrolled long reasoning loops and support explicit thinking limits. The code changes include adding a ReasoningConfig class, modifying the SamplingParams class, and implementing a ThinkingTokenBudgetLogitsProcessor class. The code review identified issues related to error handling and redundant conditions, which should be addressed to ensure the code's correctness and maintainability.

for i1, i2, direction in batch_update.moved:
if direction == MoveDirectionality.SWAP:
state1 = self._state.get(i1, {})
state2 = self._state.get(i2, {})

@rishit13 rishit13 Mar 7, 2026


I think I found a bug here: when we run multiple requests where some requests have no thinking budget, and a swap happens, the requests without a thinking budget get added to the _state dictionary if they were part of the swap slot, causing a KeyError. I made a simple change of using -1 as the thinking budget (meaning unlimited thinking) in my Spec + thinking budget PR. This avoids the issue.

Contributor Author


@rishit13 Thank you for pointing this out!
Instead of using -1, I thought it would be better to default to None when popping states. This is consistent with how other logits processors are implemented, and avoids unnecessary overhead from tracking states for requests that don't require a thinking budget. Adding a state entry for every such request would introduce extra overhead in state management.

llsj14 added 6 commits March 8, 2026 13:11
…method

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
"--max-model-len",
"2048",
"--enforce-eager",
"--no-async-scheduling",
Contributor Author


I added an e2e test. (To run, python -m pytest tests/v1/entrypoints/openai/test_thinking_token_budget.py)

Limiting the thinking token budget works with async scheduling, but achieving exact budget enforcement is difficult, because with async scheduling the output token IDs are not updated in sync with each token-generation step. I think this issue could also be addressed by @rishitdholakia13's follow-up PR (#34668), which aims to enable this feature with speculative decoding; that is also a case where more than one token can be generated per step.

Contributor

@rishitdholakia13 rishitdholakia13 Mar 8, 2026


Yes, I have addressed the issue in the spec + thinking budget PR and added e2e tests as well that ensure the exact thinking-budget limit is enforced in both spec and non-spec mode, using both sync and async scheduling.

@mergify

mergify bot commented Mar 9, 2026

Hi @llsj14, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

llsj14 added 2 commits March 9, 2026 08:31
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
@llsj14
Copy link
Contributor Author

llsj14 commented Mar 10, 2026

@njhill @chaunceyjiang
Thank you for taking the time to review this PR several times.
I addressed all the comments except for the automation part. I think it would be better to handle the automation in a separate PR along with some refactoring. Since the spec decode PR is already prepared, it might be better to merge this PR first and then move forward from there. I would appreciate it if you could review it again when you have time.

@mergify
Copy link

mergify bot commented Mar 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @llsj14.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry again fro the delay @llsj14 .. please see remaining comments.

# ThinkingTokenBudgetLogitsProcessor also needs output token ids to
# correctly track think start/end token sequences in async scheduling.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs)
or (self.vllm_config.reasoning_config is not None),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
or (self.vllm_config.reasoning_config is not None),
or self.vllm_config.reasoning_config is not None,

Comment on lines +51 to +53
self.think_start_token_ids = tokenizer.convert_tokens_to_ids(
tokenizer.tokenize(self.think_start_str)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain the reason for replacing convert_tokens_to_ids with the encode method?

performance .. we are making one call into rust without materializing intermediate python strings

Can we use encode with add_special_tokens=False?

Comment on lines +13 to +16
think_start_str: str | None = None
"""String that indicates the start of reasoning."""
think_end_str: str | None = None
"""String that indicates the end of reasoning."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, after several improvements to the ReasoningParser, some similar interfaces have gradually been introduced internally, although they are not publicly exposed yet.

This is all internal usage right? I don't understand the relevance of exposing publicly. Still confused why we could not have done this.

`initialize_token_ids`. Not intended to be configured directly."""

def initialize_token_ids(self, model_config: ModelConfig) -> None:
"""Initialize reasoning token IDs from strings using the tokenizer."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need a check here that think_start_token_ids and think_end_token_ids are None.

And we should perhaps rename them to start with an underscore and have @property accessors, as I think we do with other "derived" values in the config classes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is applied by this commit, 8252175

Comment on lines +13 to +16
think_start_str: str | None = None
"""String that indicates the start of reasoning."""
think_end_str: str | None = None
"""String that indicates the end of reasoning."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My main issue here is that we're exposing a new arg / config parameter externally that isn't really required, just because we don't want to go to the hassle of wiring up to the reasoning parsers.

Let's at least add a comment explaining that setting the parameter shouldn't be required and is a temporary state, that the parameter will likely be removed in a subsequent version.

_bad_words_token_ids: list[list[int]] | None = None

skip_reading_prefix_cache: bool | None = None
thinking_token_budget: int | None = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to add a check in the appropriate place to fail the request if thinking_token_budget is set but reasoning config is None (no logit processor initialized).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a ValueError for that situation.
45bed67

llsj14 added 5 commits March 13, 2026 13:00
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
…ig is not configured

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
@llsj14
Copy link
Contributor Author

llsj14 commented Mar 13, 2026

@njhill
Thank you for thorough reviews, and I think I resolved most of issues.

My main issue here is that we're exposing a new arg / config parameter externally that isn't really required, just because we don't want to go to the hassle of wiring up to the reasoning parsers.
Let's at least add a comment explaining that setting the parameter shouldn't be required and is a temporary state, that the parameter will likely be removed in a subsequent version.

I added a comment, and I will resolve this issue with separate PR soon. I will ask opinions of you and @chaunceyjiang again there. I think automation can be considered with not only ReasoningParsers but also model configs somehow.

@llsj14
Copy link
Contributor Author

llsj14 commented Mar 13, 2026

@userbz
I saw your comment from alarms, but couldn't find it in this PR. Did you resolve issues while experimenting with mistral model or move the comment into another issue?

I also changed tokenizer part. I replaced convert_ids_to_tokens by encode. See this commit.
Please let me know if you have further issues, and maybe we can collaborate to resolve them.

@abhinand5
Copy link

I tried a simpler approach to wire into the existing reasoning parser instead of a separate ReasoningConfig, see the draft PR #37112. Happy to consolidate efforts @njhill

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build cpu Related to CPU backends deepseek Related to DeepSeek models documentation Improvements or additions to documentation frontend gpt-oss Related to GPT-OSS models kv-connector llama Related to Llama models multi-modality Related to multi-modality (#4194) new-model Requests to new models nvidia performance Performance-related issues qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding structured-output v1

Projects

Status: Todo
Status: In review
Status: No status
Status: In progress

Development

Successfully merging this pull request may close these issues.