diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py new file mode 100644 index 000000000..769199b8f --- /dev/null +++ b/tests/test_responses_api.py @@ -0,0 +1,654 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the OpenAI-compatible Responses API.""" + +import json +import platform +import sys +from collections import OrderedDict +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi.testclient import TestClient + +pytestmark = pytest.mark.skipif( + sys.platform != "darwin" or platform.machine() != "arm64", + reason="Requires Apple Silicon", +) + + +@pytest.fixture() +def client(): + from vllm_mlx.server import app + + return TestClient(app) + + +@pytest.fixture(autouse=True) +def server_state(): + import vllm_mlx.server as srv + + original_engine = srv._engine + original_model_name = srv._model_name + original_store = srv._responses_store + original_store_max_size = srv._RESPONSES_STORE_MAX_SIZE + original_api_key = srv._api_key + + srv._engine = None + srv._model_name = "test-model" + srv._responses_store = OrderedDict() + srv._RESPONSES_STORE_MAX_SIZE = 1000 + srv._api_key = None + + try: + yield + finally: + srv._engine = original_engine + srv._model_name = original_model_name + srv._responses_store = original_store + srv._RESPONSES_STORE_MAX_SIZE = original_store_max_size + srv._api_key = original_api_key + + +def _mock_engine(*outputs): + engine = MagicMock() + engine.model_name = "test-model" + engine.preserve_native_tool_format = False + engine.chat = AsyncMock(side_effect=list(outputs)) + stream_calls = [] + + async def _stream_chat(**kwargs): + stream_calls.append(kwargs) + for output in getattr(engine, "_stream_outputs", []): + yield output + + engine._stream_calls = stream_calls + engine._stream_outputs = [] + engine.stream_chat = _stream_chat + return engine + + +def _output(text: str, prompt_tokens: int = 7, completion_tokens: int = 3): + return SimpleNamespace( + 
text=text, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + finish_reason="stop", + ) + + +def _stream_output( + new_text: str, + prompt_tokens: int = 7, + completion_tokens: int = 1, + finish_reason: str | None = None, +): + return SimpleNamespace( + new_text=new_text, + text=new_text, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + finish_reason=finish_reason, + finished=finish_reason is not None, + ) + + +def _parse_sse_events(body: str) -> list[tuple[str, dict]]: + events = [] + for chunk in body.strip().split("\n\n"): + if not chunk.strip(): + continue + event_type = None + payload = None + for line in chunk.splitlines(): + if line.startswith("event: "): + event_type = line.removeprefix("event: ").strip() + elif line.startswith("data: "): + payload = json.loads(line.removeprefix("data: ").strip()) + if event_type is not None and payload is not None: + events.append((event_type, payload)) + return events + + +class TestResponsesEndpoint: + def test_basic_response(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("Hello there")) + + resp = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Say hello"}, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["object"] == "response" + assert body["output_text"] == "Hello there" + assert body["output"][0]["type"] == "message" + assert body["output"][0]["content"][0]["type"] == "output_text" + assert body["usage"]["input_tokens"] == 7 + assert body["usage"]["output_tokens"] == 3 + + def test_previous_response_id_reuses_prior_context(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("First answer"), _output("Second answer")) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={"model": "test-model", "input": "First prompt"}, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + 
"previous_response_id": first_id, + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 200 + second_messages = engine.chat.call_args_list[1].kwargs["messages"] + assert second_messages[0]["role"] == "user" + assert second_messages[0]["content"] == "First prompt" + assert second_messages[1]["role"] == "assistant" + assert second_messages[1]["content"] == "First answer" + assert second_messages[2]["role"] == "user" + assert second_messages[2]["content"] == "Follow-up prompt" + + def test_previous_response_id_chains_across_multiple_follow_ups(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine( + _output("First answer"), + _output("Second answer"), + _output("Third answer"), + ) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={"model": "test-model", "input": "First prompt"}, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": first_id, + "input": "Second prompt", + }, + ) + second_id = second.json()["id"] + + third = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": second_id, + "input": "Third prompt", + }, + ) + + assert third.status_code == 200 + third_messages = engine.chat.call_args_list[2].kwargs["messages"] + assert third_messages[0]["role"] == "user" + assert third_messages[0]["content"] == "First prompt" + assert third_messages[1]["role"] == "assistant" + assert third_messages[1]["content"] == "First answer" + assert third_messages[2]["role"] == "user" + assert third_messages[2]["content"] == "Second prompt" + assert third_messages[3]["role"] == "assistant" + assert third_messages[3]["content"] == "Second answer" + assert third_messages[4]["role"] == "user" + assert third_messages[4]["content"] == "Third prompt" + + def test_previous_response_id_does_not_carry_prior_instructions(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("First 
answer"), _output("Second answer")) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={ + "model": "test-model", + "instructions": "First system instruction", + "input": "First prompt", + }, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "instructions": "Second system instruction", + "previous_response_id": first_id, + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 200 + second_messages = engine.chat.call_args_list[1].kwargs["messages"] + assert second_messages[0]["role"] == "system" + assert second_messages[0]["content"] == "Second system instruction" + assert "First system instruction" not in second_messages[0]["content"] + assert second_messages[1]["role"] == "user" + assert second_messages[1]["content"] == "First prompt" + assert second_messages[2]["role"] == "assistant" + assert second_messages[3]["role"] == "user" + + def test_previous_response_id_preserves_prior_system_message_items(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("First answer"), _output("Second answer")) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "system", "content": "Persist me"}, + {"type": "message", "role": "user", "content": "First prompt"}, + ], + }, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": first_id, + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 200 + second_messages = engine.chat.call_args_list[1].kwargs["messages"] + assert second_messages[0]["role"] == "system" + assert second_messages[0]["content"] == "Persist me" + assert second_messages[1]["role"] == "user" + assert second_messages[1]["content"] == "First prompt" + assert second_messages[2]["role"] == "assistant" + assert second_messages[3]["role"] == 
"user" + + def test_previous_response_id_missing_returns_404(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("unused")) + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": "resp_missing", + "input": "Follow-up prompt", + }, + ) + + assert resp.status_code == 404 + assert "resp_missing" in resp.json()["detail"] + + def test_developer_role_is_normalized_to_system(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Ready")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Hi"}, + {"type": "message", "role": "developer", "content": "Be terse"}, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Be terse" + assert messages[1]["role"] == "user" + assert messages[1]["content"] == "Hi" + + def test_instructions_and_developer_message_are_merged(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Ready")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "instructions": "System instructions", + "input": [ + { + "type": "message", + "role": "developer", + "content": "Developer note", + }, + {"type": "message", "role": "user", "content": "Hi"}, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert len([m for m in messages if m["role"] == "system"]) == 1 + assert messages[0]["role"] == "system" + assert "System instructions" in messages[0]["content"] + assert "Developer note" in messages[0]["content"] + assert messages[1]["role"] == "user" + + def test_function_call_output_input_is_mapped_cleanly(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Done")) 
+ srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Run it"}, + { + "type": "function_call", + "call_id": "call_1", + "name": "shell", + "arguments": '{"cmd":"pwd"}', + }, + { + "type": "function_call_output", + "call_id": "call_1", + "output": "/tmp/work", + }, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[1]["role"] == "assistant" + assert "[Calling tool: shell(" in messages[1]["content"] + assert messages[2]["role"] == "user" + assert "[Tool Result (call_1)]" in messages[2]["content"] + assert "/tmp/work" in messages[2]["content"] + + def test_unsupported_tools_and_items_do_not_fail(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Fallback answer")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Answer directly"}, + { + "type": "web_search_call", + "status": "completed", + "action": {"type": "search", "query": "ignored"}, + }, + ], + "tools": [ + {"type": "web_search_preview"}, + {"type": "file_search", "vector_store_ids": ["vs_123"]}, + { + "type": "function", + "name": "shell", + "parameters": {"type": "object", "properties": {}}, + }, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[0]["role"] == "system" + assert "not available on this backend" in messages[0]["content"] + assert messages[1]["role"] == "user" + assert engine.chat.call_args.kwargs["tools"][0]["type"] == "function" + + def test_function_call_response_item(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine( + _output('{"name":"shell","arguments":{"cmd":"pwd"}}') + ) + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Use a tool", + 
"tools": [ + { + "type": "function", + "name": "shell", + "parameters": {"type": "object", "properties": {}}, + } + ], + }, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["output"][0]["type"] == "function_call" + assert body["output"][0]["name"] == "shell" + assert body["output_text"] == "" + + def test_store_false_skips_persistence(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("Ephemeral answer")) + + first = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Do not store this", + "store": False, + }, + ) + + assert first.status_code == 200 + assert first.json()["store"] is False + assert first.json()["id"] not in srv._responses_store + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": first.json()["id"], + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 404 + + def test_responses_store_is_lru_bounded(self, client): + import vllm_mlx.server as srv + + srv._RESPONSES_STORE_MAX_SIZE = 2 + srv._engine = _mock_engine( + _output("First answer"), + _output("Second answer"), + _output("Third answer"), + ) + + first = client.post( + "/v1/responses", + json={"model": "test-model", "input": "First prompt"}, + ) + second = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Second prompt"}, + ) + third = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Third prompt"}, + ) + + assert first.status_code == 200 + assert second.status_code == 200 + assert third.status_code == 200 + assert list(srv._responses_store) == [second.json()["id"], third.json()["id"]] + assert first.json()["id"] not in srv._responses_store + + def test_streaming_response_returns_sse_events(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("unused")) + engine.chat = AsyncMock( + side_effect=AssertionError("stream path should not call chat") + ) + 
engine._stream_outputs = [ + _stream_output("Hello ", completion_tokens=1), + _stream_output("stream", completion_tokens=2, finish_reason="stop"), + ] + srv._engine = engine + + with client.stream( + "POST", + "/v1/responses", + json={"model": "test-model", "input": "Hello", "stream": True}, + ) as resp: + body = "".join(resp.iter_text()) + + assert resp.status_code == 200 + assert "event: response.created" in body + assert "event: response.output_text.delta" in body + assert "Hello stream" in body + assert "event: response.completed" in body + assert len(engine._stream_calls) == 1 + engine.chat.assert_not_awaited() + + def test_streaming_response_sequence_metadata_is_monotonic(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("unused")) + engine.chat = AsyncMock( + side_effect=AssertionError("stream path should not call chat") + ) + engine._stream_outputs = [ + _stream_output("Hello ", completion_tokens=1), + _stream_output("stream", completion_tokens=2, finish_reason="stop"), + ] + srv._engine = engine + + with client.stream( + "POST", + "/v1/responses", + json={"model": "test-model", "input": "Hello", "stream": True}, + ) as resp: + body = "".join(resp.iter_text()) + + assert resp.status_code == 200 + events = _parse_sse_events(body) + assert [event_type for event_type, _ in events[:2]] == [ + "response.created", + "response.in_progress", + ] + sequence_numbers = [payload["sequence_number"] for _, payload in events] + assert sequence_numbers == sorted(sequence_numbers) + created_payload = events[0][1] + completed_payload = next( + payload + for event_type, payload in events + if event_type == "response.completed" + ) + assert created_payload["response"]["id"] == completed_payload["response"]["id"] + assert completed_payload["response"]["output_text"] == "Hello stream" + + def test_json_object_response_format_is_rejected(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("Hello")) + + resp = 
client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Hello", + "text": {"format": {"type": "json_object"}}, + }, + ) + + assert resp.status_code == 400 + assert "json_object" in resp.json()["detail"] + + def test_reasoning_configuration_is_ignored(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Hello")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Hello", + "reasoning": {"effort": "xhigh"}, + }, + ) + + assert resp.status_code == 200 + assert engine.chat.await_count == 1 + + def test_reasoning_input_item_is_accepted(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Hello")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Hello"}, + { + "type": "reasoning", + "content": [{"type": "reasoning_text", "text": "x"}], + }, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[0]["role"] == "user" + assert messages[0]["content"] == "Hello" + assert messages[1]["role"] == "assistant" + assert messages[1]["content"] == "x" + + def test_length_finish_reason_marks_response_incomplete(self, client): + import vllm_mlx.server as srv + + output = _output("Cut off", completion_tokens=5) + output.finish_reason = "length" + srv._engine = _mock_engine(output) + + resp = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Hello", "max_output_tokens": 5}, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "incomplete" + assert body["incomplete_details"] == {"reason": "max_output_tokens"} diff --git a/vllm_mlx/api/__init__.py b/vllm_mlx/api/__init__.py index 62dcb6919..c3313d7a1 100644 --- a/vllm_mlx/api/__init__.py +++ b/vllm_mlx/api/__init__.py @@ -53,6 +53,24 @@ EmbeddingUsage, 
EmbeddingResponse, ) +from .responses_models import ( + ResponseTextFormat, + ResponseTextConfig, + ResponseReasoningConfig, + ResponseTextContentPart, + ResponseReasoningTextPart, + ResponseReasoningSummaryTextPart, + ResponseMessageItem, + ResponseReasoningItem, + ResponseFunctionCallItem, + ResponseFunctionCallOutputItem, + ResponseFunctionTool, + ResponsesUsage, + ResponseError, + ResponseIncompleteDetails, + ResponsesRequest, + ResponseObject, +) from .utils import ( clean_output_text, @@ -113,6 +131,22 @@ "EmbeddingData", "EmbeddingUsage", "EmbeddingResponse", + "ResponseTextFormat", + "ResponseTextConfig", + "ResponseReasoningConfig", + "ResponseTextContentPart", + "ResponseReasoningTextPart", + "ResponseReasoningSummaryTextPart", + "ResponseMessageItem", + "ResponseReasoningItem", + "ResponseFunctionCallItem", + "ResponseFunctionCallOutputItem", + "ResponseFunctionTool", + "ResponsesUsage", + "ResponseError", + "ResponseIncompleteDetails", + "ResponsesRequest", + "ResponseObject", # Utils "clean_output_text", "is_mllm_model", diff --git a/vllm_mlx/api/responses_models.py b/vllm_mlx/api/responses_models.py new file mode 100644 index 000000000..9fdfae8f3 --- /dev/null +++ b/vllm_mlx/api/responses_models.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Pydantic models for the OpenAI-compatible Responses API. + +This intentionally implements the subset needed for local coding-agent +workflows: text messages, function tools, function call outputs, and SSE +streaming events. The object and event shapes follow the conventions used by +OpenAI's gpt-oss reference server and llama.cpp's OpenAI-compatible server. 
+""" + +import time +import uuid +from typing import Literal + +from pydantic import BaseModel, Field, computed_field + + +class ResponseTextFormat(BaseModel): + """Output text format configuration.""" + + type: Literal["text", "json_object"] = "text" + + +class ResponseTextConfig(BaseModel): + """Text output configuration.""" + + format: ResponseTextFormat = Field(default_factory=ResponseTextFormat) + + +class ResponseReasoningConfig(BaseModel): + """Reasoning configuration.""" + + effort: Literal["none", "minimal", "low", "medium", "high", "xhigh"] | None = None + + +class ResponseTextContentPart(BaseModel): + """A text content part for message items.""" + + type: Literal["text", "input_text", "output_text"] = "output_text" + text: str + annotations: list[dict] = Field(default_factory=list) + logprobs: list[dict] = Field(default_factory=list) + + +class ResponseReasoningTextPart(BaseModel): + """A reasoning text content part.""" + + type: Literal["reasoning_text"] = "reasoning_text" + text: str + + +class ResponseReasoningSummaryTextPart(BaseModel): + """A reasoning summary item.""" + + type: Literal["summary_text"] = "summary_text" + text: str + + +class ResponseMessageItem(BaseModel): + """A Responses API message item.""" + + id: str | None = None + type: Literal["message"] = "message" + role: Literal["system", "user", "assistant", "developer"] = "assistant" + content: str | list[ResponseTextContentPart] = Field(default_factory=list) + status: Literal["in_progress", "completed", "incomplete"] | None = "completed" + + +class ResponseReasoningItem(BaseModel): + """A reasoning output item.""" + + id: str | None = None + type: Literal["reasoning"] = "reasoning" + summary: list[ResponseReasoningSummaryTextPart] = Field(default_factory=list) + content: list[ResponseReasoningTextPart] = Field(default_factory=list) + status: Literal["in_progress", "completed", "incomplete"] | None = "completed" + + +class ResponseFunctionCallItem(BaseModel): + """A function call output 
item.""" + + id: str | None = None + type: Literal["function_call"] = "function_call" + call_id: str + name: str + arguments: str + status: Literal["in_progress", "completed", "incomplete"] = "completed" + + +class ResponseFunctionCallOutputItem(BaseModel): + """A tool result item passed back into a later request.""" + + type: Literal["function_call_output"] = "function_call_output" + call_id: str + output: str + + +class ResponseFunctionTool(BaseModel): + """A function tool definition.""" + + type: Literal["function"] = "function" + name: str + description: str | None = "" + parameters: dict = Field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + strict: bool = False + + +class ResponsesInputTokenDetails(BaseModel): + """Input token breakdown.""" + + cached_tokens: int = 0 + + +class ResponsesOutputTokenDetails(BaseModel): + """Output token breakdown.""" + + reasoning_tokens: int = 0 + + +class ResponsesUsage(BaseModel): + """Responses API token usage.""" + + input_tokens: int + output_tokens: int + total_tokens: int + input_tokens_details: ResponsesInputTokenDetails = Field( + default_factory=ResponsesInputTokenDetails + ) + output_tokens_details: ResponsesOutputTokenDetails = Field( + default_factory=ResponsesOutputTokenDetails + ) + + +class ResponseError(BaseModel): + """Error payload.""" + + code: str + message: str + + +class ResponseIncompleteDetails(BaseModel): + """Incomplete response details.""" + + reason: str + + +class ResponsesRequest(BaseModel): + """Request payload for /v1/responses.""" + + model: str + input: ( + str + | list[ + ResponseMessageItem + | ResponseReasoningItem + | ResponseFunctionCallItem + | ResponseFunctionCallOutputItem + | dict + ] + ) + instructions: str | None = None + max_output_tokens: int | None = None + stream: bool = False + tools: list[ResponseFunctionTool | dict] = Field(default_factory=list) + tool_choice: str | dict | None = "auto" + parallel_tool_calls: bool = True + previous_response_id: str | 
None = None + temperature: float | None = None + top_p: float | None = None + metadata: dict = Field(default_factory=dict) + text: ResponseTextConfig = Field(default_factory=ResponseTextConfig) + reasoning: ResponseReasoningConfig | None = None + store: bool = True + truncation: str = "disabled" + user: str | None = None + + +class ResponseObject(BaseModel): + """Response object for /v1/responses.""" + + id: str = Field(default_factory=lambda: f"resp_{uuid.uuid4().hex}") + object: Literal["response"] = "response" + created_at: int = Field(default_factory=lambda: int(time.time())) + status: Literal["completed", "failed", "incomplete", "in_progress"] = "completed" + background: bool = False + error: ResponseError | None = None + incomplete_details: ResponseIncompleteDetails | None = None + instructions: str | None = None + max_output_tokens: int | None = None + max_tool_calls: int | None = None + metadata: dict = Field(default_factory=dict) + model: str + output: list[ + ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + ] = Field(default_factory=list) + parallel_tool_calls: bool = True + previous_response_id: str | None = None + text: ResponseTextConfig = Field(default_factory=ResponseTextConfig) + tool_choice: str | dict | None = "auto" + tools: list[ResponseFunctionTool | dict] = Field(default_factory=list) + top_p: float = 1.0 + temperature: float | None = None + truncation: str = "disabled" + usage: ResponsesUsage | None = None + user: str | None = None + store: bool = True + + @computed_field + @property + def output_text(self) -> str: + """Concatenate assistant text content into the convenience field.""" + text_parts: list[str] = [] + for item in self.output: + if not isinstance(item, ResponseMessageItem): + continue + if isinstance(item.content, str): + text_parts.append(item.content) + continue + for part in item.content: + if part.type == "output_text": + text_parts.append(part.text) + return "".join(text_parts) + + +class 
ResponsesEventBase(BaseModel): + """Base event fields.""" + + sequence_number: int + + +class ResponseCreatedEvent(ResponsesEventBase): + type: Literal["response.created"] = "response.created" + response: ResponseObject + + +class ResponseInProgressEvent(ResponsesEventBase): + type: Literal["response.in_progress"] = "response.in_progress" + response: ResponseObject + + +class ResponseCompletedEvent(ResponsesEventBase): + type: Literal["response.completed"] = "response.completed" + response: ResponseObject + + +class ResponseOutputItemAddedEvent(ResponsesEventBase): + type: Literal["response.output_item.added"] = "response.output_item.added" + output_index: int + item: ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + + +class ResponseOutputItemDoneEvent(ResponsesEventBase): + type: Literal["response.output_item.done"] = "response.output_item.done" + output_index: int + item: ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + + +class ResponseContentPartAddedEvent(ResponsesEventBase): + type: Literal["response.content_part.added"] = "response.content_part.added" + item_id: str + output_index: int + content_index: int + part: ResponseTextContentPart | ResponseReasoningTextPart + + +class ResponseContentPartDoneEvent(ResponsesEventBase): + type: Literal["response.content_part.done"] = "response.content_part.done" + item_id: str + output_index: int + content_index: int + part: ResponseTextContentPart | ResponseReasoningTextPart + + +class ResponseOutputTextDeltaEvent(ResponsesEventBase): + type: Literal["response.output_text.delta"] = "response.output_text.delta" + item_id: str + output_index: int + content_index: int + delta: str + logprobs: list[dict] = Field(default_factory=list) + + +class ResponseOutputTextDoneEvent(ResponsesEventBase): + type: Literal["response.output_text.done"] = "response.output_text.done" + item_id: str + output_index: int + content_index: int + text: str + logprobs: list[dict] = 
Field(default_factory=list) + + +class ResponseReasoningTextDeltaEvent(ResponsesEventBase): + type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta" + item_id: str + output_index: int + content_index: int + delta: str + + +class ResponseReasoningTextDoneEvent(ResponsesEventBase): + type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done" + item_id: str + output_index: int + content_index: int + text: str + + +class ResponseFunctionCallArgumentsDeltaEvent(ResponsesEventBase): + type: Literal["response.function_call_arguments.delta"] = ( + "response.function_call_arguments.delta" + ) + item_id: str + output_index: int + delta: str diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 2d4b9c3e8..0ec43233d 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -39,6 +39,7 @@ import argparse import asyncio +import copy import json import logging import os @@ -48,13 +49,14 @@ import threading import time import uuid -from collections import defaultdict +from collections import OrderedDict, defaultdict from collections.abc import AsyncIterator import uvicorn from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile from fastapi.responses import Response, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pydantic import BaseModel from starlette.routing import Match # Import from new modular API @@ -97,6 +99,31 @@ Usage, # noqa: F401 VideoUrl, # noqa: F401 ) +from .api.responses_models import ( + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallItem, + ResponseFunctionCallOutputItem, + ResponseFunctionTool, + ResponseIncompleteDetails, + ResponseInProgressEvent, + ResponseMessageItem, + ResponseObject, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputTextDeltaEvent, + ResponseOutputTextDoneEvent, + 
    ResponseReasoningItem,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseReasoningTextPart,
    ResponseTextContentPart,
    ResponsesRequest,
    ResponsesUsage,
)
from .api.tool_calling import (
    StreamingJsonFenceStripper,
    build_json_logits_processor,
@@ -171,6 +198,8 @@ def _resolve_top_p(request_value: float | None) -> float:
_enable_auto_tool_choice: bool = False
_tool_call_parser: str | None = None  # Parser name: auto, mistral, qwen, llama, hermes
_tool_parser_instance = None  # Instantiated parser
# In-memory backing store for `previous_response_id` chaining. An OrderedDict
# is used as a FIFO: once _RESPONSES_STORE_MAX_SIZE is exceeded the oldest
# stored response is evicted (see the popitem(last=False) loops below).
_responses_store: OrderedDict[str, dict] = OrderedDict()
_RESPONSES_STORE_MAX_SIZE: int = 1000

# Pattern to strip leaked tool call markup from content output.
# Safety net: the tool parser should consume these, but if it doesn't
@@ -553,6 +582,885 @@ def _parse_tool_calls_with_parser(
    return parse_tool_calls(output_text, request_dict)


def _new_response_item_id(prefix: str) -> str:
    """Generate stable OpenAI-style item ids (e.g. ``msg_<32-hex>``)."""
    return f"{prefix}_{uuid.uuid4().hex}"


def _response_content_to_text(content: str | list | None) -> str:
    """Normalize Responses API content items into plain text.

    Accepts a bare string, None, or a list of content parts (dicts or
    pydantic part objects). Only textual part types are kept; any other
    part type (e.g. images) is silently dropped here.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content

    text_parts = []
    for part in content:
        # Parts may arrive either as raw dicts (client JSON) or as parsed
        # pydantic objects — handle both shapes.
        if isinstance(part, dict):
            part_type = part.get("type")
            text = part.get("text", "")
        else:
            part_type = getattr(part, "type", None)
            text = getattr(part, "text", "")
        if part_type in {"text", "input_text", "output_text"}:
            text_parts.append(text)
    # Empty fragments are skipped so joining never produces stray newlines.
    return "\n".join(part for part in text_parts if part)


def _responses_tools_to_chat_tools(
    tools: list[ResponseFunctionTool | dict],
) -> tuple[list[dict] | None, list[str]]:
    """Convert supported Responses tools and report unsupported tool types.

    Returns a pair ``(chat_tools, unsupported)`` where ``chat_tools`` is a
    chat-completions-style tool list (or None when nothing is convertible)
    and ``unsupported`` collects the type names of tools that could not be
    converted (used to inject a warning system message upstream).
    """
    if not tools:
        return None, []

    supported: list[dict] = []
    unsupported: list[str] = []

    for tool in tools:
        if isinstance(tool, ResponseFunctionTool):
            tool_type = tool.type
            tool_name = tool.name
            tool_description = tool.description or ""
            tool_parameters = tool.parameters
        elif isinstance(tool, dict):
            tool_type = tool.get("type", "unknown")
            tool_name = tool.get("name", "")
            tool_description = tool.get("description", "")
            tool_parameters = tool.get("parameters", {})
        else:
            unsupported.append(type(tool).__name__)
            continue

        if tool_type == "function":
            supported.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "description": tool_description,
                        # Missing/empty parameters fall back to an empty
                        # object schema so chat templates always see a
                        # valid JSON-schema shape.
                        "parameters": tool_parameters
                        or {"type": "object", "properties": {}},
                    },
                }
            )
        else:
            unsupported.append(tool_type)

    return supported or None, unsupported


def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]:
    """Convert Responses API input items into chat-completions-style messages.

    Raises:
        HTTPException(404): if ``previous_response_id`` references a response
            that is not (or no longer) in the in-memory store.
    """
    messages: list[dict] = []

    # Replay the persisted conversation of the referenced prior response
    # first, so this turn's items are appended after the earlier history.
    if request.previous_response_id:
        previous = _responses_store.get(request.previous_response_id)
        if previous is None:
            raise HTTPException(
                status_code=404,
                detail=f"Previous response `{request.previous_response_id}` not found",
            )
        messages.extend(copy.deepcopy(previous["messages"]))

    if request.instructions:
        messages.append({"role": "system", "content": request.instructions})

    # A bare string input is shorthand for a single user message.
    if isinstance(request.input, str):
        messages.append({"role": "user", "content": request.input})
        return messages

    for item in request.input:
        # Items may be raw dicts or parsed pydantic models; the dict branch
        # below ends with `continue` so each item is handled exactly once.
        if isinstance(item, dict):
            item_type = item.get("type", "")
            if item_type == "message":
                role = item.get("role", "user")
                # Responses' "developer" role maps onto chat "system".
                if role == "developer":
                    role = "system"
                messages.append(
                    {
                        "role": role,
                        "content": _response_content_to_text(item.get("content")),
                    }
                )
            elif item_type == "function_call":
                messages.append(
                    {
                        "role": "assistant",
                        "content": "",
                        "tool_calls": [
                            {
                                "id": item.get(
                                    "call_id", _new_response_item_id("call")
                                ),
                                "type": "function",
                                "function": {
                                    "name": item.get("name", ""),
                                    "arguments": item.get("arguments", ""),
                                },
                            }
                        ],
                    }
                )
            elif item_type == "function_call_output":
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": item.get("call_id", ""),
                        "content": item.get("output", ""),
                    }
                )
            elif item_type == "reasoning":
                # Reasoning items are replayed as plain assistant text since
                # the chat backend has no dedicated reasoning channel.
                parts = item.get("content", [])
                reasoning_text = "\n".join(
                    p.get("text", "") for p in parts if isinstance(p, dict)
                )
                if reasoning_text:
                    messages.append({"role": "assistant", "content": reasoning_text})
            else:
                logger.info(
                    "Skipping unsupported Responses input item type %r", item_type
                )
            continue

        if isinstance(item, ResponseMessageItem):
            role = item.role
            if role == "developer":
                role = "system"
            messages.append(
                {
                    "role": role,
                    "content": _response_content_to_text(item.content),
                }
            )
        elif isinstance(item, ResponseFunctionCallItem):
            messages.append(
                {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {
                            "id": item.call_id,
                            "type": "function",
                            "function": {
                                "name": item.name,
                                "arguments": item.arguments,
                            },
                        }
                    ],
                }
            )
        elif isinstance(item, ResponseFunctionCallOutputItem):
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": item.call_id,
                    "content": item.output,
                }
            )
        elif isinstance(item, ResponseReasoningItem):
            reasoning_text = "\n".join(part.text for part in (item.content or []))
            if reasoning_text:
                messages.append({"role": "assistant", "content": reasoning_text})
        else:
            logger.info(
                "Skipping unsupported Responses input item type %r",
                getattr(item, "type", type(item).__name__),
            )

    return messages


def _responses_request_to_new_persisted_messages(
    request: ResponsesRequest,
) -> list[dict]:
    """Persist only the current request's replayable input items.

    Clearing ``previous_response_id`` and ``instructions`` on the copy makes
    the conversion below yield just this turn's items — the chained history
    is prepended separately, and instructions are deliberately not replayed.
    """
    request_without_history = request.model_copy(
        update={"previous_response_id": None, "instructions": None},
        deep=True,
    )
    return _responses_input_to_chat_messages(request_without_history)


def _responses_request_to_persisted_messages(request: ResponsesRequest) -> list[dict]:
    """Persist replayable history for chained previous_response_id requests.

    Responses `instructions` are intentionally not replayed across
    `previous_response_id`, but replayable message items are.

    Raises:
        HTTPException(404): if the referenced previous response is unknown.
    """
    messages: list[dict] = []
    if request.previous_response_id:
        previous = _responses_store.get(request.previous_response_id)
        if previous is None:
            raise HTTPException(
                status_code=404,
                detail=f"Previous response `{request.previous_response_id}` not found",
            )
        messages.extend(copy.deepcopy(previous["messages"]))
    messages.extend(_responses_request_to_new_persisted_messages(request))
    return messages


def _responses_request_to_chat_request(
    request: ResponsesRequest,
) -> ChatCompletionRequest:
    """Build a ChatCompletionRequest from a ResponsesRequest.

    Raises:
        HTTPException(400): for text.format.type='json_object', which this
            backend does not support.
    """
    if request.text.format.type == "json_object":
        raise HTTPException(
            status_code=400,
            detail="Responses text.format.type='json_object' is not supported on this backend",
        )
    if request.reasoning is not None:
        logger.debug("Ignoring reasoning configuration (not supported on this backend)")

    tools, unsupported_tools = _responses_tools_to_chat_tools(request.tools)
    messages = _responses_input_to_chat_messages(request)
    if unsupported_tools:
        # Tell the model explicitly which requested tool types don't exist
        # here, so it does not hallucinate calls to them.
        tool_list = ", ".join(sorted(set(unsupported_tools)))
        messages.insert(
            0,
            {
                "role": "system",
                "content": (
                    "The following requested tool types are not available on this "
                    f"backend: {tool_list}. Do not call them."
                ),
            },
        )

    # Collapse all system messages into a single leading system message —
    # many chat templates only honor one system slot.
    system_messages = [msg for msg in messages if msg.get("role") == "system"]
    non_system_messages = [msg for msg in messages if msg.get("role") != "system"]
    merged_system_content = "\n\n".join(
        str(msg.get("content", "")).strip()
        for msg in system_messages
        if str(msg.get("content", "")).strip()
    )
    messages = (
        [{"role": "system", "content": merged_system_content}]
        if merged_system_content
        else []
    ) + non_system_messages

    return ChatCompletionRequest(
        model=request.model,
        messages=[Message(**msg) for msg in messages],
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_output_tokens,
        stream=False,
        tools=tools,
        tool_choice=request.tool_choice,
    )


def _build_responses_output_items(
    text: str | None,
    reasoning: str | None,
    tool_calls: list[ToolCall] | None,
) -> list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem]:
    """Convert parsed assistant output into Responses API output items.

    Output ordering is: reasoning item (if any), then the text message
    (if any), then one function_call item per tool call.
    """
    output_items: list[
        ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
    ] = []

    if reasoning:
        output_items.append(
            ResponseReasoningItem(
                id=_new_response_item_id("rs"),
                content=[ResponseReasoningTextPart(text=reasoning)],
            )
        )

    if text:
        output_items.append(
            ResponseMessageItem(
                id=_new_response_item_id("msg"),
                role="assistant",
                content=[ResponseTextContentPart(type="output_text", text=text)],
            )
        )

    for tool_call in tool_calls or []:
        output_items.append(
            ResponseFunctionCallItem(
                id=_new_response_item_id("fc"),
                call_id=tool_call.id,
                name=tool_call.function.name,
                arguments=tool_call.function.arguments,
            )
        )

    return output_items


def _response_output_items_to_chat_messages(output_items: list) -> list[dict]:
    """Persist assistant output in chat-completions form for previous_response_id.

    Reasoning items are intentionally not persisted — only the visible
    message text and any function calls are replayed on chained requests.
    """
    assistant_text_parts: list[str] = []
    assistant_tool_calls: list[dict] = []

    for item in output_items:
        if isinstance(item, ResponseMessageItem):
            assistant_text_parts.append(_response_content_to_text(item.content))
        elif isinstance(item, ResponseFunctionCallItem):
            assistant_tool_calls.append(
                {
                    "id": item.call_id,
                    "type": "function",
                    "function": {
                        "name": item.name,
                        "arguments": item.arguments,
                    },
                }
            )

    if not assistant_text_parts and not assistant_tool_calls:
        return []

    return [
        {
            "role": "assistant",
            "content": "".join(assistant_text_parts),
            "tool_calls": assistant_tool_calls or None,
        }
    ]


def _build_response_object(
    request: ResponsesRequest,
    output_items: list[
        ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
    ],
    prompt_tokens: int,
    completion_tokens: int,
    finish_reason: str | None,
    response_id: str | None = None,
) -> ResponseObject:
    """Build a full Responses API object.

    When the generation stopped on the token limit (finish_reason ==
    "length") the response is marked incomplete with reason
    "max_output_tokens"; otherwise the model's default status applies.
    """
    response = ResponseObject(
        id=response_id or _new_response_item_id("resp"),
        model=_model_name or request.model,
        instructions=request.instructions,
        max_output_tokens=request.max_output_tokens,
        metadata=request.metadata,
        output=output_items,
        parallel_tool_calls=request.parallel_tool_calls,
        previous_response_id=request.previous_response_id,
        text=request.text,
        tool_choice=request.tool_choice,
        tools=request.tools,
        top_p=_resolve_top_p(request.top_p),
        temperature=_resolve_temperature(request.temperature),
        truncation=request.truncation,
        user=request.user,
        store=request.store,
        usage=ResponsesUsage(
            input_tokens=prompt_tokens,
            output_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    if finish_reason == "length":
        response.status = "incomplete"
        response.incomplete_details = ResponseIncompleteDetails(
            reason="max_output_tokens"
        )
    return response


def _prepare_responses_request(
    request: ResponsesRequest,
) -> tuple[BaseEngine, ChatCompletionRequest, list[dict], dict]:
    """Prepare a Responses request for execution on the chat engine.

    Returns the engine, the converted chat request, the extracted chat
    messages, and the keyword arguments to pass to engine.chat/stream_chat.
    """
    _validate_model_name(request.model)
    engine = get_engine()
    chat_request = _responses_request_to_chat_request(request)

    if chat_request.messages:
        # NOTE(review): f-string formats eagerly even when INFO logging is
        # disabled — consider lazy %-style args for consistency with other
        # logger calls in this module.
        logger.info(
            f"[REQUEST] POST /v1/responses stream={request.stream} "
            f"model={request.model!r} items="
            f"{len(request.input) if isinstance(request.input, list) else 1} "
            f"tools={len(request.tools)}"
        )

    messages, images, videos = extract_multimodal_content(
        chat_request.messages,
        preserve_native_format=engine.preserve_native_tool_format,
    )

    chat_kwargs = {
        # Fall back to the server-wide default when the client did not set
        # max_output_tokens.
        "max_tokens": chat_request.max_tokens or _default_max_tokens,
        "temperature": _resolve_temperature(chat_request.temperature),
        "top_p": _resolve_top_p(chat_request.top_p),
    }
    if request.tools:
        chat_kwargs["tools"] = convert_tools_for_template(chat_request.tools)
    if images:
        chat_kwargs["images"] = images
    if videos:
        chat_kwargs["videos"] = videos

    return engine, chat_request, messages, chat_kwargs


async def _run_responses_request(
    request: ResponsesRequest,
    raw_request: Request,
) -> tuple[ResponseObject | None, list[dict]]:
    """Execute a Responses API request against the backend chat engine.

    Returns ``(response, persisted_messages)``. A ``(None, [])`` result
    indicates _wait_with_disconnect returned None — presumably the client
    disconnected before generation completed (the route maps this to 499).
    """
    engine, chat_request, messages, chat_kwargs = _prepare_responses_request(request)

    timeout = _default_timeout
    output = await _wait_with_disconnect(
        engine.chat(messages=messages, **chat_kwargs),
        raw_request,
        timeout=timeout,
    )
    if output is None:
        return None, []

    cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, chat_request)
    reasoning_text = None
    # Reasoning extraction is skipped when tool calls were parsed — the tool
    # parser already consumed the relevant markup.
    if _reasoning_parser and not tool_calls:
        reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning(
            cleaned_text or output.text
        )

    output_items = _build_responses_output_items(
        clean_output_text(cleaned_text) if cleaned_text else None,
        reasoning_text,
        tool_calls,
    )
    response_object = _build_response_object(
        request=request,
        output_items=output_items,
        prompt_tokens=output.prompt_tokens,
        completion_tokens=output.completion_tokens,
        finish_reason=output.finish_reason,
    )

    persisted_messages = _responses_request_to_persisted_messages(request)
    persisted_messages.extend(_response_output_items_to_chat_messages(output_items))
    if request.store:
        _responses_store[response_object.id] = {
            "messages": copy.deepcopy(persisted_messages),
            "response": response_object.model_copy(deep=True),
        }
        # FIFO eviction: drop the oldest stored responses past the cap.
        while len(_responses_store) > _RESPONSES_STORE_MAX_SIZE:
            _responses_store.popitem(last=False)

    return response_object, persisted_messages


async def _stream_responses_request(request: ResponsesRequest) -> AsyncIterator[str]:
    """Execute a Responses API request and stream SSE events incrementally.

    Emits the Responses event sequence: response.created /
    response.in_progress, then per-delta output_item.added /
    content_part.added / *.delta events (text and reasoning items are opened
    lazily on the first delta of each kind), then the matching *.done events,
    function-call items parsed from the full text, and finally
    response.completed. ``sequence`` is a monotonically increasing event
    counter shared by every emitted event.
    """
    engine, chat_request, messages, chat_kwargs = _prepare_responses_request(request)

    response_id = _new_response_item_id("resp")
    sequence = 1
    base_response = _build_response_object(
        request=request,
        output_items=[],
        prompt_tokens=0,
        completion_tokens=0,
        finish_reason=None,
        response_id=response_id,
    )
    base_response.status = "in_progress"
    base_response.usage = None

    yield _responses_sse_event(
        "response.created",
        ResponseCreatedEvent(sequence_number=sequence, response=base_response),
    )
    sequence += 1
    yield _responses_sse_event(
        "response.in_progress",
        ResponseInProgressEvent(sequence_number=sequence, response=base_response),
    )
    sequence += 1

    prompt_tokens = 0
    completion_tokens = 0
    finish_reason = None
    last_output = None
    raw_accumulated_text = ""   # full raw model text, pre-parsing
    accumulated_text = ""       # text actually streamed to the client
    accumulated_reasoning = ""  # reasoning text streamed to the client

    text_item_id: str | None = None
    text_output_index: int | None = None
    reasoning_item_id: str | None = None
    reasoning_output_index: int | None = None
    next_output_index = 0

    def _start_text_item() -> list[str]:
        # Lazily open the assistant message item on the first text delta;
        # returns the SSE events announcing the new item (empty if open).
        nonlocal text_item_id, text_output_index, next_output_index, sequence
        events: list[str] = []
        if text_item_id is None:
            text_item_id = _new_response_item_id("msg")
_new_response_item_id("msg")
            text_output_index = next_output_index
            next_output_index += 1
            events.append(
                _responses_sse_event(
                    "response.output_item.added",
                    ResponseOutputItemAddedEvent(
                        sequence_number=sequence,
                        output_index=text_output_index,
                        item=ResponseMessageItem(
                            id=text_item_id,
                            role="assistant",
                            status="in_progress",
                            content=[],
                        ),
                    ),
                )
            )
            sequence += 1
            events.append(
                _responses_sse_event(
                    "response.content_part.added",
                    ResponseContentPartAddedEvent(
                        sequence_number=sequence,
                        item_id=text_item_id,
                        output_index=text_output_index,
                        content_index=0,
                        part=ResponseTextContentPart(type="output_text", text=""),
                    ),
                )
            )
            sequence += 1
        return events

    def _start_reasoning_item() -> list[str]:
        # Lazily open the reasoning item on the first reasoning delta;
        # mirrors _start_text_item for the reasoning output channel.
        nonlocal reasoning_item_id, reasoning_output_index, next_output_index, sequence
        events: list[str] = []
        if reasoning_item_id is None:
            reasoning_item_id = _new_response_item_id("rs")
            reasoning_output_index = next_output_index
            next_output_index += 1
            events.append(
                _responses_sse_event(
                    "response.output_item.added",
                    ResponseOutputItemAddedEvent(
                        sequence_number=sequence,
                        output_index=reasoning_output_index,
                        item=ResponseReasoningItem(
                            id=reasoning_item_id,
                            status="in_progress",
                            content=[],
                        ),
                    ),
                )
            )
            sequence += 1
            events.append(
                _responses_sse_event(
                    "response.content_part.added",
                    ResponseContentPartAddedEvent(
                        sequence_number=sequence,
                        item_id=reasoning_item_id,
                        output_index=reasoning_output_index,
                        content_index=0,
                        part=ResponseReasoningTextPart(text=""),
                    ),
                )
            )
            sequence += 1
        return events

    if _reasoning_parser:
        _reasoning_parser.reset_state()

    global _tool_parser_instance
    tool_parser = None
    tool_accumulated_text = ""
    tool_markup_possible = False
    if _enable_auto_tool_choice and _tool_call_parser:
        # The parser instance is created once and cached module-wide;
        # initialization failure downgrades to plain-text streaming.
        if _tool_parser_instance is None:
            try:
                parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
                tokenizer = None
                if _engine is not None and hasattr(_engine, "_tokenizer"):
                    tokenizer = _engine._tokenizer
                _tool_parser_instance = parser_cls(tokenizer)
                logger.info(
                    "Initialized tool call parser for responses streaming: %s",
                    _tool_call_parser,
                )
            except Exception as e:
                logger.warning(
                    "Failed to init tool parser for responses streaming: %s", e
                )
        if _tool_parser_instance is not None:
            tool_parser = _tool_parser_instance
            tool_parser.reset()

    async for output in engine.stream_chat(messages=messages, **chat_kwargs):
        last_output = output
        finish_reason = output.finish_reason
        if hasattr(output, "prompt_tokens") and output.prompt_tokens:
            prompt_tokens = output.prompt_tokens
        if hasattr(output, "completion_tokens") and output.completion_tokens:
            completion_tokens = output.completion_tokens

        delta_text = output.new_text or ""
        if not delta_text:
            continue

        previous_text = raw_accumulated_text
        raw_accumulated_text += delta_text

        if _reasoning_parser:
            # When a reasoning parser is active it owns the delta routing:
            # deltas become reasoning and/or content, and the tool-parser
            # branch below is skipped entirely (note the trailing continue).
            delta_msg = _reasoning_parser.extract_reasoning_streaming(
                previous_text, raw_accumulated_text, delta_text
            )
            if delta_msg is None:
                continue

            if delta_msg.reasoning:
                for event in _start_reasoning_item():
                    yield event
                accumulated_reasoning += delta_msg.reasoning
                yield _responses_sse_event(
                    "response.reasoning_text.delta",
                    ResponseReasoningTextDeltaEvent(
                        sequence_number=sequence,
                        item_id=reasoning_item_id,
                        output_index=reasoning_output_index,
                        content_index=0,
                        delta=delta_msg.reasoning,
                    ),
                )
                sequence += 1

            if delta_msg.content:
                for event in _start_text_item():
                    yield event
                accumulated_text += delta_msg.content
                yield _responses_sse_event(
                    "response.output_text.delta",
                    ResponseOutputTextDeltaEvent(
                        sequence_number=sequence,
                        item_id=text_item_id,
                        output_index=text_output_index,
                        content_index=0,
                        delta=delta_msg.content,
                    ),
                )
                sequence += 1
            continue

        content = SPECIAL_TOKENS_PATTERN.sub("", delta_text)
        if tool_parser and delta_text:
            # Cheap pre-filter: until a "<" is seen, no tool-call markup can
            # have started, so deltas are passed through without invoking
            # the streaming tool parser.
            if not tool_markup_possible and "<" not in delta_text:
                tool_accumulated_text += delta_text
            else:
                if not tool_markup_possible:
                    tool_markup_possible = True
                tool_result = tool_parser.extract_tool_calls_streaming(
                    tool_accumulated_text,
                    tool_accumulated_text + delta_text,
                    delta_text,
                )
                tool_accumulated_text += delta_text
                if tool_result is None:
                    # Parser is buffering — emit nothing for this delta.
                    continue
                if "tool_calls" in tool_result:
                    # Tool-call fragments are withheld; complete calls are
                    # re-parsed from the full text after the loop.
                    continue
                content = tool_result.get("content", "")

        if not content:
            continue

        for event in _start_text_item():
            yield event
        accumulated_text += content
        yield _responses_sse_event(
            "response.output_text.delta",
            ResponseOutputTextDeltaEvent(
                sequence_number=sequence,
                item_id=text_item_id,
                output_index=text_output_index,
                content_index=0,
                delta=content,
            ),
        )
        sequence += 1

    # Generation finished: parse the complete raw text once for tool calls,
    # then close any open items with their *.done events.
    cleaned_text, tool_calls = _parse_tool_calls_with_parser(
        raw_accumulated_text, chat_request
    )
    final_text = accumulated_text
    if cleaned_text is not None and not final_text and not tool_calls:
        # Nothing was streamed incrementally (e.g. everything was buffered)
        # but the final parse produced text — surface it in the done events.
        final_text = clean_output_text(cleaned_text)

    reasoning_item = None
    if reasoning_item_id is not None:
        reasoning_item = ResponseReasoningItem(
            id=reasoning_item_id,
            status="completed",
            content=[ResponseReasoningTextPart(text=accumulated_reasoning)],
        )
        yield _responses_sse_event(
            "response.reasoning_text.done",
            ResponseReasoningTextDoneEvent(
                sequence_number=sequence,
                item_id=reasoning_item_id,
                output_index=reasoning_output_index,
                content_index=0,
                text=accumulated_reasoning,
            ),
        )
        sequence += 1
        yield _responses_sse_event(
            "response.content_part.done",
            ResponseContentPartDoneEvent(
                sequence_number=sequence,
                item_id=reasoning_item_id,
                output_index=reasoning_output_index,
                content_index=0,
                part=reasoning_item.content[0],
            ),
        )
        sequence += 1
        yield _responses_sse_event(
            "response.output_item.done",
            ResponseOutputItemDoneEvent(
                sequence_number=sequence,
                output_index=reasoning_output_index,
                item=reasoning_item,
            ),
        )
        sequence += 1

    text_item = None
    if text_item_id is not None or final_text:
        if text_item_id is None:
            # Open the message item now so the added/done event pair stays
            # balanced even when all text arrived only at final parse time.
            for event in _start_text_item():
                yield event
        text_item = ResponseMessageItem(
            id=text_item_id,
            role="assistant",
            status="completed",
            content=[ResponseTextContentPart(type="output_text", text=final_text)],
        )
        yield _responses_sse_event(
            "response.output_text.done",
            ResponseOutputTextDoneEvent(
                sequence_number=sequence,
                item_id=text_item_id,
                output_index=text_output_index,
                content_index=0,
                text=final_text,
            ),
        )
        sequence += 1
        yield _responses_sse_event(
            "response.content_part.done",
            ResponseContentPartDoneEvent(
                sequence_number=sequence,
                item_id=text_item_id,
                output_index=text_output_index,
                content_index=0,
                part=text_item.content[0],
            ),
        )
        sequence += 1
        yield _responses_sse_event(
            "response.output_item.done",
            ResponseOutputItemDoneEvent(
                sequence_number=sequence,
                output_index=text_output_index,
                item=text_item,
            ),
        )
        sequence += 1

    function_call_items: list[ResponseFunctionCallItem] = []
    for tool_call in tool_calls or []:
        output_index = next_output_index
        next_output_index += 1
        item = ResponseFunctionCallItem(
            id=_new_response_item_id("fc"),
            call_id=tool_call.id,
            name=tool_call.function.name,
            arguments=tool_call.function.arguments,
        )
        function_call_items.append(item)
        yield _responses_sse_event(
            "response.output_item.added",
            ResponseOutputItemAddedEvent(
                sequence_number=sequence,
                output_index=output_index,
                item=item.model_copy(update={"status": "in_progress"}),
            ),
        )
        sequence += 1
        # Arguments were only known after the final parse, so they are
        # emitted as a single delta rather than incrementally.
        yield _responses_sse_event(
            "response.function_call_arguments.delta",
            ResponseFunctionCallArgumentsDeltaEvent(
                sequence_number=sequence,
                item_id=item.id,
                output_index=output_index,
                delta=item.arguments,
            ),
        )
        sequence += 1
        yield _responses_sse_event(
            "response.output_item.done",
            ResponseOutputItemDoneEvent(
                sequence_number=sequence,
                output_index=output_index,
                item=item,
            ),
        )
        sequence += 1

    output_items: list[
        ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
    ] = []
    if reasoning_item is not None:
        output_items.append(reasoning_item)
    if text_item is not None:
        output_items.append(text_item)
    output_items.extend(function_call_items)

    response_object = _build_response_object(
        request=request,
        output_items=output_items,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        finish_reason=finish_reason,
        response_id=response_id,
    )

    # NOTE(review): unlike the non-streaming path, storage here additionally
    # requires last_output is not None (i.e. the engine yielded at least one
    # chunk) — confirm this asymmetry is intentional.
    if request.store and last_output is not None:
        persisted_messages = _responses_request_to_persisted_messages(request)
        persisted_messages.extend(_response_output_items_to_chat_messages(output_items))
        _responses_store[response_object.id] = {
            "messages": copy.deepcopy(persisted_messages),
            "response": response_object.model_copy(deep=True),
        }
        # FIFO eviction, same policy as the non-streaming path.
        while len(_responses_store) > _RESPONSES_STORE_MAX_SIZE:
            _responses_store.popitem(last=False)

    yield _responses_sse_event(
        "response.completed",
        ResponseCompletedEvent(sequence_number=sequence, response=response_object),
    )


def _responses_sse_event(event_type: str, payload: BaseModel | dict) -> str:
    """Encode a Responses API SSE event.

    Produces the standard two-line server-sent-event frame
    (``event: <type>\\ndata: <json>\\n\\n``) from either a pydantic model
    or a plain dict payload.
    """
    data = (
        payload.model_dump_json()
        if isinstance(payload, BaseModel)
        else json.dumps(payload)
    )
    return f"event: {event_type}\ndata: {data}\n\n"


def _detect_native_tool_support() -> bool:
    """
    Detect if the active tool parser supports native tool format.
@@ -1915,6 +2823,27 @@ def _get_engine_tokenizer(engine) -> object | None:
    return None


@app.post(
    "/v1/responses",
    dependencies=[Depends(verify_api_key), Depends(check_rate_limit)],
)
async def create_response(request: ResponsesRequest, raw_request: Request):
    """Create a Responses API response.

    Streaming requests return an SSE stream of Responses API events wrapped
    in the disconnect guard; non-streaming requests return the full response
    object directly.
    """
    if request.stream:
        return StreamingResponse(
            _disconnect_guard(_stream_responses_request(request), raw_request),
            media_type="text/event-stream",
        )

    response_object, _persisted_messages = await _run_responses_request(
        request, raw_request
    )
    if response_object is None:
        # The runner returns None when generation was abandoned — presumably
        # a client disconnect — reported as 499 (client closed request).
        return Response(status_code=499)

    return response_object


def _inject_json_instruction(messages: list, instruction: str) -> list: