Skip to content

[Bugfix]: Fix structured output in multi-turn gpt-oss#34454

Merged
vllm-bot merged 6 commits intovllm-project:mainfrom
bbrowning:gptoss-multiturn-reasoning-structured
Feb 13, 2026
Merged

[Bugfix]: Fix structured output in multi-turn gpt-oss#34454
vllm-bot merged 6 commits intovllm-project:mainfrom
bbrowning:gptoss-multiturn-reasoning-structured

Conversation

@bbrowning
Copy link
Copy Markdown
Contributor

@bbrowning bbrowning commented Feb 12, 2026

Purpose

The logic in the gptoss_reasoning_parser to detect when the model has finished outputting reasoning content and starting to output content to the final channel was inadvertently matching on final channel messages from previous messages for multi-turn scenarios. In practice this meant that vLLM started applying the grammar bitmasks to the entirety of the model's output in these multi-turn conversations prematurely, causing the model to deviate from its trained Harmony format and lead to empty or invalid outputs.

This PR fixes things by never looking for the final channel marker in any message prior to the current one the model is generating so that we don't falsely believe the model is starting generation of the final channel unless it's actually doing so during this turn of the conversation.

Prior to vLLM v0.13.0 this bug existed but we didn't actually trip over it because the way we handle multi-turn conversation state with gpt-oss models was missing important tokens that coincidentally caused those prior conversations to not actually match these token id checks. But, once we fixed multi-turn conversation state, that caused structured output usage with things like json_object response formats to then hit this bug in the reasoning parser.

Fixes #32791

Test Plan

I added a unit test specifically to cover this case, following test-driven-development by ensuring the test failed initially, applied my fix, and then ensured the test passed.

The existing and new gptoss_reasoning_parser unit tests were run via:

pytest tests/reasoning/test_gptoss_reasoning_parser.py
pytest tests/v1/structured_output/test_gptoss_structural_tags.py
pytest tests/entrypoints/openai/test_gptoss_structural_tags_integration.py

Additionally, I ran the manual reproducer (labeled as case 3) in #32791:

vllm serve openai/gpt-oss-20b \
  --tool-call-parser openai \
  --enable-auto-tool-choice

curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer dummy" \
  -d '{
    "model": "openai/gpt-oss-20b",
    "messages": [
      {
        "role": "user",
        "content": "Respond with JSON only in the form {\"response\":\"hello\"}."
      },
      {
        "role": "assistant",
        "content": "{\"response\":\"hello\"}"
      },
      {
        "role": "user",
        "content": "Respond with JSON only in the form {\"response\":\"bye\"}."
      }
    ],
    "response_format": { "type": "json_object" },
    "max_tokens": 128,
    "temperature": 0
  }' | jq .

Test Result

All the unit tests passed.

For the manual curl test, prior to this change it gave a response with empty content:

{
  "id": "chatcmpl-81416dae965f4f7d",
  "object": "chat.completion",
  "created": 1770920903,
  "model": "openai/gpt-oss-20b",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": null,
        "refusal": null,
        "annotations": null,
        "audio": null,
        "function_call": null,
        "tool_calls": [],
        "reasoning": null
      },
      "logprobs": null,
      "finish_reason": "stop",
...

After this change, the model gives the expected response:

{
  "id": "chatcmpl-9c7eb34a997d07e2",
  "object": "chat.completion",
  "created": 1770923019,
  "model": "openai/gpt-oss-20b",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "{\"response\":\"bye\"}",
        "refusal": null,
        "annotations": null,
        "audio": null,
        "function_call": null,
        "tool_calls": [],
        "reasoning": "The user wants JSON only: {\"response\":\"bye\"}. So output that."
      },
      "logprobs": null,
      "finish_reason": "stop",
...

@mergify mergify bot added gpt-oss Related to GPT-OSS models bug Something isn't working labels Feb 12, 2026
The logic in the gptoss_reasoning_parser to detect when the model has
finished outputting reasoning content is is starting to output content
to the final channel was inadvertently matching on final channel
messages from previous messages for multi-turn scenarios. In practice
this meant that vLLM started applying the grammar bitmasks to the
entirety of the model's output in these multi-turn conversations
prematurely, causing the model to deviate from its trained Harmony
format and lead to empty or invalid outputs.

This PR fixes things by never looking for the final channel marker in any
message prior to the current one the model is generating so that we
don't falsely believe the model is starting generation of the final
channel unless it's actually doing so during this turn of the
conversation.

Prior to vLLM v0.13.0 this bug existed but we didn't actually trip over
it because the way we handle multi-turn conversation state with gpt-oss
models was missing important tokens that coincidentally caused those
prior conversations to not actually match these token id checks. But,
once we fixed multi-turn conversation state, that caused structured
output usage with things like `json_object` response formats to then hit
this bug in the reasoning parser.

Fixes vllm-project#32791

Signed-off-by: Ben Browning <bbrownin@redhat.com>
@bbrowning bbrowning force-pushed the gptoss-multiturn-reasoning-structured branch from 65c163e to c851d60 Compare February 12, 2026 19:27
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a critical bug in the gptoss_reasoning_parser that caused premature termination of reasoning in multi-turn conversations, leading to incorrect structured outputs. The fix, which involves stopping the backward search for the end-of-reasoning marker upon encountering a message boundary from a previous turn, is logical and well-implemented. The inclusion of a specific unit test to cover this multi-turn scenario is a great addition and significantly improves the robustness of the parser. Overall, the changes are excellent and effectively resolve the described issue. I have one suggestion to further improve the robustness of the code.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Instead of .encode followed by taking the first token, it's cleaner to just directly use model_tokenizer.vocab to fetch single token ids.

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Copy link
Copy Markdown
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Feb 13, 2026
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) February 13, 2026 13:13
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 13, 2026
CI discovered some additional tests that use gptoss_reasoning_parser but
with a mocked tokenizer. So, this adds a mocked `vocab` to that mock
tokenizer so that these tests also pass.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
auto-merge was automatically disabled February 13, 2026 14:18

Head branch was pushed to by a user without write access

@bbrowning
Copy link
Copy Markdown
Contributor Author

CI picked up some additional tests that used gptoss_reasoning_parser but with a mocked tokenizer that failed after adjusting to use .vocab instead of .encode. So, I pushed one more commit adding a vocab mock to those mock tokenizers, grepped the tests to ensure no other tests use gptoss_reasoning_parser that need updating, and updated the test plan in the PR description to reflect running the 3 unit tests that touch this code:

pytest tests/reasoning/test_gptoss_reasoning_parser.py
pytest tests/v1/structured_output/test_gptoss_structural_tags.py
pytest tests/entrypoints/openai/test_gptoss_structural_tags_integration.py

The latter two failed and caught by CI, but are passing locally now.

@bbrowning
Copy link
Copy Markdown
Contributor Author

bbrowning commented Feb 13, 2026

The amd-basic-correctness test failure looks unrelated, but I left a comment on the recently merged PR (32993) that added those tests so the authors are aware that test is failing on AMD hardware.

@vllm-bot vllm-bot merged commit fd267bc into vllm-project:main Feb 13, 2026
45 of 47 checks passed
@bbrowning bbrowning deleted the gptoss-multiturn-reasoning-structured branch February 13, 2026 19:13
wzhao18 pushed a commit to wzhao18/vllm that referenced this pull request Feb 18, 2026
…4454)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
eldarkurtic pushed a commit to eldarkurtic/vllm that referenced this pull request Feb 19, 2026
…4454)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: Eldar Kurtic <research@neuralmagic.com>
ZJY0516 pushed a commit to ZJY0516/vllm that referenced this pull request Feb 23, 2026
…4454)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…4454)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…4454)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
…4454)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
## Summary

Cherry-pick upstream bug fixes for RHAIIS 3.3.1 onto `rhai/0.13.0`. All
fixes are from upstream vLLM `main` and address critical bugs affecting
RHAIIS 3.3.0. Other releases (3.2.2, EAx) will be done separately.

**Jira Epic:**
[INFERENG-4743](https://issues.redhat.com/browse/INFERENG-4743)

## Cherry-picked commits (chronological order)

| # | Upstream PR | Jira | Summary |
|---|------------|------|---------|
| 1 | [vllm-project#30550](vllm-project#30550) |
[INFERENG-5106](https://issues.redhat.com/browse/INFERENG-5106) |
Support using chat template as custom score template for reranking
models |
| 2 | [vllm-project#31406](vllm-project#31406) |
[INFERENG-4800](https://issues.redhat.com/browse/INFERENG-4800) | Add
encoder-only/cross attention support to Triton Attention backend |
| 3 | [vllm-project#34243](vllm-project#34243) |
[INFERENG-4746](https://issues.redhat.com/browse/INFERENG-4746) | Fix
Llama-4 attn quantization by correctly permuting scales for rope (int8,
fp8) |
| 4 | [vllm-project#34454](vllm-project#34454) |
[INFERENG-5032](https://issues.redhat.com/browse/INFERENG-5032) | Fix
structured output in multi-turn GPT-OSS (content:null with json_object)
|
| 5 | [vllm-project#34507](vllm-project#34507) |
[INFERENG-5038](https://issues.redhat.com/browse/INFERENG-5038) | Fix
fused MoE int32 overflow in stride*offset for large models |
| 6 | [vllm-project#35085](vllm-project#35085) |
[INFERENG-5028](https://issues.redhat.com/browse/INFERENG-5028) |
Gracefully disable AllReduceFusionPass on GPUs without multicast support
|
| 7 | [vllm-project#35456](vllm-project#35456) |
[INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) |
Replace assert with ValueError for response_format validation
(completions) |
| 8 | [vllm-project#35510](vllm-project#35510) |
[INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) | Add
response_format validation to chat completions endpoint |


## Conflict resolutions

<details>
<summary><b>#1 — llama-nemotron-embed / score-template support
(vllm-project#30550)</b>: Clean cherry-pick, no conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#2 — Triton Attention (vllm-project#31406)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#3 — Llama-4 attn quant (vllm-project#34243)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly. 4 intermediate upstream commits touch `llama4.py` but
the fix targets a self-contained block.
</details>

<details>
<summary><b>vllm-project#4 — GPT-OSS multi-turn (vllm-project#34454)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly despite 3 intermediate upstream commits that refactored
imports in `gptoss_reasoning_parser.py`. The fix logic (adding
`eom_token_id` early-exit check in `is_reasoning_end`) was independent
of the import changes.
</details>

<details>
<summary><b>vllm-project#5 — Fused MoE int32 overflow (vllm-project#34507)</b>: Conflicts in 2
files</summary>

**`vllm/model_executor/layers/fused_moe/fused_moe.py`**: ~30
intermediate upstream commits refactored `fused_moe_kernel` with
conditional `naive_block_assignment` logic that doesn't exist in
`rhai/0.13.0`. Resolved by keeping our simpler code and applying only
the int64 cast fix:
- `fused_moe_kernel_gptq_awq`: added `.to(tl.int64)` to `tl.load()`
result
- `fused_moe_kernel`: added `offs_token = offs_token.to(tl.int64)`
before `token_mask`

**`tests/kernels/moe/test_moe.py`**: Upstream test changes depend on
`make_dummy_moe_config()` from intermediate refactors. Resolved by
keeping our existing test code (no test changes).
</details>

<details>
<summary><b>vllm-project#6 — AllReduceFusionPass multicast (vllm-project#35085)</b>: Conflict
due to file rename + API change</summary>

Upstream moved `collective_fusion.py` →
`compilation/passes/fusion/allreduce_rms_fusion.py` and changed the API
from `trtllm_create_ipc_workspace_for_all_reduce_fusion()` to
`create_allreduce_fusion_workspace()`. Resolved by applying the
try/except wrapper around our existing
`trtllm_create_ipc_workspace_for_all_reduce_fusion()` call in
`collective_fusion.py`. The error handling logic (catching RuntimeError
with "multicast" in message, logging warning, returning early) is
identical to upstream.
</details>

<details>
<summary><b>vllm-project#7 — response_format validation for completions
(vllm-project#35456)</b>: Conflict due to file restructuring</summary>

Upstream split `protocol.py` into `completion/protocol.py` and
`chat_completion/protocol.py`. Our branch still has the monolithic
`protocol.py`. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`CompletionRequest` in our `protocol.py`
- Using `ValueError` instead of upstream's `VLLMValidationError` (which
doesn't exist in our branch; `ValueError` is already handled as 400 Bad
Request in `serving_engine.py`)
- Test additions from upstream applied cleanly to
`test_completion_error.py`
</details>

<details>
<summary><b>vllm-project#8 — response_format validation for chat completions
(vllm-project#35510)</b>: Conflict due to file restructuring</summary>

Same file restructuring issue as vllm-project#6. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/chat_completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`ChatCompletionRequest` in our `protocol.py`
- Only accepting the `test_json_schema_response_format_missing_schema`
test from the conflict (discarding ~140 lines of intermediate upstream
tests that reference non-existent paths in our branch)
</details>

## Test plan

- [ ] Verify `llama-nemotron-embed-1b-v2` works correctly with the
backported score-template / bidirectional model support
- [ ] Verify Llama-4 quantized model loads correctly with int8/fp8
attention quantization
- [ ] Verify GPT-OSS multi-turn chat with `json_object` response_format
returns valid content
- [ ] Verify large MoE models (e.g. Qwen3.5-397B) don't crash with int32
overflow
- [ ] Verify MoE model loading on H200 GPUs (without multicast)
gracefully falls back
- [ ] Verify `response_format: {type: "json_schema"}` without
`json_schema` field returns 400 (not 500) for both `/v1/completions` and
`/v1/chat/completions`
- [ ] Verify encoder models (e.g. Whisper) work with Triton attention
backend on ROCm


[INFERENG-4743]:
https://redhat.atlassian.net/browse/INFERENG-4743?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-4800]:
https://redhat.atlassian.net/browse/INFERENG-4800?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-4746]:
https://redhat.atlassian.net/browse/INFERENG-4746?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-5032]:
https://redhat.atlassian.net/browse/INFERENG-5032?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-5038]:
https://redhat.atlassian.net/browse/INFERENG-5038?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

[INFERENG-5106]:
https://redhat.atlassian.net/browse/INFERENG-5106?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working gpt-oss Related to GPT-OSS models ready ONLY add when PR is ready to merge/full CI is needed structured-output v1

Projects

Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: chat.completions returns content: null for GPT-OSS multi-turn with json_object

3 participants