diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index eedade2d..bc3f19b3 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -80,6 +80,8 @@ jobs:
tests/test_api_models.py \
tests/test_api_utils.py \
tests/test_request.py \
+ tests/test_anthropic_models.py \
+ tests/test_anthropic_adapter.py \
-v --tb=short \
-k "not Integration and not InjectJson and not TestMLXMultimodalLMCache" \
--cov=vllm_mlx \
diff --git a/docs/guides/server.md b/docs/guides/server.md
index a654e92c..33b7ec02 100644
--- a/docs/guides/server.md
+++ b/docs/guides/server.md
@@ -128,6 +128,390 @@ GET /health
Returns server status.
+### Anthropic Messages API
+
+```bash
+POST /v1/messages
+```
+
+Anthropic-compatible endpoint that allows tools like Claude Code and OpenCode to connect directly to vllm-mlx. Internally it translates Anthropic requests to OpenAI format, runs inference through the engine, and converts the response back to Anthropic format.
+
+Capabilities:
+- Non-streaming and streaming responses (SSE)
+- System messages (plain string or list of content blocks)
+- Multi-turn conversations with user and assistant messages
+- Tool calling with `tool_use` / `tool_result` content blocks
+- Token counting for budget tracking
+- Multimodal content (images via `source` blocks)
+- Client disconnect detection (returns the non-standard HTTP 499 "client closed request" status)
+- Automatic special token filtering in streamed output
+
+#### Non-streaming
+
+```python
+from anthropic import Anthropic
+
+client = Anthropic(base_url="http://localhost:8000", api_key="not-needed")
+
+response = client.messages.create(
+ model="default",
+ max_tokens=256,
+ messages=[{"role": "user", "content": "Hello!"}]
+)
+print(response.content[0].text)
+# Response includes: response.id, response.model, response.stop_reason,
+# response.usage.input_tokens, response.usage.output_tokens
+```
+
+#### Streaming
+
+Streaming follows the Anthropic SSE event protocol. Events are emitted in this order:
+`message_start` -> `content_block_start` -> `content_block_delta` (repeated) -> `content_block_stop` -> `message_delta` -> `message_stop`
+
+```python
+with client.messages.stream(
+ model="default",
+ max_tokens=256,
+ messages=[{"role": "user", "content": "Tell me a story"}]
+) as stream:
+ for text in stream.text_stream:
+ print(text, end="")
+```
+
+#### System messages
+
+System messages can be a plain string or a list of content blocks:
+
+```python
+# Plain string
+response = client.messages.create(
+ model="default",
+ max_tokens=256,
+ system="You are a helpful coding assistant.",
+ messages=[{"role": "user", "content": "Write a hello world in Python"}]
+)
+
+# List of content blocks
+response = client.messages.create(
+ model="default",
+ max_tokens=256,
+ system=[
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise in your answers."},
+ ],
+ messages=[{"role": "user", "content": "What is 2+2?"}]
+)
+```
+
+#### Tool calling
+
+Define tools with `name`, `description`, and `input_schema`. The model returns `tool_use` content blocks when it wants to call a tool. Send results back as `tool_result` blocks.
+
+```python
+# Step 1: Send request with tools
+response = client.messages.create(
+ model="default",
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "What's the weather in Paris?"}],
+ tools=[{
+ "name": "get_weather",
+ "description": "Get weather for a city",
+ "input_schema": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"]
+ }
+ }]
+)
+
+# Step 2: Check if model wants to use tools
+for block in response.content:
+ if block.type == "tool_use":
+ print(f"Tool: {block.name}, Input: {block.input}, ID: {block.id}")
+ # response.stop_reason will be "tool_use"
+
+# Step 3: Send tool result back
+response = client.messages.create(
+ model="default",
+ max_tokens=1024,
+ messages=[
+ {"role": "user", "content": "What's the weather in Paris?"},
+ {"role": "assistant", "content": response.content},
+ {"role": "user", "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": block.id,
+ "content": "Sunny, 22C"
+ }
+ ]}
+ ],
+ tools=[{
+ "name": "get_weather",
+ "description": "Get weather for a city",
+ "input_schema": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"]
+ }
+ }]
+)
+print(response.content[0].text) # "The weather in Paris is sunny, 22C."
+```
+
+Tool choice modes:
+
+| `tool_choice` | Behavior |
+|---------------|----------|
+| `{"type": "auto"}` | Model decides whether to call tools (default) |
+| `{"type": "any"}` | Model must call at least one tool |
+| `{"type": "tool", "name": "get_weather"}` | Model must call the specified tool |
+| `{"type": "none"}` | Model will not call any tools |
+
+#### Multi-turn conversations
+
+```python
+messages = [
+ {"role": "user", "content": "My name is Alice."},
+ {"role": "assistant", "content": "Nice to meet you, Alice!"},
+ {"role": "user", "content": "What's my name?"},
+]
+
+response = client.messages.create(
+ model="default",
+ max_tokens=100,
+ messages=messages
+)
+```
+
+#### Token counting
+
+```bash
+POST /v1/messages/count_tokens
+```
+
+Counts input tokens for an Anthropic request using the model's tokenizer. Useful for budget tracking before sending a request. Counts tokens from system messages, conversation messages, tool_use inputs, tool_result content, and tool definitions (name, description, input_schema).
+
+```python
+import requests
+
+resp = requests.post("http://localhost:8000/v1/messages/count_tokens", json={
+ "model": "default",
+ "messages": [{"role": "user", "content": "Hello, how are you?"}],
+ "system": "You are helpful.",
+ "tools": [{
+ "name": "search",
+ "description": "Search the web",
+ "input_schema": {"type": "object", "properties": {"q": {"type": "string"}}}
+ }]
+})
+print(resp.json()) # {"input_tokens": 42}
+```
+
+#### curl examples
+
+Non-streaming:
+
+```bash
+curl http://localhost:8000/v1/messages \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "default",
+ "max_tokens": 256,
+ "messages": [{"role": "user", "content": "Hello!"}]
+ }'
+```
+
+Streaming:
+
+```bash
+curl http://localhost:8000/v1/messages \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "default",
+ "max_tokens": 256,
+ "stream": true,
+ "messages": [{"role": "user", "content": "Tell me a joke"}]
+ }'
+```
+
+Token counting:
+
+```bash
+curl http://localhost:8000/v1/messages/count_tokens \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "default",
+ "messages": [{"role": "user", "content": "Hello!"}]
+ }'
+# {"input_tokens": 12}
+```
+
+#### Request fields
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `model` | string | yes | - | Model name (use `"default"` for the loaded model) |
+| `messages` | list | yes | - | Conversation messages with `role` and `content` |
+| `max_tokens` | int | yes | - | Maximum number of tokens to generate |
+| `system` | string or list | no | null | System prompt (string or list of `{"type": "text", "text": "..."}` blocks) |
+| `stream` | bool | no | false | Enable SSE streaming |
+| `temperature` | float | no | 0.7 | Sampling temperature (0.0 = deterministic, 1.0 = creative) |
+| `top_p` | float | no | 0.9 | Nucleus sampling threshold |
+| `top_k` | int | no | null | Top-k sampling |
+| `stop_sequences` | list | no | null | Sequences that stop generation |
+| `tools` | list | no | null | Tool definitions with `name`, `description`, `input_schema` |
+| `tool_choice` | dict | no | null | Tool selection mode (`auto`, `any`, `tool`, `none`) |
+| `metadata` | dict | no | null | Arbitrary metadata (passed through, not used by server) |
+
+#### Response format
+
+Non-streaming response:
+
+```json
+{
+ "id": "msg_abc123...",
+ "type": "message",
+ "role": "assistant",
+ "model": "default",
+ "content": [
+ {"type": "text", "text": "Hello! How can I help?"}
+ ],
+ "stop_reason": "end_turn",
+ "stop_sequence": null,
+ "usage": {
+ "input_tokens": 12,
+ "output_tokens": 8
+ }
+}
+```
+
+When tools are called, `content` includes `tool_use` blocks and `stop_reason` is `"tool_use"`:
+
+```json
+{
+ "content": [
+ {"type": "text", "text": "Let me check the weather."},
+ {
+ "type": "tool_use",
+ "id": "call_abc123",
+ "name": "get_weather",
+ "input": {"city": "Paris"}
+ }
+ ],
+ "stop_reason": "tool_use"
+}
+```
+
+Stop reasons:
+
+| `stop_reason` | Meaning |
+|---------------|---------|
+| `end_turn` | Model finished naturally |
+| `tool_use` | Model wants to call a tool |
+| `max_tokens` | Hit the `max_tokens` limit |
+
+#### Using with Claude Code
+
+Point Claude Code directly at your vllm-mlx server:
+
+```bash
+# Start the server
+vllm-mlx serve mlx-community/Qwen3-Coder-Next-235B-A22B-4bit \
+ --continuous-batching \
+ --enable-auto-tool-choice \
+ --tool-call-parser hermes
+
+# In another terminal, configure Claude Code
+export ANTHROPIC_BASE_URL=http://localhost:8000
+export ANTHROPIC_API_KEY=not-needed
+claude
+```
+
+### Server Status
+
+```bash
+GET /v1/status
+```
+
+Real-time monitoring endpoint that returns server-wide statistics and per-request details. Useful for debugging performance, tracking cache efficiency, and monitoring Metal GPU memory.
+
+```bash
+curl -s http://localhost:8000/v1/status | python -m json.tool
+```
+
+Example response:
+
+```json
+{
+ "status": "running",
+ "model": "mlx-community/Qwen3-8B-4bit",
+ "uptime_s": 342.5,
+ "steps_executed": 1247,
+ "num_running": 1,
+ "num_waiting": 0,
+ "total_requests_processed": 15,
+ "total_prompt_tokens": 28450,
+ "total_completion_tokens": 3200,
+ "metal": {
+ "active_memory_gb": 5.2,
+ "peak_memory_gb": 8.1,
+ "cache_memory_gb": 2.3
+ },
+ "cache": {
+ "type": "memory_aware_cache",
+ "entries": 5,
+ "hit_rate": 0.87,
+ "memory_mb": 2350
+ },
+ "requests": [
+ {
+ "request_id": "req_abc123",
+ "phase": "generation",
+ "tokens_per_second": 45.2,
+ "ttft_s": 0.8,
+ "progress": 0.35,
+ "cache_hit_type": "prefix",
+ "cached_tokens": 1200,
+ "generated_tokens": 85,
+ "max_tokens": 256
+ }
+ ]
+}
+```
+
+Response fields:
+
+| Field | Description |
+|-------|-------------|
+| `status` | Server state: `running`, `stopped`, or `not_loaded` |
+| `model` | Name of the loaded model |
+| `uptime_s` | Seconds since the server started |
+| `steps_executed` | Total inference steps executed |
+| `num_running` | Number of requests currently generating tokens |
+| `num_waiting` | Number of requests queued for prefill |
+| `total_requests_processed` | Total requests completed since startup |
+| `total_prompt_tokens` | Total prompt tokens processed since startup |
+| `total_completion_tokens` | Total completion tokens generated since startup |
+| `metal.active_memory_gb` | Current Metal GPU memory in use (GB) |
+| `metal.peak_memory_gb` | Peak Metal GPU memory usage (GB) |
+| `metal.cache_memory_gb` | Metal cache memory usage (GB) |
+| `cache` | Cache statistics (type, entries, hit rate, memory usage) |
+| `requests` | List of active requests with per-request details |
+
+Per-request fields in `requests`:
+
+| Field | Description |
+|-------|-------------|
+| `request_id` | Unique request identifier |
+| `phase` | Current phase: `queued`, `prefill`, or `generation` |
+| `tokens_per_second` | Generation throughput for this request |
+| `ttft_s` | Time to first token (seconds) |
+| `progress` | Completion fraction (0.0 = just started, 1.0 = done) |
+| `cache_hit_type` | Cache match type: `exact`, `prefix`, `supersequence`, `lcp`, or `miss` |
+| `cached_tokens` | Number of tokens served from cache |
+| `generated_tokens` | Tokens generated so far |
+| `max_tokens` | Maximum tokens requested |
+
## Tool Calling
Enable OpenAI-compatible tool calling with `--enable-auto-tool-choice`:
diff --git a/tests/test_anthropic_adapter.py b/tests/test_anthropic_adapter.py
new file mode 100644
index 00000000..3fa179b4
--- /dev/null
+++ b/tests/test_anthropic_adapter.py
@@ -0,0 +1,457 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Tests for Anthropic-to-OpenAI adapter conversion functions.
+
+Tests all conversion functions in vllm_mlx/api/anthropic_adapter.py.
+These are pure logic tests with no MLX dependency.
+"""
+
+import json
+
+from vllm_mlx.api.anthropic_adapter import (
+ _convert_message,
+ _convert_stop_reason,
+ _convert_tool,
+ _convert_tool_choice,
+ anthropic_to_openai,
+ openai_to_anthropic,
+)
+from vllm_mlx.api.anthropic_models import (
+ AnthropicContentBlock,
+ AnthropicMessage,
+ AnthropicRequest,
+ AnthropicToolDef,
+)
+from vllm_mlx.api.models import (
+ AssistantMessage,
+ ChatCompletionChoice,
+ ChatCompletionResponse,
+ FunctionCall,
+ ToolCall,
+ Usage,
+)
+
+
+class TestConvertStopReason:
+ """Tests for _convert_stop_reason."""
+
+ def test_stop_to_end_turn(self):
+ assert _convert_stop_reason("stop") == "end_turn"
+
+ def test_tool_calls_to_tool_use(self):
+ assert _convert_stop_reason("tool_calls") == "tool_use"
+
+ def test_length_to_max_tokens(self):
+ assert _convert_stop_reason("length") == "max_tokens"
+
+ def test_content_filter_to_end_turn(self):
+ assert _convert_stop_reason("content_filter") == "end_turn"
+
+ def test_none_to_end_turn(self):
+ assert _convert_stop_reason(None) == "end_turn"
+
+ def test_unknown_to_end_turn(self):
+ assert _convert_stop_reason("something_else") == "end_turn"
+
+
+class TestConvertToolChoice:
+ """Tests for _convert_tool_choice."""
+
+ def test_auto(self):
+ assert _convert_tool_choice({"type": "auto"}) == "auto"
+
+ def test_any_to_required(self):
+ assert _convert_tool_choice({"type": "any"}) == "required"
+
+ def test_none_type(self):
+ assert _convert_tool_choice({"type": "none"}) == "none"
+
+ def test_specific_tool(self):
+ result = _convert_tool_choice({"type": "tool", "name": "search"})
+ assert result == {
+ "type": "function",
+ "function": {"name": "search"},
+ }
+
+ def test_missing_type_defaults_to_auto(self):
+ assert _convert_tool_choice({}) == "auto"
+
+ def test_unknown_type_defaults_to_auto(self):
+ assert _convert_tool_choice({"type": "unknown"}) == "auto"
+
+
+class TestConvertTool:
+ """Tests for _convert_tool."""
+
+ def test_minimal_tool(self):
+ tool = AnthropicToolDef(name="search")
+ result = _convert_tool(tool)
+ assert result.type == "function"
+ assert result.function["name"] == "search"
+ assert result.function["description"] == ""
+ assert result.function["parameters"] == {"type": "object", "properties": {}}
+
+ def test_full_tool(self):
+ tool = AnthropicToolDef(
+ name="get_weather",
+ description="Get weather for a city",
+ input_schema={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
+ )
+ result = _convert_tool(tool)
+ assert result.function["name"] == "get_weather"
+ assert result.function["description"] == "Get weather for a city"
+ assert result.function["parameters"]["required"] == ["city"]
+
+
+class TestConvertMessage:
+ """Tests for _convert_message."""
+
+ def test_simple_user_string(self):
+ msg = AnthropicMessage(role="user", content="hello")
+ result = _convert_message(msg)
+ assert len(result) == 1
+ assert result[0].role == "user"
+ assert result[0].content == "hello"
+
+ def test_simple_assistant_string(self):
+ msg = AnthropicMessage(role="assistant", content="hi there")
+ result = _convert_message(msg)
+ assert len(result) == 1
+ assert result[0].role == "assistant"
+ assert result[0].content == "hi there"
+
+ def test_user_with_text_blocks(self):
+ msg = AnthropicMessage(
+ role="user",
+ content=[
+ AnthropicContentBlock(type="text", text="first"),
+ AnthropicContentBlock(type="text", text="second"),
+ ],
+ )
+ result = _convert_message(msg)
+ assert len(result) == 1
+ assert result[0].role == "user"
+ assert result[0].content == "first\nsecond"
+
+ def test_user_with_tool_results(self):
+ msg = AnthropicMessage(
+ role="user",
+ content=[
+ AnthropicContentBlock(
+ type="tool_result",
+ tool_use_id="call_1",
+ content="sunny, 22C",
+ ),
+ AnthropicContentBlock(
+ type="tool_result",
+ tool_use_id="call_2",
+ content="rainy, 15C",
+ ),
+ ],
+ )
+ result = _convert_message(msg)
+ assert len(result) == 2
+ assert result[0].role == "tool"
+ assert result[0].content == "sunny, 22C"
+ assert result[0].tool_call_id == "call_1"
+ assert result[1].role == "tool"
+ assert result[1].content == "rainy, 15C"
+
+ def test_user_with_text_and_tool_results(self):
+ msg = AnthropicMessage(
+ role="user",
+ content=[
+ AnthropicContentBlock(type="text", text="here are results"),
+ AnthropicContentBlock(
+ type="tool_result",
+ tool_use_id="call_1",
+ content="done",
+ ),
+ ],
+ )
+ result = _convert_message(msg)
+ assert len(result) == 2
+ assert result[0].role == "user"
+ assert result[0].content == "here are results"
+ assert result[1].role == "tool"
+
+ def test_tool_result_with_list_content(self):
+ msg = AnthropicMessage(
+ role="user",
+ content=[
+ AnthropicContentBlock(
+ type="tool_result",
+ tool_use_id="call_1",
+ content=[
+ {"type": "text", "text": "line one"},
+ {"type": "text", "text": "line two"},
+ ],
+ ),
+ ],
+ )
+ result = _convert_message(msg)
+ assert result[0].role == "tool"
+ assert result[0].content == "line one\nline two"
+
+ def test_tool_result_with_none_content(self):
+ msg = AnthropicMessage(
+ role="user",
+ content=[
+ AnthropicContentBlock(
+ type="tool_result",
+ tool_use_id="call_1",
+ content=None,
+ ),
+ ],
+ )
+ result = _convert_message(msg)
+ assert result[0].content == ""
+
+ def test_assistant_with_tool_use(self):
+ msg = AnthropicMessage(
+ role="assistant",
+ content=[
+ AnthropicContentBlock(type="text", text="Let me check."),
+ AnthropicContentBlock(
+ type="tool_use",
+ id="call_abc",
+ name="search",
+ input={"q": "weather"},
+ ),
+ ],
+ )
+ result = _convert_message(msg)
+ assert len(result) == 1
+ assert result[0].role == "assistant"
+ assert result[0].content == "Let me check."
+ assert len(result[0].tool_calls) == 1
+ assert result[0].tool_calls[0]["function"]["name"] == "search"
+ args = json.loads(result[0].tool_calls[0]["function"]["arguments"])
+ assert args == {"q": "weather"}
+
+ def test_assistant_empty_content(self):
+ msg = AnthropicMessage(
+ role="assistant",
+ content=[],
+ )
+ result = _convert_message(msg)
+ assert len(result) == 1
+ assert result[0].role == "assistant"
+ assert result[0].content == ""
+
+ def test_user_empty_content(self):
+ msg = AnthropicMessage(
+ role="user",
+ content=[],
+ )
+ result = _convert_message(msg)
+ assert len(result) == 1
+ assert result[0].role == "user"
+ assert result[0].content == ""
+
+
+class TestAnthropicToOpenai:
+ """Tests for anthropic_to_openai conversion."""
+
+ def _make_request(self, **kwargs):
+ defaults = {
+ "model": "default",
+ "messages": [AnthropicMessage(role="user", content="hi")],
+ "max_tokens": 100,
+ }
+ defaults.update(kwargs)
+ return AnthropicRequest(**defaults)
+
+ def test_simple_request(self):
+ req = self._make_request()
+ result = anthropic_to_openai(req)
+ assert result.model == "default"
+ assert result.max_tokens == 100
+ assert len(result.messages) == 1
+ assert result.messages[0].role == "user"
+ assert result.messages[0].content == "hi"
+
+ def test_system_string(self):
+ req = self._make_request(system="Be helpful.")
+ result = anthropic_to_openai(req)
+ assert len(result.messages) == 2
+ assert result.messages[0].role == "system"
+ assert result.messages[0].content == "Be helpful."
+ assert result.messages[1].role == "user"
+
+ def test_system_list(self):
+ req = self._make_request(system=[{"type": "text", "text": "Be concise."}])
+ result = anthropic_to_openai(req)
+ assert result.messages[0].role == "system"
+ assert result.messages[0].content == "Be concise."
+
+ def test_temperature_default(self):
+ req = self._make_request()
+ result = anthropic_to_openai(req)
+ assert result.temperature == 0.7
+
+ def test_temperature_explicit(self):
+ req = self._make_request(temperature=0.3)
+ result = anthropic_to_openai(req)
+ assert result.temperature == 0.3
+
+ def test_top_p_default(self):
+ req = self._make_request()
+ result = anthropic_to_openai(req)
+ assert result.top_p == 0.9
+
+ def test_top_p_explicit(self):
+ req = self._make_request(top_p=0.5)
+ result = anthropic_to_openai(req)
+ assert result.top_p == 0.5
+
+ def test_stop_sequences(self):
+ req = self._make_request(stop_sequences=["END", "STOP"])
+ result = anthropic_to_openai(req)
+ assert result.stop == ["END", "STOP"]
+
+ def test_stream_flag(self):
+ req = self._make_request(stream=True)
+ result = anthropic_to_openai(req)
+ assert result.stream is True
+
+ def test_tools_conversion(self):
+ req = self._make_request(
+ tools=[
+ AnthropicToolDef(
+ name="search",
+ description="Search the web",
+ input_schema={
+ "type": "object",
+ "properties": {"q": {"type": "string"}},
+ },
+ )
+ ]
+ )
+ result = anthropic_to_openai(req)
+ assert len(result.tools) == 1
+ assert result.tools[0].function["name"] == "search"
+
+ def test_tool_choice_conversion(self):
+ req = self._make_request(tool_choice={"type": "any"})
+ result = anthropic_to_openai(req)
+ assert result.tool_choice == "required"
+
+ def test_no_tools(self):
+ req = self._make_request()
+ result = anthropic_to_openai(req)
+ assert result.tools is None
+ assert result.tool_choice is None
+
+ def test_multiple_messages(self):
+ msgs = [
+ AnthropicMessage(role="user", content="hello"),
+ AnthropicMessage(role="assistant", content="hi"),
+ AnthropicMessage(role="user", content="how are you"),
+ ]
+ req = self._make_request(messages=msgs)
+ result = anthropic_to_openai(req)
+ assert len(result.messages) == 3
+ assert result.messages[0].role == "user"
+ assert result.messages[1].role == "assistant"
+ assert result.messages[2].role == "user"
+
+
+class TestOpenaiToAnthropic:
+ """Tests for openai_to_anthropic conversion."""
+
+ def _make_response(self, content="hello", finish_reason="stop", tool_calls=None):
+ msg = AssistantMessage(content=content, tool_calls=tool_calls)
+ choice = ChatCompletionChoice(message=msg, finish_reason=finish_reason)
+ return ChatCompletionResponse(
+ model="default",
+ choices=[choice],
+ usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
+ )
+
+ def test_simple_text_response(self):
+ resp = self._make_response(content="hi there")
+ result = openai_to_anthropic(resp, "default")
+ assert result.model == "default"
+ assert result.type == "message"
+ assert result.role == "assistant"
+ assert len(result.content) == 1
+ assert result.content[0].type == "text"
+ assert result.content[0].text == "hi there"
+ assert result.stop_reason == "end_turn"
+
+ def test_usage_mapping(self):
+ resp = self._make_response()
+ result = openai_to_anthropic(resp, "default")
+ assert result.usage.input_tokens == 10
+ assert result.usage.output_tokens == 5
+
+ def test_tool_calls_response(self):
+ tc = ToolCall(
+ id="call_1",
+ type="function",
+ function=FunctionCall(
+ name="search",
+ arguments='{"q": "test"}',
+ ),
+ )
+ resp = self._make_response(
+ content="Let me search.",
+ finish_reason="tool_calls",
+ tool_calls=[tc],
+ )
+ result = openai_to_anthropic(resp, "default")
+ assert len(result.content) == 2
+ assert result.content[0].type == "text"
+ assert result.content[0].text == "Let me search."
+ assert result.content[1].type == "tool_use"
+ assert result.content[1].name == "search"
+ assert result.content[1].input == {"q": "test"}
+ assert result.stop_reason == "tool_use"
+
+ def test_tool_call_invalid_json_arguments(self):
+ tc = ToolCall(
+ id="call_1",
+ type="function",
+ function=FunctionCall(name="search", arguments="not json"),
+ )
+ resp = self._make_response(
+ content=None, finish_reason="tool_calls", tool_calls=[tc]
+ )
+ result = openai_to_anthropic(resp, "default")
+ tool_block = [b for b in result.content if b.type == "tool_use"][0]
+ assert tool_block.input == {}
+
+ def test_empty_choices(self):
+ resp = ChatCompletionResponse(
+ model="default",
+ choices=[],
+ usage=Usage(),
+ )
+ result = openai_to_anthropic(resp, "default")
+ assert result.stop_reason == "end_turn"
+ assert len(result.content) == 1
+ assert result.content[0].type == "text"
+ assert result.content[0].text == ""
+
+ def test_no_content_adds_empty_text(self):
+ resp = self._make_response(content=None)
+ result = openai_to_anthropic(resp, "default")
+ assert len(result.content) >= 1
+ has_text = any(b.type == "text" for b in result.content)
+ assert has_text
+
+ def test_stop_reason_length(self):
+ resp = self._make_response(finish_reason="length")
+ result = openai_to_anthropic(resp, "default")
+ assert result.stop_reason == "max_tokens"
+
+ def test_response_has_id(self):
+ resp = self._make_response()
+ result = openai_to_anthropic(resp, "test-model")
+ assert result.id.startswith("msg_")
+ assert result.model == "test-model"
diff --git a/tests/test_anthropic_models.py b/tests/test_anthropic_models.py
new file mode 100644
index 00000000..5be0f23d
--- /dev/null
+++ b/tests/test_anthropic_models.py
@@ -0,0 +1,360 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Tests for Anthropic Messages API Pydantic models.
+
+Tests all request/response models in vllm_mlx/api/anthropic_models.py.
+These are pure Pydantic models with no MLX dependency.
+"""
+
+import pytest
+from pydantic import ValidationError
+
+from vllm_mlx.api.anthropic_models import (
+ AnthropicContentBlock,
+ AnthropicMessage,
+ AnthropicRequest,
+ AnthropicResponse,
+ AnthropicResponseContentBlock,
+ AnthropicToolDef,
+ AnthropicUsage,
+)
+
+
+class TestAnthropicContentBlock:
+ """Tests for AnthropicContentBlock model."""
+
+ def test_text_block(self):
+ block = AnthropicContentBlock(type="text", text="hello")
+ assert block.type == "text"
+ assert block.text == "hello"
+
+ def test_tool_use_block(self):
+ block = AnthropicContentBlock(
+ type="tool_use",
+ id="call_123",
+ name="get_weather",
+ input={"city": "Paris"},
+ )
+ assert block.type == "tool_use"
+ assert block.id == "call_123"
+ assert block.name == "get_weather"
+ assert block.input == {"city": "Paris"}
+
+ def test_tool_result_block(self):
+ block = AnthropicContentBlock(
+ type="tool_result",
+ tool_use_id="call_123",
+ content="sunny",
+ )
+ assert block.type == "tool_result"
+ assert block.tool_use_id == "call_123"
+ assert block.content == "sunny"
+
+ def test_tool_result_with_error(self):
+ block = AnthropicContentBlock(
+ type="tool_result",
+ tool_use_id="call_123",
+ content="not found",
+ is_error=True,
+ )
+ assert block.is_error is True
+
+ def test_image_block(self):
+ block = AnthropicContentBlock(
+ type="image",
+ source={"type": "base64", "media_type": "image/png", "data": "abc"},
+ )
+ assert block.type == "image"
+ assert block.source["type"] == "base64"
+
+ def test_optional_fields_default_to_none(self):
+ block = AnthropicContentBlock(type="text")
+ assert block.text is None
+ assert block.id is None
+ assert block.name is None
+ assert block.input is None
+ assert block.tool_use_id is None
+ assert block.content is None
+ assert block.is_error is None
+ assert block.source is None
+
+
+class TestAnthropicMessage:
+ """Tests for AnthropicMessage model."""
+
+ def test_string_content(self):
+ msg = AnthropicMessage(role="user", content="hello")
+ assert msg.role == "user"
+ assert msg.content == "hello"
+
+ def test_list_content(self):
+ blocks = [
+ AnthropicContentBlock(type="text", text="look at this"),
+ AnthropicContentBlock(
+ type="image",
+ source={"type": "base64", "media_type": "image/png", "data": "abc"},
+ ),
+ ]
+ msg = AnthropicMessage(role="user", content=blocks)
+ assert len(msg.content) == 2
+ assert msg.content[0].type == "text"
+ assert msg.content[1].type == "image"
+
+ def test_assistant_role(self):
+ msg = AnthropicMessage(role="assistant", content="hi there")
+ assert msg.role == "assistant"
+
+ def test_missing_role_raises(self):
+ with pytest.raises(ValidationError):
+ AnthropicMessage(content="hello")
+
+ def test_missing_content_raises(self):
+ with pytest.raises(ValidationError):
+ AnthropicMessage(role="user")
+
+
+class TestAnthropicToolDef:
+ """Tests for AnthropicToolDef model."""
+
+ def test_minimal(self):
+ tool = AnthropicToolDef(name="get_weather")
+ assert tool.name == "get_weather"
+ assert tool.description is None
+ assert tool.input_schema is None
+
+ def test_full(self):
+ tool = AnthropicToolDef(
+ name="get_weather",
+ description="Get weather for a city",
+ input_schema={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
+ )
+ assert tool.description == "Get weather for a city"
+ assert tool.input_schema["required"] == ["city"]
+
+ def test_missing_name_raises(self):
+ with pytest.raises(ValidationError):
+ AnthropicToolDef(description="no name")
+
+
+class TestAnthropicRequest:
+ """Tests for AnthropicRequest model."""
+
+ def test_minimal_request(self):
+ req = AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=100,
+ )
+ assert req.model == "default"
+ assert req.max_tokens == 100
+ assert req.stream is False
+ assert req.temperature is None
+ assert req.top_p is None
+ assert req.tools is None
+ assert req.system is None
+
+ def test_with_system_string(self):
+ req = AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=100,
+ system="You are helpful.",
+ )
+ assert req.system == "You are helpful."
+
+ def test_with_system_list(self):
+ req = AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=100,
+ system=[{"type": "text", "text": "Be concise."}],
+ )
+ assert isinstance(req.system, list)
+ assert req.system[0]["text"] == "Be concise."
+
+ def test_with_tools(self):
+ req = AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=100,
+ tools=[AnthropicToolDef(name="search")],
+ )
+ assert len(req.tools) == 1
+ assert req.tools[0].name == "search"
+
+ def test_with_tool_choice(self):
+ req = AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=100,
+ tool_choice={"type": "auto"},
+ )
+ assert req.tool_choice == {"type": "auto"}
+
+ def test_streaming(self):
+ req = AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=100,
+ stream=True,
+ )
+ assert req.stream is True
+
+ def test_all_optional_params(self):
+ req = AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=256,
+ temperature=0.5,
+ top_p=0.9,
+ top_k=40,
+ stop_sequences=["END"],
+ metadata={"user_id": "123"},
+ )
+ assert req.temperature == 0.5
+ assert req.top_p == 0.9
+ assert req.top_k == 40
+ assert req.stop_sequences == ["END"]
+ assert req.metadata == {"user_id": "123"}
+
+ def test_missing_model_raises(self):
+ with pytest.raises(ValidationError):
+ AnthropicRequest(
+ messages=[AnthropicMessage(role="user", content="hi")],
+ max_tokens=100,
+ )
+
+ def test_missing_messages_raises(self):
+ with pytest.raises(ValidationError):
+ AnthropicRequest(model="default", max_tokens=100)
+
+ def test_missing_max_tokens_raises(self):
+ with pytest.raises(ValidationError):
+ AnthropicRequest(
+ model="default",
+ messages=[AnthropicMessage(role="user", content="hi")],
+ )
+
+
+class TestAnthropicUsage:
+ """Tests for AnthropicUsage model."""
+
+ def test_defaults(self):
+ usage = AnthropicUsage()
+ assert usage.input_tokens == 0
+ assert usage.output_tokens == 0
+ assert usage.cache_creation_input_tokens is None
+ assert usage.cache_read_input_tokens is None
+
+ def test_with_values(self):
+ usage = AnthropicUsage(input_tokens=100, output_tokens=50)
+ assert usage.input_tokens == 100
+ assert usage.output_tokens == 50
+
+ def test_with_cache_fields(self):
+ usage = AnthropicUsage(
+ input_tokens=100,
+ output_tokens=50,
+ cache_creation_input_tokens=20,
+ cache_read_input_tokens=80,
+ )
+ assert usage.cache_creation_input_tokens == 20
+ assert usage.cache_read_input_tokens == 80
+
+
+class TestAnthropicResponseContentBlock:
+ """Tests for AnthropicResponseContentBlock model."""
+
+ def test_text_block(self):
+ block = AnthropicResponseContentBlock(type="text", text="hello")
+ assert block.type == "text"
+ assert block.text == "hello"
+
+ def test_tool_use_block(self):
+ block = AnthropicResponseContentBlock(
+ type="tool_use",
+ id="call_abc",
+ name="search",
+ input={"query": "test"},
+ )
+ assert block.type == "tool_use"
+ assert block.id == "call_abc"
+ assert block.name == "search"
+ assert block.input == {"query": "test"}
+
+ def test_optional_fields_default_to_none(self):
+ block = AnthropicResponseContentBlock(type="text")
+ assert block.text is None
+ assert block.id is None
+ assert block.name is None
+ assert block.input is None
+
+
+class TestAnthropicResponse:
+ """Tests for AnthropicResponse model."""
+
+ def test_minimal_response(self):
+ resp = AnthropicResponse(
+ model="default",
+ content=[AnthropicResponseContentBlock(type="text", text="hi")],
+ )
+ assert resp.model == "default"
+ assert resp.type == "message"
+ assert resp.role == "assistant"
+ assert resp.id.startswith("msg_")
+ assert len(resp.id) == len("msg_") + 24
+ assert resp.stop_reason is None
+ assert resp.stop_sequence is None
+ assert resp.usage.input_tokens == 0
+ assert resp.usage.output_tokens == 0
+
+ def test_with_stop_reason(self):
+ resp = AnthropicResponse(
+ model="default",
+ content=[AnthropicResponseContentBlock(type="text", text="done")],
+ stop_reason="end_turn",
+ )
+ assert resp.stop_reason == "end_turn"
+
+ def test_with_usage(self):
+ resp = AnthropicResponse(
+ model="default",
+ content=[AnthropicResponseContentBlock(type="text", text="hi")],
+ usage=AnthropicUsage(input_tokens=10, output_tokens=5),
+ )
+ assert resp.usage.input_tokens == 10
+ assert resp.usage.output_tokens == 5
+
+ def test_unique_ids(self):
+ r1 = AnthropicResponse(
+ model="default",
+ content=[AnthropicResponseContentBlock(type="text", text="a")],
+ )
+ r2 = AnthropicResponse(
+ model="default",
+ content=[AnthropicResponseContentBlock(type="text", text="b")],
+ )
+ assert r1.id != r2.id
+
+ def test_tool_use_response(self):
+ resp = AnthropicResponse(
+ model="default",
+ content=[
+ AnthropicResponseContentBlock(type="text", text="Let me search."),
+ AnthropicResponseContentBlock(
+ type="tool_use",
+ id="call_1",
+ name="search",
+ input={"q": "test"},
+ ),
+ ],
+ stop_reason="tool_use",
+ )
+ assert len(resp.content) == 2
+ assert resp.content[0].type == "text"
+ assert resp.content[1].type == "tool_use"
+ assert resp.stop_reason == "tool_use"
diff --git a/tests/test_batching.py b/tests/test_batching.py
index 10e73c5b..7dc050ee 100644
--- a/tests/test_batching.py
+++ b/tests/test_batching.py
@@ -261,7 +261,7 @@ def test_add_duplicate_request(self, mock_model, mock_tokenizer):
scheduler.add_request(request)
def test_abort_waiting_request(self, mock_model, mock_tokenizer):
- """Test aborting a waiting request."""
+ """Test aborting a waiting request (deferred abort pattern)."""
scheduler = Scheduler(
model=mock_model,
tokenizer=mock_tokenizer,
@@ -276,21 +276,26 @@ def test_abort_waiting_request(self, mock_model, mock_tokenizer):
scheduler.add_request(request)
assert scheduler.get_num_waiting() == 1
+ # abort_request() enqueues for deferred processing
result = scheduler.abort_request("test-1")
-
assert result is True
+
+ # Process pending aborts (normally happens inside step())
+ scheduler._process_pending_aborts()
+
assert scheduler.get_num_waiting() == 0
assert "test-1" in scheduler.finished_req_ids
def test_abort_nonexistent_request(self, mock_model, mock_tokenizer):
- """Test aborting non-existent request."""
+ """Test aborting non-existent request (deferred abort always enqueues)."""
scheduler = Scheduler(
model=mock_model,
tokenizer=mock_tokenizer,
)
+ # abort_request() always returns True (enqueue is always successful)
result = scheduler.abort_request("nonexistent")
- assert result is False
+ assert result is True
def test_get_stats(self, mock_model, mock_tokenizer):
"""Test getting scheduler stats."""
diff --git a/tests/test_memory_cache.py b/tests/test_memory_cache.py
index 77c330d6..832dd7e9 100644
--- a/tests/test_memory_cache.py
+++ b/tests/test_memory_cache.py
@@ -10,6 +10,7 @@
MemoryAwarePrefixCache,
MemoryCacheConfig,
_CacheEntry,
+ _array_memory,
_get_available_memory,
estimate_kv_cache_memory,
)
@@ -116,6 +117,21 @@ def __init__(self, nbytes: int):
self.nbytes = nbytes
class MockDtype:
    """Stand-in for an MLX dtype that exposes only its byte ``size``."""

    def __init__(self, size: int):
        # Number of bytes per element, mirroring mlx dtype.size.
        self.size = size
+
+
class MockShapeArray:
    """Stand-in for an MLX array: exposes shape and dtype but no nbytes."""

    def __init__(self, shape: tuple, dtype_size: int):
        self.shape = shape
        # Wrap the element size the way MLX exposes it (dtype.size).
        self.dtype = MockDtype(dtype_size)
+
+
class MockKVCache:
"""Mock KV cache with keys/values attributes."""
@@ -136,6 +152,46 @@ def state(self):
return (self._keys, self._values)
class TestArrayMemory:
    """Tests for the _array_memory helper (shape-based, avoids lazy eval)."""

    def test_shape_dtype_estimation(self):
        """shape * dtype.size is computed without touching .nbytes."""
        arr = MockShapeArray(shape=(2, 16, 128, 64), dtype_size=2)
        # 2 * 16 * 128 * 64 * 2 = 524288 bytes
        assert _array_memory(arr) == 2 * 16 * 128 * 64 * 2

    def test_fallback_to_nbytes(self):
        """.nbytes is consulted when shape/dtype are not available."""
        assert _array_memory(MockArray(nbytes=4096)) == 4096

    def test_zero_for_unknown_object(self):
        """Objects exposing none of shape/dtype/nbytes report 0 bytes."""
        assert _array_memory(42) == 0
        assert _array_memory("string") == 0

    def test_shape_dtype_preferred_over_nbytes(self):
        """shape+dtype wins when an object exposes both paths."""

        class DualArray:
            def __init__(self):
                self.shape = (10,)
                self.dtype = MockDtype(4)
                self.nbytes = 9999  # must be ignored

        assert _array_memory(DualArray()) == 40  # 10 * 4, not 9999

    def test_estimate_uses_shape_based_for_dict_state(self):
        """estimate_kv_cache_memory routes dict layers through _array_memory."""
        k = MockShapeArray(shape=(1, 8, 100, 64), dtype_size=2)
        v = MockShapeArray(shape=(1, 8, 100, 64), dtype_size=2)
        layer = {"state": (k, v)}
        assert estimate_kv_cache_memory([layer]) == 2 * (1 * 8 * 100 * 64 * 2)
+
+
class TestEstimateKvCacheMemory:
"""Tests for estimate_kv_cache_memory function."""
diff --git a/tests/test_memory_stability.py b/tests/test_memory_stability.py
new file mode 100644
index 00000000..f332d1b4
--- /dev/null
+++ b/tests/test_memory_stability.py
@@ -0,0 +1,271 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Tests for VRAM memory stability fixes.
+
+Verifies that:
+1. BatchGenerator.close() is called when replacing/discarding generators
+2. Periodic mx.clear_cache() is triggered in generation loop
+3. Metal memory stats are reported in get_stats()
+"""
+
+from unittest.mock import MagicMock, patch
+
+from vllm_mlx.request import SamplingParams
+from vllm_mlx.scheduler import Scheduler, SchedulerConfig
+
+
def _make_scheduler(
    enable_prefix_cache=False,
) -> Scheduler:
    """Build a Scheduler around a mocked model/tokenizer for unit tests."""
    mock_model = MagicMock()
    mock_tokenizer = MagicMock()
    # Token count == whitespace-separated word count of the prompt text.
    mock_tokenizer.encode = lambda text: list(range(len(text.split())))
    mock_tokenizer.eos_token_id = 0

    cfg = SchedulerConfig(
        max_num_seqs=4,
        enable_prefix_cache=enable_prefix_cache,
    )
    return Scheduler(mock_model, mock_tokenizer, cfg)
+
+
class TestBatchGeneratorClose:
    """Verify BatchGenerator.close() is invoked at every discard point."""

    def test_close_called_on_replacement(self):
        """_close_batch_generator() closes and clears the current generator."""
        sched = _make_scheduler()

        stale = MagicMock()
        stale.close = MagicMock()
        sched.batch_generator = stale

        sched._close_batch_generator()

        stale.close.assert_called_once()
        assert sched.batch_generator is None

    def test_close_called_on_reset(self):
        """reset() must close the active generator."""
        sched = _make_scheduler()

        gen = MagicMock()
        gen.close = MagicMock()
        sched.batch_generator = gen

        sched.reset()

        gen.close.assert_called_once()
        assert sched.batch_generator is None

    def test_close_called_on_cache_error_recovery(self):
        """_recover_from_cache_error() must close the active generator."""
        sched = _make_scheduler()

        gen = MagicMock()
        gen.close = MagicMock()
        sched.batch_generator = gen

        sched._recover_from_cache_error()

        gen.close.assert_called_once()
        assert sched.batch_generator is None

    def test_close_not_called_when_none(self):
        """Closing with no generator present is a no-op, not an error."""
        sched = _make_scheduler()
        assert sched.batch_generator is None

        # Must not raise.
        sched._close_batch_generator()
        assert sched.batch_generator is None

    def test_close_exception_is_caught(self):
        """A failing close() is swallowed and the slot is still cleared."""
        sched = _make_scheduler()

        gen = MagicMock()
        gen.close = MagicMock(side_effect=RuntimeError("close failed"))
        sched.batch_generator = gen

        # Must not raise despite close() blowing up.
        sched._close_batch_generator()
        assert sched.batch_generator is None

    def test_close_called_in_ensure_batch_generator(self):
        """_ensure_batch_generator() closes the old generator when params change."""
        sched = _make_scheduler()

        old_gen = MagicMock()
        old_gen.close = MagicMock()
        sched.batch_generator = old_gen
        sched._current_sampler_params = (0.5, 0.9, 0.0)

        replacement = MagicMock()
        with patch.object(
            sched, "_create_batch_generator", return_value=replacement
        ):
            # Different sampler params force a rebuild of the generator.
            params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=100)
            sched._ensure_batch_generator(params)

        old_gen.close.assert_called_once()
        assert sched.batch_generator is replacement
+
+
class TestClearCacheInterval:
    """Tests for periodic mx.clear_cache() calls in the scheduler loop."""

    def test_clear_cache_interval_configured(self):
        """The default clear-cache interval and step counter are initialized."""
        sched = _make_scheduler()
        assert sched._step_count == 0
        assert sched._clear_cache_interval == 32

    @patch("vllm_mlx.scheduler.mx")
    def test_clear_cache_called_periodically(self, mx_mod):
        """mx.clear_cache() fires every _clear_cache_interval steps."""
        sched = _make_scheduler()
        sched._clear_cache_interval = 4  # shrink the interval for the test

        # Idle steps — no running requests are needed to tick the counter.
        for _ in range(8):
            sched.step()

        # Expected at steps 4 and 8.
        assert mx_mod.clear_cache.call_count >= 2

    @patch("vllm_mlx.scheduler.mx")
    def test_clear_cache_called_on_cleanup(self, mx_mod):
        """mx.clear_cache() runs when at least one request finishes."""
        sched = _make_scheduler()

        sched._cleanup_finished({"req-1"})

        mx_mod.clear_cache.assert_called()

    @patch("vllm_mlx.scheduler.mx")
    def test_clear_cache_not_called_on_empty_cleanup(self, mx_mod):
        """mx.clear_cache() is skipped when nothing finished."""
        sched = _make_scheduler()

        sched._cleanup_finished(set())

        mx_mod.clear_cache.assert_not_called()
+
+
class TestIncrementalCacheEval:
    """Tests for incremental per-layer cache evaluation in _cleanup_finished()."""

    @patch("vllm_mlx.scheduler.mx")
    def test_incremental_eval_called_per_layer(self, mx_mod):
        """mx.eval runs once per layer during cleanup, not as one big batch."""
        sched = _make_scheduler()

        # Finished request carrying a two-layer extracted cache (dict-state).
        req = MagicMock()
        req.prompt_token_ids = [1, 2, 3]
        req.output_token_ids = [4, 5]
        k1, v1 = MagicMock(), MagicMock()
        k2, v2 = MagicMock(), MagicMock()
        req._extracted_cache = [
            {"state": (k1, v1)},
            {"state": (k2, v2)},
        ]

        sched.running["req-1"] = req

        sched._cleanup_finished({"req-1"})

        # One mx.eval per layer, in layer order.
        calls = mx_mod.eval.call_args_list
        assert len(calls) == 2
        assert calls[0] == ((k1, v1),)
        assert calls[1] == ((k2, v2),)

    @patch("vllm_mlx.scheduler.mx")
    def test_no_eval_when_no_extracted_cache(self, mx_mod):
        """No mx.eval when the finished request carries no extracted cache."""
        sched = _make_scheduler()

        req = MagicMock()
        req.prompt_token_ids = [1, 2, 3]
        req.output_token_ids = [4, 5]
        req._extracted_cache = None

        sched.running["req-1"] = req

        sched._cleanup_finished({"req-1"})

        # Only mx.clear_cache may run during cleanup, never mx.eval.
        mx_mod.eval.assert_not_called()

    @patch("vllm_mlx.scheduler.mx")
    def test_no_eager_eval_in_extraction_path(self, mx_mod):
        """The old mx.eval(mx.array(0)) spike must not happen during extraction."""
        sched = _make_scheduler()

        # Batch response that carries a prompt cache to extract.
        resp = MagicMock()
        resp.uid = 42
        resp.token = 100
        resp.finish_reason = "stop"
        resp.prompt_cache = [MagicMock()]

        # Wire up the request/uid bookkeeping the scheduler expects.
        req = MagicMock()
        req.request_id = "req-1"
        req.output_token_ids = [100]
        req.num_output_tokens = 1
        req.num_prompt_tokens = 3
        sched.running["req-1"] = req
        sched.uid_to_request_id[42] = "req-1"

        sched._process_batch_responses([resp])

        # A lone mx.array argument to mx.eval would be the old spike pattern.
        for recorded in mx_mod.eval.call_args_list:
            args = recorded[0]
            assert not (len(args) == 1 and args[0] == mx_mod.array(0))
+
+
class TestMemoryStats:
    """Tests for Metal memory stats reported by get_stats()."""

    @patch("vllm_mlx.scheduler.mx")
    def test_metal_stats_included(self, mx_mod):
        """Metal memory numbers show up in get_stats(), converted to GB."""
        mx_mod.metal.is_available.return_value = True
        mx_mod.metal.get_active_memory.return_value = 10_000_000_000  # 10GB
        mx_mod.metal.get_peak_memory.return_value = 15_000_000_000  # 15GB
        mx_mod.metal.get_cache_memory.return_value = 2_000_000_000  # 2GB

        sched = _make_scheduler()
        stats = sched.get_stats()

        assert stats["metal_active_memory_gb"] == 10.0
        assert stats["metal_peak_memory_gb"] == 15.0
        assert stats["metal_cache_memory_gb"] == 2.0

    @patch("vllm_mlx.scheduler.mx")
    def test_metal_stats_graceful_on_error(self, mx_mod):
        """get_stats() still returns base stats when Metal queries fail."""
        mx_mod.metal.is_available.side_effect = RuntimeError("no metal")

        sched = _make_scheduler()
        stats = sched.get_stats()

        # Basic scheduler stats survive; Metal keys are simply absent.
        assert "num_waiting" in stats
        assert "metal_active_memory_gb" not in stats
diff --git a/tests/test_native_tool_format.py b/tests/test_native_tool_format.py
index 58044fc2..18411617 100644
--- a/tests/test_native_tool_format.py
+++ b/tests/test_native_tool_format.py
@@ -36,6 +36,7 @@ def test_parsers_with_native_support(self):
GraniteToolParser,
FunctionaryToolParser,
KimiToolParser,
+ HermesToolParser,
]
for parser_cls in native_parsers:
assert (
@@ -49,7 +50,6 @@ def test_parsers_without_native_support(self):
"""Parsers that don't support native tool format should return False."""
non_native_parsers = [
QwenToolParser,
- HermesToolParser,
NemotronToolParser,
xLAMToolParser,
AutoToolParser,
@@ -65,14 +65,22 @@ def test_parsers_without_native_support(self):
def test_via_manager(self):
"""Test native format detection via ToolParserManager."""
# Native support
- for name in ["mistral", "llama", "deepseek", "granite", "functionary", "kimi"]:
+ for name in [
+ "mistral",
+ "llama",
+ "deepseek",
+ "granite",
+ "functionary",
+ "kimi",
+ "hermes",
+ ]:
parser_cls = ToolParserManager.get_tool_parser(name)
assert (
parser_cls.supports_native_format() is True
), f"Parser '{name}' should support native format"
# No native support
- for name in ["qwen", "hermes", "nemotron", "xlam", "auto"]:
+ for name in ["qwen", "nemotron", "xlam", "auto"]:
parser_cls = ToolParserManager.get_tool_parser(name)
assert (
parser_cls.supports_native_format() is False
diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py
index 778f4deb..8e36a397 100644
--- a/tests/test_prefix_cache.py
+++ b/tests/test_prefix_cache.py
@@ -211,19 +211,15 @@ def test_clear(self, cache_manager):
# Stats should also be reset
assert cache_manager.stats.hits == 0
- def test_cache_deep_copy(self, cache_manager):
- """Test that fetched cache is a deep copy."""
+ def test_cache_no_copy(self, cache_manager):
+ """Test that fetched cache is a reference (no copy) — MLX arrays are immutable."""
original = [[1, 2, 3]]
cache_manager.store_cache([1, 2], original)
cache, _ = cache_manager.fetch_cache([1, 2])
- # Modify returned cache
- cache[0].append(4)
-
- # Original should be unchanged
- cache2, _ = cache_manager.fetch_cache([1, 2])
- assert cache2[0] == [1, 2, 3]
+ # Returns the same object (no deep copy overhead)
+ assert cache is original
def test_multiple_prefixes(self, cache_manager):
"""Test multiple different prefixes."""
diff --git a/vllm_mlx/api/anthropic_adapter.py b/vllm_mlx/api/anthropic_adapter.py
new file mode 100644
index 00000000..dbb94200
--- /dev/null
+++ b/vllm_mlx/api/anthropic_adapter.py
@@ -0,0 +1,312 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Adapter for converting between Anthropic Messages API and OpenAI Chat Completions API.
+
+Handles translation of:
+- Requests: Anthropic → OpenAI format
+- Responses: OpenAI → Anthropic format
+- Messages: Content blocks, tool calls, tool results
+"""
+
+import json
+import uuid
+
+from .anthropic_models import (
+ AnthropicMessage,
+ AnthropicRequest,
+ AnthropicResponse,
+ AnthropicResponseContentBlock,
+ AnthropicToolDef,
+ AnthropicUsage,
+)
+from .models import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ Message,
+ ToolDefinition,
+)
+
+
def anthropic_to_openai(request: AnthropicRequest) -> ChatCompletionRequest:
    """
    Convert an Anthropic Messages API request to OpenAI Chat Completions format.

    Handles:
    - system field → system message
    - Content blocks → OpenAI message format
    - tool_use/tool_result → OpenAI tool_calls/tool messages
    - Anthropic tools → OpenAI tools

    Args:
        request: Anthropic Messages API request

    Returns:
        OpenAI ChatCompletionRequest
    """
    openai_messages = []

    # The Anthropic "system" field becomes a leading system message.
    if request.system:
        if isinstance(request.system, str):
            system_text = request.system
        elif isinstance(request.system, list):
            # A list of content blocks: keep text blocks and bare strings.
            chunks = []
            for block in request.system:
                if isinstance(block, dict) and block.get("type") == "text":
                    chunks.append(block.get("text", ""))
                elif isinstance(block, str):
                    chunks.append(block)
            system_text = "\n".join(chunks)
        else:
            system_text = str(request.system)
        openai_messages.append(Message(role="system", content=system_text))

    # Conversation turns — each may fan out into several OpenAI messages
    # (e.g. tool_result blocks become separate role="tool" messages).
    for anthropic_msg in request.messages:
        openai_messages.extend(_convert_message(anthropic_msg))

    converted_tools = (
        [_convert_tool(t) for t in request.tools] if request.tools else None
    )
    converted_choice = (
        _convert_tool_choice(request.tool_choice) if request.tool_choice else None
    )

    return ChatCompletionRequest(
        model=request.model,
        messages=openai_messages,
        max_tokens=request.max_tokens,
        # Server-side fallbacks when the client leaves sampling unset.
        temperature=0.7 if request.temperature is None else request.temperature,
        top_p=0.9 if request.top_p is None else request.top_p,
        stream=request.stream,
        stop=request.stop_sequences,
        tools=converted_tools,
        tool_choice=converted_choice,
    )
+
+
def openai_to_anthropic(
    response: ChatCompletionResponse,
    model: str,
) -> AnthropicResponse:
    """
    Convert an OpenAI Chat Completions response to Anthropic Messages API format.

    Args:
        response: OpenAI ChatCompletionResponse
        model: Model name for the response

    Returns:
        Anthropic Messages API response
    """
    blocks = []
    first_choice = response.choices[0] if response.choices else None

    if first_choice is None:
        # No choices at all — synthesize a plain end-of-turn.
        stop_reason = "end_turn"
    else:
        message = first_choice.message

        # Assistant text becomes a single text block.
        if message.content:
            blocks.append(
                AnthropicResponseContentBlock(type="text", text=message.content)
            )

        # Each tool call becomes a tool_use block; arguments arrive as a
        # JSON string and are parsed (falling back to {} on bad JSON).
        if message.tool_calls:
            for tc in message.tool_calls:
                try:
                    parsed_input = json.loads(tc.function.arguments)
                except (json.JSONDecodeError, AttributeError):
                    parsed_input = {}

                blocks.append(
                    AnthropicResponseContentBlock(
                        type="tool_use",
                        id=tc.id,
                        name=tc.function.name,
                        input=parsed_input,
                    )
                )

        stop_reason = _convert_stop_reason(first_choice.finish_reason)

    # Anthropic responses always carry at least one content block.
    if not blocks:
        blocks.append(AnthropicResponseContentBlock(type="text", text=""))

    usage = response.usage
    return AnthropicResponse(
        model=model,
        content=blocks,
        stop_reason=stop_reason,
        usage=AnthropicUsage(
            input_tokens=usage.prompt_tokens if usage else 0,
            output_tokens=usage.completion_tokens if usage else 0,
        ),
    )
+
+
def _convert_message(msg: AnthropicMessage) -> list[Message]:
    """
    Convert one Anthropic message into one or more OpenAI messages.

    Anthropic packs tool results into user messages as "tool_result"
    content blocks; OpenAI expects each result as a separate message with
    role "tool", so a single Anthropic message may fan out into several.

    Args:
        msg: Anthropic message

    Returns:
        List of OpenAI messages
    """
    # Plain string content maps 1:1.
    if isinstance(msg.content, str):
        return [Message(role=msg.role, content=msg.content)]

    # Structured content: bucket the blocks by kind first.
    texts = []
    assistant_tool_calls = []
    tool_messages = []

    for block in msg.content:
        if block.type == "text":
            texts.append(block.text or "")

        elif block.type == "tool_use":
            # Becomes an entry in the assistant message's tool_calls list.
            assistant_tool_calls.append(
                {
                    "id": block.id or f"call_{uuid.uuid4().hex[:8]}",
                    "type": "function",
                    "function": {
                        "name": block.name or "",
                        "arguments": json.dumps(block.input or {}),
                    },
                }
            )

        elif block.type == "tool_result":
            payload = block.content
            if isinstance(payload, list):
                # Flatten nested content blocks down to their text.
                pieces = []
                for item in payload:
                    if isinstance(item, dict) and item.get("type") == "text":
                        pieces.append(item.get("text", ""))
                    elif isinstance(item, str):
                        pieces.append(item)
                payload = "\n".join(pieces)
            elif payload is None:
                payload = ""

            tool_messages.append(
                Message(
                    role="tool",
                    content=str(payload),
                    tool_call_id=block.tool_use_id or "",
                )
            )

    # Assemble the OpenAI messages from the buckets.
    out = []
    if msg.role == "assistant":
        joined = "\n".join(texts) if texts else None
        if assistant_tool_calls:
            out.append(
                Message(
                    role="assistant",
                    content=joined or "",
                    tool_calls=assistant_tool_calls,
                )
            )
        elif joined is not None:
            out.append(Message(role="assistant", content=joined))
        else:
            out.append(Message(role="assistant", content=""))
    elif msg.role == "user":
        if texts:
            out.append(Message(role="user", content="\n".join(texts)))

        # Each tool result becomes its own role="tool" message.
        out.extend(tool_messages)

        # Never emit an empty turn: keep a placeholder user message.
        if not texts and not tool_messages:
            out.append(Message(role="user", content=""))
    else:
        # Any other role: just carry the joined text through.
        out.append(Message(role=msg.role, content="\n".join(texts) if texts else ""))

    return out
+
+
def _convert_tool(tool: AnthropicToolDef) -> ToolDefinition:
    """
    Translate an Anthropic tool definition into an OpenAI function tool.

    Anthropic: {"name": "...", "description": "...", "input_schema": {...}}
    OpenAI: {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}
    """
    # Empty/missing schemas get a minimal valid JSON Schema object.
    fallback_schema = {"type": "object", "properties": {}}
    function_spec = {
        "name": tool.name,
        "description": tool.description or "",
        "parameters": tool.input_schema or fallback_schema,
    }
    return ToolDefinition(type="function", function=function_spec)
+
+
+def _convert_tool_choice(tool_choice: dict) -> str | dict | None:
+ """
+ Convert Anthropic tool_choice to OpenAI format.
+
+ Anthropic: {"type": "auto"} | {"type": "any"} | {"type": "tool", "name": "..."}
+ OpenAI: "auto" | "none" | "required" | {"type": "function", "function": {"name": "..."}}
+ """
+ choice_type = tool_choice.get("type", "auto")
+
+ if choice_type == "auto":
+ return "auto"
+ elif choice_type == "any":
+ return "required"
+ elif choice_type == "tool":
+ return {
+ "type": "function",
+ "function": {"name": tool_choice.get("name", "")},
+ }
+ elif choice_type == "none":
+ return "none"
+
+ return "auto"
+
+
+def _convert_stop_reason(openai_reason: str | None) -> str:
+ """
+ Convert OpenAI finish_reason to Anthropic stop_reason.
+
+ OpenAI: "stop" | "tool_calls" | "length" | "content_filter"
+ Anthropic: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence"
+ """
+ if openai_reason is None:
+ return "end_turn"
+
+ mapping = {
+ "stop": "end_turn",
+ "tool_calls": "tool_use",
+ "length": "max_tokens",
+ "content_filter": "end_turn",
+ }
+ return mapping.get(openai_reason, "end_turn")
diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py
new file mode 100644
index 00000000..a5bc6f77
--- /dev/null
+++ b/vllm_mlx/api/anthropic_models.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Pydantic models for Anthropic Messages API.
+
+These models define the request and response schemas for the
+Anthropic-compatible /v1/messages endpoint, enabling clients like
+Claude Code to communicate with vllm-mlx.
+"""
+
+import uuid
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+# =============================================================================
+# Request Models
+# =============================================================================
+
+
class AnthropicContentBlock(BaseModel):
    """One content block inside an Anthropic message.

    The ``type`` discriminator ("text", "image", "tool_use", "tool_result")
    selects which of the optional fields below are meaningful.
    """

    type: str
    # "text" blocks
    text: str | None = None
    # "tool_use" blocks
    id: str | None = None
    name: str | None = None
    input: dict | None = None
    # "tool_result" blocks
    tool_use_id: str | None = None
    content: str | list | None = None
    is_error: bool | None = None
    # "image" blocks
    source: dict | None = None
+
+
class AnthropicMessage(BaseModel):
    """A single turn in an Anthropic conversation."""

    # Conversation turns alternate "user" / "assistant".
    role: str
    # Either a plain string or a list of typed content blocks.
    content: str | list[AnthropicContentBlock]
+
+
class AnthropicToolDef(BaseModel):
    """A tool made available to the model, in Anthropic format."""

    name: str
    description: str | None = None
    # JSON Schema describing the tool's arguments.
    input_schema: dict | None = None
+
+
class AnthropicRequest(BaseModel):
    """Request body for the Anthropic Messages API (/v1/messages)."""

    model: str
    messages: list[AnthropicMessage]
    # System prompt: plain string or a list of text content blocks.
    system: str | list[dict] | None = None
    # Unlike OpenAI's API, max_tokens is mandatory in the Anthropic API.
    max_tokens: int
    temperature: float | None = None
    top_p: float | None = None
    stream: bool = False
    stop_sequences: list[str] | None = None
    tools: list[AnthropicToolDef] | None = None
    tool_choice: dict | None = None
    metadata: dict | None = None
    top_k: int | None = None
+
+
+# =============================================================================
+# Response Models
+# =============================================================================
+
+
class AnthropicUsage(BaseModel):
    """Token accounting attached to an Anthropic response."""

    input_tokens: int = 0
    output_tokens: int = 0
    # Prompt-caching counters; left None when caching is not in play.
    cache_creation_input_tokens: int | None = None
    cache_read_input_tokens: int | None = None
+
+
class AnthropicResponseContentBlock(BaseModel):
    """A content block emitted in an Anthropic response."""

    # Either "text" or "tool_use".
    type: str
    text: str | None = None
    # Populated only for "tool_use" blocks.
    id: str | None = None
    name: str | None = None
    input: Any | None = None
+
+
class AnthropicResponse(BaseModel):
    """Response body for the Anthropic Messages API."""

    # Anthropic-style message id: "msg_" plus 24 hex characters.
    id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:24]}")
    type: str = "message"
    role: str = "assistant"
    model: str
    content: list[AnthropicResponseContentBlock]
    stop_reason: str | None = None
    stop_sequence: str | None = None
    usage: AnthropicUsage = Field(default_factory=AnthropicUsage)
diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py
index f9aa4be6..1443c167 100644
--- a/vllm_mlx/api/tool_calling.py
+++ b/vllm_mlx/api/tool_calling.py
@@ -82,7 +82,9 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]:
return tool_calls if tool_calls else None
-def parse_tool_calls(text: str) -> Tuple[str, Optional[List[ToolCall]]]:
+def parse_tool_calls(
+ text: str, request: dict[str, Any] | None = None
+) -> Tuple[str, Optional[List[ToolCall]]]:
"""
Parse tool calls from model output.
@@ -144,7 +146,13 @@ def parse_tool_calls(text: str) -> Tuple[str, Optional[List[ToolCall]]]:
# Parse parameters from value format
        param_pattern = r"<parameter=([^>]+)>\s*(.*?)\s*</parameter>"
params = re.findall(param_pattern, params_block, re.DOTALL)
- arguments = {p_name.strip(): p_value.strip() for p_name, p_value in params}
+ arguments = {}
+ for p_name, p_value in params:
+ val = p_value.strip()
+ try:
+ arguments[p_name.strip()] = json.loads(val)
+ except (json.JSONDecodeError, ValueError):
+ arguments[p_name.strip()] = val
tool_calls.append(
ToolCall(
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index 91332bfa..b1d5005f 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -185,11 +185,28 @@ def extract_multimodal_content(
tool_calls_list = []
for tc in tool_calls:
if isinstance(tc, dict):
- tool_calls_list.append(tc)
+ tc_copy = tc
elif hasattr(tc, "model_dump"):
- tool_calls_list.append(tc.model_dump())
+ tc_copy = tc.model_dump()
elif hasattr(tc, "dict"):
- tool_calls_list.append(tc.dict())
+ tc_copy = tc.dict()
+ else:
+ continue
+
+ # Chat templates (e.g. Qwen3) iterate arguments|items,
+ # but OpenAI API sends arguments as a JSON string.
+ # Parse it into a dict so the template can iterate it.
+ func = tc_copy.get("function") or {}
+ args = func.get("arguments")
+ if isinstance(args, str):
+ try:
+ import json
+
+ func["arguments"] = json.loads(args)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ tool_calls_list.append(tc_copy)
msg_dict = {"role": role, "content": content if content else ""}
if tool_calls_list:
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 700b35b7..89ef5343 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -139,9 +139,13 @@ def serve_command(args):
use_paged_cache=args.use_paged_cache,
paged_cache_block_size=args.paged_cache_block_size,
max_cache_blocks=args.max_cache_blocks,
+ # Chunked prefill
+ chunked_prefill_tokens=args.chunked_prefill_tokens,
)
print("Mode: Continuous batching (for multiple concurrent users)")
+ if args.chunked_prefill_tokens > 0:
+ print(f"Chunked prefill: {args.chunked_prefill_tokens} tokens per step")
print(f"Stream interval: {args.stream_interval} tokens")
if args.use_paged_cache:
print(
@@ -502,6 +506,14 @@ def main():
default=1000,
help="Maximum number of cache blocks (default: 1000)",
)
+ # Chunked prefill
+ serve_parser.add_argument(
+ "--chunked-prefill-tokens",
+ type=int,
+ default=0,
+ help="Max prefill tokens per scheduler step (0=disabled). "
+ "Prevents starvation of active requests during long prefills.",
+ )
# MCP options
serve_parser.add_argument(
"--mcp-config",
@@ -574,6 +586,19 @@ def main():
f"Options: {', '.join(reasoning_choices)}."
),
)
+ # Generation defaults
+ serve_parser.add_argument(
+ "--default-temperature",
+ type=float,
+ default=None,
+ help="Override default temperature for all requests (default: use model default)",
+ )
+ serve_parser.add_argument(
+ "--default-top-p",
+ type=float,
+ default=None,
+ help="Override default top_p for all requests (default: use model default)",
+ )
# Embedding model option
serve_parser.add_argument(
"--embedding-model",
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index ddb95813..3b2e973e 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -260,6 +260,30 @@ async def _start_llm(self) -> None:
tokenizer_config=tokenizer_config,
)
+ # Set Metal memory limits to make allocation failures graceful
+ # instead of fatal Metal command buffer errors (SIGABRT)
+ try:
+ import mlx.core as mx
+
+ if mx.metal.is_available():
+ device_info = mx.device_info()
+ max_recommended = device_info.get(
+ "max_recommended_working_set_size",
+ device_info.get("memory_size", 0),
+ )
+ if max_recommended > 0:
+ soft_limit = int(max_recommended * 0.90)
+ mx.set_memory_limit(soft_limit)
+ mx.set_cache_limit(32 * 1024 * 1024 * 1024) # 32GB
+ logger.info(
+ f"Metal memory limits set: "
+ f"allocation_limit={soft_limit / 1e9:.1f}GB "
+ f"(90% of {max_recommended / 1e9:.1f}GB), "
+ f"cache_limit=32GB"
+ )
+ except Exception as e:
+ logger.warning(f"Failed to set Metal memory limits: {e}")
+
# Create engine config
scheduler_config = self._scheduler_config or SchedulerConfig()
engine_config = EngineConfig(
@@ -345,10 +369,11 @@ def _apply_chat_template(
# Fall through to standard template
if hasattr(tokenizer, "apply_chat_template"):
+ enable_thinking = "coder" not in self._model_name.lower()
template_kwargs = {
"tokenize": False,
"add_generation_prompt": True,
- "enable_thinking": True,
+ "enable_thinking": enable_thinking,
}
if tools:
template_kwargs["tools"] = tools
@@ -498,9 +523,11 @@ async def stream_generate(
stop=stop or [],
)
+ prefix_boundary = kwargs.pop("prefix_boundary", 0)
request_id = await self._engine.add_request(
prompt=prompt,
sampling_params=sampling_params,
+ prefix_boundary=prefix_boundary,
)
async for output in self._engine.stream_outputs(request_id):
@@ -575,6 +602,57 @@ async def chat(
**kwargs,
)
+ def _compute_prefix_boundary(
+ self, messages: list[dict[str, Any]], tools: list[dict] | None = None
+ ) -> int:
+ """Compute token count for the shared prefix across message variations.
+
+ Uses a two-tokenization approach: tokenize the full prompt twice
+ (once as-is, once with the last user message replaced by a dummy)
+ and find the longest common prefix (LCP). This gives the exact
+ boundary where different user suffixes diverge, avoiding template
+ discrepancies (e.g. Qwen3 markers on the last assistant turn).
+ """
+ # Find index of last user message
+ last_user_idx = None
+ for i in range(len(messages) - 1, -1, -1):
+ if messages[i].get("role") == "user":
+ last_user_idx = i
+ break
+ if last_user_idx is None or last_user_idx == 0:
+ return 0
+ try:
+ template_tools = convert_tools_for_template(tools) if tools else None
+
+ # Tokenize the real prompt
+ real_prompt = self._apply_chat_template(messages, template_tools)
+
+ # Build a dummy variant with different last user content
+ dummy_messages = list(messages)
+ dummy_messages[last_user_idx] = {
+ **messages[last_user_idx],
+ "content": "XXXXXXXXXX",
+ }
+ dummy_prompt = self._apply_chat_template(dummy_messages, template_tools)
+
+ tokenizer = self.tokenizer
+ if hasattr(tokenizer, "tokenizer"):
+ tokenizer = tokenizer.tokenizer
+
+ real_tokens = tokenizer.encode(real_prompt)
+ dummy_tokens = tokenizer.encode(dummy_prompt)
+
+ # Find LCP — the point where the two diverge is the boundary
+ lcp = 0
+ for j in range(min(len(real_tokens), len(dummy_tokens))):
+ if real_tokens[j] != dummy_tokens[j]:
+ break
+ lcp = j + 1
+
+ return lcp
+ except Exception:
+ return 0
+
async def stream_chat(
self,
messages: list[dict[str, Any]],
@@ -625,6 +703,11 @@ async def stream_chat(
num_images=len(all_images),
)
+ # Compute prefix boundary for cache
+ prefix_boundary = self._compute_prefix_boundary(messages, tools)
+ if prefix_boundary > 0:
+ kwargs["prefix_boundary"] = prefix_boundary
+
async for output in self.stream_generate(
prompt=prompt,
max_tokens=max_tokens,
@@ -660,3 +743,15 @@ def get_cache_stats(self) -> dict[str, Any] | None:
elif self._engine:
return self._engine.get_cache_stats()
return None
+
+ def save_cache_to_disk(self, cache_dir: str) -> bool:
+ """Save prefix cache to disk for persistence across restarts."""
+ if self._engine:
+ return self._engine.save_cache_to_disk(cache_dir)
+ return False
+
+ def load_cache_from_disk(self, cache_dir: str) -> int:
+ """Load prefix cache from disk. Returns number of entries loaded."""
+ if self._engine:
+ return self._engine.load_cache_from_disk(cache_dir)
+ return 0
diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py
index b174145d..aaa0ccc4 100644
--- a/vllm_mlx/engine_core.py
+++ b/vllm_mlx/engine_core.py
@@ -18,6 +18,8 @@
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, Union
+import mlx.core as mx
+
from .request import Request, RequestOutput, SamplingParams
from .scheduler import Scheduler, SchedulerConfig
from .output_collector import RequestOutputCollector, RequestStreamState
@@ -129,19 +131,68 @@ def is_running(self) -> bool:
return self._running
async def _engine_loop(self) -> None:
- """Main engine loop - optimized for minimal overhead."""
- # Cache config values for faster access
+ """Main engine loop - hybrid executor for prefill vs generation.
+
+ Prefill steps (long prompts) are run in a thread executor to keep
+ the asyncio event loop responsive. Generation-only steps (~1-3ms)
+ are called directly to avoid ~0.5-2ms context switch overhead,
+ giving ~5-10% throughput improvement during sustained generation.
+ """
+ import concurrent.futures
+
+ # Single-thread executor ensures MLX calls are never concurrent
+ _executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=1, thread_name_prefix="mlx-step"
+ )
+ loop = asyncio.get_running_loop()
+
step_interval = self.config.step_interval
stream_interval = self.config.stream_interval
use_simple_streaming = stream_interval == 1
+ # Emergency memory pressure threshold (200GB)
+ _memory_pressure_threshold = 200 * 1024 * 1024 * 1024
+ _memory_check_interval = 64
+
while self._running:
try:
if self.scheduler.has_requests():
- # Run one generation step
- output = self.scheduler.step()
+ # Hybrid approach: use executor only when prefill is likely.
+ # Prefill happens when there are waiting requests that need
+ # to be inserted into the batch (may block for seconds).
+ # Generation-only steps are fast (<3ms) and can run inline.
+ has_waiting = self.scheduler.get_num_waiting() > 0
+ has_partial = (
+ self.scheduler.batch_generator is not None
+ and getattr(self.scheduler.batch_generator, "_partial", None)
+ is not None
+ )
+ needs_executor = has_waiting or has_partial
+
+ if needs_executor:
+ output = await loop.run_in_executor(
+ _executor, self.scheduler.step
+ )
+ else:
+ output = self.scheduler.step()
+ # Yield to event loop after inline step
+ await asyncio.sleep(0)
self._steps_executed += 1
+ # Emergency memory pressure check
+ if self._steps_executed % _memory_check_interval == 0:
+ try:
+ active_mem = mx.get_active_memory()
+ if active_mem > _memory_pressure_threshold:
+ mx.clear_cache()
+ logger.warning(
+ f"[Memory pressure] {active_mem / 1e9:.1f}GB > "
+ f"{_memory_pressure_threshold / 1e9:.0f}GB threshold, "
+ f"forced cache clear"
+ )
+ except Exception:
+ pass
+
# Fast path: distribute outputs to collectors
outputs = output.outputs
if outputs:
@@ -171,9 +222,15 @@ async def _engine_loop(self) -> None:
if event:
event.set()
- # OPTIMIZATION: Only yield if streaming consumers are waiting
- if RequestOutputCollector.has_waiting_consumers():
- await asyncio.sleep(0)
+ # Free Metal buffers after distributing finished outputs
+ if output.finished_request_ids:
+ mx.clear_cache()
+
+ # Always yield to prevent event loop starvation.
+ # Without this, orphaned requests (client disconnected but
+ # request still in scheduler) block the entire event loop,
+ # making the server unresponsive to all HTTP requests.
+ await asyncio.sleep(0)
else:
# No work, yield control
await asyncio.sleep(step_interval)
@@ -191,6 +248,7 @@ async def add_request(
request_id: Optional[str] = None,
images: Optional[List[Any]] = None,
videos: Optional[List[Any]] = None,
+ prefix_boundary: int = 0,
) -> str:
"""
Add a request for processing.
@@ -201,6 +259,7 @@ async def add_request(
request_id: Optional custom request ID
images: Optional images for multimodal
videos: Optional videos for multimodal
+ prefix_boundary: Token count for shared prefix (for cache)
Returns:
The request ID
@@ -217,6 +276,7 @@ async def add_request(
sampling_params=sampling_params,
images=images,
videos=videos,
+ prefix_boundary=prefix_boundary,
)
# Setup output collector with stream_interval from config
@@ -264,16 +324,24 @@ async def stream_outputs(
Yields:
RequestOutput objects as tokens are generated
"""
+ import time as _time
+
+ _t0 = _time.monotonic()
+ _token_count = 0
+
collector = self._output_collectors.get(request_id)
if collector is None:
- # Request might not be added yet or already cleaned up
+ logger.warning(
+ f"[stream_outputs] {request_id[:12]} no collector found, returning immediately"
+ )
return
+ logger.info(f"[stream_outputs] {request_id[:12]} START waiting for tokens")
+
+ finished_normally = False
try:
while True:
try:
- # Non-blocking drain pattern from vLLM
- # Try get_nowait first to avoid task switch if output ready
if timeout:
output = collector.get_nowait()
if output is None:
@@ -283,17 +351,48 @@ async def stream_outputs(
else:
output = collector.get_nowait() or await collector.get()
+ _token_count += 1
+ if _token_count == 1:
+ logger.info(
+ f"[stream_outputs] {request_id[:12]} first token after "
+ f"{_time.monotonic() - _t0:.1f}s"
+ )
+
yield output
if output.finished:
+ finished_normally = True
+ logger.info(
+ f"[stream_outputs] {request_id[:12]} finished normally, "
+ f"{_token_count} tokens in {_time.monotonic() - _t0:.1f}s"
+ )
break
except asyncio.TimeoutError:
- logger.warning(f"Timeout waiting for request {request_id}")
+ logger.warning(
+ f"[stream_outputs] {request_id[:12]} TIMEOUT after "
+ f"{_token_count} tokens, {_time.monotonic() - _t0:.1f}s"
+ )
break
+ except (GeneratorExit, asyncio.CancelledError) as exc:
+ logger.info(
+ f"[stream_outputs] {request_id[:12]} {type(exc).__name__} after "
+ f"{_token_count} tokens, {_time.monotonic() - _t0:.1f}s"
+ )
+
finally:
+ if not finished_normally:
+ logger.info(
+ f"[stream_outputs] {request_id[:12]} ABORTING orphaned request "
+ f"({_token_count} tokens generated in {_time.monotonic() - _t0:.1f}s)"
+ )
+ aborted = self.scheduler.abort_request(request_id)
+ logger.info(
+ f"[stream_outputs] {request_id[:12]} abort_request returned {aborted}"
+ )
self._cleanup_request(request_id)
+ logger.info(f"[stream_outputs] {request_id[:12]} cleanup done")
async def generate(
self,
@@ -329,29 +428,35 @@ async def generate(
if event is None:
raise RuntimeError(f"No event for request {request_id}")
- # Wait for the request to finish
- await event.wait()
+ try:
+ # Wait for the request to finish
+ await event.wait()
- # Get the final output from collector
- collector = self._output_collectors.get(request_id)
- if collector is None:
- raise RuntimeError(f"No collector for request {request_id}")
+ # Get the final output from collector
+ collector = self._output_collectors.get(request_id)
+ if collector is None:
+ raise RuntimeError(f"No collector for request {request_id}")
- # Drain all outputs and get the last one
- final_output = None
- while True:
- output = collector.get_nowait()
- if output is None:
- break
- final_output = output
+ # Drain all outputs and get the last one
+ final_output = None
+ while True:
+ output = collector.get_nowait()
+ if output is None:
+ break
+ final_output = output
- # Cleanup
- self._cleanup_request(request_id)
+ if final_output is None:
+ raise RuntimeError(f"No output for request {request_id}")
+
+ return final_output
- if final_output is None:
- raise RuntimeError(f"No output for request {request_id}")
+ except (asyncio.CancelledError, GeneratorExit):
+ logger.info(f"[generate] {request_id[:12]} CANCELLED, aborting request")
+ self.scheduler.abort_request(request_id)
+ raise
- return final_output
+ finally:
+ self._cleanup_request(request_id)
def generate_batch_sync(
self,
@@ -416,6 +521,7 @@ def get_stats(self) -> Dict[str, Any]:
"steps_executed": self._steps_executed,
"active_requests": len(self._output_collectors),
"stream_interval": self.config.stream_interval,
+ "requests": self.scheduler.get_running_requests_info(),
**scheduler_stats,
}
@@ -423,6 +529,14 @@ def get_cache_stats(self) -> Optional[Dict[str, Any]]:
"""Get prefix cache statistics."""
return self.scheduler.get_cache_stats()
+ def save_cache_to_disk(self, cache_dir: str) -> bool:
+ """Save prefix cache to disk."""
+ return self.scheduler.save_cache_to_disk(cache_dir)
+
+ def load_cache_from_disk(self, cache_dir: str) -> int:
+ """Load prefix cache from disk."""
+ return self.scheduler.load_cache_from_disk(cache_dir)
+
def _release_model(self) -> None:
"""Release model ownership."""
if self._owns_model and not self._closed:
@@ -559,3 +673,11 @@ def get_stats(self) -> Dict[str, Any]:
def get_cache_stats(self) -> Optional[Dict[str, Any]]:
"""Get prefix cache statistics."""
return self.engine.get_cache_stats()
+
+ def save_cache_to_disk(self, cache_dir: str) -> bool:
+ """Save prefix cache to disk."""
+ return self.engine.save_cache_to_disk(cache_dir)
+
+ def load_cache_from_disk(self, cache_dir: str) -> int:
+ """Load prefix cache from disk."""
+ return self.engine.load_cache_from_disk(cache_dir)
diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py
index f414f321..902f33f7 100644
--- a/vllm_mlx/memory_cache.py
+++ b/vllm_mlx/memory_cache.py
@@ -24,7 +24,9 @@
from __future__ import annotations
+import bisect
import logging
+import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any
@@ -57,12 +59,38 @@ def _get_available_memory() -> int:
return 0
+def _array_memory(arr) -> int:
+ """
+ Estimate array memory from shape+dtype without triggering lazy eval.
+
+ Accessing .nbytes on a lazy MLX array forces evaluation of the entire
+ computation graph, causing a VRAM spike. This function uses shape and
+ dtype metadata (which are always available without eval) to compute
+ the same value.
+
+ Args:
+ arr: An MLX array or similar object.
+
+ Returns:
+ Estimated memory in bytes.
+ """
+ if hasattr(arr, "shape") and hasattr(arr, "dtype"):
+ dtype = arr.dtype
+ if hasattr(dtype, "size"):
+ return math.prod(arr.shape) * dtype.size
+ # Fallback for non-MLX arrays or objects without shape/dtype
+ if hasattr(arr, "nbytes"):
+ return arr.nbytes
+ return 0
+
+
def estimate_kv_cache_memory(cache: list[Any]) -> int:
"""
Estimate memory usage of a KV cache in bytes.
This function inspects MLX arrays in the cache and calculates their
- total memory footprint.
+ total memory footprint using shape+dtype metadata to avoid triggering
+ lazy evaluation (which would cause a VRAM spike).
Args:
cache: List of layer cache objects, each containing keys/values tensors.
@@ -81,18 +109,14 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int:
if isinstance(layer_cache, dict) and "state" in layer_cache:
# Extracted state dict
keys, values = layer_cache["state"]
- if hasattr(keys, "nbytes"):
- total_bytes += keys.nbytes
- if hasattr(values, "nbytes"):
- total_bytes += values.nbytes
+ total_bytes += _array_memory(keys)
+ total_bytes += _array_memory(values)
elif hasattr(layer_cache, "state") and not isinstance(layer_cache, dict):
# Cache with state property returning (keys, values)
try:
keys, values = layer_cache.state
- if hasattr(keys, "nbytes"):
- total_bytes += keys.nbytes
- if hasattr(values, "nbytes"):
- total_bytes += values.nbytes
+ total_bytes += _array_memory(keys)
+ total_bytes += _array_memory(values)
except (TypeError, ValueError):
pass
elif hasattr(layer_cache, "keys") and hasattr(layer_cache, "values"):
@@ -100,10 +124,10 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int:
keys_attr = layer_cache.keys
values_attr = layer_cache.values
# Ensure these are arrays, not methods
- if not callable(keys_attr) and hasattr(keys_attr, "nbytes"):
- total_bytes += keys_attr.nbytes
- if not callable(values_attr) and hasattr(values_attr, "nbytes"):
- total_bytes += values_attr.nbytes
+ if not callable(keys_attr):
+ total_bytes += _array_memory(keys_attr)
+ if not callable(values_attr):
+ total_bytes += _array_memory(values_attr)
return total_bytes
@@ -209,6 +233,28 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry:
)
+def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
+ """Create shallow copies of KVCache layers with offset reduced by *trim_by*.
+
+ This is used when returning a cached KV state to the scheduler so that
+ the last N positions are "freed" and recomputed on the next forward pass
+ (preventing duplicate KV entries). NOTE(review): ``__new__`` skips ``KVCache.__init__`` — confirm no other attributes (e.g. ``step``) are required by later updates.
+ """
+ from mlx_lm.models.cache import KVCache
+
+ trimmed: list[Any] = []
+ for layer_cache in cache:
+ if hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys"):
+ tc = KVCache.__new__(KVCache)
+ tc.keys = layer_cache.keys
+ tc.values = layer_cache.values
+ tc.offset = max(layer_cache.offset - trim_by, 0)
+ trimmed.append(tc)
+ else:
+ trimmed.append(layer_cache)
+ return trimmed
+
+
class MemoryAwarePrefixCache:
"""
Prefix cache with memory-based eviction.
@@ -246,6 +292,11 @@ def __init__(
# Key: tuple(tokens), Value: _CacheEntry
self._entries: OrderedDict[tuple[int, ...], _CacheEntry] = OrderedDict()
+ # Sorted index of token keys for efficient prefix/supersequence lookup.
+ # Tuple lexicographic ordering means a prefix key P always sorts before
+ # any extension of P, so bisect locates the range in O(log N) before a short local scan.
+ self._sorted_keys: list[tuple[int, ...]] = []
+
# Memory tracking
self._max_memory = self._config.compute_memory_limit()
self._current_memory = 0
@@ -253,6 +304,9 @@ def __init__(
# Statistics
self._stats = CacheStats(max_memory_bytes=self._max_memory)
+ # Track the match type from the last fetch() call
+ self._last_match_type: str | None = None
+
logger.info(
f"MemoryAwarePrefixCache initialized: "
f"max_memory={self._max_memory / _BYTES_PER_MB:.1f}MB, "
@@ -263,7 +317,10 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]:
"""
Find cached KV state for the given tokens.
- This method searches for exact matches and prefix matches.
+ This method searches for exact matches, prefix matches, supersequence
+ matches, and longest-common-prefix (LCP) matches. Uses a sorted key
+ index for O(log N) lookup instead of scanning all entries.
+
Returns the cached KV state directly (no copy) since MLX arrays
are immutable and safe to share.
@@ -277,47 +334,175 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]:
"""
if not tokens:
self._stats.misses += 1
+ self._last_match_type = "miss"
return None, tokens
tokens_key = tuple(tokens)
- # Check for exact match
+ # --- O(1) exact match ---
if tokens_key in self._entries:
entry = self._entries[tokens_key]
- # Move to end (most recently used)
self._entries.move_to_end(tokens_key)
self._stats.hits += 1
self._stats.tokens_saved += len(tokens)
- # Return reference directly - MLX arrays are immutable
+ self._last_match_type = "exact"
return entry.cache, []
- # Check for prefix matches (shorter cached sequences)
+ # --- O(log N) prefix & supersequence match via sorted index ---
best_match: _CacheEntry | None = None
best_length = 0
+ best_super: _CacheEntry | None = None
+
+ sorted_keys = self._sorted_keys
+ if sorted_keys:
+ # Find insertion point for tokens_key in the sorted list.
+ # Keys that are prefixes of tokens_key or supersequences will be
+ # clustered around this position due to lexicographic ordering.
+ idx = bisect.bisect_left(sorted_keys, tokens_key)
+
+ # Scan backwards from idx to find cached keys that are PREFIXES
+ # of tokens_key (shorter cached sequences). A strict prefix P of T
+ # sorts before T lexicographically, so P is at idx-1 or earlier.
+ for i in range(idx - 1, -1, -1):
+ cached_key = sorted_keys[i]
+ cached_len = len(cached_key)
+ if cached_len >= len(tokens_key):
+ continue # Not a prefix (same length or longer)
+ # Check if cached_key is a prefix of tokens_key
+ if tokens_key[:cached_len] == cached_key:
+ if cached_len > best_length:
+ best_match = self._entries[cached_key]
+ best_length = cached_len
+ # Scanning backward, the first prefix found is the longest possible one
+ break
+ # Once we go past the prefix range, stop
+ if cached_key[0] != tokens_key[0]:
+ break
+
+ # Scan forward from idx to find cached keys that are SUPERSEQUENCES
+ # of tokens_key (longer cached sequences starting with tokens_key); keeps the longest such entry — NOTE(review): the shortest would need less trimming; confirm intent.
+ for i in range(idx, len(sorted_keys)):
+ cached_key = sorted_keys[i]
+ cached_len = len(cached_key)
+ if cached_len < len(tokens_key):
+ continue
+ # Check if tokens_key is a prefix of cached_key
+ if cached_key[: len(tokens_key)] == tokens_key:
+ if best_super is None or cached_len > len(best_super.tokens):
+ best_super = self._entries[cached_key]
+ else:
+ # Past the supersequence range
+ break
+
+ # --- Supersequence match handling ---
+ if best_super is not None:
+ n_cached = len(best_super.tokens)
+ n_requested = len(tokens)
+ excess = n_cached - n_requested
+
+ has_non_trimmable = any(
+ not (hasattr(lc, "offset") and hasattr(lc, "keys"))
+ for lc in best_super.cache
+ )
- for cached_key, entry in self._entries.items():
- cached_len = len(cached_key)
- # Check if cached sequence is a prefix of requested tokens
- if (
- cached_len < len(tokens)
- and cached_len > best_length
- and tokens_key[:cached_len] == cached_key
- ):
- best_match = entry
- best_length = cached_len
-
+ if excess > 0 and has_non_trimmable:
+ logger.debug(
+ "[cache_fetch] supersequence match skipped: "
+ "non-trimmable cache layers (hybrid model)"
+ )
+ elif excess > 0:
+ trimmed_cache = _trim_cache_offset(best_super.cache, excess)
+ self._entries.move_to_end(best_super.tokens)
+ self._stats.hits += 1
+ self._stats.tokens_saved += n_requested
+ self._last_match_type = "supersequence"
+ return trimmed_cache, []
+ else:
+ self._entries.move_to_end(best_super.tokens)
+ self._stats.hits += 1
+ self._stats.tokens_saved += n_requested
+ self._last_match_type = "supersequence"
+ return best_super.cache, []
+
+ # --- Prefix match ---
if best_match is not None:
- # Move matched entry to end (most recently used)
self._entries.move_to_end(best_match.tokens)
self._stats.hits += 1
self._stats.tokens_saved += best_length
remaining = tokens[best_length:]
+ self._last_match_type = "prefix"
return best_match.cache, remaining
+ # --- LCP (Longest Common Prefix) for divergent sequences ---
+ # This handles the agentic pattern: same system+context prefix
+ # but different final user message. Use the sorted index to find
+ # the nearest neighbor which likely shares the longest prefix.
+ best_lcp_entry: _CacheEntry | None = None
+ best_lcp_length = 0
+
+ if sorted_keys:
+ idx = bisect.bisect_left(sorted_keys, tokens_key)
+ # Check neighbors around insertion point (they share the most
+ # common prefix due to lexicographic ordering).
+ for i in (idx - 1, idx):
+ if i < 0 or i >= len(sorted_keys):
+ continue
+ cached_key = sorted_keys[i]
+ if cached_key == tokens_key:
+ continue # Skip exact (already handled)
+ min_len = min(len(cached_key), len(tokens_key))
+ if min_len <= best_lcp_length:
+ continue
+ # Compute LCP length
+ lcp = 0
+ for j in range(min_len):
+ if cached_key[j] != tokens_key[j]:
+ break
+ lcp = j + 1
+ if lcp > best_lcp_length:
+ best_lcp_entry = self._entries[cached_key]
+ best_lcp_length = lcp
+ logger.debug(
+ f"[cache_fetch] LCP scan: cached_len={len(cached_key)} "
+ f"req_len={len(tokens_key)} lcp={lcp}"
+ )
+
+ if best_lcp_entry is not None and best_lcp_length > 0:
+ excess = len(best_lcp_entry.tokens) - best_lcp_length
+
+ has_non_trimmable = any(
+ not (hasattr(lc, "offset") and hasattr(lc, "keys"))
+ for lc in best_lcp_entry.cache
+ )
+ logger.debug(
+ f"[cache_fetch] LCP candidate: lcp={best_lcp_length} "
+ f"entry_len={len(best_lcp_entry.tokens)} excess={excess} "
+ f"non_trimmable={has_non_trimmable} "
+ f"cache_layers={len(best_lcp_entry.cache)} "
+ f"layer_types={[type(lc).__name__ for lc in best_lcp_entry.cache[:3]]}"
+ )
+
+ if not has_non_trimmable:
+ trimmed_cache = _trim_cache_offset(best_lcp_entry.cache, excess)
+ self._entries.move_to_end(best_lcp_entry.tokens)
+ self._stats.hits += 1
+ self._stats.tokens_saved += best_lcp_length
+ remaining = tokens[best_lcp_length:]
+ logger.debug(
+ f"[cache_fetch] LCP hit: shared={best_lcp_length} "
+ f"trimmed={excess} remaining={len(remaining)}"
+ )
+ self._last_match_type = "lcp"
+ return trimmed_cache, remaining
+
self._stats.misses += 1
+ self._last_match_type = "miss"
+
return None, tokens
- def store(self, tokens: list[int], cache: list[Any]) -> bool:
+ def store(
+ self, tokens: list[int], cache: list[Any], evict_prefixes: bool = True
+ ) -> bool:
"""
Store KV cache for future reuse.
@@ -328,6 +513,11 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool:
Args:
tokens: Token sequence that was processed.
cache: The computed KV cache to store.
+ evict_prefixes: If True, evict existing entries whose token
+ sequence is a strict prefix of ``tokens``. Set to False
+ when storing prompt+output entries to preserve prompt-only
+ entries created by prompt_cache_save (those are the entries
+ that future requests will actually match).
Returns:
True if stored successfully, False if rejected.
@@ -353,6 +543,36 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool:
)
return False
+ # Prefix-subset eviction: remove entries whose token sequence
+ # is a strict prefix of the new entry. Uses sorted index for
+ # O(log N + K) lookup instead of O(N) scan.
+ if evict_prefixes and self._sorted_keys:
+ to_remove = []
+ idx = bisect.bisect_left(self._sorted_keys, tokens_key)
+ # Scan backwards — prefixes of tokens_key are immediately before idx
+ for i in range(idx - 1, -1, -1):
+ key = self._sorted_keys[i]
+ klen = len(key)
+ if klen >= len(tokens_key):
+ continue
+ if tokens_key[:klen] == key:
+ to_remove.append(key)
+ elif key[0] != tokens_key[0]:
+ break
+ for key in to_remove:
+ old = self._entries.pop(key)
+ self._current_memory -= old.memory_bytes
+ self._stats.evictions += 1
+ self._remove_from_sorted(key)
+ logger.debug(
+ f"[prefix_evict] removed {len(key)} tokens, "
+ f"freed {old.memory_bytes / _BYTES_PER_MB:.2f}MB, "
+ f"new_entry={len(tokens_key)} tokens"
+ )
+ if to_remove:
+ self._stats.entry_count = len(self._entries)
+ self._stats.current_memory_bytes = self._current_memory
+
# Evict until we have room
while (
self._current_memory + entry.memory_bytes > self._max_memory
@@ -363,6 +583,7 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool:
# Store entry
self._entries[tokens_key] = entry
self._current_memory += entry.memory_bytes
+ bisect.insort(self._sorted_keys, tokens_key)
self._stats.entry_count = len(self._entries)
self._stats.current_memory_bytes = self._current_memory
@@ -374,6 +595,12 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool:
return True
+ def _remove_from_sorted(self, key: tuple[int, ...]) -> None:
+ """Remove a key from the sorted index using bisect for O(log N)."""
+ idx = bisect.bisect_left(self._sorted_keys, key)
+ if idx < len(self._sorted_keys) and self._sorted_keys[idx] == key:
+ self._sorted_keys.pop(idx)
+
def _evict_lru(self) -> None:
"""Evict the least recently used entry."""
if not self._entries:
@@ -382,12 +609,13 @@ def _evict_lru(self) -> None:
# popitem(last=False) removes oldest entry (FIFO order = LRU)
tokens_key, entry = self._entries.popitem(last=False)
self._current_memory -= entry.memory_bytes
+ self._remove_from_sorted(tokens_key)
self._stats.evictions += 1
self._stats.entry_count = len(self._entries)
self._stats.current_memory_bytes = self._current_memory
logger.debug(
- f"Evicted cache: {len(tokens_key)} tokens, "
+ f"[lru_evict] removed {len(tokens_key)} tokens, "
f"freed {entry.memory_bytes / _BYTES_PER_MB:.2f}MB"
)
@@ -405,6 +633,7 @@ def remove(self, tokens: list[int]) -> bool:
entry = self._entries.pop(tokens_key, None)
if entry is not None:
self._current_memory -= entry.memory_bytes
+ self._remove_from_sorted(tokens_key)
self._stats.entry_count = len(self._entries)
self._stats.current_memory_bytes = self._current_memory
return True
@@ -413,6 +642,7 @@ def remove(self, tokens: list[int]) -> bool:
def clear(self) -> None:
"""Clear all cached entries."""
self._entries.clear()
+ self._sorted_keys.clear()
self._current_memory = 0
self._stats = CacheStats(max_memory_bytes=self._max_memory)
logger.debug("Cache cleared")
@@ -446,3 +676,184 @@ def __len__(self) -> int:
def __contains__(self, tokens: list[int]) -> bool:
"""Check if tokens are cached."""
return tuple(tokens) in self._entries
+
+ # -----------------------------------------------------------------
+ # Disk persistence — survives server restarts
+ # -----------------------------------------------------------------
+
+ def save_to_disk(self, cache_dir: str) -> bool:
+ """Save all cache entries to disk using mlx_lm's safetensors format.
+
+ Directory layout::
+
+ cache_dir/
+ index.json # token keys + metadata per entry
+ entry_0.safetensors # KV arrays for entry 0
+ entry_1.safetensors
+ ...
+
+ Returns True if at least one entry was saved.
+ """
+ import json
+ import os
+ import time as _time
+
+ if not self._entries:
+ logger.info("[cache_persist] nothing to save (0 entries)")
+ return False
+
+ t0 = _time.monotonic()
+ os.makedirs(cache_dir, exist_ok=True)
+
+ try:
+ from mlx_lm.models.cache import save_prompt_cache
+ except ImportError:
+ logger.warning("[cache_persist] mlx_lm not available, cannot save")
+ return False
+
+ index = {
+ "version": 2,
+ "num_entries": len(self._entries),
+ "total_memory_bytes": self._current_memory,
+ "entries": [],
+ }
+
+ saved = 0
+ for i, (tokens_key, entry) in enumerate(self._entries.items()):
+ entry_path = os.path.join(cache_dir, f"entry_{i}.safetensors")
+ try:
+ save_prompt_cache(
+ entry_path,
+ entry.cache,
+ metadata={"num_tokens": str(len(tokens_key))},
+ )
+ # Save tokens separately (can be 100K+ ints → binary is smaller)
+ tokens_path = os.path.join(cache_dir, f"entry_{i}_tokens.bin")
+ import array as _array
+
+ arr = _array.array("i", tokens_key) # 32-bit signed ints
+ with open(tokens_path, "wb") as f:
+ arr.tofile(f)
+
+ index["entries"].append(
+ {
+ "index": i,
+ "num_tokens": len(tokens_key),
+ "memory_bytes": entry.memory_bytes,
+ }
+ )
+ saved += 1
+ logger.info(
+ f"[cache_persist] saved entry {i}: "
+ f"{len(tokens_key)} tokens, "
+ f"{entry.memory_bytes / _BYTES_PER_MB:.1f}MB KV, "
+ f"file={entry_path}"
+ )
+ except Exception as e:
+ logger.warning(f"[cache_persist] failed to save entry {i}: {e}")
+
+ index_path = os.path.join(cache_dir, "index.json")
+ with open(index_path, "w") as f:
+ json.dump(index, f, indent=2)
+
+ dt = _time.monotonic() - t0
+ logger.info(
+ f"[cache_persist] SAVED {saved}/{len(self._entries)} entries "
+ f"to {cache_dir} in {dt:.1f}s "
+ f"({self._current_memory / _BYTES_PER_MB:.0f}MB total)"
+ )
+ return saved > 0
+
+ def load_from_disk(self, cache_dir: str) -> int:
+ """Load cache entries from disk.
+
+ Returns the number of entries successfully loaded.
+ """
+ import json
+ import os
+ import time as _time
+
+ index_path = os.path.join(cache_dir, "index.json")
+ if not os.path.exists(index_path):
+ logger.info(f"[cache_persist] no index at {index_path}, nothing to load")
+ return 0
+
+ t0 = _time.monotonic()
+
+ try:
+ from mlx_lm.models.cache import load_prompt_cache
+ except ImportError:
+ logger.warning("[cache_persist] mlx_lm not available, cannot load")
+ return 0
+
+ with open(index_path) as f:
+ index = json.load(f)
+
+ version = index.get("version", 1)
+ if version < 2:
+ logger.warning(f"[cache_persist] unsupported version {version}, skipping")
+ return 0
+
+ loaded = 0
+ for entry_meta in index.get("entries", []):
+ i = entry_meta["index"]
+ entry_path = os.path.join(cache_dir, f"entry_{i}.safetensors")
+ tokens_path = os.path.join(cache_dir, f"entry_{i}_tokens.bin")
+
+ if not os.path.exists(entry_path) or not os.path.exists(tokens_path):
+ logger.warning(f"[cache_persist] missing files for entry {i}, skipping")
+ continue
+
+ try:
+ # Load tokens from binary
+ import array as _array
+
+ arr = _array.array("i")
+ with open(tokens_path, "rb") as f:
+ arr.fromfile(f, entry_meta["num_tokens"])
+ tokens = list(arr)
+
+ # Load KV cache
+ cache = load_prompt_cache(entry_path)
+
+ # Estimate memory
+ memory = estimate_kv_cache_memory(cache)
+
+ # Check if it fits
+ if self._current_memory + memory > self._max_memory:
+ logger.info(
+ f"[cache_persist] entry {i} would exceed memory limit "
+ f"({(self._current_memory + memory) / _BYTES_PER_MB:.0f}MB > "
+ f"{self._max_memory / _BYTES_PER_MB:.0f}MB), stopping load"
+ )
+ break
+
+ tokens_key = tuple(tokens)
+ entry = _CacheEntry(
+ tokens=tokens_key,
+ cache=cache,
+ memory_bytes=memory,
+ )
+ self._entries[tokens_key] = entry
+ self._current_memory += memory
+ bisect.insort(self._sorted_keys, tokens_key)
+ loaded += 1
+
+ logger.info(
+ f"[cache_persist] loaded entry {i}: "
+ f"{len(tokens)} tokens, "
+ f"{memory / _BYTES_PER_MB:.1f}MB KV"
+ )
+
+ except Exception as e:
+ logger.warning(f"[cache_persist] failed to load entry {i}: {e}")
+
+ self._stats.entry_count = len(self._entries)
+ self._stats.current_memory_bytes = self._current_memory
+
+ dt = _time.monotonic() - t0
+ logger.info(
+ f"[cache_persist] LOADED {loaded} entries from {cache_dir} "
+ f"in {dt:.1f}s ({self._current_memory / _BYTES_PER_MB:.0f}MB total)"
+ )
+ return loaded
diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py
index 91067e30..fba3ae02 100644
--- a/vllm_mlx/mllm_batch_generator.py
+++ b/vllm_mlx/mllm_batch_generator.py
@@ -359,7 +359,7 @@ def __init__(
self._old_wired_limit = None
if mx.metal.is_available():
self._old_wired_limit = mx.set_wired_limit(
- mx.metal.device_info()["max_recommended_working_set_size"]
+ mx.device_info()["max_recommended_working_set_size"]
)
def close(self) -> None:
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 6a086025..764f0543 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -662,15 +662,24 @@ async def stream_outputs(
if output_queue is None:
return
- while True:
- output = await output_queue.get()
- if output is None:
- break
- yield output
-
- # Cleanup queue
- if request_id in self.output_queues:
- del self.output_queues[request_id]
+ finished_normally = False
+ try:
+ while True:
+ output = await output_queue.get()
+ if output is None:
+ finished_normally = True
+ break
+ yield output
+ if output.finished:
+ finished_normally = True
+ break
+ finally:
+ if not finished_normally:
+ logger.info(f"Aborting orphaned MLLM request {request_id}")
+ self.abort_request(request_id)
+ # Cleanup queue
+ if request_id in self.output_queues:
+ del self.output_queues[request_id]
async def generate(
self,
diff --git a/vllm_mlx/optimizations.py b/vllm_mlx/optimizations.py
index ac4aafd9..624d5d77 100644
--- a/vllm_mlx/optimizations.py
+++ b/vllm_mlx/optimizations.py
@@ -86,7 +86,7 @@ def get_system_memory_gb() -> float:
except Exception:
# Fallback: try to get from MLX device info
try:
- device_info = mx.metal.device_info()
+ device_info = mx.device_info()
if "memory_size" in device_info:
return device_info["memory_size"] / (1024**3)
except Exception:
@@ -105,7 +105,7 @@ def detect_hardware() -> HardwareInfo:
HardwareInfo with detected hardware specifications
"""
try:
- device_info = mx.metal.device_info()
+ device_info = mx.device_info()
device_name = device_info.get("device_name", "")
actual_memory_gb = get_system_memory_gb()
@@ -182,7 +182,7 @@ def get_optimization_status() -> dict:
dict with hardware info and MLX configuration
"""
hw = detect_hardware()
- device_info = mx.metal.device_info()
+ device_info = mx.device_info()
flash_available = hasattr(mx, "fast") and hasattr(
mx.fast, "scaled_dot_product_attention"
)
diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py
index 3e79803b..e8f47a32 100644
--- a/vllm_mlx/prefix_cache.py
+++ b/vllm_mlx/prefix_cache.py
@@ -186,8 +186,8 @@ def fetch_cache(self, tokens: List[int]) -> Tuple[Optional[List[Any]], List[int]
self.stats.hits += 1
self.stats.tokens_saved += len(tokens)
self._touch_lru(tokens_tuple)
- # Deep copy to prevent mutation
- return copy.deepcopy(cache_entry.prompt_cache), []
+ # No copy needed - MLX arrays are immutable
+ return cache_entry.prompt_cache, []
if shorter:
# Shorter prefix cached - return cache and remaining tokens
@@ -197,7 +197,8 @@ def fetch_cache(self, tokens: List[int]) -> Tuple[Optional[List[Any]], List[int]
self.stats.tokens_saved += len(shorter)
self._touch_lru(tuple(shorter))
remaining = tokens[len(shorter) :]
- return copy.deepcopy(cache_entry.prompt_cache), remaining
+ # No copy needed - MLX arrays are immutable
+ return cache_entry.prompt_cache, remaining
if longer:
# Longer prefix cached - trim to match and return
diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py
index d7b9db1f..41679c0b 100644
--- a/vllm_mlx/request.py
+++ b/vllm_mlx/request.py
@@ -111,6 +111,7 @@ class Request:
prompt_cache: Optional[List[Any]] = None # Cached KV state from prefix cache
cached_tokens: int = 0 # Number of tokens retrieved from cache
remaining_tokens: Optional[List[int]] = None # Tokens still needing processing
+ prefix_boundary: int = 0 # Token count for shared prefix (messages[:-1])
# Paged cache fields (for BlockAwarePrefixCache)
block_table: Optional["BlockTable"] = None # Block table for paged cache
@@ -129,6 +130,12 @@ class Request:
# Metadata
finish_reason: Optional[str] = None
+ first_token_time: Optional[float] = (
+ None # Time when first output token was generated
+ )
+ cache_hit_type: Optional[str] = (
+ None # Type of cache hit: exact/prefix/supersequence/lcp/miss
+ )
@property
def num_output_tokens(self) -> int:
diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index c5f2c52d..26ef5315 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -17,6 +17,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
+import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
@@ -77,6 +78,17 @@ class SchedulerConfig:
paged_cache_block_size: int = 64 # Tokens per block
max_cache_blocks: int = 1000 # Maximum number of cache blocks
+ # Chunked prefill: max tokens to prefill per scheduler step (0 = disabled)
+ # When enabled, large prompts are split into chunks so that active
+ # generation requests are not starved during long prefills.
+ chunked_prefill_tokens: int = 0
+
+ # Mid-prefill cache saving: save intermediate KV cache every N tokens
+ # during chunked prefill. If the client disconnects mid-prefill, the
+ # saved cache is reused for the next request with the same prefix.
+ # 0 = disabled. Only effective when chunked_prefill_tokens > 0.
+ mid_prefill_save_interval: int = 8192
+
@dataclass
class SchedulerOutput:
@@ -98,6 +110,375 @@ class SchedulerOutput:
has_work: bool = False
+def _install_chunked_prefill(
+ batch_gen: "BatchGenerator",
+ budget: int,
+ mid_prefill_save=None,
+ prompt_cache_save=None,
+ pending_abort_ids: Optional[Set[str]] = None,
+ uid_to_request_id: Optional[Dict[int, str]] = None,
+ requests: Optional[Dict[str, Any]] = None,
+) -> None:
+ """
+ Monkey-patch a BatchGenerator instance so that large prefills are
+ broken into chunks of at most *budget* tokens each.
+
+ Between chunks the generation loop gets a chance to produce one token
+ for every active request, preventing starvation during long prefills.
+
+ Args:
+ batch_gen: The BatchGenerator to patch.
+ budget: Max tokens per prefill chunk.
+ mid_prefill_save: Optional callback(uid, processed, prompt_cache)
+ called after each chunk to save intermediate KV cache state.
+ """
+ import time as _time
+
+ from mlx_lm.generate import (
+ Batch,
+ _left_pad_prompts,
+ _make_cache,
+ _merge_caches,
+ _right_pad_prompts,
+ )
+
+ # Keep references to originals
+ _orig_next = batch_gen._next
+ _orig_remove = batch_gen.remove
+ _orig_process_prompts = batch_gen._process_prompts
+
+ # Partial prefill state (None when no prefill in progress)
+ batch_gen._partial = None
+
+ # Monkey-patch _process_prompts to capture prompt-only cache state.
+ # At the point where _process_prompts returns, the Batch cache contains
+ # the exact prompt-only state: all prompt tokens have been processed
+ # through the model, but no output token has been fed back yet.
+ # This is the only safe capture point for hybrid Mamba+Transformer
+ # models whose MambaCache state is cumulative.
+ if prompt_cache_save is not None:
+
+ def _patched_process_prompts(prompts, _self=batch_gen):
+ batch = _orig_process_prompts(prompts)
+ for e, uid in enumerate(batch.uids):
+ if batch.num_tokens[e] == 0:
+ try:
+ prompt_cache_save(uid, batch.extract_cache(e))
+ except Exception:
+ pass
+ return batch
+
+ batch_gen._process_prompts = _patched_process_prompts
+
+ def _generation_step(self=batch_gen):
+ """Run one generation step on the active batch. Returns responses."""
+ batch = self.active_batch
+ if batch is None or len(batch) == 0:
+ return []
+
+ tic_gen = _time.perf_counter()
+ y, logprobs = batch.y, batch.logprobs
+ for i, toks in enumerate(batch.tokens):
+ batch.tokens[i] = mx.concatenate((toks, y[i : i + 1]))
+ batch.y, batch.logprobs = self._step(
+ y[:, None],
+ batch.cache,
+ batch.samplers,
+ batch.logits_processors,
+ batch.tokens,
+ )
+ mx.async_eval(batch.y, batch.logprobs)
+
+ y = y.tolist()
+ self._stats.generation_time += _time.perf_counter() - tic_gen
+
+ keep_idx = []
+ end_idx = []
+ responses = []
+ for e, (t, uid, num_tok, max_tok) in enumerate(
+ zip(y, batch.uids, batch.num_tokens, batch.max_tokens)
+ ):
+ cache_out = None
+ num_tok += 1
+ batch.num_tokens[e] = num_tok
+ if t in self.stop_tokens:
+ finish_reason = "stop"
+ end_idx.append(e)
+ elif num_tok >= max_tok:
+ finish_reason = "length"
+ end_idx.append(e)
+ else:
+ finish_reason = None
+ keep_idx.append(e)
+ if finish_reason is not None:
+ cache_out = batch.extract_cache(e)
+ responses.append(
+ self.Response(uid, t, logprobs[e], finish_reason, cache_out)
+ )
+
+ if len(end_idx):
+ if len(keep_idx) > 0:
+ batch.filter(keep_idx)
+ else:
+ self.active_batch = None
+
+ self._stats.generation_tokens += len(responses)
+ return responses
+
+ def _chunked_next(self=batch_gen): # noqa: C901
+ """
+ Replacement for _next() that chunks large prefills.
+
+ Only intercepts when:
+ 1. A partial prefill is in progress (_partial is not None)
+ 2. The next prompt batch exceeds the budget
+
+ Everything else delegates to the original _next().
+ """
+ # ----- Continue a partial prefill -----
+ if self._partial is not None:
+ # Check for pending aborts BEFORE processing next chunk
+ if pending_abort_ids is not None and uid_to_request_id is not None:
+ partial_rids = {uid_to_request_id.get(u) for u in self._partial["uids"]}
+ aborted_rids = partial_rids & pending_abort_ids
+ if aborted_rids:
+ logger.info(
+ f"[chunked_prefill] abort detected mid-prefill, "
+ f"clearing partial for: {aborted_rids}"
+ )
+ self._partial = None
+ mx.clear_cache()
+ return _generation_step()
+
+ tic = _time.perf_counter()
+ partial = self._partial
+ inputs = partial["inputs"]
+ prompt_cache = partial["cache"]
+ remaining = inputs.shape[1]
+
+ n_to_process = min(budget, remaining - 1) if remaining > 1 else 0
+
+ if n_to_process > 0:
+ self.model(inputs[:, :n_to_process], cache=prompt_cache)
+ mx.eval([c.state for c in prompt_cache])
+ inputs = inputs[:, n_to_process:]
+ partial["inputs"] = inputs
+ partial["processed"] += n_to_process
+
+ self.prompt_progress_callback(
+ [
+ (uid, partial["processed"], partial["total"])
+ for uid in partial["uids"]
+ ]
+ )
+
+ # Save intermediate cache for disconnect resilience
+ if mid_prefill_save is not None and len(partial["uids"]) == 1:
+ mid_prefill_save(
+ partial["uids"][0], partial["processed"], prompt_cache
+ )
+
+ if partial.get("is_cached"):
+ mx.clear_cache()
+
+ # Check if prefill is done (only 1 token left or 0)
+ if inputs.shape[1] <= 1:
+ # Finalize
+ if partial.get("is_cached"):
+ mx.eval([c.state for c in prompt_cache])
+ inputs = partial["last_inputs"]
+
+ for c in prompt_cache:
+ c.finalize()
+ mx.clear_cache()
+
+ y, logprobs = self._step(
+ inputs,
+ prompt_cache,
+ partial["samplers"],
+ partial["logits_processors"],
+ partial["tokens"],
+ )
+ mx.async_eval(y, logprobs)
+
+ new_batch = Batch(
+ list(partial["uids"]),
+ y,
+ logprobs,
+ list(partial["max_tokens"]),
+ [0] * len(partial["uids"]),
+ prompt_cache,
+ list(partial["samplers"]),
+ list(partial["logits_processors"]),
+ partial["tokens"],
+ )
+
+ # Save prompt-only cache BEFORE merging into active batch.
+ # This is the chunked-prefill equivalent of the
+ # _patched_process_prompts hook — at this point the cache
+ # contains the exact prompt-only state (num_tokens == 0).
+ if prompt_cache_save is not None and len(partial["uids"]) == 1:
+ uid = partial["uids"][0]
+ try:
+ prompt_cache_save(uid, new_batch.extract_cache(0))
+ except Exception:
+ pass
+
+ if self.active_batch is None:
+ self.active_batch = new_batch
+ else:
+ self.active_batch.extend(new_batch)
+
+ self._partial = None
+ self._stats.prompt_time += _time.perf_counter() - tic
+ else:
+ # Not done yet — record prompt time for this chunk
+ self._stats.prompt_time += _time.perf_counter() - tic
+
+ # Generation step for active requests between chunks
+ return _generation_step()
+
+ # ----- No partial — check if next prompt batch needs chunking -----
+ num_active = len(self.active_batch) if self.active_batch else 0
+ num_to_add = self.completion_batch_size - num_active
+
+ if num_to_add >= self.prefill_batch_size and self.unprocessed_prompts:
+ batch_prompts = self.unprocessed_prompts[: self.prefill_batch_size]
+ if batch_prompts:
+ total_tokens = sum(len(p[1]) for p in batch_prompts)
+
+ # Check if any prompt has a prefix_boundary that
+ # requires two-phase prefill for cache save at that boundary.
+ _needs_boundary_split = False
+ if requests is not None and uid_to_request_id is not None:
+ for _uid, _toks, *_ in batch_prompts:
+ _rid = uid_to_request_id.get(_uid)
+ _req = requests.get(_rid) if _rid else None
+ if _req and getattr(_req, "prefix_boundary", 0) > 0:
+ _needs_boundary_split = True
+ break
+
+ if total_tokens > budget or _needs_boundary_split:
+ # Large prompt batch or prefix boundary — start partial prefill
+ tic = _time.perf_counter()
+
+ # Eval outstanding generation tokens before switching
+ if self.active_batch is not None:
+ mx.eval(self.active_batch.y, self.active_batch.logprobs)
+ self._stats.generation_time += _time.perf_counter() - tic
+ tic = _time.perf_counter()
+
+ (
+ uids,
+ inputs_raw,
+ max_tokens_list,
+ caches,
+ samplers,
+ logits_processors,
+ ) = zip(*batch_prompts)
+ lengths = [len(p) for p in inputs_raw]
+ max_length = max(lengths)
+ padding = [max_length - ln for ln in lengths]
+ tokens = [mx.array(inp) for inp in inputs_raw]
+ is_cached = not all(c[0].empty() for c in caches)
+
+ self._stats.prompt_tokens += sum(lengths)
+
+ if not is_cached:
+ padded = _left_pad_prompts(inputs_raw, max_length=max_length)
+ prompt_cache = _make_cache(self.model, padding)
+ else:
+ last_inputs = mx.array([p[-1:] for p in inputs_raw])
+ padded = _right_pad_prompts(inputs_raw, max_length=max_length)
+ prompt_cache = _merge_caches(caches)
+ for c in prompt_cache:
+ c.prepare(
+ lengths=[ln - 1 for ln in lengths],
+ right_padding=padding,
+ )
+
+ # Remove from unprocessed
+ self.unprocessed_prompts = self.unprocessed_prompts[
+ self.prefill_batch_size :
+ ]
+
+ # Process first chunk — if prefix_boundary is set,
+ # use it as the first chunk size so that mid_prefill_save
+ # can capture the exact prefix cache state (critical for
+ # hybrid Mamba+Transformer models where trim is unsafe).
+ # When the request already has cached tokens (cache hit),
+ # adjust the boundary relative to the remaining tokens.
+ _first_chunk = budget
+ if _needs_boundary_split and len(batch_prompts) == 1:
+ _uid0 = uids[0]
+ _rid0 = uid_to_request_id.get(_uid0)
+ _req0 = requests.get(_rid0) if _rid0 else None
+ _pb = getattr(_req0, "prefix_boundary", 0) if _req0 else 0
+ _cached = getattr(_req0, "cached_tokens", 0) if _req0 else 0
+ _adjusted_pb = _pb - _cached
+ if 0 < _adjusted_pb < padded.shape[1]:
+ _first_chunk = _adjusted_pb
+ n_to_process = min(_first_chunk, padded.shape[1] - 1)
+ if n_to_process > 0:
+ self.model(padded[:, :n_to_process], cache=prompt_cache)
+ mx.eval([c.state for c in prompt_cache])
+ padded = padded[:, n_to_process:]
+ if is_cached:
+ mx.clear_cache()
+
+ self._partial = {
+ "uids": list(uids),
+ "inputs": padded,
+ "cache": prompt_cache,
+ "tokens": tokens,
+ "max_tokens": list(max_tokens_list),
+ "samplers": list(samplers),
+ "logits_processors": list(logits_processors),
+ "processed": n_to_process,
+ "total": max_length,
+ "is_cached": is_cached,
+ }
+ if is_cached:
+ self._partial["last_inputs"] = last_inputs
+
+ self.prompt_progress_callback(
+ [
+ (uid, n_to_process, max_length)
+ for uid in self._partial["uids"]
+ ]
+ )
+
+ # Save intermediate cache for disconnect resilience
+ if mid_prefill_save is not None and len(uids) == 1:
+ mid_prefill_save(uids[0], n_to_process, prompt_cache)
+
+ self._stats.prompt_time += _time.perf_counter() - tic
+
+ # Generation step for active requests
+ return _generation_step()
+
+ # Small prompts, pure generation, or no work — delegate to original
+ return _orig_next()
+
+ def _patched_remove(uids_to_remove, _self=batch_gen):
+ """Clear partial state if aborted request is being prefilled."""
+ if _self._partial is not None:
+ partial_uids = set(_self._partial["uids"])
+ if partial_uids & set(uids_to_remove):
+ logger.info(
+ f"[chunked_prefill] clearing partial state for aborted uids: "
+ f"{partial_uids & set(uids_to_remove)}"
+ )
+ _self._partial = None
+ mx.clear_cache() # flush Metal encoders after dropping partial state
+ _orig_remove(uids_to_remove)
+
+ batch_gen._next = _chunked_next
+ batch_gen.remove = _patched_remove
+
+ logger.info(f"[chunked_prefill] installed with budget={budget} tokens per step")
+
+
class Scheduler:
"""
Scheduler for continuous batching using mlx-lm BatchGenerator.
@@ -192,11 +573,21 @@ def __init__(
f"Prefix cache enabled with max_entries={self.config.prefix_cache_size}"
)
+ # Thread-safe set for deferred aborts (main thread → executor thread)
+ # CPython GIL guarantees set.add() and `x in set` are atomic.
+ self._pending_abort_ids: Set[str] = set()
+
# Statistics
self.num_requests_processed = 0
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
+ # Memory management: periodic mx.clear_cache() to free Metal command buffers
+ # Lower interval = less VRAM spike during generation but slight throughput cost
+ self._step_count = 0
+ self._clear_cache_interval = 32
+ self._memory_log_interval = 256
+
def _get_actual_tokenizer(self, tokenizer: Any) -> Any:
"""
Get the actual tokenizer from a processor or tokenizer.
@@ -254,7 +645,16 @@ def _create_batch_generator(
if sampling_params.stop_token_ids:
stop_tokens.update(sampling_params.stop_token_ids)
- return BatchGenerator(
+ def _prefill_progress(progress_list):
+ """Log prefill progress for each uid chunk."""
+ for uid, processed, total in progress_list:
+ rid = self.uid_to_request_id.get(uid, "?")
+ logger.info(
+ f"[prefill] request={rid[:12] if isinstance(rid, str) else rid} "
+ f"tokens={processed}/{total}"
+ )
+
+ bg = BatchGenerator(
model=self.model,
max_tokens=sampling_params.max_tokens,
stop_tokens=stop_tokens,
@@ -262,8 +662,159 @@ def _create_batch_generator(
prefill_batch_size=self.config.prefill_batch_size,
completion_batch_size=self.config.completion_batch_size,
prefill_step_size=self.config.prefill_step_size,
+ prompt_progress_callback=_prefill_progress,
)
+ # Install chunked prefill when explicitly configured OR when
+ # memory-aware cache is active (needed for prefix_boundary saves
+ # in agentic multi-turn workloads with hybrid Mamba+Transformer models).
+ chunked_budget = self.config.chunked_prefill_tokens
+ need_chunked = chunked_budget > 0 or self.memory_aware_cache is not None
+ if need_chunked:
+ if chunked_budget <= 0:
+ # No explicit budget — use a very large value so normal
+ # prompts pass through unchanged. Prefix boundary splits
+ # still trigger via _needs_boundary_split.
+ chunked_budget = 999_999
+ mid_prefill_cb = None
+ save_interval = self.config.mid_prefill_save_interval
+ if save_interval > 0 and self.memory_aware_cache is not None:
+ mid_prefill_cb = self._make_mid_prefill_save_callback(save_interval)
+ logger.info(f"[mid_prefill_cache] enabled, interval={save_interval}")
+ prompt_cache_cb = None
+ if self.memory_aware_cache is not None:
+ prompt_cache_cb = self._make_prompt_cache_save_callback()
+ _install_chunked_prefill(
+ bg,
+ chunked_budget,
+ mid_prefill_cb,
+ prompt_cache_save=prompt_cache_cb,
+ pending_abort_ids=self._pending_abort_ids,
+ uid_to_request_id=self.uid_to_request_id,
+ requests=self.requests,
+ )
+
+ return bg
+
+ def _make_prompt_cache_save_callback(self):
+ """Create a callback that stores prompt-only KV/Mamba cache.
+
+ Called from ``_generation_step`` right before the first output token
+ is fed into the model. At that point ``num_tokens == 0`` and the
+ batch cache contains the exact prompt-only state (correct for both
+ KVCache and MambaCache/ArraysCache layers).
+
+ The cache is stored with key = prompt_token_ids so that a future
+ request with the identical prompt gets an exact hit.
+ """
+ import time as _time
+
+ def _prompt_cache_save(uid, extracted_cache):
+ request_id = self.uid_to_request_id.get(uid)
+ if not request_id:
+ return
+ request = self.requests.get(request_id)
+ if not request or not request.prompt_token_ids:
+ return
+
+ prompt_tokens = list(request.prompt_token_ids)
+ _t0 = _time.monotonic()
+ # evict_prefixes=False: keep mid-prefill boundary entries so
+ # that future requests with the same prefix but different
+ # suffix get a prefix cache hit (critical for agentic multi-turn).
+ stored = self.memory_aware_cache.store(
+ prompt_tokens, extracted_cache, evict_prefixes=False
+ )
+ _dt = _time.monotonic() - _t0
+ if stored:
+ logger.info(
+ f"[prompt_cache_save] request={request_id[:12]} "
+ f"prompt_tokens={len(prompt_tokens)} "
+ f"store_time={_dt:.3f}s"
+ )
+
+ return _prompt_cache_save
+
+ def _make_mid_prefill_save_callback(self, save_interval: int):
+ """Create a callback for saving intermediate KV cache during chunked prefill.
+
+ The callback is called after each chunk with (uid, processed_tokens,
+ prompt_cache). It extracts the cache state (immutable MLX array
+ snapshots), reconstructs KVCache objects, and stores them in the
+ memory-aware prefix cache so that a subsequent request with the same
+ prompt prefix can skip the already-computed tokens.
+ """
+ import time as _time
+
+ def _mid_prefill_save(uid, processed_tokens, prompt_cache):
+ request_id = self.uid_to_request_id.get(uid)
+ if not request_id:
+ return
+ request = self.requests.get(request_id)
+ if not request or not request.prompt_token_ids:
+ return
+
+ total_cached = (request.cached_tokens or 0) + processed_tokens
+
+ # Always save at prefix_boundary (message boundary for cache
+ # reuse with different final user messages).
+ prefix_boundary = getattr(request, "prefix_boundary", 0)
+ at_prefix_boundary = prefix_boundary > 0 and total_cached == prefix_boundary
+
+ # Throttle: only save every save_interval tokens,
+ # unless we're at the prefix boundary.
+ last_save = getattr(request, "_mid_prefill_last_save", 0)
+ if not at_prefix_boundary and total_cached - last_save < save_interval:
+ return
+
+ # Extract immutable state snapshots
+ extracted = self._extract_cache_states(prompt_cache)
+ if not extracted:
+ return
+
+ # Reconstruct cache objects (directly usable by BatchGenerator)
+ reconstructed = self._reconstruct_cache_from_states(extracted)
+ if not reconstructed:
+ return
+
+ prefix_tokens = list(request.prompt_token_ids[:total_cached])
+
+ # Remove previous intermediate entry to avoid memory waste
+ old_key = getattr(request, "_mid_prefill_cache_key", None)
+ if old_key is not None:
+ self.memory_aware_cache.remove(list(old_key))
+
+ _t0 = _time.monotonic()
+ stored = self.memory_aware_cache.store(prefix_tokens, reconstructed)
+ _dt = _time.monotonic() - _t0
+
+ if stored:
+ request._mid_prefill_last_save = total_cached
+ request._mid_prefill_cache_key = tuple(prefix_tokens)
+ logger.info(
+ f"[mid_prefill_cache] request={request_id[:12]} "
+ f"saved {total_cached}/{len(request.prompt_token_ids)} tokens "
+ f"({total_cached * 100 // len(request.prompt_token_ids)}%) "
+ f"store_time={_dt:.3f}s"
+ )
+ else:
+ logger.debug(
+ f"[mid_prefill_cache] request={request_id[:12]} "
+ f"store rejected for {total_cached} tokens"
+ )
+
+ return _mid_prefill_save
+
+ def _close_batch_generator(self) -> None:
+ """Properly close BatchGenerator to restore wired_limit."""
+ if self.batch_generator is not None:
+ try:
+ if hasattr(self.batch_generator, "close"):
+ self.batch_generator.close()
+ except Exception as e:
+ logger.debug(f"Error closing BatchGenerator: {e}")
+ self.batch_generator = None
+
def _ensure_batch_generator(self, sampling_params: SamplingParams) -> None:
"""Ensure BatchGenerator exists with compatible settings."""
sampler_params = (
@@ -285,20 +836,26 @@ def _ensure_batch_generator(self, sampling_params: SamplingParams) -> None:
)
return
- # Clear prefix cache when BatchGenerator changes
- # BatchKVCache objects are tied to their generator instance
+ # Keep prefix cache across BatchGenerator recreations.
+ # KV cache entries depend only on the input tokens, not on
+ # sampling params (temperature, top_p, min_p). Since the
+ # server runs a single model, the cache is always valid.
if self.batch_generator is not None:
+ n_entries = 0
if self.memory_aware_cache is not None:
- logger.debug(
- "Clearing memory-aware cache: BatchGenerator being recreated"
- )
- self.memory_aware_cache.clear()
+ n_entries = len(self.memory_aware_cache._entries)
elif self.prefix_cache is not None:
- logger.debug(
- "Clearing prefix cache: BatchGenerator being recreated"
+ n_entries = (
+ len(self.prefix_cache)
+ if hasattr(self.prefix_cache, "__len__")
+ else 0
)
- self.prefix_cache.clear()
+ logger.info(
+ f"[batch_generator] recreating (sampler params changed), "
+ f"keeping {n_entries} cache entries"
+ )
+ self._close_batch_generator()
self.batch_generator = self._create_batch_generator(sampling_params)
self._current_sampler_params = sampler_params
@@ -306,8 +863,11 @@ def _validate_cache(self, cache: Any) -> bool:
"""
Validate that a cache object is usable.
- This prevents NoneType errors when mlx-lm's BatchKVCache
- contains invalid/stale references.
+ Checks for None references AND shape compatibility. Restored
+ cache entries must have batch_size == 1 (single sequence) so
+ they can be merged into the running batch by _merge_caches.
+ A shape mismatch here (e.g. batch=2 from a stale entry) would
+ cause a concatenation crash inside _merge_caches.
Args:
cache: The cache object to validate
@@ -331,6 +891,23 @@ def _validate_cache(self, cache: Any) -> bool:
return False
if hasattr(layer_cache, "values") and layer_cache.values is None:
return False
+ # Validate batch dimension == 1 for KVCache layers
+ if hasattr(layer_cache, "keys") and layer_cache.keys is not None:
+ if layer_cache.keys.shape[0] != 1:
+ logger.debug(
+ f"Cache layer invalid: keys batch={layer_cache.keys.shape[0]}, expected 1"
+ )
+ return False
+ # Validate batch dimension for MambaCache layers
+ if hasattr(layer_cache, "cache") and isinstance(
+ layer_cache.cache, list
+ ):
+ for arr in layer_cache.cache:
+ if arr is not None and arr.shape[0] != 1:
+ logger.debug(
+ f"Cache layer invalid: mamba batch={arr.shape[0]}, expected 1"
+ )
+ return False
# Check BatchKVCache structure
if hasattr(cache, "caches"):
@@ -363,13 +940,14 @@ def _extract_cache_states(self, raw_cache: List[Any]) -> List[Dict[str, Any]]:
for layer_cache in raw_cache:
try:
if hasattr(layer_cache, "state") and hasattr(layer_cache, "meta_state"):
- state = layer_cache.state # (keys, values) MLX arrays
+ state = layer_cache.state # (keys, values) or more for Mamba
meta = layer_cache.meta_state # (offset,) as strings
extracted.append(
{
"state": state,
"meta_state": meta,
"class_name": type(layer_cache).__name__,
+ "class_ref": type(layer_cache),
}
)
except Exception as e:
@@ -378,6 +956,72 @@ def _extract_cache_states(self, raw_cache: List[Any]) -> List[Dict[str, Any]]:
return extracted if len(extracted) == len(raw_cache) else []
+ def _reconstruct_cache_from_states(
+ self, extracted_states: List[Dict[str, Any]]
+ ) -> Optional[List[Any]]:
+ """
+ Reconstruct cache objects from extracted cache states.
+
+ This is the inverse of _extract_cache_states(). Uses mlx-lm's
+ _BaseCache.from_state() to reconstruct any cache type (KVCache,
+ MambaCache, etc.) from its state/meta_state.
+
+ Args:
+ extracted_states: List of dicts from _extract_cache_states()
+
+ Returns:
+ List of cache objects, or None if reconstruction fails
+ """
+ if not extracted_states:
+ return None
+
+ try:
+ caches = []
+ for layer_state in extracted_states:
+ state = layer_state.get("state")
+ meta_state = layer_state.get("meta_state")
+ cache_cls = layer_state.get("class_ref")
+ if state is None:
+ return None
+
+ if cache_cls is not None and hasattr(cache_cls, "from_state"):
+ # BatchKVCache doesn't inherit from KVCache, so
+ # _merge_caches can't handle it. Convert to KVCache
+ # (safe because mid-prefill save is always batch_size=1).
+ from mlx_lm.models.cache import (
+ BatchKVCache as _BatchKVCache,
+ KVCache as _KVCache,
+ )
+
+ if cache_cls is _BatchKVCache:
+ # BatchKVCache.state = (keys, values, offset, left_padding)
+ keys, values = state[0], state[1]
+ cache = _KVCache()
+ cache.keys = keys
+ cache.values = values
+ cache.offset = keys.shape[2]
+ else:
+ cache = cache_cls.from_state(state, meta_state)
+ else:
+ # Fallback: try KVCache manual reconstruction
+ from mlx_lm.models.cache import KVCache
+
+ if len(state) != 2:
+ return None
+ cache = KVCache()
+ cache.keys, cache.values = state
+ cache.offset = (
+ int(meta_state[0]) if meta_state else cache.keys.shape[2]
+ )
+
+ caches.append(cache)
+
+ return caches
+
+ except Exception as e:
+ logger.info(f"[mid_prefill_cache] reconstruct EXCEPTION: {e}")
+ return None
+
def add_request(self, request: Request) -> None:
"""
Add a new request to the scheduler.
@@ -418,6 +1062,7 @@ def add_request(self, request: Request) -> None:
request.prompt_token_ids,
)
if block_table and block_table.num_tokens > 0:
+ request.cache_hit_type = "hit"
# Reconstruct actual KVCache objects from stored tensor data
reconstructed = self.block_aware_cache.reconstruct_cache(block_table)
if reconstructed:
@@ -433,30 +1078,44 @@ def add_request(self, request: Request) -> None:
)
else:
# Reconstruction failed, treat as cache miss
+ request.cache_hit_type = "miss"
request.remaining_tokens = request.prompt_token_ids
logger.debug(
f"Request {request.request_id}: paged cache reconstruction failed"
)
else:
+ request.cache_hit_type = "miss"
request.remaining_tokens = request.prompt_token_ids
elif self.memory_aware_cache is not None:
# Use memory-aware prefix cache
+ import time as _time
+
+ _fetch_t0 = _time.monotonic()
cache, remaining = self.memory_aware_cache.fetch(request.prompt_token_ids)
+ _fetch_dt = _time.monotonic() - _fetch_t0
+ request.cache_hit_type = self.memory_aware_cache._last_match_type
if cache:
request.prompt_cache = cache
request.cached_tokens = len(request.prompt_token_ids) - len(remaining)
request.remaining_tokens = remaining
- logger.debug(
- f"Request {request.request_id}: memory-aware cache hit, "
- f"{request.cached_tokens} tokens cached, "
- f"{len(remaining)} tokens remaining"
+ logger.info(
+ f"[cache_fetch] request={request.request_id[:12]} HIT "
+ f"prompt_tokens={len(request.prompt_token_ids)} "
+ f"cached={request.cached_tokens} remaining={len(remaining)} "
+ f"time={_fetch_dt:.3f}s"
)
else:
request.remaining_tokens = request.prompt_token_ids
+ logger.info(
+ f"[cache_fetch] request={request.request_id[:12]} MISS "
+ f"prompt_tokens={len(request.prompt_token_ids)} "
+ f"time={_fetch_dt:.3f}s entries={len(self.memory_aware_cache._entries)}"
+ )
elif self.prefix_cache is not None:
# Use legacy prefix cache
cache, remaining = self.prefix_cache.fetch_cache(request.prompt_token_ids)
if cache:
+ request.cache_hit_type = "hit"
request.prompt_cache = cache
request.cached_tokens = len(request.prompt_token_ids) - len(remaining)
request.remaining_tokens = remaining
@@ -466,8 +1125,10 @@ def add_request(self, request: Request) -> None:
f"{len(remaining)} tokens remaining"
)
else:
+ request.cache_hit_type = "miss"
request.remaining_tokens = request.prompt_token_ids
else:
+ request.cache_hit_type = "miss"
request.remaining_tokens = request.prompt_token_ids
# Add to tracking
@@ -480,41 +1141,82 @@ def add_request(self, request: Request) -> None:
def abort_request(self, request_id: str) -> bool:
"""
- Abort a request.
+        Queue a request for abort. Thread-safe; may be called from any thread.
+
+ The actual abort is deferred to the executor thread (inside step())
+ to avoid race conditions with in-flight Metal GPU operations.
Args:
request_id: The request ID to abort
Returns:
- True if request was found and aborted, False otherwise
+ True (abort is always enqueued)
+ """
+ self._pending_abort_ids.add(request_id)
+ logger.info(f"[abort_request] {request_id[:12]} enqueued for deferred abort")
+ return True
+
+ def _process_pending_aborts(self) -> None:
+ """Drain and process pending abort requests. Called from executor thread."""
+ while self._pending_abort_ids:
+ request_id = self._pending_abort_ids.pop()
+ self._do_abort_request(request_id)
+
+ def _do_abort_request(self, request_id: str) -> bool:
+ """
+ Actually abort a request. Must be called from the executor thread.
+
+ Handles the case where the request was already removed from
+ self.requests by _cleanup_request() but still lives in the
+ BatchGenerator (e.g. in _partial or active_batch).
+
+ Args:
+ request_id: The request ID to abort
+
+ Returns:
+            True (cleanup is always attempted; this never returns False)
"""
request = self.requests.get(request_id)
- if request is None:
- return False
+ was_waiting = False
+ was_running = False
+ removed_from_batch = False
# Remove from waiting queue
- if request.status == RequestStatus.WAITING:
+ if request is not None and request.status == RequestStatus.WAITING:
+ was_waiting = True
try:
self.waiting.remove(request)
except ValueError:
pass
- # Remove from running (BatchGenerator)
- if request.request_id in self.request_id_to_uid:
- uid = self.request_id_to_uid[request.request_id]
+ # Remove from running (BatchGenerator) — do this even if request
+ # was already cleaned up from self.requests, because the UID may
+ # still be live inside the BatchGenerator (_partial / active_batch).
+ if request_id in self.request_id_to_uid:
+ was_running = True
+ uid = self.request_id_to_uid[request_id]
if self.batch_generator is not None:
self.batch_generator.remove([uid])
+ removed_from_batch = True
del self.uid_to_request_id[uid]
- del self.request_id_to_uid[request.request_id]
+ del self.request_id_to_uid[request_id]
if request_id in self.running:
del self.running[request_id]
- # Mark as aborted
- request.set_finished(RequestStatus.FINISHED_ABORTED)
+ if request is not None:
+ request.set_finished(RequestStatus.FINISHED_ABORTED)
self.finished_req_ids.add(request_id)
- logger.debug(f"Aborted request {request_id}")
+ # Flush Metal encoders after removing arrays from batch
+ mx.clear_cache()
+
+ logger.info(
+ f"[abort_request] {request_id[:12]} ABORTED "
+ f"was_waiting={was_waiting} was_running={was_running} "
+ f"removed_from_batch={removed_from_batch} "
+ f"remaining_running={len(self.running)} remaining_waiting={len(self.waiting)}"
+ )
return True
def has_requests(self) -> bool:
@@ -577,12 +1279,34 @@ def _schedule_waiting(self) -> List[Request]:
request.remaining_tokens = request.prompt_token_ids
tokens_to_process = request.prompt_token_ids
- # Insert into BatchGenerator with optional cache
- uids = self.batch_generator.insert(
- [tokens_to_process],
- max_tokens=[request.sampling_params.max_tokens],
- caches=[cache_to_use] if cache_to_use else None,
- )
+ # Insert into BatchGenerator with optional cache.
+ # Wrap in try/except: if cache shapes are incompatible
+ # (e.g. stale entry after BatchGenerator recreation),
+ # fall back to no-cache insert instead of crashing.
+ try:
+ uids = self.batch_generator.insert(
+ [tokens_to_process],
+ max_tokens=[request.sampling_params.max_tokens],
+ caches=[cache_to_use] if cache_to_use else None,
+ )
+ except Exception as e:
+ if cache_to_use is not None:
+ logger.warning(
+ f"[cache_insert_error] request={request.request_id[:12]} "
+ f"cache insert failed ({e}), retrying without cache"
+ )
+ cache_to_use = None
+ request.prompt_cache = None
+ request.cached_tokens = 0
+ request.remaining_tokens = request.prompt_token_ids
+ tokens_to_process = request.prompt_token_ids
+ uids = self.batch_generator.insert(
+ [tokens_to_process],
+ max_tokens=[request.sampling_params.max_tokens],
+ caches=None,
+ )
+ else:
+ raise
if uids:
uid = uids[0]
@@ -599,9 +1323,13 @@ def _schedule_waiting(self) -> List[Request]:
if request.cached_tokens > 0
else ""
)
- logger.debug(
- f"Scheduled request {request.request_id} (uid={uid}) "
- f"with {request.num_prompt_tokens} tokens{cache_info}"
+ tokens_to_prefill = len(tokens_to_process)
+ logger.info(
+ f"[schedule] request={request.request_id[:12]} uid={uid} "
+ f"prompt_tokens={request.num_prompt_tokens} "
+ f"tokens_to_prefill={tokens_to_prefill}{cache_info} "
+ f"max_tokens={request.sampling_params.max_tokens} "
+ f"running={len(self.running)} waiting={len(self.waiting)}"
)
return scheduled
@@ -633,8 +1361,17 @@ def _process_batch_responses(
# Append token to request
request.append_output_token(response.token)
- # Decode the new token
- new_text = self._decode_tokens([response.token])
+ # Record first token time for TTFT metric
+ if request.first_token_time is None and request.num_output_tokens > 0:
+ import time as _time
+
+ request.first_token_time = _time.time()
+
+ # Decode the new token (skip stop tokens — they are not content)
+ if response.finish_reason == "stop":
+ new_text = ""
+ else:
+ new_text = self._decode_tokens([response.token])
# Create output
output = RequestOutput(
@@ -661,7 +1398,7 @@ def _process_batch_responses(
output.output_text = self._decode_tokens(request.output_token_ids)
request.output_text = output.output_text
- # Extract cache for future reuse
+ # Extract cache for future reuse (critical for agentic multi-turn)
if hasattr(response, "prompt_cache"):
try:
# prompt_cache may be callable or direct attribute
@@ -735,6 +1472,11 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
# unused blocks when under memory pressure.
elif self.memory_aware_cache is not None:
+ # Keep mid-prefill entry as prefix cache for future
+ # requests that share a common prefix (e.g. same system
+ # prompt + tools but different user message). LRU
+ # eviction handles memory pressure.
+
# Store in memory-aware prefix cache
# Key includes both prompt and output tokens for multi-turn chat caching
if (
@@ -745,13 +1487,33 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
full_token_sequence = list(request.prompt_token_ids) + list(
request.output_token_ids
)
- self.memory_aware_cache.store(
+ import time as _time
+
+ _store_t0 = _time.monotonic()
+ stored = self.memory_aware_cache.store(
full_token_sequence,
request._extracted_cache,
+ evict_prefixes=False,
)
- logger.debug(
- f"Stored memory-aware cache for request {request_id} "
- f"({len(full_token_sequence)} tokens: {len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output)"
+ _store_dt = _time.monotonic() - _store_t0
+ # NOTE: We intentionally do NOT store a prompt-only
+ # cache entry. Hybrid Mamba+Transformer models
+ # (like Qwen3-Coder-Next) have MambaCache layers
+ # whose state is cumulative and cannot be trimmed
+ # back to "prompt only". Reusing such state causes
+ # the model to immediately produce EOS.
+ # The full prompt+output entry is stored above; a
+ # future request with the same prompt will hit the
+ # supersequence match path in the fetch, which is
+ # now disabled for safety (see memory_cache.py).
+
+ logger.info(
+ f"[cache_store] request={request_id[:12]} "
+ f"tokens={len(full_token_sequence)} "
+ f"({len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output) "
+ f"stored={stored} time={_store_dt:.3f}s "
+ f"cache_entries={len(self.memory_aware_cache._entries)} "
+ f"cache_mem={self.memory_aware_cache._current_memory / 1e6:.0f}MB"
)
except Exception as e:
logger.debug(
@@ -781,6 +1543,24 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
except Exception as e:
logger.debug(f"Failed to store cache for {request_id}: {e}")
+ # Evaluate stored cache tensors incrementally (per-layer) to prevent
+ # a deferred batch evaluation spike when all lazy ops resolve at once.
+ # This spreads the VRAM cost across smaller per-layer evaluations.
+ if (
+ request is not None
+ and hasattr(request, "_extracted_cache")
+ and request._extracted_cache
+ ):
+ for layer in request._extracted_cache:
+ if isinstance(layer, dict) and "state" in layer:
+ keys, values = layer["state"]
+ mx.eval(keys, values)
+ elif hasattr(layer, "keys") and hasattr(layer, "values"):
+ keys_attr = layer.keys
+ values_attr = layer.values
+ if not callable(keys_attr) and not callable(values_attr):
+ mx.eval(keys_attr, values_attr)
+
# Remove from running
if request_id in self.running:
del self.running[request_id]
@@ -795,6 +1575,10 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
# Track as finished
self.finished_req_ids.add(request_id)
+ # Free Metal command buffers after cleanup (prevents end-of-generation spike)
+ if finished_ids:
+ mx.clear_cache()
+
def _is_cache_corruption_error(self, error: Exception) -> bool:
"""Check if an error indicates cache corruption."""
error_str = str(error)
@@ -802,8 +1586,8 @@ def _is_cache_corruption_error(self, error: Exception) -> bool:
def _recover_from_cache_error(self) -> None:
"""Recover from cache corruption error."""
- # Clear batch generator (this is the source of the corruption)
- self.batch_generator = None
+ # Properly close batch generator (this is the source of the corruption)
+ self._close_batch_generator()
self._current_sampler_params = None
# Clear caches
@@ -856,6 +1640,9 @@ def step(self, max_retries: int = 1) -> SchedulerOutput:
"""
output = SchedulerOutput()
+ # Process pending aborts FIRST (in executor thread, safe for MLX)
+ self._process_pending_aborts()
+
for attempt in range(max_retries + 1):
try:
# Schedule waiting requests
@@ -907,6 +1694,27 @@ def step(self, max_retries: int = 1) -> SchedulerOutput:
old_finished = self.finished_req_ids
self.finished_req_ids = set()
+ # Periodically clear Metal cache to prevent memory accumulation
+ self._step_count += 1
+ if self._step_count % self._clear_cache_interval == 0:
+ mx.clear_cache()
+
+ # Periodically log memory stats for monitoring
+ if self._step_count % self._memory_log_interval == 0:
+ try:
+ if mx.metal.is_available():
+ active_gb = mx.get_active_memory() / 1e9
+ peak_gb = mx.get_peak_memory() / 1e9
+ cache_gb = mx.get_cache_memory() / 1e9
+ logger.info(
+ f"[Metal memory] active={active_gb:.1f}GB "
+ f"peak={peak_gb:.1f}GB cache={cache_gb:.1f}GB "
+ f"step={self._step_count} "
+ f"running={len(self.running)} waiting={len(self.waiting)}"
+ )
+ except Exception:
+ pass
+
return output
def get_request(self, request_id: str) -> Optional[Request]:
@@ -917,6 +1725,74 @@ def remove_finished_request(self, request_id: str) -> Optional[Request]:
"""Remove a finished request from tracking."""
return self.requests.pop(request_id, None)
+ def get_running_requests_info(self) -> List[Dict[str, Any]]:
+ """Per-request details for status endpoint."""
+ import time as _time
+
+ now = _time.time()
+ result = []
+
+ # Waiting requests
+ for req in self.waiting:
+ result.append(
+ {
+ "request_id": req.request_id,
+ "status": "waiting",
+ "phase": "queued",
+ "elapsed_s": round(now - req.arrival_time, 2),
+ "prompt_tokens": req.num_prompt_tokens,
+ "completion_tokens": 0,
+ "max_tokens": req.max_tokens,
+ "progress": 0.0,
+ "tokens_per_second": None,
+ "ttft_s": None,
+ "cache_hit_type": req.cache_hit_type,
+ "cached_tokens": req.cached_tokens,
+ }
+ )
+
+ # Running requests
+ for req in self.running.values():
+ n_out = req.num_output_tokens
+ elapsed = now - req.arrival_time
+
+ # Phase detection
+ if n_out == 0:
+ phase = "prefill"
+ else:
+ phase = "generation"
+
+ # Tokens per second (generation phase only)
+ tok_s = None
+ ttft = None
+ if req.first_token_time is not None:
+ ttft = round(req.first_token_time - req.arrival_time, 3)
+ gen_elapsed = now - req.first_token_time
+ if gen_elapsed > 0 and n_out > 0:
+ tok_s = round(n_out / gen_elapsed, 1)
+
+ # Progress: completion_tokens / max_tokens
+ progress = round(n_out / req.max_tokens, 3) if req.max_tokens > 0 else 0.0
+
+ result.append(
+ {
+ "request_id": req.request_id,
+ "status": "running",
+ "phase": phase,
+ "elapsed_s": round(elapsed, 2),
+ "prompt_tokens": req.num_prompt_tokens,
+ "completion_tokens": n_out,
+ "max_tokens": req.max_tokens,
+ "progress": min(progress, 1.0),
+ "tokens_per_second": tok_s,
+ "ttft_s": ttft,
+ "cache_hit_type": req.cache_hit_type,
+ "cached_tokens": req.cached_tokens,
+ }
+ )
+
+ return result
+
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
stats = {
@@ -926,6 +1802,15 @@ def get_stats(self) -> Dict[str, Any]:
"total_prompt_tokens": self.total_prompt_tokens,
"total_completion_tokens": self.total_completion_tokens,
}
+ # Include Metal memory stats
+ try:
+ if mx.metal.is_available():
+ stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2)
+ stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2)
+ stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2)
+ except Exception:
+ pass
+
# Include cache stats
if self.block_aware_cache is not None:
stats["paged_cache"] = self.block_aware_cache.get_stats()
@@ -947,9 +1832,12 @@ def get_cache_stats(self) -> Optional[Dict[str, Any]]:
def reset(self) -> None:
"""Reset the scheduler state."""
- # Abort all requests
+ # Drain any pending deferred aborts
+ self._pending_abort_ids.clear()
+
+ # Abort all requests directly (reset is synchronous)
for request_id in list(self.requests.keys()):
- self.abort_request(request_id)
+ self._do_abort_request(request_id)
self.waiting.clear()
self.running.clear()
@@ -957,7 +1845,7 @@ def reset(self) -> None:
self.finished_req_ids.clear()
self.request_id_to_uid.clear()
self.uid_to_request_id.clear()
- self.batch_generator = None
+ self._close_batch_generator()
self._current_sampler_params = None
# Clear caches
@@ -997,3 +1885,21 @@ def deep_reset(self) -> None:
gc.collect()
logger.info("Deep reset completed - all caches cleared")
+
+ # -----------------------------------------------------------------
+ # Cache persistence
+ # -----------------------------------------------------------------
+
+ def save_cache_to_disk(self, cache_dir: str) -> bool:
+ """Save prefix cache to disk for persistence across restarts."""
+ if self.memory_aware_cache is not None:
+ return self.memory_aware_cache.save_to_disk(cache_dir)
+ logger.info("[cache_persist] no memory-aware cache to save")
+ return False
+
+ def load_cache_from_disk(self, cache_dir: str) -> int:
+ """Load prefix cache from disk. Returns number of entries loaded."""
+ if self.memory_aware_cache is not None:
+ return self.memory_aware_cache.load_from_disk(cache_dir)
+ logger.info("[cache_persist] no memory-aware cache to load into")
+ return 0
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index d36779ef..7b131067 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -49,7 +49,6 @@
import uuid
from collections import defaultdict
from collections.abc import AsyncIterator
-from contextlib import asynccontextmanager
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
@@ -58,6 +57,8 @@
# Import from new modular API
# Re-export for backwards compatibility with tests
+from .api.anthropic_adapter import anthropic_to_openai, openai_to_anthropic
+from .api.anthropic_models import AnthropicRequest
from .api.models import (
AssistantMessage, # noqa: F401
ChatCompletionChoice, # noqa: F401
@@ -96,6 +97,7 @@
parse_tool_calls,
)
from .api.utils import (
+ SPECIAL_TOKENS_PATTERN,
clean_output_text,
extract_multimodal_content,
is_mllm_model, # noqa: F401
@@ -157,7 +159,50 @@ def _resolve_top_p(request_value: float | None) -> float:
_tool_parser_instance = None # Instantiated parser
-@asynccontextmanager
+def _load_prefix_cache_from_disk() -> None:
+ """Load prefix cache from disk during startup."""
+ try:
+ d = _get_cache_dir()
+ logger.info(f"[lifespan] Loading prefix cache from {d}")
+ loaded = _engine.load_cache_from_disk(d)
+ if loaded > 0:
+ logger.info(f"[lifespan] Loaded {loaded} prefix cache entries")
+ else:
+ logger.info("[lifespan] No prefix cache entries found on disk")
+ except Exception as e:
+ logger.warning(f"[lifespan] Failed to load cache from disk: {e}", exc_info=True)
+
+
+def _save_prefix_cache_to_disk() -> None:
+ """Save prefix cache to disk during shutdown."""
+ try:
+ d = _get_cache_dir()
+ logger.info(f"[lifespan] Saving prefix cache to {d}")
+ saved = _engine.save_cache_to_disk(d)
+ if saved:
+ logger.info(f"[lifespan] Saved prefix cache to {d}")
+ else:
+ logger.info("[lifespan] No cache to save")
+ except Exception as e:
+ logger.warning(f"[lifespan] Failed to save cache to disk: {e}", exc_info=True)
+
+
+def _get_cache_dir() -> str:
+ """Get cache persistence directory based on model name."""
+ # Use global _model_name which is always a string, set during load_model()
+ model_name = _model_name if _model_name else "default"
+ logger.info(
+ f"[_get_cache_dir] _model_name={_model_name!r} type={type(_model_name)}"
+ )
+ # Sanitize model name for filesystem
+ safe_name = str(model_name).replace("/", "--").replace("\\", "--")
+ cache_dir = os.path.join(
+ os.path.expanduser("~"), ".cache", "vllm-mlx", "prefix_cache", safe_name
+ )
+ logger.info(f"[_get_cache_dir] cache_dir={cache_dir!r}")
+ return cache_dir
+
+
async def lifespan(app: FastAPI):
"""FastAPI lifespan for startup/shutdown events."""
global _engine, _mcp_manager
@@ -166,6 +211,10 @@ async def lifespan(app: FastAPI):
if _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded:
await _engine.start()
+ # Load persisted cache from disk (AFTER engine start — AsyncEngineCore must exist)
+ if _engine is not None and hasattr(_engine, "load_cache_from_disk"):
+ _load_prefix_cache_from_disk()
+
# Initialize MCP if config provided
mcp_config = os.environ.get("VLLM_MLX_MCP_CONFIG")
if mcp_config:
@@ -173,6 +222,10 @@ async def lifespan(app: FastAPI):
yield
+ # Shutdown: Save cache to disk BEFORE stopping engine
+ if _engine is not None and hasattr(_engine, "save_cache_to_disk"):
+ _save_prefix_cache_to_disk()
+
# Shutdown: Close MCP connections and stop engine
if _mcp_manager is not None:
await _mcp_manager.stop()
@@ -300,9 +353,11 @@ def _parse_tool_calls_with_parser(
"""
global _tool_parser_instance
+ request_dict = request.model_dump() if request else None
+
# If auto tool choice is not enabled, use the generic parser
if not _enable_auto_tool_choice or not _tool_call_parser:
- return parse_tool_calls(output_text)
+ return parse_tool_calls(output_text, request_dict)
# Initialize parser if needed
if _tool_parser_instance is None:
@@ -319,13 +374,13 @@ def _parse_tool_calls_with_parser(
f"Failed to initialize tool parser '{_tool_call_parser}': {e}"
)
logger.warning("Falling back to generic parser")
- return parse_tool_calls(output_text)
+ return parse_tool_calls(output_text, request_dict)
# Use the configured parser
try:
# Reset parser state between requests
_tool_parser_instance.reset()
- result = _tool_parser_instance.extract_tool_calls(output_text)
+ result = _tool_parser_instance.extract_tool_calls(output_text, request_dict)
if result.tools_called:
tool_calls = [
ToolCall(
@@ -340,10 +395,12 @@ def _parse_tool_calls_with_parser(
]
return result.content or "", tool_calls
else:
- return result.content or output_text, None
+            # Fallback: the specific parser found no tool calls, so try the
+            # generic parser, which handles more formats (e.g. Nemotron XML)
+ return parse_tool_calls(output_text, request_dict)
except Exception as e:
logger.warning(f"Tool parser error: {e}")
- return parse_tool_calls(output_text)
+ return parse_tool_calls(output_text, request_dict)
def _detect_native_tool_support() -> bool:
@@ -505,6 +562,36 @@ async def health():
}
+@app.get("/v1/status")
+async def status():
+ """Real-time status with per-request details for debugging and monitoring."""
+ if _engine is None:
+ return {"status": "not_loaded", "model": None, "requests": []}
+
+ stats = _engine.get_stats()
+
+ return {
+ "status": "running" if stats.get("running") else "stopped",
+ "model": _model_name,
+ "uptime_s": round(stats.get("uptime_seconds", 0), 1),
+ "steps_executed": stats.get("steps_executed", 0),
+ "num_running": stats.get("num_running", 0),
+ "num_waiting": stats.get("num_waiting", 0),
+ "total_requests_processed": stats.get("num_requests_processed", 0),
+ "total_prompt_tokens": stats.get("total_prompt_tokens", 0),
+ "total_completion_tokens": stats.get("total_completion_tokens", 0),
+ "metal": {
+ "active_memory_gb": stats.get("metal_active_memory_gb"),
+ "peak_memory_gb": stats.get("metal_peak_memory_gb"),
+ "cache_memory_gb": stats.get("metal_cache_memory_gb"),
+ },
+ "cache": stats.get("memory_aware_cache")
+ or stats.get("paged_cache")
+ or stats.get("prefix_cache"),
+ "requests": stats.get("requests", []),
+ }
+
+
@app.get("/v1/cache/stats")
async def cache_stats():
"""Get cache statistics for debugging and monitoring."""
@@ -892,6 +979,184 @@ async def list_voices(model: str = "kokoro"):
return {"voices": ["default"]}
+# =============================================================================
+# Streaming disconnect detection
+# =============================================================================
+
+
+async def _disconnect_guard(
+ generator: AsyncIterator[str],
+ raw_request: Request,
+ poll_interval: float = 0.5,
+) -> AsyncIterator[str]:
+ """Wrap streaming generator to abort on client disconnect.
+
+ Uses asyncio racing: each __anext__() on the inner generator is
+ raced against a disconnect poller. This catches disconnects even
+ during prefill when no chunks are being yielded for tens of seconds.
+
+    On disconnect, the pending __anext__() task is cancelled; the cancellation
+    propagates to engine_core.stream_outputs()'s finally block → abort_request().
+ """
+ import time as _time
+
+ _t0 = _time.monotonic()
+
+ def _elapsed():
+ return f"{_time.monotonic() - _t0:.1f}s"
+
+ logger.info(f"[disconnect_guard] START poll_interval={poll_interval}s")
+
+ async def _wait_disconnect():
+ poll_count = 0
+ while True:
+ await asyncio.sleep(poll_interval)
+ poll_count += 1
+ is_disc = await raw_request.is_disconnected()
+ if poll_count % 10 == 0 or is_disc:
+ logger.info(
+ f"[disconnect_guard] poll #{poll_count} "
+ f"disconnected={is_disc} elapsed={_elapsed()}"
+ )
+ if is_disc:
+ return
+
+ chunk_count = 0
+ disconnect_task: asyncio.Task | None = None
+ anext_task: asyncio.Task | None = None
+ try:
+ aiter = generator.__aiter__()
+ disconnect_task = asyncio.create_task(_wait_disconnect())
+ while True:
+ anext_task = asyncio.ensure_future(aiter.__anext__())
+ done, _ = await asyncio.wait(
+ [anext_task, disconnect_task],
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+ if disconnect_task in done:
+ logger.info(
+ f"[disconnect_guard] CLIENT DISCONNECTED after "
+ f"{chunk_count} chunks, elapsed={_elapsed()}"
+ )
+ anext_task.cancel()
+ try:
+ await anext_task
+ except (asyncio.CancelledError, StopAsyncIteration):
+ pass
+ break
+ try:
+ chunk = anext_task.result()
+ except StopAsyncIteration:
+ logger.info(
+ f"[disconnect_guard] generator exhausted normally, "
+ f"{chunk_count} chunks, elapsed={_elapsed()}"
+ )
+ break
+ chunk_count += 1
+ if chunk_count == 1:
+ logger.info(
+ f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}"
+ )
+ yield chunk
+ except GeneratorExit:
+ logger.info(
+ f"[disconnect_guard] GeneratorExit after {chunk_count} chunks, elapsed={_elapsed()}"
+ )
+ finally:
+ if disconnect_task and not disconnect_task.done():
+ disconnect_task.cancel()
+ if anext_task and not anext_task.done():
+ anext_task.cancel()
+ # NOTE: Do NOT call generator.aclose() here. With run_in_executor,
+ # scheduler.step() runs in a background thread. aclose() would throw
+ # GeneratorExit into the async-generator chain, which can trigger
+ # mlx::core::eval on the main thread while the executor thread is also
+ # mid-eval → Metal assertion failure → SIGABRT.
+ #
+ # Instead, rely on the task cancellation propagation:
+ # anext_task.cancel() → CancelledError in stream_outputs()
+ # → finally block → abort_request() → request removed from scheduler
+ logger.info(
+ f"[disconnect_guard] CLEANUP done, {chunk_count} chunks total, elapsed={_elapsed()}"
+ )
+
+
+async def _wait_with_disconnect(
+ coro,
+ raw_request: Request,
+ timeout: float,
+ poll_interval: float = 0.5,
+):
+ """Run a coroutine with both timeout and client disconnect detection.
+
+ For non-streaming requests where _disconnect_guard() can't be used.
+ Races the coroutine against a disconnect poller, same pattern as
+ _disconnect_guard but for awaitable (non-generator) coroutines.
+ """
+ import time as _time
+
+ _t0 = _time.monotonic()
+
+ task = asyncio.ensure_future(coro)
+
+ async def _wait_disconnect():
+ poll_count = 0
+ while True:
+ await asyncio.sleep(poll_interval)
+ poll_count += 1
+ is_disc = await raw_request.is_disconnected()
+ if poll_count % 10 == 0 or is_disc:
+ logger.info(
+ f"[disconnect_guard] poll #{poll_count} "
+ f"disconnected={is_disc} elapsed={_time.monotonic() - _t0:.1f}s"
+ )
+ if is_disc:
+ return
+
+ disconnect_task = asyncio.create_task(_wait_disconnect())
+
+ try:
+ done, _ = await asyncio.wait(
+ [task, disconnect_task],
+ timeout=timeout,
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+
+ if not done:
+ # Timeout
+ task.cancel()
+ try:
+ await task
+ except (asyncio.CancelledError, Exception):
+ pass
+ raise HTTPException(
+ status_code=504,
+ detail=f"Request timed out after {timeout:.1f} seconds",
+ )
+
+ if disconnect_task in done:
+ # Client disconnected
+ logger.info(
+ f"[disconnect_guard] CLIENT DISCONNECTED (non-stream) "
+ f"elapsed={_time.monotonic() - _t0:.1f}s"
+ )
+ task.cancel()
+ try:
+ await task
+ except (asyncio.CancelledError, Exception):
+ pass
+ return None # Signal to caller that client disconnected
+
+ # Task completed
+ return task.result()
+
+ finally:
+ if not disconnect_task.done():
+ disconnect_task.cancel()
+ if not task.done():
+ task.cancel()
+
+
# =============================================================================
# Completion Endpoints
# =============================================================================
@@ -900,16 +1165,28 @@ async def list_voices(model: str = "kokoro"):
@app.post(
"/v1/completions", dependencies=[Depends(verify_api_key), Depends(check_rate_limit)]
)
-async def create_completion(request: CompletionRequest):
+async def create_completion(request: CompletionRequest, raw_request: Request):
"""Create a text completion."""
engine = get_engine()
# Handle single prompt or list of prompts
prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt]
+ # --- Detailed request logging ---
+ prompt_preview = prompts[0][:200] if prompts else "(empty)"
+ prompt_len = sum(len(p) for p in prompts)
+ logger.info(
+ f"[REQUEST] POST /v1/completions stream={request.stream} "
+ f"max_tokens={request.max_tokens} temp={request.temperature} "
+ f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}"
+ )
+
if request.stream:
return StreamingResponse(
- stream_completion(engine, prompts[0], request),
+ _disconnect_guard(
+ stream_completion(engine, prompts[0], request),
+ raw_request,
+ ),
media_type="text/event-stream",
)
@@ -921,21 +1198,19 @@ async def create_completion(request: CompletionRequest):
total_prompt_tokens = 0
for i, prompt in enumerate(prompts):
- try:
- output = await asyncio.wait_for(
- engine.generate(
- prompt=prompt,
- max_tokens=request.max_tokens or _default_max_tokens,
- temperature=_resolve_temperature(request.temperature),
- top_p=_resolve_top_p(request.top_p),
- stop=request.stop,
- ),
- timeout=timeout,
- )
- except asyncio.TimeoutError:
- raise HTTPException(
- status_code=504, detail=f"Request timed out after {timeout:.1f} seconds"
- )
+ output = await _wait_with_disconnect(
+ engine.generate(
+ prompt=prompt,
+ max_tokens=request.max_tokens or _default_max_tokens,
+ temperature=_resolve_temperature(request.temperature),
+ top_p=_resolve_top_p(request.top_p),
+ stop=request.stop,
+ ),
+ raw_request,
+ timeout=timeout,
+ )
+ if output is None:
+ return Response(status_code=499) # Client closed request
choices.append(
CompletionChoice(
@@ -970,7 +1245,7 @@ async def create_completion(request: CompletionRequest):
"/v1/chat/completions",
dependencies=[Depends(verify_api_key), Depends(check_rate_limit)],
)
-async def create_chat_completion(request: ChatCompletionRequest):
+async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
"""
Create a chat completion (supports multimodal content for VLM models).
@@ -1014,6 +1289,27 @@ async def create_chat_completion(request: ChatCompletionRequest):
"""
engine = get_engine()
+ # --- Detailed request logging ---
+ n_msgs = len(request.messages)
+ msg_roles = [m.role for m in request.messages]
+ total_chars = 0
+ last_user_preview = ""
+ for m in request.messages:
+ content = m.content if isinstance(m.content, str) else str(m.content)
+ total_chars += len(content)
+ if m.role == "user":
+ last_user_preview = content[:300]
+ has_tools = bool(request.tools)
+ n_tools = len(request.tools) if request.tools else 0
+ logger.info(
+ f"[REQUEST] POST /v1/chat/completions stream={request.stream} "
+ f"model={request.model!r} max_tokens={request.max_tokens} "
+ f"temp={request.temperature} msgs={n_msgs} roles={msg_roles} "
+ f"total_chars={total_chars} tools={n_tools} "
+ f"response_format={request.response_format}"
+ )
+ logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}")
+
# For MLLM models, keep original messages with embedded images
# (MLLM.chat() extracts images from message content internally)
if engine.is_mllm:
@@ -1063,7 +1359,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
if request.stream:
return StreamingResponse(
- stream_chat_completion(engine, messages, request, **chat_kwargs),
+ _disconnect_guard(
+ stream_chat_completion(engine, messages, request, **chat_kwargs),
+ raw_request,
+ ),
media_type="text/event-stream",
)
@@ -1071,14 +1370,13 @@ async def create_chat_completion(request: ChatCompletionRequest):
start_time = time.perf_counter()
timeout = request.timeout or _default_timeout
- try:
- output = await asyncio.wait_for(
- engine.chat(messages=messages, **chat_kwargs), timeout=timeout
- )
- except asyncio.TimeoutError:
- raise HTTPException(
- status_code=504, detail=f"Request timed out after {timeout:.1f} seconds"
- )
+ output = await _wait_with_disconnect(
+ engine.chat(messages=messages, **chat_kwargs),
+ raw_request,
+ timeout=timeout,
+ )
+ if output is None:
+ return Response(status_code=499) # Client closed request
elapsed = time.perf_counter() - start_time
tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0
@@ -1163,6 +1461,349 @@ def _inject_json_instruction(messages: list, instruction: str) -> list:
return messages
+# =============================================================================
+# Anthropic Messages API Endpoints
+# =============================================================================
+
+
+@app.post("/v1/messages")
+async def create_anthropic_message(
+ request: Request,
+):
+ """
+ Anthropic Messages API endpoint.
+
+ Translates Anthropic-format requests to OpenAI format, runs inference
+ through the existing engine, and converts the response back.
+
+ Supports both streaming and non-streaming modes.
+ """
+ engine = get_engine()
+
+ # Parse the raw body to handle Anthropic request format
+ body = await request.json()
+ anthropic_request = AnthropicRequest(**body)
+
+ # --- Detailed request logging ---
+ n_msgs = len(anthropic_request.messages)
+ total_chars = 0
+ last_user_preview = ""
+ for m in anthropic_request.messages:
+ content = m.content if isinstance(m.content, str) else str(m.content)
+ total_chars += len(content)
+ if m.role == "user":
+ last_user_preview = content[:300]
+ sys_chars = len(anthropic_request.system) if anthropic_request.system else 0
+ n_tools = len(anthropic_request.tools) if anthropic_request.tools else 0
+ logger.info(
+ f"[REQUEST] POST /v1/messages (anthropic) stream={anthropic_request.stream} "
+ f"model={anthropic_request.model!r} max_tokens={anthropic_request.max_tokens} "
+ f"msgs={n_msgs} total_chars={total_chars} system_chars={sys_chars} "
+ f"tools={n_tools}"
+ )
+ logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}")
+
+ # Convert Anthropic request -> OpenAI request
+ openai_request = anthropic_to_openai(anthropic_request)
+
+ if anthropic_request.stream:
+ return StreamingResponse(
+ _disconnect_guard(
+ _stream_anthropic_messages(engine, openai_request, anthropic_request),
+ request,
+ ),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ },
+ )
+
+ # Non-streaming: run inference through existing engine
+ messages, images, videos = extract_multimodal_content(
+ openai_request.messages,
+ preserve_native_format=engine.preserve_native_tool_format,
+ )
+
+ chat_kwargs = {
+ "max_tokens": openai_request.max_tokens or _default_max_tokens,
+ "temperature": openai_request.temperature,
+ "top_p": openai_request.top_p,
+ }
+
+ if openai_request.tools:
+ chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools)
+
+ start_time = time.perf_counter()
+ timeout = _default_timeout
+
+ output = await _wait_with_disconnect(
+ engine.chat(messages=messages, **chat_kwargs),
+ request,
+ timeout=timeout,
+ )
+ if output is None:
+ return Response(status_code=499) # Client closed request
+
+ elapsed = time.perf_counter() - start_time
+ tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0
+ logger.info(
+ f"Anthropic messages: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+ )
+
+ # Parse tool calls
+ cleaned_text, tool_calls = _parse_tool_calls_with_parser(
+ output.text, openai_request
+ )
+
+ # Clean output text
+ final_content = None
+ if cleaned_text:
+ final_content = clean_output_text(cleaned_text)
+
+ # Determine finish reason
+ finish_reason = "tool_calls" if tool_calls else output.finish_reason
+
+ # Build OpenAI response to convert
+ openai_response = ChatCompletionResponse(
+ model=openai_request.model,
+ choices=[
+ ChatCompletionChoice(
+ message=AssistantMessage(
+ content=final_content,
+ tool_calls=tool_calls,
+ ),
+ finish_reason=finish_reason,
+ )
+ ],
+ usage=Usage(
+ prompt_tokens=output.prompt_tokens,
+ completion_tokens=output.completion_tokens,
+ total_tokens=output.prompt_tokens + output.completion_tokens,
+ ),
+ )
+
+ # Convert to Anthropic response
+ anthropic_response = openai_to_anthropic(openai_response, anthropic_request.model)
+ return Response(
+ content=anthropic_response.model_dump_json(exclude_none=True),
+ media_type="application/json",
+ )
+
+
+@app.post("/v1/messages/count_tokens")
+async def count_anthropic_tokens(request: Request):
+ """
+ Count tokens for an Anthropic Messages API request.
+
+ Uses the model's tokenizer for accurate counting.
+ Claude Code calls this endpoint for token budgeting.
+ Note: Don't parse via AnthropicRequest — count_tokens requests
+ from Claude Code don't include max_tokens.
+ """
+ body = await request.json()
+
+ engine = get_engine()
+ tokenizer = engine.tokenizer
+
+ total_tokens = 0
+
+ # System message
+ system = body.get("system", "")
+ if isinstance(system, str) and system:
+ total_tokens += len(tokenizer.encode(system))
+ elif isinstance(system, list):
+ for block in system:
+ if isinstance(block, dict):
+ text = block.get("text", "")
+ if text:
+ total_tokens += len(tokenizer.encode(text))
+
+ # Messages
+ for msg in body.get("messages", []):
+ content = msg.get("content", "")
+ if isinstance(content, str):
+ if content:
+ total_tokens += len(tokenizer.encode(content))
+ elif isinstance(content, list):
+ for block in content:
+ if isinstance(block, dict):
+ text = block.get("text", "")
+ if text:
+ total_tokens += len(tokenizer.encode(text))
+ # tool_use input
+ if block.get("input"):
+ total_tokens += len(
+ tokenizer.encode(json.dumps(block["input"]))
+ )
+ # tool_result content
+ sub_content = block.get("content", "")
+ if isinstance(sub_content, str) and sub_content:
+ total_tokens += len(tokenizer.encode(sub_content))
+ elif isinstance(sub_content, list):
+ for item in sub_content:
+ if isinstance(item, dict):
+ item_text = item.get("text", "")
+ if item_text:
+ total_tokens += len(tokenizer.encode(item_text))
+
+ # Tools
+ for tool in body.get("tools", []):
+ name = tool.get("name", "")
+ if name:
+ total_tokens += len(tokenizer.encode(name))
+ desc = tool.get("description", "")
+ if desc:
+ total_tokens += len(tokenizer.encode(desc))
+ if tool.get("input_schema"):
+ total_tokens += len(tokenizer.encode(json.dumps(tool["input_schema"])))
+
+ return {"input_tokens": total_tokens}
+
+
+async def _stream_anthropic_messages(
+ engine: BaseEngine,
+ openai_request: ChatCompletionRequest,
+ anthropic_request: AnthropicRequest,
+) -> AsyncIterator[str]:
+ """
+ Stream Anthropic Messages API SSE events.
+
+ Converts OpenAI streaming chunks to Anthropic event format:
+ message_start -> content_block_start -> content_block_delta* ->
+ content_block_stop -> message_delta -> message_stop
+ """
+ msg_id = f"msg_{uuid.uuid4().hex[:24]}"
+ start_time = time.perf_counter()
+
+ # Extract messages for engine
+ messages, images, videos = extract_multimodal_content(
+ openai_request.messages,
+ preserve_native_format=engine.preserve_native_tool_format,
+ )
+
+ chat_kwargs = {
+ "max_tokens": openai_request.max_tokens or _default_max_tokens,
+ "temperature": openai_request.temperature,
+ "top_p": openai_request.top_p,
+ }
+
+ if openai_request.tools:
+ chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools)
+
+ # Emit message_start
+ message_start = {
+ "type": "message_start",
+ "message": {
+ "id": msg_id,
+ "type": "message",
+ "role": "assistant",
+ "model": anthropic_request.model,
+ "content": [],
+ "stop_reason": None,
+ "stop_sequence": None,
+ "usage": {
+ "input_tokens": 0,
+ "output_tokens": 0,
+ },
+ },
+ }
+ yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
+
+ # Emit content_block_start for text
+ content_block_start = {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {"type": "text", "text": ""},
+ }
+ yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n"
+
+ # Stream content deltas
+ accumulated_text = ""
+ completion_tokens = 0
+
+ async for output in engine.stream_chat(messages=messages, **chat_kwargs):
+ delta_text = output.new_text
+
+ # Track token counts
+ if hasattr(output, "completion_tokens") and output.completion_tokens:
+ completion_tokens = output.completion_tokens
+
+ if delta_text:
+ # Filter special tokens
+ content = SPECIAL_TOKENS_PATTERN.sub("", delta_text)
+
+ if content:
+ accumulated_text += content
+ delta_event = {
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": {"type": "text_delta", "text": content},
+ }
+ yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
+
+ # Check for tool calls in accumulated text
+ _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request)
+
+ # Emit content_block_stop for text block
+ yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
+
+ # If there are tool calls, emit tool_use blocks
+ if tool_calls:
+ for i, tc in enumerate(tool_calls):
+ tool_index = i + 1
+ try:
+ tool_input = json.loads(tc.function.arguments)
+ except (json.JSONDecodeError, AttributeError):
+ tool_input = {}
+
+ # content_block_start for tool_use
+ tool_block_start = {
+ "type": "content_block_start",
+ "index": tool_index,
+ "content_block": {
+ "type": "tool_use",
+ "id": tc.id,
+ "name": tc.function.name,
+ "input": {},
+ },
+ }
+ yield f"event: content_block_start\ndata: {json.dumps(tool_block_start)}\n\n"
+
+ # Send input as a single delta
+ input_json = json.dumps(tool_input)
+ input_delta = {
+ "type": "content_block_delta",
+ "index": tool_index,
+ "delta": {"type": "input_json_delta", "partial_json": input_json},
+ }
+ yield f"event: content_block_delta\ndata: {json.dumps(input_delta)}\n\n"
+
+ # content_block_stop
+ yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': tool_index})}\n\n"
+
+ # Determine stop reason
+ stop_reason = "tool_use" if tool_calls else "end_turn"
+
+ # Emit message_delta with stop_reason and usage
+ message_delta = {
+ "type": "message_delta",
+ "delta": {"stop_reason": stop_reason, "stop_sequence": None},
+ "usage": {"output_tokens": completion_tokens},
+ }
+ yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n"
+
+ # Log throughput
+ elapsed = time.perf_counter() - start_time
+ tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0
+ logger.info(
+ f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+ )
+
+ # Emit message_stop
+ yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n"
+
+
# =============================================================================
# Streaming Helpers
# =============================================================================
@@ -1209,6 +1850,7 @@ async def stream_chat_completion(
) -> AsyncIterator[str]:
"""Stream chat completion response."""
response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
+ start_time = time.perf_counter()
# Check if we should include usage in the final chunk
include_usage = request.stream_options and request.stream_options.include_usage
@@ -1242,6 +1884,28 @@ async def stream_chat_completion(
completion_tokens = 0
last_output = None
+ # Tool call streaming state
+ global _tool_parser_instance
+ tool_parser = None
+ tool_accumulated_text = ""
+ tool_calls_detected = False
+ tool_markup_possible = False # Fast path: skip parsing until '<' seen
+ if _enable_auto_tool_choice and _tool_call_parser:
+ # Initialize parser if needed (same as _parse_tool_calls_with_parser)
+ if _tool_parser_instance is None:
+ try:
+ parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
+ tokenizer = None
+ if _engine is not None and hasattr(_engine, "_tokenizer"):
+ tokenizer = _engine._tokenizer
+ _tool_parser_instance = parser_cls(tokenizer)
+ logger.info(f"Initialized tool call parser: {_tool_call_parser}")
+ except Exception as e:
+ logger.warning(f"Failed to init tool parser for streaming: {e}")
+ if _tool_parser_instance is not None:
+ tool_parser = _tool_parser_instance
+ tool_parser.reset()
+
# Stream content
async for output in engine.stream_chat(messages=messages, **kwargs):
delta_text = output.new_text
@@ -1284,11 +1948,60 @@ async def stream_chat_completion(
# Standard path without reasoning parsing
content = delta_text
+ # Filter special tokens that may leak into streaming output
+ if content:
+ content = SPECIAL_TOKENS_PATTERN.sub("", content)
+
# Add prefix on first content chunk for thinking models
if is_thinking_model and not think_prefix_sent and content:
+ content = "<think>" + content
think_prefix_sent = True
+ # Tool call streaming parsing
+ if tool_parser and delta_text:
+ # Fast path: skip full parsing until '<' is seen in the stream,
+ # which could start tool markup (e.g. <tool_call>). This avoids
+ # per-token string scanning on the growing accumulated text.
+ if not tool_markup_possible and "<" not in delta_text:
+ tool_accumulated_text += delta_text
+ # No tool markup yet, fall through to normal chunk emission
+ else:
+ if not tool_markup_possible:
+ tool_markup_possible = True
+ tool_previous = tool_accumulated_text
+ tool_accumulated_text += delta_text
+ tool_result = tool_parser.extract_tool_calls_streaming(
+ tool_previous, tool_accumulated_text, delta_text
+ )
+
+ if tool_result is None:
+ # Inside tool markup - suppress output
+ continue
+
+ if "tool_calls" in tool_result:
+ # Emit structured tool calls
+ tool_calls_detected = True
+ chunk = ChatCompletionChunk(
+ id=response_id,
+ model=request.model,
+ choices=[
+ ChatCompletionChunkChoice(
+ delta=ChatCompletionChunkDelta(
+ tool_calls=tool_result["tool_calls"]
+ ),
+ finish_reason=(
+ "tool_calls" if output.finished else None
+ ),
+ )
+ ],
+ usage=get_usage(output) if output.finished else None,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+ continue
+
+ # Normal content from tool parser
+ content = tool_result.get("content", "")
+
chunk = ChatCompletionChunk(
id=response_id,
model=request.model,
@@ -1297,13 +2010,59 @@ async def stream_chat_completion(
delta=ChatCompletionChunkDelta(
content=content if content else None
),
- finish_reason=output.finish_reason if output.finished else None,
+ finish_reason=(
+ "tool_calls"
+ if (output.finished and tool_calls_detected)
+ else (output.finish_reason if output.finished else None)
+ ),
)
],
usage=get_usage(output) if output.finished else None,
)
yield f"data: {chunk.model_dump_json()}\n\n"
+ # Fallback: if tool parser accumulated text but never emitted tool_calls
+ # (e.g., </tool_call> never arrived - incomplete tool call)
+ if (
+ tool_parser
+ and tool_accumulated_text
+ and not tool_calls_detected
+ and "<tool_call>" in tool_accumulated_text
+ ):
+ result = tool_parser.extract_tool_calls(tool_accumulated_text)
+ if result.tools_called:
+ tool_chunk = ChatCompletionChunk(
+ id=response_id,
+ model=request.model,
+ choices=[
+ ChatCompletionChunkChoice(
+ delta=ChatCompletionChunkDelta(
+ tool_calls=[
+ {
+ "index": i,
+ "id": tc["id"],
+ "type": "function",
+ "function": {
+ "name": tc["name"],
+ "arguments": tc["arguments"],
+ },
+ }
+ for i, tc in enumerate(result.tool_calls)
+ ]
+ ),
+ finish_reason="tool_calls",
+ )
+ ],
+ )
+ yield f"data: {tool_chunk.model_dump_json()}\n\n"
+
+ # Log throughput
+ elapsed = time.perf_counter() - start_time
+ tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0
+ logger.info(
+ f"Chat completion (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+ )
+
# Send final chunk with usage if requested
if include_usage:
usage_chunk = ChatCompletionChunk(
diff --git a/vllm_mlx/tool_parsers/hermes_tool_parser.py b/vllm_mlx/tool_parsers/hermes_tool_parser.py
index 0605e640..da91a816 100644
--- a/vllm_mlx/tool_parsers/hermes_tool_parser.py
+++ b/vllm_mlx/tool_parsers/hermes_tool_parser.py
@@ -36,12 +36,23 @@ class HermesToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser hermes are set.
"""
+ # Qwen3 / Hermes chat templates handle role="tool" and tool_calls natively.
+ # Without this, tool history is converted to "[Calling tool: ...]" text,
+ # which causes the model to mimic that text format instead of producing
+ # proper XML after a few rounds of tool use.
+ SUPPORTS_NATIVE_TOOL_FORMAT = True
+
# Standard format: <tool_call>{"name": ..., "arguments": ...}</tool_call>
TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
# Lenient format: <tool_call> followed by JSON (handles malformed tags)
TOOL_CALL_LENIENT_PATTERN = re.compile(
r'<tool_call>\s*(\{.*?\})', re.DOTALL
)
+ NEMOTRON_PATTERN = re.compile(
+ r"\s*<toolcall\s+name=([^>]+)>(.*?)</toolcall>\s*", re.DOTALL
+ )
+ PARAM_PATTERN = re.compile(r"<param\s+name=([^>]+)>\s*(.*?)\s*</param>", re.DOTALL)
REASONING_PATTERN = re.compile(
r"<think>(.*?)</think>", re.DOTALL
)
@@ -91,7 +102,29 @@ def extract_tool_calls(
if matches:
cleaned_text = self.TOOL_CALL_PATTERN.sub("", cleaned_text).strip()
- # Fallback 1: try lenient pattern for malformed tags like <tool_call>
+ # Try Nemotron XML format if no JSON tool calls found
+ if not tool_calls:
+ nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text)
+ for name, params_block in nemotron_matches:
+ params = self.PARAM_PATTERN.findall(params_block)
+ arguments = {}
+ for p_name, p_value in params:
+ val = p_value.strip()
+ try:
+ arguments[p_name.strip()] = json.loads(val)
+ except (json.JSONDecodeError, ValueError):
+ arguments[p_name.strip()] = val
+ tool_calls.append(
+ {
+ "id": generate_tool_id(),
+ "name": name.strip(),
+ "arguments": json.dumps(arguments, ensure_ascii=False),
+ }
+ )
+ if nemotron_matches:
+ cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip()
+
+ # Fallback: try lenient pattern for malformed tags like <tool_call>
if not tool_calls:
lenient_matches = self.TOOL_CALL_LENIENT_PATTERN.findall(cleaned_text)
for match in lenient_matches[:1]: # Only first to avoid hallucinations
@@ -117,16 +150,14 @@ def extract_tool_calls(
except json.JSONDecodeError:
continue
- # Fallback 2: try raw JSON format if no tagged tool calls found
+ # Fallback: try raw JSON format if no tagged tool calls found
# Only parse the FIRST valid tool call to avoid hallucinated multiple calls
if not tool_calls:
raw_matches = self.RAW_JSON_TOOL_PATTERN.findall(cleaned_text)
if raw_matches:
- # Only take the first match to avoid hallucinated tool calls
name, args_str = raw_matches[0]
try:
arguments = json.loads(args_str)
- # Validate: only accept if tool name exists in request tools
valid_tool = True
if request and "tools" in request:
tool_names = [
@@ -144,7 +175,6 @@ def extract_tool_calls(
"arguments": json.dumps(arguments, ensure_ascii=False),
}
)
- # Remove the matched tool call from text
cleaned_text = self.RAW_JSON_TOOL_PATTERN.sub(
"", cleaned_text, count=1
).strip()
diff --git a/vllm_mlx/worker.py b/vllm_mlx/worker.py
index 7c3a9047..02f44573 100644
--- a/vllm_mlx/worker.py
+++ b/vllm_mlx/worker.py
@@ -212,7 +212,7 @@ def shutdown(self) -> None:
try:
import mlx.core as mx
- mx.metal.clear_cache()
+ mx.clear_cache()
except Exception:
pass