diff --git a/tests/test_chat_template_kwargs.py b/tests/test_chat_template_kwargs.py
new file mode 100644
index 00000000..534dc6e0
--- /dev/null
+++ b/tests/test_chat_template_kwargs.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for chat template kwargs forwarding."""
+
+from unittest.mock import MagicMock, patch
+
+from fastapi.testclient import TestClient
+
+import vllm_mlx.server as srv
+from vllm_mlx.engine.base import GenerationOutput
+
+
+def test_chat_completion_request_preserves_chat_template_kwargs():
+    request = srv.ChatCompletionRequest(
+        model="test-model",
+        messages=[srv.Message(role="user", content="Hello")],
+        chat_template_kwargs={"enable_thinking": False},
+    )
+
+    assert request.chat_template_kwargs == {"enable_thinking": False}
+
+
+def test_batched_engine_applies_chat_template_kwargs():
+    with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False):
+        from vllm_mlx.engine.batched import BatchedEngine
+
+        engine = BatchedEngine("test-model")
+        engine._tokenizer = MagicMock()
+        engine._tokenizer.apply_chat_template.return_value = "prompt"
+
+        prompt = engine._apply_chat_template(
+            [{"role": "user", "content": "Hello"}],
+            chat_template_kwargs={"enable_thinking": False},
+        )
+
+        assert prompt == "prompt"
+        engine._tokenizer.apply_chat_template.assert_called_once()
+        assert (
+            engine._tokenizer.apply_chat_template.call_args.kwargs["enable_thinking"]
+            is False
+        )
+
+
+def test_chat_completion_endpoint_forwards_chat_template_kwargs():
+    captured = {}
+
+    class FakeEngine:
+        model_name = "test-model"
+        is_mllm = False
+        preserve_native_tool_format = False
+
+        async def chat(self, messages, **kwargs):
+            captured["messages"] = messages
+            captured["kwargs"] = kwargs
+            return GenerationOutput(
+                text="ORBIT",
+                prompt_tokens=4,
+                completion_tokens=1,
+                finish_reason="stop",
+            )
+
+    client = TestClient(srv.app)
+    original_engine = srv._engine
+    original_model_name = srv._model_name
+    srv._engine = FakeEngine()
+    srv._model_name = "test-model"
+    try:
+        response = client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "test-model",
+                "messages": [{"role": "user", "content": "Reply with ORBIT."}],
+                "max_tokens": 8,
+                "chat_template_kwargs": {"enable_thinking": False},
+            },
+        )
+    finally:
+        srv._engine = original_engine
+        srv._model_name = original_model_name
+
+    assert response.status_code == 200
+    assert captured["kwargs"]["chat_template_kwargs"] == {"enable_thinking": False}
+    assert response.json()["choices"][0]["message"]["content"] == "ORBIT"
diff --git a/tests/test_paged_cache.py b/tests/test_paged_cache.py
index 8e3082c3..5d5eac40 100644
--- a/tests/test_paged_cache.py
+++ b/tests/test_paged_cache.py
@@ -725,3 +725,93 @@ def test_clear(self):
         stats = cache.get_stats()
         # After clear, null block is still allocated (vLLM style)
         assert stats["allocated_blocks"] == 1  # only null block
+
+    def test_reconstructs_hybrid_cache_from_boundary_snapshot(self):
+        from mlx_lm.models.cache import ArraysCache, KVCache
+        import mlx.core as mx
+
+        from vllm_mlx.paged_cache import PagedCacheManager
+        from vllm_mlx.prefix_cache import BlockAwarePrefixCache
+
+        paged_manager = PagedCacheManager(block_size=4, max_blocks=10)
+        cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager)
+
+        tokens = list(range(8))
+        kv_keys = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3)
+        kv_values = mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3)
+        linear_state = [
+            mx.arange(1 * 3 * 8).reshape(1, 3, 8),
+            mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4),
+        ]
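+        # Layer 0 mimics standard attention KV state (all tensors 4-D, so it
+        # can be sliced per block); layer 1 mimics a recurrent cache whose
+        # mixed-rank state is only snapshotted at the sequence boundary.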
+        extracted = [
+            {
+                "state": (kv_keys, kv_values),
+                "meta_state": "",
+                "class_ref": KVCache,
+                "class_name": "KVCache",
+            },
+            {
+                "state": linear_state,
+                "meta_state": "",
+                "class_ref": ArraysCache,
+                "class_name": "ArraysCache",
+            },
+        ]
+
+        block_table = cache.store_cache("req-1", tokens, extracted)
+        first_block = paged_manager.allocated_blocks[block_table.block_ids[0]]
+        last_block = paged_manager.allocated_blocks[block_table.block_ids[-1]]
+
+        assert first_block.cache_data[0] is not None
+        assert first_block.cache_data[1] is None
+        assert last_block.cache_data[1] is not None
+
+        reconstructed = cache.reconstruct_cache(block_table)
+
+        assert reconstructed is not None
+        assert isinstance(reconstructed[0], KVCache)
+        assert isinstance(reconstructed[1], ArraysCache)
+        assert reconstructed[0].state[0].tolist() == kv_keys.tolist()
+        assert reconstructed[0].state[1].tolist() == kv_values.tolist()
+        assert reconstructed[1].state[0].tolist() == linear_state[0].tolist()
+        assert reconstructed[1].state[1].tolist() == linear_state[1].tolist()
+
+    def test_rejects_hybrid_prefix_without_boundary_snapshot(self):
+        from mlx_lm.models.cache import ArraysCache, KVCache
+        import mlx.core as mx
+
+        from vllm_mlx.paged_cache import BlockTable, PagedCacheManager
+        from vllm_mlx.prefix_cache import BlockAwarePrefixCache
+
+        paged_manager = PagedCacheManager(block_size=4, max_blocks=10)
+        cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager)
+
+        extracted = [
+            {
+                "state": (
+                    mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3),
+                    mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3),
+                ),
+                "meta_state": "",
+                "class_ref": KVCache,
+                "class_name": "KVCache",
+            },
+            {
+                "state": [
+                    mx.arange(1 * 3 * 8).reshape(1, 3, 8),
+                    mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4),
+                ],
+                "meta_state": "",
+                "class_ref": ArraysCache,
+                "class_name": "ArraysCache",
+            },
+        ]
+
+        block_table = cache.store_cache("req-1", list(range(8)), extracted)
+        prefix_table = BlockTable(
+            request_id="req-prefix",
+            block_ids=[block_table.block_ids[0]],
+            num_tokens=4,
+        )
+
+        assert cache.reconstruct_cache(prefix_table) is None
diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py
new file mode 100644
index 00000000..5cb136fc
--- /dev/null
+++ b/tests/test_responses_api.py
@@ -0,0 +1,282 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the OpenAI-compatible Responses API."""
+
+import platform
+import sys
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from fastapi.testclient import TestClient
+
+pytestmark = pytest.mark.skipif(
+    sys.platform != "darwin" or platform.machine() != "arm64",
+    reason="Requires Apple Silicon",
+)
+
+
+@pytest.fixture()
+def client():
+    from vllm_mlx.server import app
+
+    return TestClient(app)
+
+
+@pytest.fixture(autouse=True)
+def server_state():
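+    # Autouse: each test sees isolated server globals and a fresh in-memory
+    # responses store; the originals are restored afterwards.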
+    import vllm_mlx.server as srv
+
+    original_engine = srv._engine
+    original_model_name = srv._model_name
+    original_store = srv._responses_store
+    original_api_key = srv._api_key
+
+    srv._engine = None
+    srv._model_name = "test-model"
+    srv._responses_store = {}
+    srv._api_key = None
+
+    try:
+        yield
+    finally:
+        srv._engine = original_engine
+        srv._model_name = original_model_name
+        srv._responses_store = original_store
+        srv._api_key = original_api_key
+
+
+def _mock_engine(*outputs):
+    engine = MagicMock()
+    engine.model_name = "test-model"
+    engine.preserve_native_tool_format = False
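+    # side_effect pops one canned output per chat() call, so multi-turn
+    # tests can script successive answers.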
+    engine.chat = AsyncMock(side_effect=list(outputs))
+    return engine
+
+
+def _output(text: str, prompt_tokens: int = 7, completion_tokens: int = 3):
+    return SimpleNamespace(
+        text=text,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        finish_reason="stop",
+    )
+
+
+class TestResponsesEndpoint:
+    def test_basic_response(self, client):
+        import vllm_mlx.server as srv
+
+        srv._engine = _mock_engine(_output("Hello there"))
+
+        resp = client.post(
+            "/v1/responses",
+            json={"model": "test-model", "input": "Say hello"},
+        )
+
+        assert resp.status_code == 200
+        body = resp.json()
+        assert body["object"] == "response"
+        assert body["output_text"] == "Hello there"
+        assert body["output"][0]["type"] == "message"
+        assert body["output"][0]["content"][0]["type"] == "output_text"
+        assert body["usage"]["input_tokens"] == 7
+        assert body["usage"]["output_tokens"] == 3
+
+    def test_previous_response_id_reuses_prior_context(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("First answer"), _output("Second answer"))
+        srv._engine = engine
+
+        first = client.post(
+            "/v1/responses",
+            json={"model": "test-model", "input": "First prompt"},
+        )
+        first_id = first.json()["id"]
+
+        second = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "previous_response_id": first_id,
+                "input": "Follow-up prompt",
+            },
+        )
+
+        assert second.status_code == 200
+        second_messages = engine.chat.call_args_list[1].kwargs["messages"]
+        assert second_messages[0]["role"] == "user"
+        assert second_messages[0]["content"] == "First prompt"
+        assert second_messages[1]["role"] == "assistant"
+        assert second_messages[1]["content"] == "First answer"
+        assert second_messages[2]["role"] == "user"
+        assert second_messages[2]["content"] == "Follow-up prompt"
+
+    def test_developer_role_is_normalized_to_system(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("Ready"))
+        srv._engine = engine
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "input": [
+                    {"type": "message", "role": "user", "content": "Hi"},
+                    {"type": "message", "role": "developer", "content": "Be terse"},
+                ],
+            },
+        )
+
+        assert resp.status_code == 200
+        messages = engine.chat.call_args.kwargs["messages"]
+        assert messages[0]["role"] == "system"
+        assert messages[0]["content"] == "Be terse"
+        assert messages[1]["role"] == "user"
+        assert messages[1]["content"] == "Hi"
+
+    def test_instructions_and_developer_message_are_merged(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("Ready"))
+        srv._engine = engine
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "instructions": "System instructions",
+                "input": [
+                    {"type": "message", "role": "developer", "content": "Developer note"},
+                    {"type": "message", "role": "user", "content": "Hi"},
+                ],
+            },
+        )
+
+        assert resp.status_code == 200
+        messages = engine.chat.call_args.kwargs["messages"]
+        assert len([m for m in messages if m["role"] == "system"]) == 1
+        assert messages[0]["role"] == "system"
+        assert "System instructions" in messages[0]["content"]
+        assert "Developer note" in messages[0]["content"]
+        assert messages[1]["role"] == "user"
+
+    def test_function_call_output_input_is_mapped_cleanly(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("Done"))
+        srv._engine = engine
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "input": [
+                    {"type": "message", "role": "user", "content": "Run it"},
+                    {
+                        "type": "function_call",
+                        "call_id": "call_1",
+                        "name": "shell",
+                        "arguments": "{\"cmd\":\"pwd\"}",
+                    },
+                    {
+                        "type": "function_call_output",
+                        "call_id": "call_1",
+                        "output": "/tmp/work",
+                    },
+                ],
+            },
+        )
+
+        assert resp.status_code == 200
+        messages = engine.chat.call_args.kwargs["messages"]
+        assert messages[1]["role"] == "assistant"
+        assert "[Calling tool: shell(" in messages[1]["content"]
+        assert messages[2]["role"] == "user"
+        assert "[Tool Result (call_1)]" in messages[2]["content"]
+        assert "/tmp/work" in messages[2]["content"]
+
+    def test_unsupported_tools_and_items_do_not_fail(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("Fallback answer"))
+        srv._engine = engine
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "input": [
+                    {"type": "message", "role": "user", "content": "Answer directly"},
+                    {
+                        "type": "web_search_call",
+                        "status": "completed",
+                        "action": {"type": "search", "query": "ignored"},
+                    },
+                ],
+                "tools": [
+                    {"type": "web_search_preview"},
+                    {"type": "file_search", "vector_store_ids": ["vs_123"]},
+                    {
+                        "type": "function",
+                        "name": "shell",
+                        "parameters": {"type": "object", "properties": {}},
+                    },
+                ],
+            },
+        )
+
+        assert resp.status_code == 200
+        messages = engine.chat.call_args.kwargs["messages"]
+        assert messages[0]["role"] == "system"
+        assert "not available on this backend" in messages[0]["content"]
+        assert messages[1]["role"] == "user"
+        assert engine.chat.call_args.kwargs["tools"][0]["type"] == "function"
+
+    def test_function_call_response_item(self, client):
+        import vllm_mlx.server as srv
+
+        srv._engine = _mock_engine(
+            _output('{"name":"shell","arguments":{"cmd":"pwd"}}')
+        )
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "input": "Use a tool",
+                "tools": [
+                    {
+                        "type": "function",
+                        "name": "shell",
+                        "parameters": {"type": "object", "properties": {}},
+                    }
+                ],
+            },
+        )
+
+        assert resp.status_code == 200
+        body = resp.json()
+        assert body["output"][0]["type"] == "function_call"
+        assert body["output"][0]["name"] == "shell"
+        assert body["output_text"] == ""
+
+    def test_streaming_response_events(self, client):
+        import vllm_mlx.server as srv
+
+        srv._engine = _mock_engine(_output("Hello stream"))
+
+        with client.stream(
+            "POST",
+            "/v1/responses",
+            json={"model": "test-model", "input": "Hello", "stream": True},
+        ) as resp:
+            stream_text = "".join(resp.iter_text())
+
+        assert resp.status_code == 200
+        assert "event: response.created" in stream_text
+        assert "event: response.in_progress" in stream_text
+        assert "event: response.output_text.delta" in stream_text
+        assert "event: response.completed" in stream_text
+        assert "Hello stream" in stream_text
diff --git a/tests/test_server.py b/tests/test_server.py
index 9fb86a3e..81d74b95 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -167,6 +167,41 @@ def test_basic_completion_request(self):
         assert request.max_tokens is None  # uses _default_max_tokens when None
 
 
+class TestServeCli:
+    """Test serve CLI argument parsing."""
+
+    def test_tool_call_parser_accepts_harmony_aliases(self):
+        """GPT-OSS/Harmony parsers should be selectable from the serve CLI."""
+        from vllm_mlx.cli import create_parser
+
+        parser = create_parser()
+        args = parser.parse_args(
+            [
+                "serve",
+                "lmstudio-community/gpt-oss-20b-MLX-8bit",
+                "--enable-auto-tool-choice",
+                "--tool-call-parser",
+                "harmony",
+            ]
+        )
+
+        assert args.command == "serve"
+        assert args.tool_call_parser == "harmony"
+        assert args.enable_auto_tool_choice is True
+
+        args = parser.parse_args(
+            [
+                "serve",
+                "lmstudio-community/gpt-oss-20b-MLX-8bit",
+                "--enable-auto-tool-choice",
+                "--tool-call-parser",
+                "gpt-oss",
+            ]
+        )
+
+        assert args.tool_call_parser == "gpt-oss"
+
+
 # =============================================================================
 # Helper Function Tests
 # =============================================================================
diff --git a/tests/test_tokenizer_utils.py b/tests/test_tokenizer_utils.py
new file mode 100644
index 00000000..d95fecc7
--- /dev/null
+++ b/tests/test_tokenizer_utils.py
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for tokenizer utility helpers."""
+
+import platform
+import sys
+from unittest.mock import patch
+
+import pytest
+
+pytestmark = pytest.mark.skipif(
+    sys.platform != "darwin" or platform.machine() != "arm64",
+    reason="Requires Apple Silicon",
+)
+
+
+class TestLoadModelWithFallback:
+    def test_returns_successful_load_result(self):
+        from vllm_mlx.utils.tokenizer import load_model_with_fallback
+
+        fake_model = object()
+        fake_tokenizer = object()
+
+        with patch("mlx_lm.load", return_value=(fake_model, fake_tokenizer)) as load:
+            model, tokenizer = load_model_with_fallback("mlx-community/Qwen3.5-4B")
+
+        load.assert_called_once()
+        assert model is fake_model
+        assert tokenizer is fake_tokenizer
+
+    def test_uses_tokenizer_fallback_for_tokenizer_errors(self):
+        from vllm_mlx.utils.tokenizer import load_model_with_fallback
+
+        fake_model = object()
+        fake_tokenizer = object()
+
+        with patch(
+            "mlx_lm.load",
+            side_effect=ValueError("Tokenizer class Foo does not exist"),
+        ), patch(
+            "vllm_mlx.utils.tokenizer._load_with_tokenizer_fallback",
+            return_value=(fake_model, fake_tokenizer),
+        ) as fallback:
+            model, tokenizer = load_model_with_fallback("example/model")
+
+        fallback.assert_called_once_with("example/model")
+        assert model is fake_model
+        assert tokenizer is fake_tokenizer
diff --git a/vllm_mlx/api/__init__.py b/vllm_mlx/api/__init__.py
index cfb62f45..552f253f 100644
--- a/vllm_mlx/api/__init__.py
+++ b/vllm_mlx/api/__init__.py
@@ -53,6 +53,24 @@
     EmbeddingUsage,
     EmbeddingResponse,
 )
+from .responses_models import (
+    ResponseTextFormat,
+    ResponseTextConfig,
+    ResponseReasoningConfig,
+    ResponseTextContentPart,
+    ResponseReasoningTextPart,
+    ResponseReasoningSummaryTextPart,
+    ResponseMessageItem,
+    ResponseReasoningItem,
+    ResponseFunctionCallItem,
+    ResponseFunctionCallOutputItem,
+    ResponseFunctionTool,
+    ResponsesUsage,
+    ResponseError,
+    ResponseIncompleteDetails,
+    ResponsesRequest,
+    ResponseObject,
+)
 
 from .utils import (
     clean_output_text,
@@ -111,6 +129,22 @@
     "EmbeddingData",
     "EmbeddingUsage",
     "EmbeddingResponse",
+    "ResponseTextFormat",
+    "ResponseTextConfig",
+    "ResponseReasoningConfig",
+    "ResponseTextContentPart",
+    "ResponseReasoningTextPart",
+    "ResponseReasoningSummaryTextPart",
+    "ResponseMessageItem",
+    "ResponseReasoningItem",
+    "ResponseFunctionCallItem",
+    "ResponseFunctionCallOutputItem",
+    "ResponseFunctionTool",
+    "ResponsesUsage",
+    "ResponseError",
+    "ResponseIncompleteDetails",
+    "ResponsesRequest",
+    "ResponseObject",
     # Utils
     "clean_output_text",
     "is_mllm_model",
diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py
index 32b26e03..f7bcaaaa 100644
--- a/vllm_mlx/api/models.py
+++ b/vllm_mlx/api/models.py
@@ -11,6 +11,7 @@
 
 import time
 import uuid
+from typing import Any
 
 from pydantic import BaseModel, Field, computed_field
 
@@ -169,6 +170,8 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: str | dict | None = None  # "auto", "none", or specific tool
     # Structured output
     response_format: ResponseFormat | dict | None = None
+    # Extra kwargs forwarded to tokenizer.apply_chat_template
+    chat_template_kwargs: dict[str, Any] | None = None
     # MLLM-specific parameters
     video_fps: float | None = None
     video_max_frames: int | None = None
diff --git a/vllm_mlx/api/responses_models.py b/vllm_mlx/api/responses_models.py
new file mode 100644
index 00000000..ed84d28e
--- /dev/null
+++ b/vllm_mlx/api/responses_models.py
@@ -0,0 +1,316 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Pydantic models for the OpenAI-compatible Responses API.
+
+This intentionally implements the subset needed for local coding-agent
+workflows: text messages, function tools, function call outputs, and SSE
+streaming events. The object and event shapes follow the conventions used by
+OpenAI's gpt-oss reference server and llama.cpp's OpenAI-compatible server.
+"""
+
+import time
+import uuid
+from typing import Literal
+
+from pydantic import BaseModel, Field, computed_field
+
+
+class ResponseTextFormat(BaseModel):
+    """Output text format configuration."""
+
+    type: Literal["text", "json_object"] = "text"
+
+
+class ResponseTextConfig(BaseModel):
+    """Text output configuration."""
+
+    format: ResponseTextFormat = Field(default_factory=ResponseTextFormat)
+
+
+class ResponseReasoningConfig(BaseModel):
+    """Reasoning configuration."""
+
+    effort: Literal["low", "medium", "high"] | None = None
+
+
+class ResponseTextContentPart(BaseModel):
+    """A text content part for message items."""
+
+    type: Literal["text", "input_text", "output_text"] = "output_text"
+    text: str
+    annotations: list[dict] = Field(default_factory=list)
+    logprobs: list[dict] = Field(default_factory=list)
+
+
+class ResponseReasoningTextPart(BaseModel):
+    """A reasoning text content part."""
+
+    type: Literal["reasoning_text"] = "reasoning_text"
+    text: str
+
+
+class ResponseReasoningSummaryTextPart(BaseModel):
+    """A reasoning summary item."""
+
+    type: Literal["summary_text"] = "summary_text"
+    text: str
+
+
+class ResponseMessageItem(BaseModel):
+    """A Responses API message item."""
+
+    id: str | None = None
+    type: Literal["message"] = "message"
+    role: Literal["system", "user", "assistant", "developer"] = "assistant"
+    content: str | list[ResponseTextContentPart] = Field(default_factory=list)
+    status: Literal["in_progress", "completed", "incomplete"] | None = "completed"
+
+
+class ResponseReasoningItem(BaseModel):
+    """A reasoning output item."""
+
+    id: str | None = None
+    type: Literal["reasoning"] = "reasoning"
+    summary: list[ResponseReasoningSummaryTextPart] = Field(default_factory=list)
+    content: list[ResponseReasoningTextPart] = Field(default_factory=list)
+    status: Literal["in_progress", "completed", "incomplete"] | None = "completed"
+
+
+class ResponseFunctionCallItem(BaseModel):
+    """A function call output item."""
+
+    id: str | None = None
+    type: Literal["function_call"] = "function_call"
+    call_id: str
+    name: str
+    arguments: str
+    status: Literal["in_progress", "completed", "incomplete"] = "completed"
+
+
+class ResponseFunctionCallOutputItem(BaseModel):
+    """A tool result item passed back into a later request."""
+
+    type: Literal["function_call_output"] = "function_call_output"
+    call_id: str
+    output: str
+
+
+class ResponseFunctionTool(BaseModel):
+    """A function tool definition."""
+
+    type: Literal["function"] = "function"
+    name: str
+    description: str | None = ""
+    parameters: dict = Field(
+        default_factory=lambda: {"type": "object", "properties": {}}
+    )
+    strict: bool = False
+
+
+class ResponsesInputTokenDetails(BaseModel):
+    """Input token breakdown."""
+
+    cached_tokens: int = 0
+
+
+class ResponsesOutputTokenDetails(BaseModel):
+    """Output token breakdown."""
+
+    reasoning_tokens: int = 0
+
+
+class ResponsesUsage(BaseModel):
+    """Responses API token usage."""
+
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+    input_tokens_details: ResponsesInputTokenDetails = Field(
+        default_factory=ResponsesInputTokenDetails
+    )
+    output_tokens_details: ResponsesOutputTokenDetails = Field(
+        default_factory=ResponsesOutputTokenDetails
+    )
+
+
+class ResponseError(BaseModel):
+    """Error payload."""
+
+    code: str
+    message: str
+
+
+class ResponseIncompleteDetails(BaseModel):
+    """Incomplete response details."""
+
+    reason: str
+
+
+class ResponsesRequest(BaseModel):
+    """Request payload for /v1/responses."""
+
+    model: str
+    input: (
+        str
+        | list[
+            ResponseMessageItem
+            | ResponseReasoningItem
+            | ResponseFunctionCallItem
+            | ResponseFunctionCallOutputItem
+            | dict
+        ]
+    )
+    instructions: str | None = None
+    max_output_tokens: int | None = None
+    stream: bool = False
+    tools: list[ResponseFunctionTool | dict] = Field(default_factory=list)
+    tool_choice: str | dict | None = "auto"
+    parallel_tool_calls: bool = True
+    previous_response_id: str | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    metadata: dict = Field(default_factory=dict)
+    text: ResponseTextConfig = Field(default_factory=ResponseTextConfig)
+    reasoning: ResponseReasoningConfig | None = None
+    store: bool = True
+    truncation: str = "disabled"
+    user: str | None = None
+
+
+class ResponseObject(BaseModel):
+    """Response object for /v1/responses."""
+
+    id: str = Field(default_factory=lambda: f"resp_{uuid.uuid4().hex}")
+    object: Literal["response"] = "response"
+    created_at: int = Field(default_factory=lambda: int(time.time()))
+    status: Literal["completed", "failed", "incomplete", "in_progress"] = "completed"
+    background: bool = False
+    error: ResponseError | None = None
+    incomplete_details: ResponseIncompleteDetails | None = None
+    instructions: str | None = None
+    max_output_tokens: int | None = None
+    max_tool_calls: int | None = None
+    metadata: dict = Field(default_factory=dict)
+    model: str
+    output: list[
+        ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
+    ] = Field(default_factory=list)
+    parallel_tool_calls: bool = True
+    previous_response_id: str | None = None
+    text: ResponseTextConfig = Field(default_factory=ResponseTextConfig)
+    tool_choice: str | dict | None = "auto"
+    tools: list[ResponseFunctionTool | dict] = Field(default_factory=list)
+    top_p: float = 1.0
+    temperature: float | None = None
+    truncation: str = "disabled"
+    usage: ResponsesUsage | None = None
+    user: str | None = None
+    store: bool = True
+
+    @computed_field
+    @property
+    def output_text(self) -> str:
+        """Concatenate assistant text content into the convenience field."""
+        text_parts: list[str] = []
+        for item in self.output:
+            if not isinstance(item, ResponseMessageItem):
+                continue
+            if isinstance(item.content, str):
+                text_parts.append(item.content)
+                continue
+            for part in item.content:
+                if part.type == "output_text":
+                    text_parts.append(part.text)
+        return "".join(text_parts)
+
+
+class ResponsesEventBase(BaseModel):
+    """Base event fields."""
+
+    sequence_number: int
+
+
+class ResponseCreatedEvent(ResponsesEventBase):
+    type: Literal["response.created"] = "response.created"
+    response: ResponseObject
+
+
+class ResponseInProgressEvent(ResponsesEventBase):
+    type: Literal["response.in_progress"] = "response.in_progress"
+    response: ResponseObject
+
+
+class ResponseCompletedEvent(ResponsesEventBase):
+    type: Literal["response.completed"] = "response.completed"
+    response: ResponseObject
+
+
+class ResponseOutputItemAddedEvent(ResponsesEventBase):
+    type: Literal["response.output_item.added"] = "response.output_item.added"
+    output_index: int
+    item: ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
+
+
+class ResponseOutputItemDoneEvent(ResponsesEventBase):
+    type: Literal["response.output_item.done"] = "response.output_item.done"
+    output_index: int
+    item: ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
+
+
+class ResponseContentPartAddedEvent(ResponsesEventBase):
+    type: Literal["response.content_part.added"] = "response.content_part.added"
+    item_id: str
+    output_index: int
+    content_index: int
+    part: ResponseTextContentPart | ResponseReasoningTextPart
+
+
+class ResponseContentPartDoneEvent(ResponsesEventBase):
+    type: Literal["response.content_part.done"] = "response.content_part.done"
+    item_id: str
+    output_index: int
+    content_index: int
+    part: ResponseTextContentPart | ResponseReasoningTextPart
+
+
+class ResponseOutputTextDeltaEvent(ResponsesEventBase):
+    type: Literal["response.output_text.delta"] = "response.output_text.delta"
+    item_id: str
+    output_index: int
+    content_index: int
+    delta: str
+    logprobs: list[dict] = Field(default_factory=list)
+
+
+class ResponseOutputTextDoneEvent(ResponsesEventBase):
+    type: Literal["response.output_text.done"] = "response.output_text.done"
+    item_id: str
+    output_index: int
+    content_index: int
+    text: str
+    logprobs: list[dict] = Field(default_factory=list)
+
+
+class ResponseReasoningTextDeltaEvent(ResponsesEventBase):
+    type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta"
+    item_id: str
+    output_index: int
+    content_index: int
+    delta: str
+
+
+class ResponseReasoningTextDoneEvent(ResponsesEventBase):
+    type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done"
+    item_id: str
+    output_index: int
+    content_index: int
+    text: str
+
+
+class ResponseFunctionCallArgumentsDeltaEvent(ResponsesEventBase):
+    type: Literal["response.function_call_arguments.delta"] = (
+        "response.function_call_arguments.delta"
+    )
+    item_id: str
+    output_index: int
+    delta: str
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 8a90bc9b..7f9ea088 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -593,7 +593,8 @@ def bench_kv_cache_command(args):
     )
 
 
-def main():
+def create_parser() -> argparse.ArgumentParser:
+    """Build the top-level CLI parser."""
     parser = argparse.ArgumentParser(
         description="vllm-mlx: Apple Silicon MLX backend for vLLM",
         formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -832,6 +833,8 @@ def main():
             "qwen3_coder",
             "llama",
             "hermes",
+            "harmony",
+            "gpt-oss",
             "deepseek",
             "kimi",
             "granite",
@@ -843,7 +846,8 @@ def main():
         help=(
             "Select the tool call parser for the model. Options: "
            "auto (auto-detect), mistral, qwen, qwen3_coder, llama, hermes, "
-            "deepseek, kimi, granite, nemotron, xlam, functionary, glm47. "
+            "harmony, gpt-oss, deepseek, kimi, granite, nemotron, xlam, "
+            "functionary, glm47. "
             "Required for --enable-auto-tool-choice."
         ),
     )
@@ -1023,6 +1027,12 @@ def main():
         help="Quantization group size (default: 64)",
     )
 
+    return parser
+
+
+def main():
+    parser = create_parser()
+
     args = parser.parse_args()
 
     if args.command == "serve":
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index ce33e628..c30421e1 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -335,6 +335,7 @@ def _apply_chat_template(
         messages: list[dict[str, Any]],
         tools: list[dict] | None = None,
         num_images: int = 0,
+        chat_template_kwargs: dict[str, Any] | None = None,
     ) -> str:
         """Apply chat template to messages.
 
@@ -367,7 +368,9 @@ def _apply_chat_template(
             "tokenize": False,
             "add_generation_prompt": True,
         }
-        if tools:
+        if chat_template_kwargs:
+            template_kwargs.update(chat_template_kwargs)
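+        # Caller-supplied kwargs take precedence; tools are only injected
+        # when the caller has not already provided them.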
+        if tools and "tools" not in template_kwargs:
             template_kwargs["tools"] = tools
 
         try:
@@ -375,11 +378,10 @@ def _apply_chat_template(
                 messages, **template_kwargs
             )
         except TypeError as e:
-            # Some templates don't accept 'tools'; retry without them.
+            # Some templates don't accept extra kwargs; retry without them.
             logger.debug(f"Chat template TypeError, retrying without extras: {e}")
-            for key in ["tools"]:
-                if key in template_kwargs:
-                    del template_kwargs[key]
+            for key in ["tools", *(chat_template_kwargs or {}).keys()]:
+                template_kwargs.pop(key, None)
             return template_applicator.apply_chat_template(
                 messages, **template_kwargs
             )
@@ -620,12 +622,14 @@ async def chat(
 
         # Convert tools for template
         template_tools = convert_tools_for_template(tools) if tools else None
+        chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})
 
         # Apply chat template
         prompt = self._apply_chat_template(
             messages,
             template_tools,
             num_images=len(all_images),
+            chat_template_kwargs=chat_template_kwargs,
         )
 
         return await self.generate(
@@ -639,7 +643,10 @@ async def chat(
     def _compute_prefix_boundary(
-        self, messages: list[dict[str, Any]], tools: list[dict] | None = None
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict] | None = None,
+        chat_template_kwargs: dict[str, Any] | None = None,
     ) -> int:
         """Compute token count for the shared prefix across message variations.
 
@@ -661,7 +668,11 @@ def _compute_prefix_boundary(
         template_tools = convert_tools_for_template(tools) if tools else None
 
         # Tokenize the real prompt
-        real_prompt = self._apply_chat_template(messages, template_tools)
+        real_prompt = self._apply_chat_template(
+            messages,
+            template_tools,
+            chat_template_kwargs=chat_template_kwargs,
+        )
 
         # Build a dummy variant with different last user content
         dummy_messages = list(messages)
@@ -669,7 +680,11 @@ def _compute_prefix_boundary(
         dummy_messages[last_user_idx] = {
             **messages[last_user_idx],
             "content": "XXXXXXXXXX",
         }
-        dummy_prompt = self._apply_chat_template(dummy_messages, template_tools)
+        dummy_prompt = self._apply_chat_template(
+            dummy_messages,
+            template_tools,
+            chat_template_kwargs=chat_template_kwargs,
+        )
 
         tokenizer = self.tokenizer
         if hasattr(tokenizer, "tokenizer"):
@@ -731,16 +746,22 @@ async def stream_chat(
 
         # Convert tools for template
         template_tools = convert_tools_for_template(tools) if tools else None
+        chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})
 
         # Apply chat template
         prompt = self._apply_chat_template(
             messages,
             template_tools,
             num_images=len(all_images),
+            chat_template_kwargs=chat_template_kwargs,
         )
 
         # Compute prefix boundary for cache
-        prefix_boundary = self._compute_prefix_boundary(messages, tools)
+        prefix_boundary = self._compute_prefix_boundary(
+            messages,
+            tools,
+            chat_template_kwargs=chat_template_kwargs,
+        )
         if prefix_boundary > 0:
             kwargs["prefix_boundary"] = prefix_boundary
diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py
index e8f47a32..0bfe329f 100644
--- a/vllm_mlx/prefix_cache.py
+++ b/vllm_mlx/prefix_cache.py
@@ -586,7 +586,7 @@ def store_cache(
             # Extract and store actual tensor slices for this block
             if is_tensor_data and HAS_MLX:
                 block_kv_data = self._extract_block_tensor_slice(
-                    cache_data, global_start, global_end
+                    cache_data, global_start, global_end, len(tokens)
                 )
                 if block_kv_data:
                     block.cache_data = block_kv_data
@@ -629,56 +629,120 @@ def _extract_block_tensor_slice(
         cache_data: List[Dict[str, Any]],
         start_idx: int,
         end_idx: int,
-    ) -> Optional[List[Tuple[Any, Any]]]:
+        total_tokens: int,
+    ) -> Optional[List[Optional[Dict[str, Any]]]]:
         """
-        Extract tensor slices for a single block from cache data.
+        Extract per-layer cache data for a single block.
 
         Args:
-            cache_data: List of layer states, each containing 'state': (keys, values)
+            cache_data: List of extracted layer states
             start_idx: Start token index in the sequence
             end_idx: End token index in the sequence
+            total_tokens: Total number of tokens covered by cache_data
 
         Returns:
-            List of (keys_slice, values_slice) for each layer, or None on failure
+            Per-layer block cache state, or None on failure
         """
         if not HAS_MLX or not cache_data:
             return None
 
         try:
-            block_slices = []
+            block_slices: List[Optional[Dict[str, Any]]] = []
             for layer_state in cache_data:
                 if "state" not in layer_state:
+                    block_slices.append(None)
                     continue
 
-                keys, values = layer_state["state"]
+                state = layer_state["state"]
+                meta_state = layer_state.get("meta_state")
+                class_ref = layer_state.get("class_ref")
+                class_name = layer_state.get("class_name")
 
-                # KV cache shape: (batch, n_kv_heads, seq_len, head_dim)
-                # Slice along seq_len dimension (axis 2)
-                seq_len = keys.shape[2] if hasattr(keys, "shape") else 0
+                if self._can_concatenate_cache_state(state):
+                    state_slice = self._slice_concat_cache_state(
+                        state, start_idx, end_idx
+                    )
+                    block_slices.append(
+                        {
+                            "state": state_slice,
+                            "meta_state": meta_state,
+                            "class_ref": class_ref,
+                            "class_name": class_name,
+                            "storage": "concat",
+                            "seq_axis": 2,
+                        }
+                    )
+                    continue
 
-                if end_idx > seq_len:
-                    # Requested range extends beyond available data
-                    logger.debug(
-                        f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}"
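+                # Non-sliceable (recurrent) state: keep a full snapshot only
+                # for the block that ends exactly at the sequence boundary;
+                # earlier blocks record None for this layer.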
+                if end_idx == total_tokens:
+                    block_slices.append(
+                        {
+                            "state": state,
+                            "meta_state": meta_state,
+                            "class_ref": class_ref,
+                            "class_name": class_name,
+                            "storage": "latest",
+                        }
                     )
-                    # Use whatever is available
-                    actual_end = min(end_idx, seq_len)
-                    if start_idx >= actual_end:
-                        continue
-                    keys_slice = keys[:, :, start_idx:actual_end, :]
-                    values_slice = values[:, :, start_idx:actual_end, :]
                 else:
-                    keys_slice = keys[:, :, start_idx:end_idx, :]
-                    values_slice = values[:, :, start_idx:end_idx, :]
+                    block_slices.append(None)
 
-                block_slices.append((keys_slice, values_slice))
-
-            return block_slices if block_slices else None
+            return block_slices if any(entry is not None for entry in block_slices) else None
 
         except Exception as e:
             logger.warning(f"Failed to extract block tensor slice: {e}")
             return None
 
+    def _can_concatenate_cache_state(self, state: Any) -> bool:
+        """Return True when cache state can be concatenated block-by-block."""
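+        # Heuristic: regular attention KV state is a list/tuple of 4-D
+        # tensors (batch, n_kv_heads, seq_len, head_dim); recurrent caches
+        # such as ArraysCache carry differently shaped arrays that cannot
+        # be sliced along a token axis.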
+        if not isinstance(state, (list, tuple)) or not state:
+            return False
+        return all(
+            tensor is not None
+            and hasattr(tensor, "shape")
+            and len(tensor.shape) == 4
+            for tensor in state
+        )
+
+    def _slice_concat_cache_state(
+        self,
+        state: Tuple[Any, ...] | List[Any],
+        start_idx: int,
+        end_idx: int,
+    ) -> Tuple[Any, ...] | List[Any]:
+        """Slice a sequence-backed cache state across the token axis."""
+        seq_len = state[0].shape[2]
+        actual_end = min(end_idx, seq_len)
+        if start_idx >= actual_end:
+            raise ValueError(
+                f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}"
+            )
+
+        def _slice_tensor(tensor: Any) -> Any:
+            slices = [slice(None)] * len(tensor.shape)
+            slices[2] = slice(start_idx, actual_end)
+            return tensor[tuple(slices)]
+
+        sliced = [_slice_tensor(tensor) for tensor in state]
+        return tuple(sliced) if isinstance(state, tuple) else sliced
+
+    def _concat_cache_states(
+        self,
+        states: List[Tuple[Any, ...] | List[Any]],
+        seq_axis: int,
+    ) -> Optional[Tuple[Any, ...] | List[Any]]:
+        """Concatenate state fragments for a sequence-backed cache layer."""
+        if not states:
+            return None
+        arity = len(states[0])
+        concatenated = []
+        for idx in range(arity):
+            parts = [state[idx] for state in states]
+            if any(part is None for part in parts):
+                return None
+            concatenated.append(mx.concatenate(parts, axis=seq_axis))
+        return tuple(concatenated) if isinstance(states[0], tuple) else concatenated
+
     def get_cache_for_generation(
         self,
         request_id: str,
@@ -763,10 +827,11 @@ def reconstruct_cache(
         block_table: BlockTable,
     ) -> Optional[List[Any]]:
         """
-        Reconstruct KVCache objects from stored block tensor data.
+        Reconstruct cache objects from stored block tensor data.
 
-        This method concatenates tensor slices from all blocks and
-        creates new KVCache objects that can be used for inference.
+        Sequence-backed caches are concatenated block-by-block. Recurrent
+        caches such as ArraysCache are restored from the latest sequence
+        boundary snapshot that was actually stored.
 
         Args:
             block_table: BlockTable containing block IDs to reconstruct from
@@ -800,67 +865,62 @@ def reconstruct_cache(
         if not all_block_data:
             return None
 
-        # Get number of layers from first block
-        num_layers = len(all_block_data[0])
+        # Get number of layers from the richest block
+        num_layers = max(len(block_data) for block_data in all_block_data)
         if num_layers == 0:
             return None
 
-        # Concatenate tensors for each layer
         reconstructed_caches = []
         for layer_idx in range(num_layers):
-            layer_keys = []
-            layer_values = []
+            layer_entries = [
+                block_data[layer_idx]
+                for block_data in all_block_data
+                if layer_idx < len(block_data)
+            ]
+            layer_entries = [entry for entry in layer_entries if entry is not None]
+            if not layer_entries:
+                return None
 
-            for block_data in all_block_data:
-                if layer_idx < len(block_data):
-                    keys_slice, values_slice = block_data[layer_idx]
-                    layer_keys.append(keys_slice)
-                    layer_values.append(values_slice)
+            layer_meta = layer_entries[-1]
+            state = layer_meta["state"]
+            if layer_meta["storage"] == "concat":
+                state = self._concat_cache_states(
+                    [entry["state"] for entry in layer_entries],
+                    layer_meta["seq_axis"],
+                )
+            elif layer_meta["storage"] == "latest":
+                state = layer_entries[-1]["state"]
 
-            if not layer_keys:
-                continue
+            if state is None:
+                return None
 
-            # Concatenate along sequence dimension (axis 2)
-            # Shape: (batch, n_kv_heads, seq_len, head_dim)
-            concat_keys = mx.concatenate(layer_keys, axis=2)
-            concat_values = mx.concatenate(layer_values, axis=2)
+            cache_cls = layer_meta.get("class_ref")
+            meta_state = layer_meta.get("meta_state")
 
-            # Create KVCache object
-            # Try to use mlx_lm's KVCache.from_state if available
-            try:
+            if cache_cls is not None and hasattr(cache_cls, "from_state"):
+                from mlx_lm.models.cache import (
+                    BatchKVCache as _BatchKVCache,
+                    KVCache as _KVCache,
+                )
+
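+                # Assumption: a stored BatchKVCache snapshot is rehydrated as
+                # a plain KVCache (keys/values/offset) so single-sequence
+                # generation can consume it; other classes round-trip through
+                # their own from_state().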
+                if cache_cls is _BatchKVCache:
+                    keys, values = state[0], state[1]
+                    cache = _KVCache()
+                    cache.keys = keys
+                    cache.values = values
+                    cache.offset = keys.shape[2]
+                else:
+                    cache = cache_cls.from_state(state, meta_state)
+            else:
                 from mlx_lm.models.cache import KVCache
 
-                # Create new cache and set its state
+                if len(state) != 2:
+                    return None
                 cache = KVCache()
-                seq_len = concat_keys.shape[2]
-
-                # Set internal state directly
-                # KVCache stores keys/values and offset
-                cache.keys = concat_keys
-                cache.values = concat_values
-                cache.offset = seq_len
-
-                reconstructed_caches.append(cache)
-
-            except ImportError:
-                # Fallback: create a simple cache-like object
-                class SimpleKVCache:
-                    def __init__(self, keys, values):
-                        self.keys = keys
-                        self.values = values
-                        self.offset = keys.shape[2]
-
-                    @property
-                    def state(self):
-                        return (self.keys, self.values)
-
-                    @property
-                    def meta_state(self):
-                        return (str(self.offset),)
-
-                cache = SimpleKVCache(concat_keys, concat_values)
-                reconstructed_caches.append(cache)
+                cache.keys, cache.values = state
+                cache.offset = cache.keys.shape[2]
+
+            reconstructed_caches.append(cache)
 
         if not reconstructed_caches:
             return None
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index a0038d5f..4ec280d0 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -39,6 +39,7 @@
 
 import argparse
 import asyncio
+import copy
 import json
 import logging
 import os
@@ -54,6 +55,7 @@
 from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
 from fastapi.responses import Response, StreamingResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from pydantic import BaseModel
 
 # Import from new modular API
 # Re-export for backwards compatibility with tests
@@ -90,6 +92,30 @@
     Usage,  # noqa: F401
     VideoUrl,  # noqa: F401
 )
+from .api.responses_models import (
+    ResponseCompletedEvent,
+    ResponseContentPartAddedEvent,
+    ResponseContentPartDoneEvent,
+    ResponseCreatedEvent,
+    ResponseFunctionCallArgumentsDeltaEvent,
+    ResponseFunctionCallItem,
+    ResponseFunctionCallOutputItem,
+    ResponseFunctionTool,
+    ResponseInProgressEvent,
+    ResponseMessageItem,
+    ResponseObject,
+    ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
+    ResponseOutputTextDeltaEvent,
+    ResponseOutputTextDoneEvent,
+    ResponseReasoningItem,
+    ResponseReasoningTextDeltaEvent,
+    ResponseReasoningTextDoneEvent,
+    ResponseReasoningTextPart,
+    ResponseTextContentPart,
+    ResponsesRequest,
+    ResponsesUsage,
+)
 from .api.tool_calling import (
     build_json_system_prompt,
     convert_tools_for_template,
@@ -160,6 +186,7 @@ def _resolve_top_p(request_value: float | None) -> float:
 _enable_auto_tool_choice: bool = False
 _tool_call_parser: str | None = None  # Parser name: auto, mistral, qwen, llama, hermes
 _tool_parser_instance = None  # Instantiated parser
+_responses_store: dict[str, dict] = {}
 
 
 def _load_prefix_cache_from_disk() -> None:
@@ -419,6 +446,612 @@ def _parse_tool_calls_with_parser(
     return parse_tool_calls(output_text, request_dict)
 
 
+def _new_response_item_id(prefix: str) -> str:
+    """Generate stable OpenAI-style item ids."""
+    return f"{prefix}_{uuid.uuid4().hex}"
+
+
+def _response_content_to_text(content) -> str:
+    """Normalize Responses API content items into plain text."""
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+
+    text_parts = []
+    for part in content:
+        if isinstance(part, dict):
+            part_type = part.get("type")
+            text = part.get("text", "")
+        else:
+            part_type = getattr(part, "type", None)
+            text = getattr(part, "text", "")
+        if part_type in {"text", "input_text", "output_text"}:
+            text_parts.append(text)
+    return "\n".join(part for part in text_parts if part)
+
+
+def _responses_tools_to_chat_tools(
+    tools: list[ResponseFunctionTool | dict],
+) -> tuple[list[dict] | None, list[str]]:
+    """Convert supported Responses tools and report unsupported tool types."""
+    if not tools:
+        return None, []
+
+    supported: list[dict] = []
+    unsupported: list[str] = []
+
+    for tool in tools:
+        if isinstance(tool, ResponseFunctionTool):
+            tool_type = tool.type
+            tool_name = tool.name
+            tool_description = tool.description or ""
+            tool_parameters = tool.parameters
+        elif isinstance(tool, dict):
+            tool_type = tool.get("type", "unknown")
+            tool_name = tool.get("name", "")
+            tool_description = tool.get("description", "")
+            tool_parameters = tool.get("parameters", {})
+        else:
+            unsupported.append(type(tool).__name__)
+            continue
+
+        if tool_type == "function":
+            supported.append(
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool_name,
+                        "description": tool_description,
+                        "parameters": tool_parameters
+                        or {"type": "object", "properties": {}},
+                    },
+                }
+            )
+        else:
+            unsupported.append(tool_type)
+
+    return supported or None, unsupported
+
+
+def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]:
+    """Convert Responses API input items into chat-completions-style messages."""
+    messages: list[dict] = []
+
+    if request.previous_response_id:
+        previous = _responses_store.get(request.previous_response_id)
+        if previous is None:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Previous response `{request.previous_response_id}` not found",
+            )
+        messages.extend(copy.deepcopy(previous["messages"]))
+
+    if request.instructions:
+        messages.append({"role": "system", "content": request.instructions})
+
+    if isinstance(request.input, str):
+        messages.append({"role": "user", "content": request.input})
+        return messages
+
+    for item in request.input:
+        if isinstance(item, dict):
+            item_type = item.get("type", "")
+            if item_type == "message":
+                role = item.get("role", "user")
+                if role == "developer":
+                    role = "system"
+                messages.append(
+                    {
+                        "role": role,
+                        "content": _response_content_to_text(item.get("content")),
+                    }
+                )
+            elif item_type == "function_call":
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": "",
+                        "tool_calls": [
+                            {
+                                "id": item.get("call_id", _new_response_item_id("call")),
+                                "type": "function",
+                                "function": {
+                                    "name": item.get("name", ""),
+                                    "arguments": item.get("arguments", ""),
+                                },
+                            }
+                        ],
+                    }
+                )
+            elif item_type == "function_call_output":
+                messages.append(
+                    {
+                        "role": "tool",
+                        "tool_call_id": item.get("call_id", ""),
+                        "content": item.get("output", ""),
+                    }
+                )
+            else:
+                logger.info(
+                    "Skipping unsupported Responses input item type %r", item_type
+                )
+            continue
+
+        if isinstance(item, ResponseMessageItem):
+            role = item.role
+            if role == "developer":
+                role = "system"
+            messages.append(
+                {
+                    "role": role,
+                    "content": _response_content_to_text(item.content),
+                }
+            )
+        elif isinstance(item, ResponseFunctionCallItem):
+            messages.append(
+                {
+                    "role": "assistant",
+                    "content": "",
+                    "tool_calls": [
+                        {
+                            "id": item.call_id,
+                            "type": "function",
+                            "function": {
+                                "name": item.name,
+                                "arguments": item.arguments,
+                            },
+                        }
+                    ],
+                }
+            )
+        elif isinstance(item, ResponseFunctionCallOutputItem):
+            messages.append(
+                {
+                    "role": "tool",
+                    "tool_call_id": item.call_id,
+                    "content": item.output,
+                }
+            )
+        elif isinstance(item, ResponseReasoningItem):
+            # Reasoning items are metadata from previous turns; keep them out of the
+            # model prompt because the underlying backend only consumes plain chat
+            # messages plus tool-call markers.
+            continue
+        else:
+            logger.info(
+                "Skipping unsupported Responses input item type %r",
+                getattr(item, "type", type(item).__name__),
+            )
+
+    return messages
+
+
+def _responses_request_to_chat_request(request: ResponsesRequest) -> ChatCompletionRequest:
+    """Build a ChatCompletionRequest from a ResponsesRequest."""
+    response_format = None
+    if request.text.format.type == "json_object":
+        response_format = {"type": "json_object"}
+
+    tools, unsupported_tools = _responses_tools_to_chat_tools(request.tools)
+    messages = _responses_input_to_chat_messages(request)
+    if unsupported_tools:
+        tool_list = ", ".join(sorted(set(unsupported_tools)))
+        messages.insert(
+            0,
+            {
+                "role": "system",
+                "content": (
+                    "The following requested tool types are not available on this "
+                    f"backend: {tool_list}. Do not call them."
+                ),
+            },
+        )
+
+    system_messages = [msg for msg in messages if msg.get("role") == "system"]
+    non_system_messages = [msg for msg in messages if msg.get("role") != "system"]
+    merged_system_content = "\n\n".join(
+        str(msg.get("content", "")).strip()
+        for msg in system_messages
+        if str(msg.get("content", "")).strip()
+    )
+    messages = (
+        [{"role": "system", "content": merged_system_content}]
+        if merged_system_content
+        else []
+    ) + non_system_messages
+
+    return ChatCompletionRequest(
+        model=request.model,
+        messages=[Message(**msg) for msg in messages],
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_tokens=request.max_output_tokens,
+        stream=False,
+        tools=tools,
+        tool_choice=request.tool_choice,
+        response_format=response_format,
+    )
+
+
+def _build_responses_output_items(
+    text: str | None,
+    reasoning: str | None,
+    tool_calls: list[ToolCall] | None,
+) -> list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem]:
+    """Convert parsed assistant output into Responses API output items."""
+    output_items: list[
+        ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
+    ] = []
+
+    if reasoning:
+        output_items.append(
+            ResponseReasoningItem(
+                id=_new_response_item_id("rs"),
+                content=[ResponseReasoningTextPart(text=reasoning)],
+            )
+        )
+
+    if text:
+        output_items.append(
+            ResponseMessageItem(
+                id=_new_response_item_id("msg"),
+                role="assistant",
+                content=[ResponseTextContentPart(type="output_text", text=text)],
+            )
+        )
+
+    for tool_call in tool_calls or []:
+        output_items.append(
+            ResponseFunctionCallItem(
+                id=_new_response_item_id("fc"),
+                call_id=tool_call.id,
+                name=tool_call.function.name,
+                arguments=tool_call.function.arguments,
+            )
+        )
+
+    return output_items
+
+
+def _response_output_items_to_chat_messages(output_items: list) -> list[dict]:
+    """Persist assistant output in chat-completions form for previous_response_id."""
+    assistant_text_parts: list[str] = []
+    assistant_tool_calls: list[dict] = []
+
+    for item in output_items:
+        if isinstance(item, ResponseMessageItem):
+            assistant_text_parts.append(_response_content_to_text(item.content))
+        elif isinstance(item, ResponseFunctionCallItem):
+            assistant_tool_calls.append(
+                {
+                    "id": item.call_id,
+                    "type": "function",
+                    "function": {
+                        "name": item.name,
+                        "arguments": item.arguments,
+                    },
+                }
+            )
+
+    if not assistant_text_parts and not assistant_tool_calls:
+        return []
+
+    return [
+        {
+            "role": "assistant",
+            "content": "".join(assistant_text_parts),
+            "tool_calls": assistant_tool_calls or None,
+        }
+    ]
+
+
+def _build_response_object(
+    request: ResponsesRequest,
+    output_items: list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem],
+    prompt_tokens: int,
+    completion_tokens: int,
+    response_id: str | None = None,
+) -> ResponseObject:
+    """Build a full Responses API object."""
+    response = ResponseObject(
+        id=response_id or _new_response_item_id("resp"),
+        model=_model_name or request.model,
+        instructions=request.instructions,
+        max_output_tokens=request.max_output_tokens,
+        metadata=request.metadata,
+        output=output_items,
+        parallel_tool_calls=request.parallel_tool_calls,
+        previous_response_id=request.previous_response_id,
+        text=request.text,
+        tool_choice=request.tool_choice,
+        tools=request.tools,
+        top_p=_resolve_top_p(request.top_p),
+        temperature=_resolve_temperature(request.temperature),
+        truncation=request.truncation,
+        user=request.user,
+        store=request.store,
+        usage=ResponsesUsage(
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
+async def _run_responses_request(
+    request: ResponsesRequest,
+    raw_request: Request,
+) -> tuple[ResponseObject | None, list[dict]]:
+    """Execute a Responses API request against the backend chat engine."""
+    _validate_model_name(request.model)
+    engine = get_engine()
+    chat_request = _responses_request_to_chat_request(request)
+
+    if chat_request.messages:
+        logger.info(
+            f"[REQUEST] POST /v1/responses stream={request.stream} "
+            f"model={request.model!r} items="
+            f"{len(request.input) if isinstance(request.input, list) else 1} "
+            f"tools={len(request.tools)}"
+        )
+
+    messages, images, videos = extract_multimodal_content(
+        chat_request.messages,
+        preserve_native_format=engine.preserve_native_tool_format,
+    )
+
+    chat_kwargs = {
+        "max_tokens": chat_request.max_tokens or _default_max_tokens,
+        "temperature": _resolve_temperature(chat_request.temperature),
+        "top_p": _resolve_top_p(chat_request.top_p),
+    }
+    if request.tools:
+        chat_kwargs["tools"] = convert_tools_for_template(chat_request.tools)
+    if images:
+        chat_kwargs["images"] = images
+    if videos:
+        chat_kwargs["videos"] = videos
+
+    timeout = _default_timeout
+    output = await _wait_with_disconnect(
+        engine.chat(messages=messages, **chat_kwargs),
+        raw_request,
+        timeout=timeout,
+    )
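+    # A None result means the client disconnected mid-generation; the route
+    # surfaces that as HTTP 499 (client closed request).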
+    if output is None:
+        return None, []
+
+    cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, chat_request)
+    reasoning_text = None
+    if _reasoning_parser and not tool_calls:
+        reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning(
+            cleaned_text or output.text
+        )
+
+    output_items = _build_responses_output_items(
+        clean_output_text(cleaned_text) if cleaned_text else None,
+        reasoning_text,
+        tool_calls,
+    )
+    response_object = _build_response_object(
+        request=request,
+        output_items=output_items,
+        prompt_tokens=output.prompt_tokens,
+        completion_tokens=output.completion_tokens,
+    )
+
+    persisted_messages = _responses_input_to_chat_messages(request)
+    persisted_messages.extend(_response_output_items_to_chat_messages(output_items))
+    if request.store:
+        _responses_store[response_object.id] = {
+            "messages": copy.deepcopy(persisted_messages),
+            "response": response_object.model_copy(deep=True),
+        }
+
+    return response_object, persisted_messages
+
+
+def _responses_sse_event(event_type: str, payload: BaseModel | dict) -> str:
+    """Encode a Responses API SSE event."""
+    data = payload.model_dump_json() if isinstance(payload, BaseModel) else json.dumps(payload)
+    return f"event: {event_type}\ndata: {data}\n\n"
+
+
+async def _stream_response_object(response: ResponseObject) -> AsyncIterator[str]:
+    """Stream a completed response object as OpenAI-style SSE events."""
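+    # The full response was generated up front; replay it as the standard
+    # OpenAI event sequence, emitting each item's text as a single delta
+    # rather than token-by-token.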
+    sequence = 1
+    in_progress = response.model_copy(deep=True)
+    in_progress.status = "in_progress"
+    in_progress.usage = None
+    in_progress.output = []
+
+    yield _responses_sse_event(
+        "response.created",
+        ResponseCreatedEvent(sequence_number=sequence, response=in_progress),
+    )
+    sequence += 1
+    yield _responses_sse_event(
+        "response.in_progress",
+        ResponseInProgressEvent(sequence_number=sequence, response=in_progress),
+    )
+    sequence += 1
+
+    for output_index, item in enumerate(response.output):
+        if isinstance(item, ResponseReasoningItem):
+            item_id = item.id or _new_response_item_id("rs")
+            in_progress_item = item.model_copy(update={"id": item_id, "status": "in_progress"})
+            yield _responses_sse_event(
+                "response.output_item.added",
+                ResponseOutputItemAddedEvent(
+                    sequence_number=sequence,
+                    output_index=output_index,
+                    item=in_progress_item,
+                ),
+            )
+            sequence += 1
+            part = item.content[0] if item.content else ResponseReasoningTextPart(text="")
+            yield _responses_sse_event(
+                "response.content_part.added",
+                ResponseContentPartAddedEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    part=ResponseReasoningTextPart(text=""),
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.reasoning_text.delta",
+                ResponseReasoningTextDeltaEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    delta=part.text,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.reasoning_text.done",
+                ResponseReasoningTextDoneEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    text=part.text,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.content_part.done",
+                ResponseContentPartDoneEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    part=part,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.output_item.done",
+                ResponseOutputItemDoneEvent(
+                    sequence_number=sequence,
+                    output_index=output_index,
+                    item=item,
+                ),
+            )
+            sequence += 1
+            continue
+
+        if isinstance(item, ResponseMessageItem):
+            item_id = item.id or _new_response_item_id("msg")
+            in_progress_item = item.model_copy(update={"id": item_id, "status": "in_progress", "content": []})
+            text_part = item.content[0] if isinstance(item.content, list) and item.content else ResponseTextContentPart(type="output_text", text="")
+            yield _responses_sse_event(
+                "response.output_item.added",
+                ResponseOutputItemAddedEvent(
+                    sequence_number=sequence,
+                    output_index=output_index,
+                    item=in_progress_item,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.content_part.added",
+                ResponseContentPartAddedEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    part=ResponseTextContentPart(type="output_text", text=""),
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.output_text.delta",
+                ResponseOutputTextDeltaEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    delta=text_part.text,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.output_text.done",
+                ResponseOutputTextDoneEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    text=text_part.text,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.content_part.done",
+                ResponseContentPartDoneEvent(
+                    sequence_number=sequence,
+                    item_id=item_id,
+                    output_index=output_index,
+                    content_index=0,
+                    part=text_part,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.output_item.done",
+                ResponseOutputItemDoneEvent(
+                    sequence_number=sequence,
+                    output_index=output_index,
+                    item=item,
+                ),
+            )
+            sequence += 1
+            continue
+
+        if isinstance(item, ResponseFunctionCallItem):
+            in_progress_item = item.model_copy(update={"status": "in_progress"})
+            yield _responses_sse_event(
+                "response.output_item.added",
+                ResponseOutputItemAddedEvent(
+                    sequence_number=sequence,
+                    output_index=output_index,
+                    item=in_progress_item,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.function_call_arguments.delta",
+                ResponseFunctionCallArgumentsDeltaEvent(
+                    sequence_number=sequence,
+                    item_id=item.id or _new_response_item_id("fc"),
+                    output_index=output_index,
+                    delta=item.arguments,
+                ),
+            )
+            sequence += 1
+            yield _responses_sse_event(
+                "response.output_item.done",
+                ResponseOutputItemDoneEvent(
+                    sequence_number=sequence,
+                    output_index=output_index,
+                    item=item,
+                ),
+            )
+            sequence += 1
+
+    yield _responses_sse_event(
+        "response.completed",
+        ResponseCompletedEvent(sequence_number=sequence, response=response),
+    )
+
+
 def _detect_native_tool_support() -> bool:
     """
     Detect if the active tool parser supports native tool format.
@@ -1420,6 +2053,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
         chat_kwargs["specprefill"] = request.specprefill
     if request.specprefill_keep_pct is not None:
         chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct
+    if request.chat_template_kwargs:
+        chat_kwargs["chat_template_kwargs"] = dict(request.chat_template_kwargs)
 
     # Add tools if provided
     if request.tools:
@@ -1496,6 +2131,27 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
     )
 
 
+@app.post(
+    "/v1/responses",
+    dependencies=[Depends(verify_api_key), Depends(check_rate_limit)],
+)
+async def create_response(request: ResponsesRequest, raw_request: Request):
+    """Create a Responses API response."""
+    response_object, _persisted_messages = await _run_responses_request(
+        request, raw_request
+    )
+    if response_object is None:
+        return Response(status_code=499)
+
+    if request.stream:
+        return StreamingResponse(
+            _disconnect_guard(_stream_response_object(response_object), raw_request),
+            media_type="text/event-stream",
+        )
+
+    return response_object
+
+
 def _inject_json_instruction(messages: list, instruction: str) -> list:
     """
     Inject JSON instruction into messages.
diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py
index a5088395..aaaeae55 100644
--- a/vllm_mlx/utils/tokenizer.py
+++ b/vllm_mlx/utils/tokenizer.py
@@ -52,6 +52,7 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
 
     try:
         model, tokenizer = load(model_name, tokenizer_config=tokenizer_config)
+        return model, tokenizer
     except ValueError as e:
         # Fallback for models with non-standard tokenizers
         if "TokenizersBackend" in str(e) or "Tokenizer class" in str(e):