Skip to content

[Bugfix] Multiple fixes for gpt-oss Chat Completion prompting#28729

Merged
chaunceyjiang merged 5 commits intovllm-project:mainfrom
bbrowning:chat-harmony-tests
Dec 12, 2025
Merged

[Bugfix] Multiple fixes for gpt-oss Chat Completion prompting#28729
chaunceyjiang merged 5 commits intovllm-project:mainfrom
bbrowning:chat-harmony-tests

Conversation

@bbrowning
Copy link
Contributor

@bbrowning bbrowning commented Nov 14, 2025

Purpose

This fixes multiple issues with how we were prompting gpt-oss models for Chat Completion requests. A summary of the fixes:

  • We were not setting the recipient to "assistant" on tool responses
  • We were not mapping the tool_call_id to the proper function name on tool responses, resulting in the model getting confused about which tool response matched which tool call.
  • When the model generates tool calls during a turn, we return any model output to the commentary channel as the content of that tool call response. We were not properly restoring this content as a commentary message when those tool calls were passed back in via Chat Completion messages.
  • When the model generates reasoning messages during a tool call turn, we return that as a reasoning properly on the Chat Completion response. We were not properly restoring that reasoning content as a message to the analysis channel when subsequent turns passed that back in.
  • When non-tool reasoning responses from the model were passed back in via the reasoning extra_body in subsequent turns, we were not properly turning those back into analysis messages. This is the same problem we had with reasoning in tool calls, but a slightly different fix for the non-tool call path.
  • We were not properly marking the channel as final for non-tool, non-reasoning model output when generating the Harmony messages for subsequent turns.
  • We were not dropping analysis messages that precede a model output to the final channel in all cases, resulting in us often showing the model analysis messages from previous turns where it has already generated a final output leading to increased token usage, context pollution, and not following the best practices for handling raw chain of thought outlined by OpenAI.

I added a series of tests that pass in Chat Completion requests and thoroughly validate the exact Harmony messages generated from those requests to test for all of the cases above.

Test Plan

Unit Tests

pytest tests/entrypoints/openai/test_serving_chat.py::TestServingChatWithHarmony

pytest tests/entrypoints/openai/parser/test_harmony_utils.py

Berkeley Function Calling Leaderboard v4 Results

Run vLLM built with this fix:

vllm serve openai/gpt-oss-120b \
  --tool-call-parser openai \
  --reasoning-parser openai_gptoss \
  --enable-auto-tool-choice \
  --tensor-parallel-size 4

Run bfcl generate against that vLLM:

# Assumes you're inside a local clone of bfcl and adjusted it to use openai/gpt-oss-120b model
cat <<EOF >> bfcl_eval/constants/model_config.py
MODEL_CONFIG_MAPPING = {
    "openai/gpt-oss-120b": ModelConfig(
        model_name="openai/gpt-oss-120b",
        display_name="openai/gpt-oss-120b (FC) (vLLM)",
        url="https://huggingface.co/openai/gpt-oss-120b",
        org="OpenAI",
        license="apache-2.0",
        model_handler=OpenAICompletionsHandler,
        input_price=None,
        output_price=None,
        is_fc_model=True,
        underscore_to_dot=True,
    ),
}
EOF

OPENAI_BASE_URL="http://localhost:8000/v1" \
OPENAI_API_KEY="fake" \
bfcl generate \
  --model openai/gpt-oss-120b \
  --test-category multi_turn \
  --num-threads 8

Test Result

Unit Tests

pytest -q --disable-warnings \
  tests/entrypoints/openai/test_serving_chat.py::TestServingChatWithHarmony
...........                                                       [100%]
11 passed, 5 warnings in 8.65s

pytest -q --disable-warnings \
  tests/entrypoints/openai/parser/test_harmony_utils.py
........................................................................        [100%]
72 passed, 2 warnings in 14.47s

Berkeley Function Calling Leaderboard v4 Results

Baseline, from vLLM main as of Nov 12, 2025

commit 3044195

45.0% overall accuracy

🔍 Running test: multi_turn_base
✅ Test completed: multi_turn_base. 🎯 Accuracy: 50.00%
🔍 Running test: multi_turn_long_context
✅ Test completed: multi_turn_long_context. 🎯 Accuracy: 43.00%
🔍 Running test: multi_turn_miss_func
✅ Test completed: multi_turn_miss_func. 🎯 Accuracy: 45.50%
🔍 Running test: multi_turn_miss_param
✅ Test completed: multi_turn_miss_param. 🎯 Accuracy: 41.50%

$ cat score-baseline/data_multi_turn.csv 
Rank,Model,Multi Turn Overall Acc,Base,Miss Func,Miss Param,Long Context
1,openai/gpt-oss-120b (FC) (vLLM),45.00%,50.00%,45.50%,41.50%,43.00%

With this fix on top of same commit from main

  • 53.5% overall accuracy, an improvement of 8.5%
  • 12% improvement in multi_turn_base
  • 8.5% improvement in multi_turn_long_context
  • 4.5% improvement in multi_turn_miss_func
  • 8.0% improvement in multi_turn_miss_param
✅ Test completed: multi_turn_base. 🎯 Accuracy: 62.00%                                                                                                                                                                                                                                         
🔍 Running test: multi_turn_long_context                                                                                                                                                                                                                                                        
✅ Test completed: multi_turn_long_context. 🎯 Accuracy: 51.50%                                                                                                                                                                                                                                 
🔍 Running test: multi_turn_miss_func                                                                                                                                                                                                                                                           
✅ Test completed: multi_turn_miss_func. 🎯 Accuracy: 50.00%                                                                                                                                                                                                                                    
🔍 Running test: multi_turn_miss_param                                                                                                                                                                                                                                                          
✅ Test completed: multi_turn_miss_param. 🎯 Accuracy: 49.50%

$ cat score-harmony_fixes/data_multi_turn.csv 
Rank,Model,Multi Turn Overall Acc,Base,Miss Func,Miss Param,Long Context
1,openai/gpt-oss-120b (FC) (vLLM),53.25%,62.00%,50.00%,49.50%,51.50%

Updated results after rebase on latest main as of Dec 10th

✅ Test completed: multi_turn_base. 🎯 Accuracy: 62.00%
🔍 Running test: multi_turn_long_context
✅ Test completed: multi_turn_long_context. 🎯 Accuracy: 59.00%
🔍 Running test: multi_turn_miss_func
✅ Test completed: multi_turn_miss_func. 🎯 Accuracy: 52.50%
🔍 Running test: multi_turn_miss_param
✅ Test completed: multi_turn_miss_param. 🎯 Accuracy: 46.50%
$ cat score/data_multi_turn.csv 
Rank,Model,Multi Turn Overall Acc,Base,Miss Func,Miss Param,Long Context
1,openai/gpt-oss-120b (FC) (vLLM),55.00%,62.00%,52.50%,46.50%,59.00%

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a comprehensive set of fixes for prompting gpt-oss models for Chat Completion, addressing several issues related to tool calls, reasoning, and multi-turn conversations. The changes significantly improve the model's performance on function calling benchmarks. The addition of a thorough test suite to validate these fixes is a great contribution. The code is well-structured, and the fixes in vllm/entrypoints/harmony_utils.py are robust. I have one suggestion to improve the code by removing a side effect, which will enhance maintainability.

@bbrowning
Copy link
Contributor Author

Before this fix, when running the bfcl multi_turn and single_turn suits (4,441 tests) I received 8 HarmonyErrors of unexpected tokens remaining in message header and 1 HarmonyError of Unknown role: final. After this fix, I received zero HarmonyErrors on the same test suites.

So, I can also confidently say that by fixing these issues in how we were prompting the gpt-oss-120b model, we are able to reduce the frequency that it does not follow its Harmony format in its responses.

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@ehfd
Copy link
Contributor

ehfd commented Nov 15, 2025

Are we able to get this merged before v0.11.1? Looks very critical.

@ZJY0516
Copy link
Member

ZJY0516 commented Nov 15, 2025

CC @chaunceyjiang

@groundsada
Copy link

Can this be merged please 🙏?

@bbrowning
Copy link
Contributor Author

I'm going to make some minor changes to address the review feedback from the two code review bots and to isolate this change entirely from the Responses API path so that we can merge this quicker without worry of it impacting Responses handling in any way. I'll get that done, re-tested, and push the updates today.

@bbrowning
Copy link
Contributor Author

Are we able to get this merged before v0.11.1? Looks very critical.

Note that even without this change, the gpt-oss support on latest merged main branch is substantially better than what v0.11.0 or v0.10.2 shipped with. I don't know the timing or cutoff for fixes in v0.11.1, but even if this does not make that release there will still be measurably better gpt-oss support in v0.11.1. With that said, I support getting this in if the window for additional fixes is still open for v0.11.1.

@bbrowning
Copy link
Contributor Author

Ok, from my perspective I've addressed the review bot comments, entirely isolated these changes to only impact the Chat Completions path, expanded the existing test_harmony_utils.py test suite so that it tests both Chat Completions and Responses paths with a common set of tests while moving logic specific to each into its own test class, and re-run my BFCL evaluation on gpt-oss-120b with the latest changes to ensure the numbers are approximately equal as before (approximately because there is a bit of run-to-run variance in this test suite as even though the inputs are identical in every test run the outputs are not entirely deterministic).

I'll keep an eye out for any other review feedback or comments required to get this in.

@ehfd
Copy link
Contributor

ehfd commented Nov 16, 2025

While I agree that tool calling and various other quirks in GPT-OSS have improved from basically unusable to a usable state before this PR, this should be designated as a milestone for v0.11.1.

@heheda12345
Copy link
Collaborator

@chaunceyjiang
Copy link
Collaborator

/cc @yeqcharlotte PTAL.

Copy link
Contributor

@alecsolder alecsolder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! I think we're finally landing on an implementation that takes the learnings from all of us over the past few months :)

for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
if content:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this case for? If tool calls is populated, I can understand reasoning being populated, but not this (which would imply it would be more like a "final" message)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're right that this content closer matches the final content, but now you made me double-check how we generate the outputs and we have some issues in mapping of reasoning and tool call outputs when generating Chat Completion responses. In my head, the content here matched commentary channel responses from tool calls while the reasoning key matches analysis channel responses. However, that's not necessarily the case and it looks like there are some edge cases in our reasoning and tool call parsers for gpt-oss models that can sometimes end up with this, sometimes end up with analysis or commentary messages missed entirely, or most often end up with analyis messages as reasoning and final messages as the tool call content.

In practice, this still seems to work quite well as-is. But, we can likely pick up a bit more accuracy by doing a once-over of Chat Completions reasoning/tool parsers and chat_completion_*_generator paths for harmony models to ensure streaming/non-streaming cases are all aligned with exactly how we parse input. I probably need a new test suite here that starts with Harmony messages, goes through our response parsing to Chat Completion responses, sends those back in as Chat Completion requests, and ensures the Harmony messages generated from that match what we originally sent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Part of what worries me here is on the response generation side, where it makes assumptions about which messages to map as reasoning vs content solely based on order of the output messages without checking of any channels:

if len(output_msgs) == 0:
# The generation has stopped during reasoning.
reasoning = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
reasoning = output_msgs[0].content[0].text
final_content = parser.current_content
else:
reasoning_msg = output_msgs[:-1]
final_msg = output_msgs[-1]
reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg])
final_content = final_msg.content[0].text

Then, we overwrite he content from that if using a tool call parser at

tool_call_info = tool_parser.extract_tool_calls(
"",
request=request,
token_ids=token_ids, # type: ignore
)
content = tool_call_info.content
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_call_info.tool_calls,
)

The tool call parser at https://github.com/vllm-project/vllm/blob/63fed5550609b96b578d2512aefced09efe76e1e/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py only looks at harmony messages with function recipients or to the final channel.

So, we may end up mixing analysis and commentary channel messages in reasoning or even dropping one of those entirely when making the Chat Completion response. To get this request mapping to Harmony messages exactly right, I'll need to clean up and pair this with some adjustments to the Harmony message to response mapping path so I can do this deterministically.

That will take a bit of time to match up, add some round-trip tests to ensure it doesn't deviate in the future, and re-run the eval. I can tackle this next week, either as part of this or a folllow-up based on timing preferences.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, there were some bugs on the Chat Completion response generation side around what I outlined above. I vastly expanded the new tests in test_serving_chat.py to verify various round-trips of Chat Completion requests to Harmony messages and Harmony outputs to Chat Completion responses including multi-turn variants that then send those Chat Completion response messages back in to future Chat Completion requests.

analysis_msg = Message.from_role_and_content(
Role.ASSISTANT, reasoning_content
)
analysis_msg = analysis_msg.with_channel("analysis")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This runs counter to how I feel like it should work, but I think is correct.

Trying to understand OpenAI's implementation, when reasoning is output to the commentary channel, that information is meant to signal that this reasoning should be shown to the user as-is (vs needing a summarization) i.e they want to provide the actual reasoning for the tool call so the user can approve it or not vs only a summarization. But then when it is provided as input again, it doesn't matter so they just always change it back to analysis so that it can be dropped.

msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
msg = msg.with_content_type("json")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could use a comment as it differs from the OpenAI implementation, however I think there has been enough proof at this point that it increases benchmark scores so should be kept

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment to this effect, but also indicating it may need further evaluation.

if not msg.channel:
msg = msg.with_channel("final")
msgs.append(msg)
# For user/system/developer messages, add them directly even if no content.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think system and developer cases here need more complex logic.

I think from a high level:

  • We should never have more than 1 system message
  • We should never have more than 1 developer message

(Can this validation be added to a test?)

Since we create system message and developer message every time in _make_request_with_harmony I'm worried about what happens with system messages here.

Since there is no "instructions" field like in responses, we need to merge the developer message created with the included tools with the developer message here, but not duplicating the tools for example. Complicated and definitely worth the tests I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not addressed this yet. This behavior should be unchanged from how it was before my PR, and I'm tempted to defer addressing this specific system/developer concern as this one is already large.

@msrodlab
Copy link

msrodlab commented Nov 19, 2025

Some >30% of the time, the model stops generating right after "reasoning_content"/"reasoning" is finished outputting with stop_reason 200012 (200012 corresponds to the <|call|> token). No final answer in the "content" field is generated. This behavior is incompatible for most clients, i.e. Cline. Out of curiosity, will this PR fix said issue?

This is prevalent in vLLM 0.11.1 (from my observation, 0.11.1 may have made it occur more frequently now) and 0.11.0

@ehfd
Copy link
Contributor

ehfd commented Nov 19, 2025

Some >30% of the time, the model stops generating right after "reasoning_content"/"reasoning" is finished outputting with stop_reason 200012 (200012 corresponds to the <|call|> token). No final answer in the "content" field is generated. This behavior is incompatible for most clients, i.e. Cline. Out of curiosity, will this PR fix said issue?

This is prevalent in vLLM 0.11.1 (from my observation, 0.11.1 may have made it occur more frequently now) and 0.11.0

@bbRLdev Yes, I think that it is plausible that the discussions and fixes here would help. It seems that very few self-hosted LLM runtimes have actually mastered OpenAI's Harmony format correctly, yet.

This isolates the fixes made for turning Chat Completion Requests into
Harmony Messages from the code path that Responses uses for
previous_input_messages so that Chat Completions support can be iterated
and improved without worrying about regressing Responses previous input
handling.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This addresses a few minor points of review feedback from my larger set
of fixes for Chat Completions and gpt-oss models:

- In `auto_drop_analysis_messages`, drop any analysis messages before
  the latest final message instead of just ones from assistant.
- Cleaned up `parse_chat_output` so that is is checking message channel
  before assigning messages as reasoning or final content instead of
  relying solely on the order of messages and assumptions about the
  number of messages present in each response that were incorrect.
- Some misc. typos, extracting copy/paste code into methods, etc.

This also adds many tests, as I found and fixed issues raised during
review feedback and manual testing. There are now round-trip tests for
converting Chat Completion requests to Harmony messages and from
converting Harmony output to Chat Completion responses, all with mocked
generation so they're part of the unit test suite. These test in both
streaming and non-streaming mode to ensure parity there.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
@bbrowning
Copy link
Contributor Author

Rebased, pushed, and re-ran the BFCL suite on the same machine as previous runs. Still seeing much better results than the baseline overall. Latest test scores with this rebased set of changes on top of latest main:

✅ Test completed: multi_turn_base. 🎯 Accuracy: 62.00%
🔍 Running test: multi_turn_long_context
✅ Test completed: multi_turn_long_context. 🎯 Accuracy: 59.00%
🔍 Running test: multi_turn_miss_func
✅ Test completed: multi_turn_miss_func. 🎯 Accuracy: 52.50%
🔍 Running test: multi_turn_miss_param
✅ Test completed: multi_turn_miss_param. 🎯 Accuracy: 46.50%
$ cat score/data_multi_turn.csv 
Rank,Model,Multi Turn Overall Acc,Base,Miss Func,Miss Param,Long Context
1,openai/gpt-oss-120b (FC) (vLLM),55.00%,62.00%,52.50%,46.50%,59.00%

Everything is above the baseline multi_turn results from before this PR, with the largest gains in the multi_turn_long_context and multi_turn_base. This is expected, as these fixes are primarily targeted at multi-turn correctness and preventing errors from compounding in the context as turns increase.

@chaunceyjiang chaunceyjiang self-assigned this Dec 11, 2025
@chaunceyjiang
Copy link
Collaborator

Copy link
Collaborator

@chaunceyjiang chaunceyjiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks~ @bbrowning

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Dec 11, 2025
@bbrowning
Copy link
Contributor Author

https://buildkite.com/vllm/ci/builds/42879#019b0b5c-4b8b-4f9e-8d64-50e00f58d1ad @bbrowning PTAL.

I see that failure, but that's in a dependency resolution issue building a container image:

[2025-12-11T03:04:48Z] WARNING: The requested image's platform (linux/arm64/v8) does not match the detected host platform (linux/amd64/v4) and no specific platform was requested
...
[2025-12-11T03:05:16Z] ERROR: Could not find a version that satisfies the requirement decord (from mantis-vl) (from versions: none)
[2025-12-11T03:05:17Z] ERROR: No matching distribution found for decord

Search around a bit, I found #30062 (comment) where it was identified that this was a bug and the fix in the ci-infra repo for it should have merged yesterday in vllm-project/ci-infra#243. So, based on that this is unrelated to my changes, and will either have to be force merged or trigger that failed test to rerun with the newer changes merged in the ci-infra repo.

@chaunceyjiang chaunceyjiang merged commit 8f8fda2 into vllm-project:main Dec 12, 2025
48 checks passed
@bbrowning bbrowning deleted the chat-harmony-tests branch December 12, 2025 13:28
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Dec 15, 2025
…roject#28729)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…roject#28729)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
@rhajou
Copy link

rhajou commented Dec 31, 2025

This is part of which vllm version? 0.13.0 ? is there anything not released yet in 0.13.0?

@Baescott
Copy link

Baescott commented Jan 8, 2026

This is part of which vllm version? 0.13.0 ? is there anything not released yet in 0.13.0?

Yes, you could see in v0.13.0-tagged main branch. See serving_chat.py with blame on

@bbrowning
Copy link
Contributor Author

These fixes are in v0.13.0, yes. If you're still seeing any issues with gpt-oss models in vLLM 0.13.0, please raise an issue so we can dig into things.

dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…roject#28729)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

frontend gpt-oss Related to GPT-OSS models ready ONLY add when PR is ready to merge/full CI is needed tool-calling

Projects

Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.