diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index fe482112d386..c609cfb5c067 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -1,13 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
 import asyncio
 from contextlib import suppress
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from unittest.mock import MagicMock
 
 import pytest
+import pytest_asyncio
 
 from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
@@ -17,6 +20,164 @@
                                                  OpenAIServingModels)
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
+from ...utils import RemoteOpenAIServer
+
+if TYPE_CHECKING:
+    from openai import OpenAI
+
+GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
+
+
+@pytest.fixture(scope="module")
+def monkeypatch_module():
+    from _pytest.monkeypatch import MonkeyPatch
+    mpatch = MonkeyPatch()
+    yield mpatch
+    mpatch.undo()
+
+
+@pytest.fixture(scope="module")
+def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
+    with monkeypatch_module.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        args = [
+            "--enforce-eager",
+            "--max-model-len",
+            "8192",
+            "--tool-call-parser",
+            "openai",
+            "--reasoning-parser",
+            "openai_gptoss",
+            "--enable-auto-tool-choice",
+        ]
+        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
+            yield remote_server
+
+
+@pytest_asyncio.fixture
+async def gptoss_client(gptoss_server):
+    async with gptoss_server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string"
+                    },
+                    "state": {
+                        "type": "string"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["city", "state", "unit"],
+            },
+        },
+    }]
+
+    messages = [
+        {
+            "role": "user",
+            "content": "What is the weather in Dallas, TX?"
+        },
+    ]
+
+    stream = await gptoss_client.chat.completions.create(
+        model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
+
+    name = None
+    args_buf = ""
+    async for chunk in stream:
+        delta = chunk.choices[0].delta
+        if delta.tool_calls:
+            tc = delta.tool_calls[0]
+            if tc.function and tc.function.name:
+                name = tc.function.name
+            if tc.function and tc.function.arguments:
+                args_buf += tc.function.arguments
+
+    assert name is not None
+    assert len(args_buf) > 0
+
+
+@pytest.mark.asyncio
+async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string"
+                    },
+                    "state": {
+                        "type": "string"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["city", "state", "unit"],
+            },
+        },
+    }]
+
+    messages = [
+        {
+            "role": "system",
+            "content": "you are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "What is the weather in Dallas, TX?"
+        },
+    ]
+
+    first = await gptoss_client.chat.completions.create(
+        model=GPT_OSS_MODEL_NAME,
+        messages=messages,
+        tools=tools,
+        temperature=0.0,
+    )
+    first_msg = first.choices[0].message
+    assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
+    tc = first_msg.tool_calls[0]
+    assert tc.function is not None and tc.function.name == "get_current_weather"
+    args1 = tc.function.arguments
+    assert args1 is not None and len(args1) > 0
+
+    messages.append({"role": "assistant", "content": args1})
+    messages.append({
+        "role": "user",
+        "content": "Now convert to celsius and return JSON only"
+    })
+
+    second = await gptoss_client.chat.completions.create(
+        model=GPT_OSS_MODEL_NAME,
+        messages=messages,
+        tools=tools,
+        temperature=0.0,
+    )
+    second_msg = second.choices[0].message
+    assert (second_msg.content is not None and len(second_msg.content) > 0) or \
+        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)  # noqa: E501
+
+
 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_use/test_openai_tool_parser.py
new file mode 100644
index 000000000000..0192c7d2765c
--- /dev/null
+++ b/tests/tool_use/test_openai_tool_parser.py
@@ -0,0 +1,147 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+from openai_harmony import (Conversation, DeveloperContent,
+                            HarmonyEncodingName, Message, Role, SystemContent,
+                            load_harmony_encoding)
+
+from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+MODEL = "gpt2"
+
+
+@pytest.fixture(scope="module")
+def openai_tokenizer():
+    # The parser does not use the tokenizer, but the constructor requires it.
+    return get_tokenizer(MODEL)
+
+
+@pytest.fixture
+def openai_tool_parser(openai_tokenizer):
+    return OpenAIToolParser(openai_tokenizer)
+
+
+@pytest.fixture(scope="module")
+def harmony_encoding():
+    return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+
+def assert_tool_calls(
+    actual_tool_calls: list[ToolCall],
+    expected_tool_calls: list[ToolCall],
+):
+    assert len(actual_tool_calls) == len(expected_tool_calls)
+
+    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
+                                                    expected_tool_calls):
+        assert isinstance(actual_tool_call.id, str)
+        assert len(actual_tool_call.id) > 16  # Default from protocol.py
+        assert actual_tool_call.type == "function"
+        assert actual_tool_call.function == expected_tool_call.function
+
+
+def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
+    convo = Conversation.from_messages([
+        Message.from_role_and_content(
+            Role.SYSTEM,
+            SystemContent.new(),
+        ),
+        Message.from_role_and_content(
+            Role.DEVELOPER,
+            DeveloperContent.new().with_instructions("Talk like a pirate!")),
+        Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
+        Message.from_role_and_content(Role.ASSISTANT,
+                                      "This is a test").with_channel("final")
+    ])
+    token_ids = harmony_encoding.render_conversation_for_completion(
+        convo, Role.ASSISTANT)
+    extracted_info = openai_tool_parser.extract_tool_calls(
+        "",
+        request=None,
+        token_ids=token_ids,
+    )
+    assert not extracted_info.tools_called
+    assert extracted_info.tool_calls == []
+    assert extracted_info.content == "This is a test"
+
+
+def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
+    convo = Conversation.from_messages([
+        Message.from_role_and_content(Role.USER,
+                                      "What is the weather in Tokyo?"),
+        Message.from_role_and_content(
+            Role.ASSISTANT,
+            'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.',  # noqa: E501
+        ).with_channel("analysis"),
+        Message.from_role_and_content(
+            Role.ASSISTANT,
+            '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
+                "functions.get_current_weather").with_content_type("json"),
+    ])
+    token_ids = harmony_encoding.render_conversation_for_completion(
+        convo, Role.ASSISTANT)
+
+    extracted_info = openai_tool_parser.extract_tool_calls(
+        "",
+        request=None,
+        token_ids=token_ids,
+    )
+    assert extracted_info.tools_called
+    expected_tool_calls = [
+        ToolCall(function=FunctionCall(
+            name="get_current_weather",
+            arguments=json.dumps({"location": "Tokyo"}),
+        ))
+    ]
+    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
+    assert extracted_info.content is None
+
+
+def test_extract_tool_calls_multiple_tools(
+    openai_tool_parser,
+    harmony_encoding,
+):
+    convo = Conversation.from_messages([
+        Message.from_role_and_content(
+            Role.USER, "What is the weather in Tokyo based on where I'm at?"),
+        Message.from_role_and_content(
+            Role.ASSISTANT,
+            'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
+        ).with_channel("analysis"),
+        Message.from_role_and_content(
+            Role.ASSISTANT,
+            '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
+                "functions.get_current_weather").with_content_type("json"),
+        Message.from_role_and_content(
+            Role.ASSISTANT,
+            '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
+                "functions.get_user_location").with_content_type("json"),
+    ])
+    token_ids = harmony_encoding.render_conversation_for_completion(
+        convo,
+        Role.ASSISTANT,
+    )
+
+    extracted_info = openai_tool_parser.extract_tool_calls(
+        "",
+        request=None,
+        token_ids=token_ids,
+    )
+    assert extracted_info.tools_called
+    expected_tool_calls = [
+        ToolCall(function=FunctionCall(
+            name="get_current_weather",
+            arguments=json.dumps({"location": "Tokyo"}),
+        )),
+        ToolCall(function=FunctionCall(
+            name="get_user_location",
+            arguments=json.dumps({"location": "Tokyo"}),
+        ))
+    ]
+    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
+    assert extracted_info.content is None
diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py
index 078d31684425..d1ff06425fcb 100644
--- a/vllm/entrypoints/harmony_utils.py
+++ b/vllm/entrypoints/harmony_utils.py
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
 import datetime
 import json
 from collections.abc import Iterable, Sequence
@@ -18,7 +21,8 @@
                             Role, StreamableParser, SystemContent, TextContent,
                             ToolDescription, load_harmony_encoding)
 
-from vllm.entrypoints.openai.protocol import ResponseInputOutputItem
+from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
+                                              ResponseInputOutputItem)
 from vllm.utils import random_uuid
 
 REASONING_EFFORT = {
@@ -63,13 +67,29 @@ def get_system_message(
     return sys_msg
 
 
-def get_developer_message(instructions: Optional[str] = None,
-                          tools: Optional[list[Tool]] = None) -> Message:
+def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
+    if isinstance(tool, ChatCompletionToolsParam):
+        return ToolDescription.new(
+            name=tool.function.name,
+            description=tool.function.description,
+            parameters=tool.function.parameters,
+        )
+    return ToolDescription.new(
+        name=tool.name,
+        description=tool.description,
+        parameters=tool.parameters,
+    )
+
+
+def get_developer_message(
+    instructions: Optional[str] = None,
+    tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
+) -> Message:
     dev_msg_content = DeveloperContent.new()
     if instructions is not None:
         dev_msg_content = dev_msg_content.with_instructions(instructions)
     if tools is not None:
-        function_tools = []
+        function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
         for tool in tools:
            if tool.type in ("web_search_preview", "code_interpreter"):
                 # These are built-in tools that are added to the system message.
@@ -80,11 +100,7 @@ def get_developer_message(instructions: Optional[str] = None,
                 raise ValueError(f"tool type {tool.type} not supported")
         if function_tools:
             function_tool_descriptions = [
-                ToolDescription.new(
-                    name=tool.name,
-                    description=tool.description,
-                    parameters=tool.parameters,
-                ) for tool in function_tools
+                create_tool_definition(tool) for tool in function_tools
             ]
             dev_msg_content = dev_msg_content.with_function_tools(
                 function_tool_descriptions)
@@ -148,16 +164,46 @@ def parse_response_input(
     return msg
 
 
-def parse_chat_input(chat_msg) -> Message:
-    role = chat_msg["role"]
-    content = chat_msg["content"]
+def parse_chat_input(chat_msg) -> list[Message]:
+    if not isinstance(chat_msg, dict):
+        # Handle Pydantic models
+        chat_msg = chat_msg.model_dump(exclude_none=True)
+
+    role = chat_msg.get("role")
+
+    # Assistant message with tool calls
+    tool_calls = chat_msg.get("tool_calls")
+    if role == "assistant" and tool_calls:
+        msgs: list[Message] = []
+        for call in tool_calls:
+            func = call.get("function", {})
+            name = func.get("name", "")
+            arguments = func.get("arguments", "") or ""
+            msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
+            msg = msg.with_channel("commentary")
+            msg = msg.with_recipient(f"functions.{name}")
+            msg = msg.with_content_type("json")
+            msgs.append(msg)
+        return msgs
+
+    # Tool role message (tool output)
+    if role == "tool":
+        name = chat_msg.get("name", "")
+        content = chat_msg.get("content", "") or ""
+        msg = Message.from_author_and_content(
+            Author.new(Role.TOOL, f"functions.{name}"),
+            content).with_channel("commentary")
+        return [msg]
+
+    # Default: user/assistant/system messages with content
+    content = chat_msg.get("content", "")
     if isinstance(content, str):
         contents = [TextContent(text=content)]
     else:
         # TODO: Support refusal.
         contents = [TextContent(text=c.get("text", "")) for c in content]
     msg = Message.from_role_and_contents(role, contents)
-    return msg
+    return [msg]
 
 
 def render_for_completion(messages: list[Message]) -> list[int]:
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index fff6dcd724ad..4cc22787a020 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -6,7 +6,7 @@
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
-from typing import Callable, Final, Optional, Union
+from typing import TYPE_CHECKING, Callable, Final, Optional, Union
 
 import jinja2
 import partial_json_parser
@@ -489,6 +489,8 @@ async def chat_completion_stream_generator(
             get_streamable_parser_for_assistant()
             for _ in range(num_choices)
         ]
+        harmony_tools_streamed = [False] * num_choices
+        tools_streamed = [False] * num_choices
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -662,13 +664,11 @@ async def chat_completion_stream_generator(
 
                    if self.use_harmony:
                         harmony_parser = harmony_parsers[i]
+                        prev_recipient = harmony_parser.current_recipient
                         for token_id in output.token_ids:
                             harmony_parser.process(token_id)
-                        is_reasoning = \
-                            harmony_parser.current_channel == "analysis"
-                        if not request.include_reasoning and is_reasoning:
-                            # Skip the reasoning content.
-                            continue
+                        cur_channel = harmony_parser.current_channel
+                        cur_recipient = harmony_parser.current_recipient
                         delta_text = harmony_parser.last_content_delta or ""
                     else:
                         delta_text = output.text
@@ -681,8 +681,7 @@ async def chat_completion_stream_generator(
                     delta_message: Optional[DeltaMessage]
 
                     # just update previous_texts and previous_token_ids
-                    if ((tool_choice_auto or self.reasoning_parser)
-                            and not self.use_harmony):
+                    if tool_choice_auto or self.reasoning_parser:
                         assert previous_texts is not None
                         assert all_previous_token_ids is not None
                         previous_text = previous_texts[i]
@@ -696,11 +695,54 @@ async def chat_completion_stream_generator(
                         current_token_ids = as_list(output.token_ids)
 
                     if self.use_harmony:
-                        if is_reasoning:
-                            delta_message = DeltaMessage(
-                                reasoning_content=delta_text)
-                        else:
+                        if cur_channel == "final":
                             delta_message = DeltaMessage(content=delta_text)
+                        elif cur_channel == "analysis":
+                            if request.include_reasoning:
+                                delta_message = DeltaMessage(
+                                    reasoning_content=delta_text)
+                            else:
+                                delta_message = None
+                        elif (cur_channel == "commentary" and cur_recipient
+                              and cur_recipient.startswith("functions.")):
+                            # Count completed tool calls to determine index
+                            base_index = 0
+                            for msg in harmony_parser.messages:
+                                if (msg.channel == "commentary"
+                                        and msg.recipient
+                                        and msg.recipient.startswith(
+                                            "functions.")):
+                                    base_index += 1
+
+                            if prev_recipient != cur_recipient:
+                                tool_name = cur_recipient.split(
+                                    "functions.", 1)[1]
+                                delta_message = DeltaMessage(tool_calls=[
+                                    DeltaToolCall(
+                                        id=make_tool_call_id(),
+                                        type="function",
+                                        function=DeltaFunctionCall(
+                                            name=tool_name,
+                                            arguments="",
+                                        ),
+                                        index=base_index,
+                                    )
+                                ])
+                            elif delta_text:
+                                delta_message = DeltaMessage(tool_calls=[
+                                    DeltaToolCall(
+                                        index=base_index,
+                                        function=DeltaFunctionCall(
+                                            arguments=delta_text),
+                                    )
+                                ])
+                            else:
+                                delta_message = None
+
+                            if delta_message is not None:
+                                harmony_tools_streamed[i] = True
+                        else:
+                            delta_message = None
                     # handle streaming deltas for tools with named tool_choice
                     elif tool_choice_function_name:
                         if (self.reasoning_parser and not reasoning_end_arr[i]
@@ -758,6 +800,7 @@ async def chat_completion_stream_generator(
                             delta_message = DeltaMessage(tool_calls=[
                                 delta_tool_call,
                             ])
+                            tools_streamed[i] = True
 
                     elif request.tool_choice == "required":
                         assert previous_texts is not None
@@ -783,6 +826,7 @@ async def chat_completion_stream_generator(
                         if (delta_message and delta_message.tool_calls
                                 and delta_message.tool_calls[0].id is not None):
                             history_tool_call_cnt += 1
+                            tools_streamed[i] = True
 
                         # update the previous values for the next iteration
                         previous_texts[i] = current_text
@@ -859,6 +903,8 @@ async def chat_completion_stream_generator(
                                     current_token_ids=current_token_ids,
                                     delta_token_ids=delta_token_ids,
                                     request=request))
+                            if delta_message and delta_message.tool_calls:
+                                tools_streamed[i] = True
                         # when only tool calls
                         elif tool_choice_auto:
                             assert tool_parser is not None
@@ -871,6 +917,8 @@ async def chat_completion_stream_generator(
                                     current_token_ids=current_token_ids,
                                     delta_token_ids=output.token_ids,
                                     request=request))
+                            if delta_message and delta_message.tool_calls:
+                                tools_streamed[i] = True
 
                         # when only reasoning
                         elif self.reasoning_parser:
@@ -907,7 +955,10 @@ async def chat_completion_stream_generator(
                     # wasn't ready to send a token, then
                    # get the next token without streaming a chunk
                     if delta_message is None:
-                        continue
+                        if output.finish_reason is None:
+                            continue
+                        else:
+                            delta_message = DeltaMessage()
 
                     # Log streaming delta if output logging is enabled
                     if self.enable_log_outputs and self.request_logger:
@@ -993,12 +1044,18 @@ async def chat_completion_stream_generator(
                             ])
 
                     # Send the finish response for each request.n only once
+                    if auto_tools_called or tools_streamed[i] or (
+                            self.use_harmony
+                            and harmony_tools_streamed[i]):
+                        finish_reason_ = "tool_calls"
+                    else:
+                        finish_reason_ = output.finish_reason \
+                            if output.finish_reason else "stop"
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=delta_message,
                         logprobs=logprobs,
-                        finish_reason=output.finish_reason
-                        if not auto_tools_called else "tool_calls",
+                        finish_reason=finish_reason_,
                         stop_reason=output.stop_reason,
                         token_ids=(as_list(output.token_ids)
                                    if request.return_token_ids else None))
@@ -1131,31 +1188,32 @@ async def chat_completion_full_generator(
                 logprobs = None
 
             if self.use_harmony:
-                reasoning_content, final_content, is_tool_call = (
-                    parse_chat_output(token_ids))
-                if not request.include_reasoning:
-                    reasoning_content = None
-
-                if is_tool_call:
-                    # TODO(woosuk): Implement tool call for gpt-oss.
-                    # For now, only Responses API supports tool call for
-                    # gpt-oss.
-                    raise NotImplementedError(
-                        "Tool call in Chat Completion API is not supported "
-                        "for gpt-oss yet. Please use Responses API instead.")
-                else:
-                    # Normal message
-                    message = ChatMessage(
-                        role=role,
-                        reasoning_content=reasoning_content,
-                        content=final_content,
-                    )
+                if TYPE_CHECKING:
+                    assert self.tool_parser is not None
+                tool_parser = self.tool_parser(tokenizer)
+                # NOTE: We use token_ids for openai tool parser
+                tool_call_info = tool_parser.extract_tool_calls(
+                    "",
+                    request=request,
+                    token_ids=token_ids,  # type: ignore
+                )
+                reasoning_content, content = None, tool_call_info.content
+                if request.include_reasoning:
+                    reasoning_content, content, _ = parse_chat_output(
+                        token_ids)
+                message = ChatMessage(
+                    role=role,
+                    reasoning_content=reasoning_content,
+                    content=content,
+                    tool_calls=tool_call_info.tool_calls,
+                )
 
                 choice_data = ChatCompletionResponseChoice(
                     index=output.index,
                     message=message,
                     logprobs=logprobs,
-                    finish_reason="tool_calls" if is_tool_call else
+                    finish_reason="tool_calls"
+                    if tool_call_info.tools_called else
                     output.finish_reason if output.finish_reason else "stop",
                     stop_reason=output.stop_reason,
                 )
@@ -1504,12 +1562,12 @@ def _make_request_with_harmony(
         messages.append(sys_msg)
 
         # Add developer message.
-        dev_msg = get_developer_message()
+        dev_msg = get_developer_message(tools=request.tools)
         messages.append(dev_msg)
 
         # Add user message.
         for chat_msg in request.messages:
-            messages.append(parse_chat_input(chat_msg))
+            messages.extend(parse_chat_input(chat_msg))
 
         # Render prompt token ids.
         prompt_token_ids = render_for_completion(messages)
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 44aa1208a54c..35096b046136 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -16,6 +16,7 @@
 from .llama_tool_parser import Llama3JsonToolParser
 from .minimax_tool_parser import MinimaxToolParser
 from .mistral_tool_parser import MistralToolParser
+from .openai_tool_parser import OpenAIToolParser
 from .phi4mini_tool_parser import Phi4MiniJsonToolParser
 from .pythonic_tool_parser import PythonicToolParser
 from .qwen3coder_tool_parser import Qwen3CoderToolParser
@@ -46,4 +47,5 @@
     "Qwen3CoderToolParser",
     "SeedOssToolParser",
     "Step3ToolParser",
+    "OpenAIToolParser",
 ]
diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
new file mode 100644
index 000000000000..c5d59514b944
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from vllm.entrypoints.harmony_utils import parse_output_into_messages
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+
+if TYPE_CHECKING:
+    from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+@ToolParserManager.register_module("openai")
+class OpenAIToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+        token_ids: Sequence[int] | None = None,
+    ) -> ExtractedToolCallInformation:
+        if token_ids is None:
+            raise NotImplementedError(
+                "OpenAIToolParser requires token IDs and does not support text-based extraction."  # noqa: E501
+            )
+
+        parser = parse_output_into_messages(token_ids)
+        tool_calls = []
+        final_content = None
+
+        if len(parser.messages) > 0:
+            for msg in parser.messages:
+                if msg.recipient and msg.recipient.startswith("functions."):
+                    tool_calls.append(
+                        ToolCall(
+                            type="function",
+                            function=FunctionCall(
+                                name=msg.recipient.split("functions.")[1],
+                                arguments=msg.content[0].text,
+                            ),
+                        ))
+                elif msg.channel == "final":
+                    final_content = msg.content[0].text
+
+        return ExtractedToolCallInformation(
+            tools_called=len(tool_calls) > 0,
+            tool_calls=tool_calls,
+            content=final_content,
+        )
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> DeltaMessage | None:
+        raise NotImplementedError(
+            "Not being used, manual parsing in serving_chat.py"  # noqa: E501
+        )
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 0245e89f7da7..8b76a54332f8 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -256,7 +256,7 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         decoding_config = vllm_config.decoding_config
         if decoding_config.reasoning_backend == "":
-            decoding_config.reasoning_backend = "GptOss"
+            decoding_config.reasoning_backend = "openai_gptoss"
 
         # Increase the max capture size from 512 to 1024 for performance.
         # NOTE(woosuk): This will increase the number of CUDA graphs
diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py
index 05a72ac23bf2..3bd4d872ce22 100644
--- a/vllm/reasoning/gptoss_reasoning_parser.py
+++ b/vllm/reasoning/gptoss_reasoning_parser.py
@@ -6,6 +6,7 @@
 
 from transformers import PreTrainedTokenizerBase
 
+from vllm.entrypoints.harmony_utils import parse_chat_output
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage)
 from vllm.logger import init_logger
@@ -14,7 +15,7 @@
 logger = init_logger(__name__)
 
 
-@ReasoningParserManager.register_module("GptOss")
+@ReasoningParserManager.register_module("openai_gptoss")
 class GptOssReasoningParser(ReasoningParser):
     """
     Reasoning parser for GptOss model.
@@ -39,9 +40,10 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return False
 
     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
-        raise RuntimeError(
-            "GptOss model uses harmony to extract reasoning content. This "
-            "function should not be called.")
+        _, content, _ = parse_chat_output(input_ids)
+        if content is None:
+            return []
+        return self.model_tokenizer.encode(content)
 
     def extract_reasoning_content_streaming(
         self,
@@ -52,13 +54,34 @@ def extract_reasoning_content_streaming(
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
     ) -> Union[DeltaMessage, None]:
-        raise RuntimeError(
-            "GptOss model uses harmony to extract reasoning content. This "
-            "function should not be called.")
+        prev_reasoning, prev_content, _ = parse_chat_output(
+            list(previous_token_ids))
+        cur_reasoning, cur_content, _ = parse_chat_output(
+            list(current_token_ids))
+        reasoning_delta = None
+        content_delta = None
+        if cur_reasoning is not None:
+            prev_r = prev_reasoning or ""
+            if cur_reasoning.startswith(prev_r):
+                reasoning_delta = cur_reasoning[len(prev_r):] or None
+            else:
+                reasoning_delta = cur_reasoning
+        if cur_content is not None:
+            prev_c = prev_content or ""
+            if cur_content.startswith(prev_c):
+                content_delta = cur_content[len(prev_c):] or None
+            else:
+                content_delta = cur_content
+        if reasoning_delta is None and content_delta is None:
+            return None
+        return DeltaMessage(reasoning_content=reasoning_delta,
+                            content=content_delta)
 
     def extract_reasoning_content(
-        self, model_output: str, request: ChatCompletionRequest
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
     ) -> tuple[Optional[str], Optional[str]]:
-        raise RuntimeError(
-            "GptOss model uses harmony to extract reasoning content. This "
-            "function should not be called.")
+        raise NotImplementedError(
+            "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used."  # noqa: E501
+        )