
[Bugfix] Map reasoning_effort="none" to enable_thinking=False for Qwen3 chat templates #38100

Open
Lidang-Jiang wants to merge 2 commits into vllm-project:main from
Lidang-Jiang:fix/reasoning-effort-enable-thinking

Conversation

@Lidang-Jiang

Summary

Fixes #37909

  • reasoning_effort="none" now correctly disables thinking for Qwen3/Qwen3.5 models by injecting enable_thinking=False into chat_template_kwargs
  • Injection happens at request-parse time (Pydantic model validator), so both the Jinja chat template and the reasoning parser see the flag consistently
  • For templates that don't declare enable_thinking, the kwarg is automatically filtered out by resolve_chat_template_kwargs — no side effects on other models
  • Does not override explicit chat_template_kwargs={"enable_thinking": True} if set by the user
  • Also applies the same mapping in ResponsesRequest.build_chat_params for the Responses API

Why a model validator instead of build_chat_params?

The reasoning parser is instantiated before build_chat_params runs (see serving.py:222-229), using request.chat_template_kwargs directly. If we inject enable_thinking=False only in build_chat_params, the parser never sees it — it defaults thinking_enabled=True and misclassifies all generated content as "truncated reasoning", returning content: null.
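The ordering problem can be shown with a standalone toy; all names below are assumptions for illustration, not vLLM's serving code:

```python
# Toy model of the request flow: the parser is constructed from the raw
# request kwargs before build_chat_params runs.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Request:
    chat_template_kwargs: dict[str, Any] = field(default_factory=dict)


def parser_thinking_enabled(request: Request) -> bool:
    # The reasoning parser reads request.chat_template_kwargs directly
    # and defaults to thinking enabled — and it is created first.
    return request.chat_template_kwargs.get("enable_thinking", True)


def build_chat_params(request: Request) -> dict[str, Any]:
    # Injecting only here is too late: the parser above never sees it.
    kwargs = dict(request.chat_template_kwargs)
    kwargs["enable_thinking"] = False
    return kwargs


req = Request()  # no validator-time injection
print(parser_thinking_enabled(req), build_chat_params(req)["enable_thinking"])
# True False — template and parser disagree
```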

Test plan

  • 10 unit tests covering all edge cases (none/low/medium/high, explicit override, existing kwargs preserved)
  • Pre-commit hooks pass (ruff, mypy, typos, etc.)
  • E2E test with Qwen3-8B + --reasoning-parser qwen3

E2E results

Before: reasoning_effort="none" — model still thinks, content is null
{
    "choices": [
        {
            "message": {
                "content": null,
                "reasoning": null
            },
            "finish_reason": "length"
        }
    ],
    "usage": {
        "completion_tokens": 200
    }
}

The model generated 200 thinking tokens (hitting max_tokens), include_reasoning=False stripped them, and no content was returned.

After: reasoning_effort="none" — thinking disabled, content returned
{
    "choices": [
        {
            "message": {
                "content": "2 + 2 equals 4.",
                "reasoning": null
            },
            "finish_reason": "stop"
        }
    ],
    "usage": {
        "completion_tokens": 9
    }
}

Only 9 tokens generated, direct answer returned.

Normal request (no reasoning_effort) — thinking still works
{
    "choices": [
        {
            "message": {
                "content": null,
                "reasoning": "\nOkay, the user is asking \"What is 2+2?\" That's a basic arithmetic question..."
            },
            "finish_reason": "length"
        }
    ],
    "usage": {
        "completion_tokens": 200
    }
}

Normal thinking behavior is preserved.

Test commands

pytest tests/entrypoints/openai/chat_completion/test_reasoning_effort.py -v -s
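In the same spirit, the mapping logic can be exercised standalone; the helper below is a hypothetical restatement of the validator's behaviour, not the PR's actual test file:

```python
def inject_enable_thinking(reasoning_effort, chat_template_kwargs):
    """Model of the validator's mapping, for illustration only."""
    kwargs = dict(chat_template_kwargs or {})
    if reasoning_effort == "none":
        # Do not override an explicit user setting.
        kwargs.setdefault("enable_thinking", False)
    return kwargs


# "none" disables thinking; other efforts leave kwargs untouched
assert inject_enable_thinking("none", None) == {"enable_thinking": False}
assert inject_enable_thinking("high", None) == {}
# explicit enable_thinking=True is preserved
assert inject_enable_thinking("none", {"enable_thinking": True}) == {"enable_thinking": True}
# existing kwargs are preserved alongside the injected flag
assert inject_enable_thinking("none", {"foo": 1}) == {"foo": 1, "enable_thinking": False}
print("all checks passed")
```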

Notes

  • No duplicate PR exists for this issue
  • AI assistance was used (Claude). All changes have been reviewed and tested by the human submitter.
  • This PR does not address the broader question of whether set_include_reasoning_for_none_effort should be reverted (discussed in the issue thread) — that is a separate concern

🤖 Generated with Claude Code

@mergify mergify bot added the frontend, qwen (Related to Qwen models), and bug (Something isn't working) labels on Mar 25, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new feature to map reasoning_effort="none" to enable_thinking=False within the chat_template_kwargs for OpenAI chat completion requests. This ensures that models like Qwen3/Qwen3.5, which utilize an enable_thinking chat-template kwarg, consistently receive this flag for both Jinja templating and reasoning parsing.

The changes include a new model_validator in vllm/entrypoints/openai/chat_completion/protocol.py to inject enable_thinking=False if reasoning_effort is "none" and enable_thinking is not explicitly set. Additionally, vllm/entrypoints/openai/responses/protocol.py is updated to propagate this setting when building ChatParams.

A comprehensive new test file, tests/entrypoints/openai/chat_completion/test_reasoning_effort.py, has been added to validate this behavior, covering various scenarios including preservation of existing kwargs and explicit user settings. No feedback to provide.

# For templates that don't use enable_thinking, it is
# automatically filtered out by resolve_chat_template_kwargs.
if reasoning_effort == "none":
    extra_kwargs["enable_thinking"] = False
Member


Hmm I think this might be too model specific, wdyt @chaunceyjiang ? There is also the concern about unnecessary warnings/errors being generated if the chat template doesn't support this argument.


My 2 cents (expressed in #37909 (comment)): this is model-specific but we could temporarily introduce this to avoid breaking compatibility and encourage model providers to use reasoning_effort in their chat templates / libs now the none value has been made available to them (#36238).

Author


Good concern — let me address both points:

1. No warnings/errors for unsupported templates: resolve_chat_template_kwargs (in chat_utils.py) introspects the Jinja template's declared variables and automatically filters out any kwargs not accepted by the template. So if a chat template doesn't declare enable_thinking, the kwarg is silently dropped before rendering — no warnings, no errors.

2. On model-specificity: enable_thinking is widely adopted across reasoning model families (Qwen3, Qwen3.5, QwQ, etc.). As @scwgoire noted in the issue discussion, this serves as a reasonable short-term compatibility bridge — ideally model providers will adopt reasoning_effort directly in their chat templates over time.

I've also included the revert of set_include_reasoning_for_none_effort in this PR (per @scwgoire's suggestion), since that validator was silently dropping already-generated tokens and could discard the entire response when </think> was present in the template.
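The filtering behaviour described in point 1 can be sketched with Jinja2's own template introspection; this is a simplification of what resolve_chat_template_kwargs does, not its actual code:

```python
# Hedged sketch: keep only kwargs the template actually references,
# using Jinja2's variable introspection.
import jinja2
from jinja2 import meta

env = jinja2.Environment()


def filter_template_kwargs(template_source: str, kwargs: dict) -> dict:
    ast = env.parse(template_source)
    referenced = meta.find_undeclared_variables(ast)
    return {k: v for k, v in kwargs.items() if k in referenced}


qwen_like = "{% if enable_thinking %}<think>{% endif %}{{ messages }}"
plain = "{{ messages }}"

print(filter_template_kwargs(qwen_like, {"enable_thinking": False}))
# {'enable_thinking': False}
print(filter_template_kwargs(plain, {"enable_thinking": False}))
# {} — silently dropped, no warnings
```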

Collaborator


I agree with @DarkLight1337’s point. The enable_thinking flag is only supported by a small number of models, so this approach isn’t very general.

Recently, #20859 introduced a thinking_token_budget parameter to control the number of tokens used for thinking.

According to the design at the time, different values of reasoning_effort correspond to different thinking_token_budget values:

  • None → no limit
  • "none" → 0
  • "low" → 1024
  • "medium" → 2048
  • "high" → 8192

see #38204
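That proposed design amounts to a simple lookup; the values come from the list above, while the function name and wiring are assumptions for illustration:

```python
# reasoning_effort -> thinking_token_budget, per the design described above.
EFFORT_TO_BUDGET = {
    None: None,     # no limit on thinking tokens
    "none": 0,      # thinking effectively disabled
    "low": 1024,
    "medium": 2048,
    "high": 8192,
}


def thinking_token_budget(reasoning_effort):
    return EFFORT_TO_BUDGET[reasoning_effort]


print(thinking_token_budget("none"), thinking_token_budget("medium"))
# 0 2048
```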



You are right that enable_thinking is supported by only a few models; hence the proposal to add this feature as a short-term solution while model vendors transition to reasoning_effort="none" (which has only been available in vLLM since 0.18.0, which may explain model vendors' enable_thinking workaround).

Our use case is that our customers don't need vLLM knowledge; they simply use the service as an OpenAI-compliant endpoint. The vast majority of them use off-the-shelf applications or the OpenAI SDKs (and their documentation) to build their applications. It is close to impossible to educate tens of thousands of them to use "non-standard" features such as enable_thinking and thinking_token_budget; this is a showstopper.

This is why I'm trying to find solutions to implement "do not think" through the "standard" reasoning_effort="none" (can't speak for Lidang Jiang 😁). Going through the standards should always be the main approach whenever practical. It helps keep the learning curve as flat as possible for end users (I am aware the OpenAI API is everything but simple 😅) and also provides better adoption by off-the-shelf applications.

Hope this point of view helps 🤗

Author


Thanks for pointing to #38204 — the thinking_token_budget approach is indeed more general and not model-specific. Happy to close this PR once #38204 lands if the team prefers that direction. In the meantime I've rebased to resolve the merge conflicts so CI can run cleanly.

@mergify
Contributor

mergify bot commented Mar 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Lidang-Jiang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

…hat templates

Models like Qwen3/Qwen3.5 use an `enable_thinking` chat-template kwarg
to toggle thinking on/off, but vLLM's `reasoning_effort="none"` was not
mapped to it. This caused the model to still generate thinking tokens
(wasting compute), and the subsequent `include_reasoning=False` filter
dropped all content, returning `content: null` to the user.

Fix: add a Pydantic model validator on ChatCompletionRequest that injects
`enable_thinking=False` into `chat_template_kwargs` when
`reasoning_effort="none"`. Injecting at request-parse time (before the
reasoning-parser is created) ensures both the Jinja template and the
parser see the flag consistently. For templates that don't declare the
variable, it is filtered out by `resolve_chat_template_kwargs`.

Also apply the same mapping in `ResponsesRequest.build_chat_params` for
the Responses API.

Signed-off-by: Lidang Jiang <lidangjiang@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
Remove the `set_include_reasoning_for_none_effort` validator that
forcibly set `include_reasoning=False` when `reasoning_effort="none"`.

With the `enable_thinking=False` injection already in place, this
validator is redundant for models that support the flag (Qwen3 etc.)
and harmful for models that don't — it silently drops already-generated
reasoning tokens, and can even discard the entire response when
`</think>` is already present in the chat template.

`include_reasoning` now stays at its default (`True`), letting users
control it independently if needed.

Closes vllm-project#37909

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
@Lidang-Jiang force-pushed the fix/reasoning-effort-enable-thinking branch from 77d5a1c to 6eb6e19 on April 2, 2026 at 02:14

Labels

bug (Something isn't working), frontend, qwen (Related to Qwen models)


Development

Successfully merging this pull request may close these issues.

[Bug]: "none" reasoning effort doesn't do what it says it does (and may break output)

4 participants