diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eedade2d..bc3f19b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,6 +80,8 @@ jobs: tests/test_api_models.py \ tests/test_api_utils.py \ tests/test_request.py \ + tests/test_anthropic_models.py \ + tests/test_anthropic_adapter.py \ -v --tb=short \ -k "not Integration and not InjectJson and not TestMLXMultimodalLMCache" \ --cov=vllm_mlx \ diff --git a/docs/guides/server.md b/docs/guides/server.md index a654e92c..33b7ec02 100644 --- a/docs/guides/server.md +++ b/docs/guides/server.md @@ -128,6 +128,390 @@ GET /health Returns server status. +### Anthropic Messages API + +```bash +POST /v1/messages +``` + +Anthropic-compatible endpoint that allows tools like Claude Code and OpenCode to connect directly to vllm-mlx. Internally it translates Anthropic requests to OpenAI format, runs inference through the engine, and converts the response back to Anthropic format. + +Capabilities: +- Non-streaming and streaming responses (SSE) +- System messages (plain string or list of content blocks) +- Multi-turn conversations with user and assistant messages +- Tool calling with `tool_use` / `tool_result` content blocks +- Token counting for budget tracking +- Multimodal content (images via `source` blocks) +- Client disconnect detection (returns HTTP 499) +- Automatic special token filtering in streamed output + +#### Non-streaming + +```python +from anthropic import Anthropic + +client = Anthropic(base_url="http://localhost:8000", api_key="not-needed") + +response = client.messages.create( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Hello!"}] +) +print(response.content[0].text) +# Response includes: response.id, response.model, response.stop_reason, +# response.usage.input_tokens, response.usage.output_tokens +``` + +#### Streaming + +Streaming follows the Anthropic SSE event protocol. Events are emitted in this order: +`message_start` -> `content_block_start` -> `content_block_delta` (repeated) -> `content_block_stop` -> `message_delta` -> `message_stop` + +```python +with client.messages.stream( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Tell me a story"}] +) as stream: + for text in stream.text_stream: + print(text, end="") +``` + +#### System messages + +System messages can be a plain string or a list of content blocks: + +```python +# Plain string +response = client.messages.create( + model="default", + max_tokens=256, + system="You are a helpful coding assistant.", + messages=[{"role": "user", "content": "Write a hello world in Python"}] +) + +# List of content blocks +response = client.messages.create( + model="default", + max_tokens=256, + system=[ + {"type": "text", "text": "You are a helpful assistant."}, + {"type": "text", "text": "Be concise in your answers."}, + ], + messages=[{"role": "user", "content": "What is 2+2?"}] +) +``` + +#### Tool calling + +Define tools with `name`, `description`, and `input_schema`. The model returns `tool_use` content blocks when it wants to call a tool. Send results back as `tool_result` blocks. + +```python +# Step 1: Send request with tools +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) + +# Step 2: Check if model wants to use tools +for block in response.content: + if block.type == "tool_use": + print(f"Tool: {block.name}, Input: {block.input}, ID: {block.id}") + # response.stop_reason will be "tool_use" + +# Step 3: Send tool result back +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[ + {"role": "user", "content": "What's the weather in Paris?"}, + {"role": "assistant", "content": response.content}, + {"role": "user", "content": [ + { + "type": "tool_result", + "tool_use_id": block.id, + "content": "Sunny, 22C" + } + ]} + ], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) +print(response.content[0].text) # "The weather in Paris is sunny, 22C." +``` + +Tool choice modes: + +| `tool_choice` | Behavior | +|---------------|----------| +| `{"type": "auto"}` | Model decides whether to call tools (default) | +| `{"type": "any"}` | Model must call at least one tool | +| `{"type": "tool", "name": "get_weather"}` | Model must call the specified tool | +| `{"type": "none"}` | Model will not call any tools | + +#### Multi-turn conversations + +```python +messages = [ + {"role": "user", "content": "My name is Alice."}, + {"role": "assistant", "content": "Nice to meet you, Alice!"}, + {"role": "user", "content": "What's my name?"}, +] + +response = client.messages.create( + model="default", + max_tokens=100, + messages=messages +) +``` + +#### Token counting + +```bash +POST /v1/messages/count_tokens +``` + +Counts input tokens for an Anthropic request using the model's tokenizer. Useful for budget tracking before sending a request. Counts tokens from system messages, conversation messages, tool_use inputs, tool_result content, and tool definitions (name, description, input_schema). + +```python +import requests + +resp = requests.post("http://localhost:8000/v1/messages/count_tokens", json={ + "model": "default", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "system": "You are helpful.", + "tools": [{ + "name": "search", + "description": "Search the web", + "input_schema": {"type": "object", "properties": {"q": {"type": "string"}}} + }] +}) +print(resp.json()) # {"input_tokens": 42} +``` + +#### curl examples + +Non-streaming: + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +Streaming: + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "stream": true, + "messages": [{"role": "user", "content": "Tell me a joke"}] + }' +``` + +Token counting: + +```bash +curl http://localhost:8000/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}] + }' +# {"input_tokens": 12} +``` + +#### Request fields + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `model` | string | yes | - | Model name (use `"default"` for the loaded model) | +| `messages` | list | yes | - | Conversation messages with `role` and `content` | +| `max_tokens` | int | yes | - | Maximum number of tokens to generate | +| `system` | string or list | no | null | System prompt (string or list of `{"type": "text", "text": "..."}` blocks) | +| `stream` | bool | no | false | Enable SSE streaming | +| `temperature` | float | no | 0.7 | Sampling temperature (0.0 = deterministic, 1.0 = creative) | +| `top_p` | float | no | 0.9 | Nucleus sampling threshold | +| `top_k` | int | no | null | Top-k sampling | +| `stop_sequences` | list | no | null | Sequences that stop generation | +| `tools` | list | no | null | Tool definitions with `name`, `description`, `input_schema` | +| `tool_choice` | dict | no | null | Tool selection mode (`auto`, `any`, `tool`, `none`) | +| `metadata` | dict | no | null | Arbitrary metadata (passed through, not used by server) | + +#### Response format + +Non-streaming response: + +```json +{ + "id": "msg_abc123...", + "type": "message", + "role": "assistant", + "model": "default", + "content": [ + {"type": "text", "text": "Hello! How can I help?"} + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 12, + "output_tokens": 8 + } +} +``` + +When tools are called, `content` includes `tool_use` blocks and `stop_reason` is `"tool_use"`: + +```json +{ + "content": [ + {"type": "text", "text": "Let me check the weather."}, + { + "type": "tool_use", + "id": "call_abc123", + "name": "get_weather", + "input": {"city": "Paris"} + } + ], + "stop_reason": "tool_use" +} +``` + +Stop reasons: + +| `stop_reason` | Meaning | +|---------------|---------| +| `end_turn` | Model finished naturally | +| `tool_use` | Model wants to call a tool | +| `max_tokens` | Hit the `max_tokens` limit | + +#### Using with Claude Code + +Point Claude Code directly at your vllm-mlx server: + +```bash +# Start the server +vllm-mlx serve mlx-community/Qwen3-Coder-Next-235B-A22B-4bit \ + --continuous-batching \ + --enable-auto-tool-choice \ + --tool-call-parser hermes + +# In another terminal, configure Claude Code +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### Server Status + +```bash +GET /v1/status +``` + +Real-time monitoring endpoint that returns server-wide statistics and per-request details. Useful for debugging performance, tracking cache efficiency, and monitoring Metal GPU memory. + +```bash +curl -s http://localhost:8000/v1/status | python -m json.tool +``` + +Example response: + +```json +{ + "status": "running", + "model": "mlx-community/Qwen3-8B-4bit", + "uptime_s": 342.5, + "steps_executed": 1247, + "num_running": 1, + "num_waiting": 0, + "total_requests_processed": 15, + "total_prompt_tokens": 28450, + "total_completion_tokens": 3200, + "metal": { + "active_memory_gb": 5.2, + "peak_memory_gb": 8.1, + "cache_memory_gb": 2.3 + }, + "cache": { + "type": "memory_aware_cache", + "entries": 5, + "hit_rate": 0.87, + "memory_mb": 2350 + }, + "requests": [ + { + "request_id": "req_abc123", + "phase": "generation", + "tokens_per_second": 45.2, + "ttft_s": 0.8, + "progress": 0.35, + "cache_hit_type": "prefix", + "cached_tokens": 1200, + "generated_tokens": 85, + "max_tokens": 256 + } + ] +} +``` + +Response fields: + +| Field | Description | +|-------|-------------| +| `status` | Server state: `running`, `stopped`, or `not_loaded` | +| `model` | Name of the loaded model | +| `uptime_s` | Seconds since the server started | +| `steps_executed` | Total inference steps executed | +| `num_running` | Number of requests currently generating tokens | +| `num_waiting` | Number of requests queued for prefill | +| `total_requests_processed` | Total requests completed since startup | +| `total_prompt_tokens` | Total prompt tokens processed since startup | +| `total_completion_tokens` | Total completion tokens generated since startup | +| `metal.active_memory_gb` | Current Metal GPU memory in use (GB) | +| `metal.peak_memory_gb` | Peak Metal GPU memory usage (GB) | +| `metal.cache_memory_gb` | Metal cache memory usage (GB) | +| `cache` | Cache statistics (type, entries, hit rate, memory usage) | +| `requests` | List of active requests with per-request details | + +Per-request fields in `requests`: + +| Field | Description | +|-------|-------------| +| `request_id` | Unique request identifier | +| `phase` | Current phase: `queued`, `prefill`, or `generation` | +| `tokens_per_second` | Generation throughput for this request | +| `ttft_s` | Time to first token (seconds) | +| `progress` | Completion percentage (0.0 to 1.0) | +| `cache_hit_type` | Cache match type: `exact`, `prefix`, `supersequence`, `lcp`, or `miss` | +| `cached_tokens` | Number of tokens served from cache | +| `generated_tokens` | Tokens generated so far | +| `max_tokens` | Maximum tokens requested | + ## Tool Calling Enable OpenAI-compatible tool calling with `--enable-auto-tool-choice`: diff --git a/tests/test_anthropic_adapter.py b/tests/test_anthropic_adapter.py new file mode 100644 index 00000000..3fa179b4 --- /dev/null +++ b/tests/test_anthropic_adapter.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for Anthropic-to-OpenAI adapter conversion functions. + +Tests all conversion functions in vllm_mlx/api/anthropic_adapter.py. +These are pure logic tests with no MLX dependency. +""" + +import json + +from vllm_mlx.api.anthropic_adapter import ( + _convert_message, + _convert_stop_reason, + _convert_tool, + _convert_tool_choice, + anthropic_to_openai, + openai_to_anthropic, +) +from vllm_mlx.api.anthropic_models import ( + AnthropicContentBlock, + AnthropicMessage, + AnthropicRequest, + AnthropicToolDef, +) +from vllm_mlx.api.models import ( + AssistantMessage, + ChatCompletionChoice, + ChatCompletionResponse, + FunctionCall, + ToolCall, + Usage, +) + + +class TestConvertStopReason: + """Tests for _convert_stop_reason.""" + + def test_stop_to_end_turn(self): + assert _convert_stop_reason("stop") == "end_turn" + + def test_tool_calls_to_tool_use(self): + assert _convert_stop_reason("tool_calls") == "tool_use" + + def test_length_to_max_tokens(self): + assert _convert_stop_reason("length") == "max_tokens" + + def test_content_filter_to_end_turn(self): + assert _convert_stop_reason("content_filter") == "end_turn" + + def test_none_to_end_turn(self): + assert _convert_stop_reason(None) == "end_turn" + + def test_unknown_to_end_turn(self): + assert _convert_stop_reason("something_else") == "end_turn" + + +class TestConvertToolChoice: + """Tests for _convert_tool_choice.""" + + def test_auto(self): + assert _convert_tool_choice({"type": "auto"}) == "auto" + + def test_any_to_required(self): + assert _convert_tool_choice({"type": "any"}) == "required" + + def test_none_type(self): + assert _convert_tool_choice({"type": "none"}) == "none" + + def test_specific_tool(self): + result = _convert_tool_choice({"type": "tool", "name": "search"}) + assert result == { + "type": "function", + "function": {"name": "search"}, + } + + def test_missing_type_defaults_to_auto(self): + assert _convert_tool_choice({}) == "auto" + + def test_unknown_type_defaults_to_auto(self): + assert _convert_tool_choice({"type": "unknown"}) == "auto" + + +class TestConvertTool: + """Tests for _convert_tool.""" + + def test_minimal_tool(self): + tool = AnthropicToolDef(name="search") + result = _convert_tool(tool) + assert result.type == "function" + assert result.function["name"] == "search" + assert result.function["description"] == "" + assert result.function["parameters"] == {"type": "object", "properties": {}} + + def test_full_tool(self): + tool = AnthropicToolDef( + name="get_weather", + description="Get weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ) + result = _convert_tool(tool) + assert result.function["name"] == "get_weather" + assert result.function["description"] == "Get weather for a city" + assert result.function["parameters"]["required"] == ["city"] + + +class TestConvertMessage: + """Tests for _convert_message.""" + + def test_simple_user_string(self): + msg = AnthropicMessage(role="user", content="hello") + result = _convert_message(msg) + assert len(result) == 1 + assert result[0].role == "user" + assert result[0].content == "hello" + + def test_simple_assistant_string(self): + msg = AnthropicMessage(role="assistant", content="hi there") + result = _convert_message(msg) + assert len(result) == 1 + assert result[0].role == "assistant" + assert result[0].content == "hi there" + + def test_user_with_text_blocks(self): + msg = AnthropicMessage( + role="user", + content=[ + AnthropicContentBlock(type="text", text="first"), + AnthropicContentBlock(type="text", text="second"), + ], + ) + result = _convert_message(msg) + assert len(result) == 1 + assert result[0].role == "user" + assert result[0].content == "first\nsecond" + + def test_user_with_tool_results(self): + msg = AnthropicMessage( + role="user", + content=[ + AnthropicContentBlock( + type="tool_result", + tool_use_id="call_1", + content="sunny, 22C", + ), + AnthropicContentBlock( + type="tool_result", + tool_use_id="call_2", + content="rainy, 15C", + ), + ], + ) + result = _convert_message(msg) + assert len(result) == 2 + assert result[0].role == "tool" + assert result[0].content == "sunny, 22C" + assert result[0].tool_call_id == "call_1" + assert result[1].role == "tool" + assert result[1].content == "rainy, 15C" + + def test_user_with_text_and_tool_results(self): + msg = AnthropicMessage( + role="user", + content=[ + AnthropicContentBlock(type="text", text="here are results"), + AnthropicContentBlock( + type="tool_result", + tool_use_id="call_1", + content="done", + ), + ], + ) + result = _convert_message(msg) + assert len(result) == 2 + assert result[0].role == "user" + assert result[0].content == "here are results" + assert result[1].role == "tool" + + def test_tool_result_with_list_content(self): + msg = AnthropicMessage( + role="user", + content=[ + AnthropicContentBlock( + type="tool_result", + tool_use_id="call_1", + content=[ + {"type": "text", "text": "line one"}, + {"type": "text", "text": "line two"}, + ], + ), + ], + ) + result = _convert_message(msg) + assert result[0].role == "tool" + assert result[0].content == "line one\nline two" + + def test_tool_result_with_none_content(self): + msg = AnthropicMessage( + role="user", + content=[ + AnthropicContentBlock( + type="tool_result", + tool_use_id="call_1", + content=None, + ), + ], + ) + result = _convert_message(msg) + assert result[0].content == "" + + def test_assistant_with_tool_use(self): + msg = AnthropicMessage( + role="assistant", + content=[ + AnthropicContentBlock(type="text", text="Let me check."), + AnthropicContentBlock( + type="tool_use", + id="call_abc", + name="search", + input={"q": "weather"}, + ), + ], + ) + result = _convert_message(msg) + assert len(result) == 1 + assert result[0].role == "assistant" + assert result[0].content == "Let me check." + assert len(result[0].tool_calls) == 1 + assert result[0].tool_calls[0]["function"]["name"] == "search" + args = json.loads(result[0].tool_calls[0]["function"]["arguments"]) + assert args == {"q": "weather"} + + def test_assistant_empty_content(self): + msg = AnthropicMessage( + role="assistant", + content=[], + ) + result = _convert_message(msg) + assert len(result) == 1 + assert result[0].role == "assistant" + assert result[0].content == "" + + def test_user_empty_content(self): + msg = AnthropicMessage( + role="user", + content=[], + ) + result = _convert_message(msg) + assert len(result) == 1 + assert result[0].role == "user" + assert result[0].content == "" + + +class TestAnthropicToOpenai: + """Tests for anthropic_to_openai conversion.""" + + def _make_request(self, **kwargs): + defaults = { + "model": "default", + "messages": [AnthropicMessage(role="user", content="hi")], + "max_tokens": 100, + } + defaults.update(kwargs) + return AnthropicRequest(**defaults) + + def test_simple_request(self): + req = self._make_request() + result = anthropic_to_openai(req) + assert result.model == "default" + assert result.max_tokens == 100 + assert len(result.messages) == 1 + assert result.messages[0].role == "user" + assert result.messages[0].content == "hi" + + def test_system_string(self): + req = self._make_request(system="Be helpful.") + result = anthropic_to_openai(req) + assert len(result.messages) == 2 + assert result.messages[0].role == "system" + assert result.messages[0].content == "Be helpful." + assert result.messages[1].role == "user" + + def test_system_list(self): + req = self._make_request(system=[{"type": "text", "text": "Be concise."}]) + result = anthropic_to_openai(req) + assert result.messages[0].role == "system" + assert result.messages[0].content == "Be concise." + + def test_temperature_default(self): + req = self._make_request() + result = anthropic_to_openai(req) + assert result.temperature == 0.7 + + def test_temperature_explicit(self): + req = self._make_request(temperature=0.3) + result = anthropic_to_openai(req) + assert result.temperature == 0.3 + + def test_top_p_default(self): + req = self._make_request() + result = anthropic_to_openai(req) + assert result.top_p == 0.9 + + def test_top_p_explicit(self): + req = self._make_request(top_p=0.5) + result = anthropic_to_openai(req) + assert result.top_p == 0.5 + + def test_stop_sequences(self): + req = self._make_request(stop_sequences=["END", "STOP"]) + result = anthropic_to_openai(req) + assert result.stop == ["END", "STOP"] + + def test_stream_flag(self): + req = self._make_request(stream=True) + result = anthropic_to_openai(req) + assert result.stream is True + + def test_tools_conversion(self): + req = self._make_request( + tools=[ + AnthropicToolDef( + name="search", + description="Search the web", + input_schema={ + "type": "object", + "properties": {"q": {"type": "string"}}, + }, + ) + ] + ) + result = anthropic_to_openai(req) + assert len(result.tools) == 1 + assert result.tools[0].function["name"] == "search" + + def test_tool_choice_conversion(self): + req = self._make_request(tool_choice={"type": "any"}) + result = anthropic_to_openai(req) + assert result.tool_choice == "required" + + def test_no_tools(self): + req = self._make_request() + result = anthropic_to_openai(req) + assert result.tools is None + assert result.tool_choice is None + + def test_multiple_messages(self): + msgs = [ + AnthropicMessage(role="user", content="hello"), + AnthropicMessage(role="assistant", content="hi"), + AnthropicMessage(role="user", content="how are you"), + ] + req = self._make_request(messages=msgs) + result = anthropic_to_openai(req) + assert len(result.messages) == 3 + assert result.messages[0].role == "user" + assert result.messages[1].role == "assistant" + assert result.messages[2].role == "user" + + +class TestOpenaiToAnthropic: + """Tests for openai_to_anthropic conversion.""" + + def _make_response(self, content="hello", finish_reason="stop", tool_calls=None): + msg = AssistantMessage(content=content, tool_calls=tool_calls) + choice = ChatCompletionChoice(message=msg, finish_reason=finish_reason) + return ChatCompletionResponse( + model="default", + choices=[choice], + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + def test_simple_text_response(self): + resp = self._make_response(content="hi there") + result = openai_to_anthropic(resp, "default") + assert result.model == "default" + assert result.type == "message" + assert result.role == "assistant" + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "hi there" + assert result.stop_reason == "end_turn" + + def test_usage_mapping(self): + resp = self._make_response() + result = openai_to_anthropic(resp, "default") + assert result.usage.input_tokens == 10 + assert result.usage.output_tokens == 5 + + def test_tool_calls_response(self): + tc = ToolCall( + id="call_1", + type="function", + function=FunctionCall( + name="search", + arguments='{"q": "test"}', + ), + ) + resp = self._make_response( + content="Let me search.", + finish_reason="tool_calls", + tool_calls=[tc], + ) + result = openai_to_anthropic(resp, "default") + assert len(result.content) == 2 + assert result.content[0].type == "text" + assert result.content[0].text == "Let me search." + assert result.content[1].type == "tool_use" + assert result.content[1].name == "search" + assert result.content[1].input == {"q": "test"} + assert result.stop_reason == "tool_use" + + def test_tool_call_invalid_json_arguments(self): + tc = ToolCall( + id="call_1", + type="function", + function=FunctionCall(name="search", arguments="not json"), + ) + resp = self._make_response( + content=None, finish_reason="tool_calls", tool_calls=[tc] + ) + result = openai_to_anthropic(resp, "default") + tool_block = [b for b in result.content if b.type == "tool_use"][0] + assert tool_block.input == {} + + def test_empty_choices(self): + resp = ChatCompletionResponse( + model="default", + choices=[], + usage=Usage(), + ) + result = openai_to_anthropic(resp, "default") + assert result.stop_reason == "end_turn" + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "" + + def test_no_content_adds_empty_text(self): + resp = self._make_response(content=None) + result = openai_to_anthropic(resp, "default") + assert len(result.content) >= 1 + has_text = any(b.type == "text" for b in result.content) + assert has_text + + def test_stop_reason_length(self): + resp = self._make_response(finish_reason="length") + result = openai_to_anthropic(resp, "default") + assert result.stop_reason == "max_tokens" + + def test_response_has_id(self): + resp = self._make_response() + result = openai_to_anthropic(resp, "test-model") + assert result.id.startswith("msg_") + assert result.model == "test-model" diff --git a/tests/test_anthropic_models.py b/tests/test_anthropic_models.py new file mode 100644 index 00000000..5be0f23d --- /dev/null +++ b/tests/test_anthropic_models.py @@ -0,0 +1,360 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for Anthropic Messages API Pydantic models. + +Tests all request/response models in vllm_mlx/api/anthropic_models.py. +These are pure Pydantic models with no MLX dependency. +""" + +import pytest +from pydantic import ValidationError + +from vllm_mlx.api.anthropic_models import ( + AnthropicContentBlock, + AnthropicMessage, + AnthropicRequest, + AnthropicResponse, + AnthropicResponseContentBlock, + AnthropicToolDef, + AnthropicUsage, +) + + +class TestAnthropicContentBlock: + """Tests for AnthropicContentBlock model.""" + + def test_text_block(self): + block = AnthropicContentBlock(type="text", text="hello") + assert block.type == "text" + assert block.text == "hello" + + def test_tool_use_block(self): + block = AnthropicContentBlock( + type="tool_use", + id="call_123", + name="get_weather", + input={"city": "Paris"}, + ) + assert block.type == "tool_use" + assert block.id == "call_123" + assert block.name == "get_weather" + assert block.input == {"city": "Paris"} + + def test_tool_result_block(self): + block = AnthropicContentBlock( + type="tool_result", + tool_use_id="call_123", + content="sunny", + ) + assert block.type == "tool_result" + assert block.tool_use_id == "call_123" + assert block.content == "sunny" + + def test_tool_result_with_error(self): + block = AnthropicContentBlock( + type="tool_result", + tool_use_id="call_123", + content="not found", + is_error=True, + ) + assert block.is_error is True + + def test_image_block(self): + block = AnthropicContentBlock( + type="image", + source={"type": "base64", "media_type": "image/png", "data": "abc"}, + ) + assert block.type == "image" + assert block.source["type"] == "base64" + + def test_optional_fields_default_to_none(self): + block = AnthropicContentBlock(type="text") + assert block.text is None + assert block.id is None + assert block.name is None + assert block.input is None + assert block.tool_use_id is None + assert block.content is None + assert block.is_error is None + assert block.source is None + + +class TestAnthropicMessage: + """Tests for AnthropicMessage model.""" + + def test_string_content(self): + msg = AnthropicMessage(role="user", content="hello") + assert msg.role == "user" + assert msg.content == "hello" + + def test_list_content(self): + blocks = [ + AnthropicContentBlock(type="text", text="look at this"), + AnthropicContentBlock( + type="image", + source={"type": "base64", "media_type": "image/png", "data": "abc"}, + ), + ] + msg = AnthropicMessage(role="user", content=blocks) + assert len(msg.content) == 2 + assert msg.content[0].type == "text" + assert msg.content[1].type == "image" + + def test_assistant_role(self): + msg = AnthropicMessage(role="assistant", content="hi there") + assert msg.role == "assistant" + + def test_missing_role_raises(self): + with pytest.raises(ValidationError): + AnthropicMessage(content="hello") + + def test_missing_content_raises(self): + with pytest.raises(ValidationError): + AnthropicMessage(role="user") + + +class TestAnthropicToolDef: + """Tests for AnthropicToolDef model.""" + + def test_minimal(self): + tool = AnthropicToolDef(name="get_weather") + assert tool.name == "get_weather" + assert tool.description is None + assert tool.input_schema is None + + def test_full(self): + tool = AnthropicToolDef( + name="get_weather", + description="Get weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ) + assert tool.description == "Get weather for a city" + assert tool.input_schema["required"] == ["city"] + + def test_missing_name_raises(self): + with pytest.raises(ValidationError): + AnthropicToolDef(description="no name") + + +class TestAnthropicRequest: + """Tests for AnthropicRequest model.""" + + def test_minimal_request(self): + req = AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + ) + assert req.model == "default" + assert req.max_tokens == 100 + assert req.stream is False + assert req.temperature is None + assert req.top_p is None + assert req.tools is None + assert req.system is None + + def test_with_system_string(self): + req = AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + system="You are helpful.", + ) + assert req.system == "You are helpful." + + def test_with_system_list(self): + req = AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + system=[{"type": "text", "text": "Be concise."}], + ) + assert isinstance(req.system, list) + assert req.system[0]["text"] == "Be concise." + + def test_with_tools(self): + req = AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + tools=[AnthropicToolDef(name="search")], + ) + assert len(req.tools) == 1 + assert req.tools[0].name == "search" + + def test_with_tool_choice(self): + req = AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + tool_choice={"type": "auto"}, + ) + assert req.tool_choice == {"type": "auto"} + + def test_streaming(self): + req = AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + stream=True, + ) + assert req.stream is True + + def test_all_optional_params(self): + req = AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=256, + temperature=0.5, + top_p=0.9, + top_k=40, + stop_sequences=["END"], + metadata={"user_id": "123"}, + ) + assert req.temperature == 0.5 + assert req.top_p == 0.9 + assert req.top_k == 40 + assert req.stop_sequences == ["END"] + assert req.metadata == {"user_id": "123"} + + def test_missing_model_raises(self): + with pytest.raises(ValidationError): + AnthropicRequest( + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + ) + + def test_missing_messages_raises(self): + with pytest.raises(ValidationError): + AnthropicRequest(model="default", max_tokens=100) + + def test_missing_max_tokens_raises(self): + with pytest.raises(ValidationError): + AnthropicRequest( + model="default", + messages=[AnthropicMessage(role="user", content="hi")], + ) + + +class TestAnthropicUsage: + """Tests for AnthropicUsage model.""" + + def test_defaults(self): + usage = AnthropicUsage() + assert usage.input_tokens == 0 + assert usage.output_tokens == 0 + assert usage.cache_creation_input_tokens is None + assert usage.cache_read_input_tokens is None + + def test_with_values(self): + usage = AnthropicUsage(input_tokens=100, output_tokens=50) + assert usage.input_tokens == 100 + assert usage.output_tokens == 50 + + def test_with_cache_fields(self): + usage = AnthropicUsage( + input_tokens=100, + output_tokens=50, + cache_creation_input_tokens=20, + cache_read_input_tokens=80, + ) + assert usage.cache_creation_input_tokens == 20 + assert usage.cache_read_input_tokens == 80 + + +class TestAnthropicResponseContentBlock: + """Tests for AnthropicResponseContentBlock model.""" + + def test_text_block(self): + block = AnthropicResponseContentBlock(type="text", text="hello") + assert block.type == "text" + assert block.text == "hello" + + def test_tool_use_block(self): + block = AnthropicResponseContentBlock( + type="tool_use", + id="call_abc", + name="search", + input={"query": "test"}, + ) + assert block.type == "tool_use" + assert block.id == "call_abc" + assert block.name == "search" + assert block.input == {"query": "test"} + + def test_optional_fields_default_to_none(self): + block = AnthropicResponseContentBlock(type="text") + assert block.text is None + assert block.id is None + assert block.name is None + assert block.input is None + + +class TestAnthropicResponse: + """Tests for AnthropicResponse model.""" + + def test_minimal_response(self): + resp = AnthropicResponse( + model="default", + content=[AnthropicResponseContentBlock(type="text", text="hi")], + ) + assert resp.model == "default" + assert resp.type == "message" + assert resp.role == "assistant" + assert resp.id.startswith("msg_") + assert len(resp.id) == len("msg_") + 24 + assert resp.stop_reason is None + assert resp.stop_sequence is None + assert resp.usage.input_tokens == 0 + assert resp.usage.output_tokens == 0 + + def test_with_stop_reason(self): + resp = AnthropicResponse( + model="default", + content=[AnthropicResponseContentBlock(type="text", text="done")], + stop_reason="end_turn", + ) + assert resp.stop_reason == "end_turn" + + def test_with_usage(self): + resp = AnthropicResponse( + model="default", + content=[AnthropicResponseContentBlock(type="text", text="hi")], + usage=AnthropicUsage(input_tokens=10, output_tokens=5), + ) + assert resp.usage.input_tokens == 10 + assert resp.usage.output_tokens == 5 + + def test_unique_ids(self): + r1 = AnthropicResponse( + model="default", + content=[AnthropicResponseContentBlock(type="text", text="a")], + ) + r2 = AnthropicResponse( + model="default", + content=[AnthropicResponseContentBlock(type="text", text="b")], + ) + assert r1.id != r2.id + + def test_tool_use_response(self): + resp = AnthropicResponse( + model="default", + content=[ + AnthropicResponseContentBlock(type="text", text="Let me search."), + AnthropicResponseContentBlock( + type="tool_use", + id="call_1", + name="search", + input={"q": "test"}, + ), + ], + stop_reason="tool_use", + ) + assert len(resp.content) == 2 + assert resp.content[0].type == "text" + assert resp.content[1].type == "tool_use" + assert resp.stop_reason == "tool_use" diff --git a/tests/test_batching.py b/tests/test_batching.py index 10e73c5b..7dc050ee 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -261,7 +261,7 @@ def test_add_duplicate_request(self, mock_model, mock_tokenizer): scheduler.add_request(request) def test_abort_waiting_request(self, mock_model, mock_tokenizer): - """Test aborting a waiting request.""" + """Test aborting a waiting request (deferred abort pattern).""" scheduler = Scheduler( model=mock_model, tokenizer=mock_tokenizer, @@ -276,21 +276,26 @@ def test_abort_waiting_request(self, mock_model, mock_tokenizer): scheduler.add_request(request) assert scheduler.get_num_waiting() == 1 + # abort_request() enqueues for deferred processing result = scheduler.abort_request("test-1") - assert result is True + + # Process pending aborts (normally happens inside step()) + scheduler._process_pending_aborts() + assert scheduler.get_num_waiting() == 0 assert "test-1" in scheduler.finished_req_ids def test_abort_nonexistent_request(self, mock_model, mock_tokenizer): - """Test aborting non-existent request.""" + """Test aborting non-existent request (deferred abort always enqueues).""" scheduler = Scheduler( model=mock_model, tokenizer=mock_tokenizer, ) + # abort_request() always returns True (enqueue is always successful) result = scheduler.abort_request("nonexistent") - assert result is False + assert result is True def test_get_stats(self, mock_model, mock_tokenizer): """Test getting scheduler stats.""" diff --git a/tests/test_memory_cache.py b/tests/test_memory_cache.py index 77c330d6..832dd7e9 100644 --- a/tests/test_memory_cache.py +++ b/tests/test_memory_cache.py @@ -10,6 +10,7 @@ MemoryAwarePrefixCache, MemoryCacheConfig, _CacheEntry, + _array_memory, _get_available_memory, estimate_kv_cache_memory, ) @@ -116,6 +117,21 @@ def __init__(self, nbytes: int): self.nbytes = nbytes +class MockDtype: + """Mock dtype with size attribute.""" + + def __init__(self, size: int): + self.size = size + + +class MockShapeArray: + """Mock array with shape and dtype (like MLX arrays) but no nbytes.""" + + def __init__(self, shape: tuple, dtype_size: int): + self.shape = shape + self.dtype = MockDtype(dtype_size) + + class MockKVCache: """Mock KV cache with keys/values attributes.""" @@ -136,6 +152,46 @@ def state(self): return (self._keys, self._values) +class TestArrayMemory: + """Tests for _array_memory helper (shape-based, no lazy eval trigger).""" + + def test_shape_dtype_estimation(self): + """Verify shape*dtype.size computation without .nbytes access.""" + arr = MockShapeArray(shape=(2, 16, 128, 64), dtype_size=2) + # 2 * 16 * 128 * 64 * 2 = 524288 + assert _array_memory(arr) == 2 * 16 * 128 * 64 * 2 + + def test_fallback_to_nbytes(self): + """Verify fallback to .nbytes when shape/dtype not available.""" + arr = MockArray(nbytes=4096) + assert _array_memory(arr) == 4096 + + def test_zero_for_unknown_object(self): + """Return 0 for objects without shape/dtype/nbytes.""" + assert _array_memory(42) == 0 + assert _array_memory("string") == 0 + + def test_shape_dtype_preferred_over_nbytes(self): + """When both shape+dtype and nbytes exist, shape+dtype is used.""" + + class DualArray: + def __init__(self): + self.shape = (10,) + self.dtype = MockDtype(4) + self.nbytes = 9999 # should NOT be used + + arr = DualArray() + assert _array_memory(arr) == 40 # 10 * 4, not 9999 + + def test_estimate_uses_shape_based_for_dict_state(self): + """estimate_kv_cache_memory uses _array_memory (shape-based) for dicts.""" + keys = MockShapeArray(shape=(1, 8, 100, 64), dtype_size=2) + values = MockShapeArray(shape=(1, 8, 100, 64), dtype_size=2) + layer = {"state": (keys, values)} + expected = 2 * (1 * 8 * 100 * 64 * 2) + assert estimate_kv_cache_memory([layer]) == expected + + class TestEstimateKvCacheMemory: """Tests for estimate_kv_cache_memory function.""" diff --git a/tests/test_memory_stability.py b/tests/test_memory_stability.py new file mode 100644 index 00000000..f332d1b4 --- /dev/null +++ b/tests/test_memory_stability.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for VRAM memory stability fixes. + +Verifies that: +1. BatchGenerator.close() is called when replacing/discarding generators +2. Periodic mx.clear_cache() is triggered in generation loop +3. Metal memory stats are reported in get_stats() +""" + +from unittest.mock import MagicMock, patch + +from vllm_mlx.request import SamplingParams +from vllm_mlx.scheduler import Scheduler, SchedulerConfig + + +def _make_scheduler( + enable_prefix_cache=False, +) -> Scheduler: + """Create a scheduler with mocked model/tokenizer for unit tests.""" + model = MagicMock() + tokenizer = MagicMock() + tokenizer.encode = lambda x: list(range(len(x.split()))) + tokenizer.eos_token_id = 0 + + config = SchedulerConfig( + max_num_seqs=4, + enable_prefix_cache=enable_prefix_cache, + ) + return Scheduler(model, tokenizer, config) + + +class TestBatchGeneratorClose: + """Tests that BatchGenerator.close() is called properly.""" + + def test_close_called_on_replacement(self): + """Verify .close() is called when BatchGenerator is replaced.""" + scheduler = _make_scheduler() + + # Create a mock BatchGenerator with close() + old_generator = MagicMock() + old_generator.close = MagicMock() + scheduler.batch_generator = old_generator + + # Trigger replacement via _close_batch_generator + scheduler._close_batch_generator() + + old_generator.close.assert_called_once() + assert scheduler.batch_generator is None + + def test_close_called_on_reset(self): + """Verify .close() is called during reset().""" + scheduler = _make_scheduler() + + mock_generator = MagicMock() + mock_generator.close = MagicMock() + scheduler.batch_generator = mock_generator + + scheduler.reset() + + mock_generator.close.assert_called_once() + assert scheduler.batch_generator is None + + def test_close_called_on_cache_error_recovery(self): + """Verify .close() is called during _recover_from_cache_error().""" + scheduler = _make_scheduler() + + mock_generator = MagicMock() + mock_generator.close = MagicMock() + scheduler.batch_generator = mock_generator + + scheduler._recover_from_cache_error() + + mock_generator.close.assert_called_once() + assert scheduler.batch_generator is None + + def test_close_not_called_when_none(self): + """Verify no error when batch_generator is already None.""" + scheduler = _make_scheduler() + assert scheduler.batch_generator is None + + # Should not raise + scheduler._close_batch_generator() + assert scheduler.batch_generator is None + + def test_close_exception_is_caught(self): + """Verify exceptions in close() are caught gracefully.""" + scheduler = _make_scheduler() + + mock_generator = MagicMock() + mock_generator.close = MagicMock(side_effect=RuntimeError("close failed")) + scheduler.batch_generator = mock_generator + + # Should not raise + scheduler._close_batch_generator() + assert scheduler.batch_generator is None + + def test_close_called_in_ensure_batch_generator(self): + """Verify _close_batch_generator is called when _ensure_batch_generator replaces.""" + scheduler = _make_scheduler() + + mock_generator = MagicMock() + mock_generator.close = MagicMock() + scheduler.batch_generator = mock_generator + scheduler._current_sampler_params = (0.5, 0.9, 0.0) + + # Patch _create_batch_generator to return a new mock + new_generator = MagicMock() + with patch.object( + scheduler, "_create_batch_generator", return_value=new_generator + ): + # Different params forces recreation + params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=100) + scheduler._ensure_batch_generator(params) + + # Old generator should have been closed + mock_generator.close.assert_called_once() + assert scheduler.batch_generator is new_generator + + +class TestClearCacheInterval: + """Tests for periodic mx.clear_cache() calls.""" + + def test_clear_cache_interval_configured(self): + """Verify default clear_cache interval is set.""" + scheduler = _make_scheduler() + assert scheduler._step_count == 0 + assert scheduler._clear_cache_interval == 32 + + @patch("vllm_mlx.scheduler.mx") + def test_clear_cache_called_periodically(self, mock_mx): + """Verify mx.clear_cache() is called every _clear_cache_interval steps.""" + scheduler = _make_scheduler() + scheduler._clear_cache_interval = 4 # Small interval for testing + + # Simulate steps without actual generation (no running requests) + for _i in range(8): + scheduler.step() + + # Should have been called at step 4 and 8 + assert mock_mx.clear_cache.call_count >= 2 + + @patch("vllm_mlx.scheduler.mx") + def test_clear_cache_called_on_cleanup(self, mock_mx): + """Verify mx.clear_cache() is called when requests finish.""" + scheduler = _make_scheduler() + + # Call _cleanup_finished with non-empty set + scheduler._cleanup_finished({"req-1"}) + + mock_mx.clear_cache.assert_called() + + @patch("vllm_mlx.scheduler.mx") + def test_clear_cache_not_called_on_empty_cleanup(self, mock_mx): + """Verify mx.clear_cache() is NOT called when no requests finish.""" + scheduler = _make_scheduler() + + scheduler._cleanup_finished(set()) + + mock_mx.clear_cache.assert_not_called() + + +class TestIncrementalCacheEval: + """Tests for incremental per-layer cache evaluation in _cleanup_finished().""" + + @patch("vllm_mlx.scheduler.mx") + def test_incremental_eval_called_per_layer(self, mock_mx): + """Verify mx.eval is called per layer during cleanup, not as one batch.""" + scheduler = _make_scheduler() + + # Create a mock request with extracted cache (dict-state format) + mock_request = MagicMock() + mock_request.prompt_token_ids = [1, 2, 3] + mock_request.output_token_ids = [4, 5] + mock_keys_1 = MagicMock() + mock_values_1 = MagicMock() + mock_keys_2 = MagicMock() + mock_values_2 = MagicMock() + mock_request._extracted_cache = [ + {"state": (mock_keys_1, mock_values_1)}, + {"state": (mock_keys_2, mock_values_2)}, + ] + + scheduler.running["req-1"] = mock_request + + scheduler._cleanup_finished({"req-1"}) + + # mx.eval should have been called once per layer (2 layers) + eval_calls = mock_mx.eval.call_args_list + assert len(eval_calls) == 2 + # First call with layer 1 keys/values + assert eval_calls[0] == ((mock_keys_1, mock_values_1),) + # Second call with layer 2 keys/values + assert eval_calls[1] == ((mock_keys_2, mock_values_2),) + + @patch("vllm_mlx.scheduler.mx") + def test_no_eval_when_no_extracted_cache(self, mock_mx): + """Verify mx.eval is not called when request has no extracted cache.""" + scheduler = _make_scheduler() + + mock_request = MagicMock() + mock_request.prompt_token_ids = [1, 2, 3] + mock_request.output_token_ids = [4, 5] + mock_request._extracted_cache = None + + scheduler.running["req-1"] = mock_request + + scheduler._cleanup_finished({"req-1"}) + + # mx.eval should NOT have been called (only mx.clear_cache for cleanup) + mock_mx.eval.assert_not_called() + + @patch("vllm_mlx.scheduler.mx") + def test_no_eager_eval_in_extraction_path(self, mock_mx): + """Verify mx.eval(mx.array(0)) is NOT called during cache extraction.""" + scheduler = _make_scheduler() + + # Create a mock response with prompt_cache + mock_response = MagicMock() + mock_response.uid = 42 + mock_response.token = 100 + mock_response.finish_reason = "stop" + mock_response.prompt_cache = [MagicMock()] + + # Setup request/uid mapping + mock_request = MagicMock() + mock_request.request_id = "req-1" + mock_request.output_token_ids = [100] + mock_request.num_output_tokens = 1 + mock_request.num_prompt_tokens = 3 + scheduler.running["req-1"] = mock_request + scheduler.uid_to_request_id[42] = "req-1" + + scheduler._process_batch_responses([mock_response]) + + # Verify mx.eval was NOT called with mx.array(0) — the old spike pattern + for call_args in mock_mx.eval.call_args_list: + args = call_args[0] + # Should not be called with a single mx.array argument + assert not (len(args) == 1 and args[0] == mock_mx.array(0)) + + +class TestMemoryStats: + """Tests for Metal memory stats in get_stats().""" + + @patch("vllm_mlx.scheduler.mx") + def test_metal_stats_included(self, mock_mx): + """Verify Metal memory stats appear in get_stats().""" + mock_mx.metal.is_available.return_value = True + mock_mx.metal.get_active_memory.return_value = 10_000_000_000 # 10GB + mock_mx.metal.get_peak_memory.return_value = 15_000_000_000 # 15GB + mock_mx.metal.get_cache_memory.return_value = 2_000_000_000 # 2GB + + scheduler = _make_scheduler() + stats = scheduler.get_stats() + + assert stats["metal_active_memory_gb"] == 10.0 + assert stats["metal_peak_memory_gb"] == 15.0 + assert stats["metal_cache_memory_gb"] == 2.0 + + @patch("vllm_mlx.scheduler.mx") + def test_metal_stats_graceful_on_error(self, mock_mx): + """Verify get_stats() works even if Metal stats fail.""" + mock_mx.metal.is_available.side_effect = RuntimeError("no metal") + + scheduler = _make_scheduler() + stats = scheduler.get_stats() + + # Should still return basic stats without Metal info + assert "num_waiting" in stats + assert "metal_active_memory_gb" not in stats diff --git a/tests/test_native_tool_format.py b/tests/test_native_tool_format.py index 58044fc2..18411617 100644 --- a/tests/test_native_tool_format.py +++ b/tests/test_native_tool_format.py @@ -36,6 +36,7 @@ def test_parsers_with_native_support(self): GraniteToolParser, FunctionaryToolParser, KimiToolParser, + HermesToolParser, ] for parser_cls in native_parsers: assert ( @@ -49,7 +50,6 @@ def test_parsers_without_native_support(self): """Parsers that don't support native tool format should return False.""" non_native_parsers = [ QwenToolParser, - HermesToolParser, NemotronToolParser, xLAMToolParser, AutoToolParser, @@ -65,14 +65,22 @@ def test_parsers_without_native_support(self): def test_via_manager(self): """Test native format detection via ToolParserManager.""" # Native support - for name in ["mistral", "llama", "deepseek", "granite", "functionary", "kimi"]: + for name in [ + "mistral", + "llama", + "deepseek", + "granite", + "functionary", + "kimi", + "hermes", + ]: parser_cls = ToolParserManager.get_tool_parser(name) assert ( parser_cls.supports_native_format() is True ), f"Parser '{name}' should support native format" # No native support - for name in ["qwen", "hermes", "nemotron", "xlam", "auto"]: + for name in ["qwen", "nemotron", "xlam", "auto"]: parser_cls = ToolParserManager.get_tool_parser(name) assert ( parser_cls.supports_native_format() is False diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 778f4deb..8e36a397 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -211,19 +211,15 @@ def test_clear(self, cache_manager): # Stats should also be reset assert cache_manager.stats.hits == 0 - def test_cache_deep_copy(self, cache_manager): - """Test that fetched cache is a deep copy.""" + def test_cache_no_copy(self, cache_manager): + """Test that fetched cache is a reference (no copy) — MLX arrays are immutable.""" original = [[1, 2, 3]] cache_manager.store_cache([1, 2], original) cache, _ = cache_manager.fetch_cache([1, 2]) - # Modify returned cache - cache[0].append(4) - - # Original should be unchanged - cache2, _ = cache_manager.fetch_cache([1, 2]) - assert cache2[0] == [1, 2, 3] + # Returns the same object (no deep copy overhead) + assert cache is original def test_multiple_prefixes(self, cache_manager): """Test multiple different prefixes.""" diff --git a/vllm_mlx/api/anthropic_adapter.py b/vllm_mlx/api/anthropic_adapter.py new file mode 100644 index 00000000..dbb94200 --- /dev/null +++ b/vllm_mlx/api/anthropic_adapter.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Adapter for converting between Anthropic Messages API and OpenAI Chat Completions API. + +Handles translation of: +- Requests: Anthropic → OpenAI format +- Responses: OpenAI → Anthropic format +- Messages: Content blocks, tool calls, tool results +""" + +import json +import uuid + +from .anthropic_models import ( + AnthropicMessage, + AnthropicRequest, + AnthropicResponse, + AnthropicResponseContentBlock, + AnthropicToolDef, + AnthropicUsage, +) +from .models import ( + ChatCompletionRequest, + ChatCompletionResponse, + Message, + ToolDefinition, +) + + +def anthropic_to_openai(request: AnthropicRequest) -> ChatCompletionRequest: + """ + Convert an Anthropic Messages API request to OpenAI Chat Completions format. + + Handles: + - system field → system message + - Content blocks → OpenAI message format + - tool_use/tool_result → OpenAI tool_calls/tool messages + - Anthropic tools → OpenAI tools + + Args: + request: Anthropic Messages API request + + Returns: + OpenAI ChatCompletionRequest + """ + messages = [] + + # Convert system to system message + if request.system: + if isinstance(request.system, str): + system_text = request.system + elif isinstance(request.system, list): + # System can be a list of content blocks + parts = [] + for block in request.system: + if isinstance(block, dict) and block.get("type") == "text": + parts.append(block.get("text", "")) + elif isinstance(block, str): + parts.append(block) + system_text = "\n".join(parts) + else: + system_text = str(request.system) + messages.append(Message(role="system", content=system_text)) + + # Convert each message + for msg in request.messages: + converted = _convert_message(msg) + messages.extend(converted) + + # Convert tools + tools = None + if request.tools: + tools = [_convert_tool(t) for t in request.tools] + + # Convert tool_choice + tool_choice = None + if request.tool_choice: + tool_choice = _convert_tool_choice(request.tool_choice) + + return ChatCompletionRequest( + model=request.model, + messages=messages, + max_tokens=request.max_tokens, + temperature=request.temperature if request.temperature is not None else 0.7, + top_p=request.top_p if request.top_p is not None else 0.9, + stream=request.stream, + stop=request.stop_sequences, + tools=tools, + tool_choice=tool_choice, + ) + + +def openai_to_anthropic( + response: ChatCompletionResponse, + model: str, +) -> AnthropicResponse: + """ + Convert an OpenAI Chat Completions response to Anthropic Messages API format. + + Args: + response: OpenAI ChatCompletionResponse + model: Model name for the response + + Returns: + Anthropic Messages API response + """ + content = [] + choice = response.choices[0] if response.choices else None + + if choice: + # Add text content + if choice.message.content: + content.append( + AnthropicResponseContentBlock( + type="text", + text=choice.message.content, + ) + ) + + # Add tool use blocks + if choice.message.tool_calls: + for tc in choice.message.tool_calls: + try: + tool_input = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError): + tool_input = {} + + content.append( + AnthropicResponseContentBlock( + type="tool_use", + id=tc.id, + name=tc.function.name, + input=tool_input, + ) + ) + + stop_reason = _convert_stop_reason(choice.finish_reason) + else: + stop_reason = "end_turn" + + # If no content blocks, add empty text + if not content: + content.append(AnthropicResponseContentBlock(type="text", text="")) + + return AnthropicResponse( + model=model, + content=content, + stop_reason=stop_reason, + usage=AnthropicUsage( + input_tokens=response.usage.prompt_tokens if response.usage else 0, + output_tokens=response.usage.completion_tokens if response.usage else 0, + ), + ) + + +def _convert_message(msg: AnthropicMessage) -> list[Message]: + """ + Convert an Anthropic message to one or more OpenAI messages. + + Anthropic tool_result blocks (sent as user messages) need to be + split into separate OpenAI tool messages. + + Args: + msg: Anthropic message + + Returns: + List of OpenAI messages + """ + # Simple string content + if isinstance(msg.content, str): + return [Message(role=msg.role, content=msg.content)] + + # Content is a list of blocks + messages = [] + text_parts = [] + tool_calls_for_assistant = [] + tool_results = [] + + for block in msg.content: + if block.type == "text": + text_parts.append(block.text or "") + + elif block.type == "tool_use": + # Assistant message with tool calls + tool_input = block.input or {} + tool_calls_for_assistant.append( + { + "id": block.id or f"call_{uuid.uuid4().hex[:8]}", + "type": "function", + "function": { + "name": block.name or "", + "arguments": json.dumps(tool_input), + }, + } + ) + + elif block.type == "tool_result": + # Tool result → OpenAI tool message + result_content = block.content + if isinstance(result_content, list): + # Extract text from content blocks + parts = [] + for item in result_content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(item.get("text", "")) + elif isinstance(item, str): + parts.append(item) + result_content = "\n".join(parts) + elif result_content is None: + result_content = "" + + tool_results.append( + Message( + role="tool", + content=str(result_content), + tool_call_id=block.tool_use_id or "", + ) + ) + + # Build the messages + if msg.role == "assistant": + combined_text = "\n".join(text_parts) if text_parts else None + if tool_calls_for_assistant: + messages.append( + Message( + role="assistant", + content=combined_text or "", + tool_calls=tool_calls_for_assistant, + ) + ) + elif combined_text is not None: + messages.append(Message(role="assistant", content=combined_text)) + else: + messages.append(Message(role="assistant", content="")) + elif msg.role == "user": + # User messages: collect text parts, then add tool results separately + if text_parts: + combined_text = "\n".join(text_parts) + messages.append(Message(role="user", content=combined_text)) + + # Tool results become separate tool messages + messages.extend(tool_results) + + # If no text and no tool results, add empty user message + if not text_parts and not tool_results: + messages.append(Message(role="user", content="")) + else: + # Other roles + combined_text = "\n".join(text_parts) if text_parts else "" + messages.append(Message(role=msg.role, content=combined_text)) + + return messages + + +def _convert_tool(tool: AnthropicToolDef) -> ToolDefinition: + """ + Convert an Anthropic tool definition to OpenAI format. + + Anthropic: {"name": "...", "description": "...", "input_schema": {...}} + OpenAI: {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}} + """ + return ToolDefinition( + type="function", + function={ + "name": tool.name, + "description": tool.description or "", + "parameters": tool.input_schema or {"type": "object", "properties": {}}, + }, + ) + + +def _convert_tool_choice(tool_choice: dict) -> str | dict | None: + """ + Convert Anthropic tool_choice to OpenAI format. + + Anthropic: {"type": "auto"} | {"type": "any"} | {"type": "tool", "name": "..."} + OpenAI: "auto" | "none" | "required" | {"type": "function", "function": {"name": "..."}} + """ + choice_type = tool_choice.get("type", "auto") + + if choice_type == "auto": + return "auto" + elif choice_type == "any": + return "required" + elif choice_type == "tool": + return { + "type": "function", + "function": {"name": tool_choice.get("name", "")}, + } + elif choice_type == "none": + return "none" + + return "auto" + + +def _convert_stop_reason(openai_reason: str | None) -> str: + """ + Convert OpenAI finish_reason to Anthropic stop_reason. + + OpenAI: "stop" | "tool_calls" | "length" | "content_filter" + Anthropic: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" + """ + if openai_reason is None: + return "end_turn" + + mapping = { + "stop": "end_turn", + "tool_calls": "tool_use", + "length": "max_tokens", + "content_filter": "end_turn", + } + return mapping.get(openai_reason, "end_turn") diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py new file mode 100644 index 00000000..a5bc6f77 --- /dev/null +++ b/vllm_mlx/api/anthropic_models.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Pydantic models for Anthropic Messages API. + +These models define the request and response schemas for the +Anthropic-compatible /v1/messages endpoint, enabling clients like +Claude Code to communicate with vllm-mlx. +""" + +import uuid +from typing import Any + +from pydantic import BaseModel, Field + +# ============================================================================= +# Request Models +# ============================================================================= + + +class AnthropicContentBlock(BaseModel): + """A content block in an Anthropic message.""" + + type: str # "text", "image", "tool_use", "tool_result" + # text block + text: str | None = None + # tool_use block + id: str | None = None + name: str | None = None + input: dict | None = None + # tool_result block + tool_use_id: str | None = None + content: str | list | None = None + is_error: bool | None = None + # image block + source: dict | None = None + + +class AnthropicMessage(BaseModel): + """A message in an Anthropic conversation.""" + + role: str # "user" | "assistant" + content: str | list[AnthropicContentBlock] + + +class AnthropicToolDef(BaseModel): + """Definition of a tool in Anthropic format.""" + + name: str + description: str | None = None + input_schema: dict | None = None + + +class AnthropicRequest(BaseModel): + """Request for Anthropic Messages API.""" + + model: str + messages: list[AnthropicMessage] + system: str | list[dict] | None = None + max_tokens: int # Required in Anthropic API + temperature: float | None = None + top_p: float | None = None + stream: bool = False + stop_sequences: list[str] | None = None + tools: list[AnthropicToolDef] | None = None + tool_choice: dict | None = None + metadata: dict | None = None + top_k: int | None = None + + +# ============================================================================= +# Response Models +# ============================================================================= + + +class AnthropicUsage(BaseModel): + """Token usage for Anthropic response.""" + + input_tokens: int = 0 + output_tokens: int = 0 + cache_creation_input_tokens: int | None = None + cache_read_input_tokens: int | None = None + + +class AnthropicResponseContentBlock(BaseModel): + """A content block in the Anthropic response.""" + + type: str # "text" or "tool_use" + text: str | None = None + # tool_use fields + id: str | None = None + name: str | None = None + input: Any | None = None + + +class AnthropicResponse(BaseModel): + """Response for Anthropic Messages API.""" + + id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:24]}") + type: str = "message" + role: str = "assistant" + model: str + content: list[AnthropicResponseContentBlock] + stop_reason: str | None = None + stop_sequence: str | None = None + usage: AnthropicUsage = Field(default_factory=AnthropicUsage) diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py index f9aa4be6..1443c167 100644 --- a/vllm_mlx/api/tool_calling.py +++ b/vllm_mlx/api/tool_calling.py @@ -82,7 +82,9 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: return tool_calls if tool_calls else None -def parse_tool_calls(text: str) -> Tuple[str, Optional[List[ToolCall]]]: +def parse_tool_calls( + text: str, request: dict[str, Any] | None = None +) -> Tuple[str, Optional[List[ToolCall]]]: """ Parse tool calls from model output. @@ -144,7 +146,13 @@ def parse_tool_calls(text: str) -> Tuple[str, Optional[List[ToolCall]]]: # Parse parameters from value format param_pattern = r"]+)>\s*(.*?)\s*" params = re.findall(param_pattern, params_block, re.DOTALL) - arguments = {p_name.strip(): p_value.strip() for p_name, p_value in params} + arguments = {} + for p_name, p_value in params: + val = p_value.strip() + try: + arguments[p_name.strip()] = json.loads(val) + except (json.JSONDecodeError, ValueError): + arguments[p_name.strip()] = val tool_calls.append( ToolCall( diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 91332bfa..b1d5005f 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -185,11 +185,28 @@ def extract_multimodal_content( tool_calls_list = [] for tc in tool_calls: if isinstance(tc, dict): - tool_calls_list.append(tc) + tc_copy = tc elif hasattr(tc, "model_dump"): - tool_calls_list.append(tc.model_dump()) + tc_copy = tc.model_dump() elif hasattr(tc, "dict"): - tool_calls_list.append(tc.dict()) + tc_copy = tc.dict() + else: + continue + + # Chat templates (e.g. Qwen3) iterate arguments|items, + # but OpenAI API sends arguments as a JSON string. + # Parse it into a dict so the template can iterate it. + func = tc_copy.get("function") or {} + args = func.get("arguments") + if isinstance(args, str): + try: + import json + + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, ValueError): + pass + + tool_calls_list.append(tc_copy) msg_dict = {"role": role, "content": content if content else ""} if tool_calls_list: diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 700b35b7..89ef5343 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -139,9 +139,13 @@ def serve_command(args): use_paged_cache=args.use_paged_cache, paged_cache_block_size=args.paged_cache_block_size, max_cache_blocks=args.max_cache_blocks, + # Chunked prefill + chunked_prefill_tokens=args.chunked_prefill_tokens, ) print("Mode: Continuous batching (for multiple concurrent users)") + if args.chunked_prefill_tokens > 0: + print(f"Chunked prefill: {args.chunked_prefill_tokens} tokens per step") print(f"Stream interval: {args.stream_interval} tokens") if args.use_paged_cache: print( @@ -502,6 +506,14 @@ def main(): default=1000, help="Maximum number of cache blocks (default: 1000)", ) + # Chunked prefill + serve_parser.add_argument( + "--chunked-prefill-tokens", + type=int, + default=0, + help="Max prefill tokens per scheduler step (0=disabled). " + "Prevents starvation of active requests during long prefills.", + ) # MCP options serve_parser.add_argument( "--mcp-config", @@ -574,6 +586,19 @@ def main(): f"Options: {', '.join(reasoning_choices)}." ), ) + # Generation defaults + serve_parser.add_argument( + "--default-temperature", + type=float, + default=None, + help="Override default temperature for all requests (default: use model default)", + ) + serve_parser.add_argument( + "--default-top-p", + type=float, + default=None, + help="Override default top_p for all requests (default: use model default)", + ) # Embedding model option serve_parser.add_argument( "--embedding-model", diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ddb95813..3b2e973e 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -260,6 +260,30 @@ async def _start_llm(self) -> None: tokenizer_config=tokenizer_config, ) + # Set Metal memory limits to make allocation failures graceful + # instead of fatal Metal command buffer errors (SIGABRT) + try: + import mlx.core as mx + + if mx.metal.is_available(): + device_info = mx.device_info() + max_recommended = device_info.get( + "max_recommended_working_set_size", + device_info.get("memory_size", 0), + ) + if max_recommended > 0: + soft_limit = int(max_recommended * 0.90) + mx.set_memory_limit(soft_limit) + mx.set_cache_limit(32 * 1024 * 1024 * 1024) # 32GB + logger.info( + f"Metal memory limits set: " + f"allocation_limit={soft_limit / 1e9:.1f}GB " + f"(90% of {max_recommended / 1e9:.1f}GB), " + f"cache_limit=32GB" + ) + except Exception as e: + logger.warning(f"Failed to set Metal memory limits: {e}") + # Create engine config scheduler_config = self._scheduler_config or SchedulerConfig() engine_config = EngineConfig( @@ -345,10 +369,11 @@ def _apply_chat_template( # Fall through to standard template if hasattr(tokenizer, "apply_chat_template"): + enable_thinking = "coder" not in self._model_name.lower() template_kwargs = { "tokenize": False, "add_generation_prompt": True, - "enable_thinking": True, + "enable_thinking": enable_thinking, } if tools: template_kwargs["tools"] = tools @@ -498,9 +523,11 @@ async def stream_generate( stop=stop or [], ) + prefix_boundary = kwargs.pop("prefix_boundary", 0) request_id = await self._engine.add_request( prompt=prompt, sampling_params=sampling_params, + prefix_boundary=prefix_boundary, ) async for output in self._engine.stream_outputs(request_id): @@ -575,6 +602,57 @@ async def chat( **kwargs, ) + def _compute_prefix_boundary( + self, messages: list[dict[str, Any]], tools: list[dict] | None = None + ) -> int: + """Compute token count for the shared prefix across message variations. + + Uses a two-tokenization approach: tokenize the full prompt twice + (once as-is, once with the last user message replaced by a dummy) + and find the longest common prefix (LCP). This gives the exact + boundary where different user suffixes diverge, avoiding template + discrepancies (e.g. Qwen3 markers on last assistant). + """ + # Find index of last user message + last_user_idx = None + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user": + last_user_idx = i + break + if last_user_idx is None or last_user_idx == 0: + return 0 + try: + template_tools = convert_tools_for_template(tools) if tools else None + + # Tokenize the real prompt + real_prompt = self._apply_chat_template(messages, template_tools) + + # Build a dummy variant with different last user content + dummy_messages = list(messages) + dummy_messages[last_user_idx] = { + **messages[last_user_idx], + "content": "XXXXXXXXXX", + } + dummy_prompt = self._apply_chat_template(dummy_messages, template_tools) + + tokenizer = self.tokenizer + if hasattr(tokenizer, "tokenizer"): + tokenizer = tokenizer.tokenizer + + real_tokens = tokenizer.encode(real_prompt) + dummy_tokens = tokenizer.encode(dummy_prompt) + + # Find LCP — the point where the two diverge is the boundary + lcp = 0 + for j in range(min(len(real_tokens), len(dummy_tokens))): + if real_tokens[j] != dummy_tokens[j]: + break + lcp = j + 1 + + return lcp + except Exception: + return 0 + async def stream_chat( self, messages: list[dict[str, Any]], @@ -625,6 +703,11 @@ async def stream_chat( num_images=len(all_images), ) + # Compute prefix boundary for cache + prefix_boundary = self._compute_prefix_boundary(messages, tools) + if prefix_boundary > 0: + kwargs["prefix_boundary"] = prefix_boundary + async for output in self.stream_generate( prompt=prompt, max_tokens=max_tokens, @@ -660,3 +743,15 @@ def get_cache_stats(self) -> dict[str, Any] | None: elif self._engine: return self._engine.get_cache_stats() return None + + def save_cache_to_disk(self, cache_dir: str) -> bool: + """Save prefix cache to disk for persistence across restarts.""" + if self._engine: + return self._engine.save_cache_to_disk(cache_dir) + return False + + def load_cache_from_disk(self, cache_dir: str) -> int: + """Load prefix cache from disk. Returns number of entries loaded.""" + if self._engine: + return self._engine.load_cache_from_disk(cache_dir) + return 0 diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py index b174145d..aaa0ccc4 100644 --- a/vllm_mlx/engine_core.py +++ b/vllm_mlx/engine_core.py @@ -18,6 +18,8 @@ from dataclasses import dataclass from typing import Any, AsyncIterator, Dict, List, Optional, Union +import mlx.core as mx + from .request import Request, RequestOutput, SamplingParams from .scheduler import Scheduler, SchedulerConfig from .output_collector import RequestOutputCollector, RequestStreamState @@ -129,19 +131,68 @@ def is_running(self) -> bool: return self._running async def _engine_loop(self) -> None: - """Main engine loop - optimized for minimal overhead.""" - # Cache config values for faster access + """Main engine loop - hybrid executor for prefill vs generation. + + Prefill steps (long prompts) are run in a thread executor to keep + the asyncio event loop responsive. Generation-only steps (~1-3ms) + are called directly to avoid ~0.5-2ms context switch overhead, + giving ~5-10% throughput improvement during sustained generation. + """ + import concurrent.futures + + # Single-thread executor ensures MLX calls are never concurrent + _executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mlx-step" + ) + loop = asyncio.get_running_loop() + step_interval = self.config.step_interval stream_interval = self.config.stream_interval use_simple_streaming = stream_interval == 1 + # Emergency memory pressure threshold (200GB) + _memory_pressure_threshold = 200 * 1024 * 1024 * 1024 + _memory_check_interval = 64 + while self._running: try: if self.scheduler.has_requests(): - # Run one generation step - output = self.scheduler.step() + # Hybrid approach: use executor only when prefill is likely. + # Prefill happens when there are waiting requests that need + # to be inserted into the batch (may block for seconds). + # Generation-only steps are fast (<3ms) and can run inline. + has_waiting = self.scheduler.get_num_waiting() > 0 + has_partial = ( + self.scheduler.batch_generator is not None + and getattr(self.scheduler.batch_generator, "_partial", None) + is not None + ) + needs_executor = has_waiting or has_partial + + if needs_executor: + output = await loop.run_in_executor( + _executor, self.scheduler.step + ) + else: + output = self.scheduler.step() + # Yield to event loop after inline step + await asyncio.sleep(0) self._steps_executed += 1 + # Emergency memory pressure check + if self._steps_executed % _memory_check_interval == 0: + try: + active_mem = mx.get_active_memory() + if active_mem > _memory_pressure_threshold: + mx.clear_cache() + logger.warning( + f"[Memory pressure] {active_mem / 1e9:.1f}GB > " + f"{_memory_pressure_threshold / 1e9:.0f}GB threshold, " + f"forced cache clear" + ) + except Exception: + pass + # Fast path: distribute outputs to collectors outputs = output.outputs if outputs: @@ -171,9 +222,15 @@ async def _engine_loop(self) -> None: if event: event.set() - # OPTIMIZATION: Only yield if streaming consumers are waiting - if RequestOutputCollector.has_waiting_consumers(): - await asyncio.sleep(0) + # Free Metal buffers after distributing finished outputs + if output.finished_request_ids: + mx.clear_cache() + + # Always yield to prevent event loop starvation. + # Without this, orphaned requests (client disconnected but + # request still in scheduler) block the entire event loop, + # making the server unresponsive to all HTTP requests. + await asyncio.sleep(0) else: # No work, yield control await asyncio.sleep(step_interval) @@ -191,6 +248,7 @@ async def add_request( request_id: Optional[str] = None, images: Optional[List[Any]] = None, videos: Optional[List[Any]] = None, + prefix_boundary: int = 0, ) -> str: """ Add a request for processing. @@ -201,6 +259,7 @@ async def add_request( request_id: Optional custom request ID images: Optional images for multimodal videos: Optional videos for multimodal + prefix_boundary: Token count for shared prefix (for cache) Returns: The request ID @@ -217,6 +276,7 @@ async def add_request( sampling_params=sampling_params, images=images, videos=videos, + prefix_boundary=prefix_boundary, ) # Setup output collector with stream_interval from config @@ -264,16 +324,24 @@ async def stream_outputs( Yields: RequestOutput objects as tokens are generated """ + import time as _time + + _t0 = _time.monotonic() + _token_count = 0 + collector = self._output_collectors.get(request_id) if collector is None: - # Request might not be added yet or already cleaned up + logger.warning( + f"[stream_outputs] {request_id[:12]} no collector found, returning immediately" + ) return + logger.info(f"[stream_outputs] {request_id[:12]} START waiting for tokens") + + finished_normally = False try: while True: try: - # Non-blocking drain pattern from vLLM - # Try get_nowait first to avoid task switch if output ready if timeout: output = collector.get_nowait() if output is None: @@ -283,17 +351,48 @@ async def stream_outputs( else: output = collector.get_nowait() or await collector.get() + _token_count += 1 + if _token_count == 1: + logger.info( + f"[stream_outputs] {request_id[:12]} first token after " + f"{_time.monotonic() - _t0:.1f}s" + ) + yield output if output.finished: + finished_normally = True + logger.info( + f"[stream_outputs] {request_id[:12]} finished normally, " + f"{_token_count} tokens in {_time.monotonic() - _t0:.1f}s" + ) break except asyncio.TimeoutError: - logger.warning(f"Timeout waiting for request {request_id}") + logger.warning( + f"[stream_outputs] {request_id[:12]} TIMEOUT after " + f"{_token_count} tokens, {_time.monotonic() - _t0:.1f}s" + ) break + except (GeneratorExit, asyncio.CancelledError) as exc: + logger.info( + f"[stream_outputs] {request_id[:12]} {type(exc).__name__} after " + f"{_token_count} tokens, {_time.monotonic() - _t0:.1f}s" + ) + finally: + if not finished_normally: + logger.info( + f"[stream_outputs] {request_id[:12]} ABORTING orphaned request " + f"({_token_count} tokens generated in {_time.monotonic() - _t0:.1f}s)" + ) + aborted = self.scheduler.abort_request(request_id) + logger.info( + f"[stream_outputs] {request_id[:12]} abort_request returned {aborted}" + ) self._cleanup_request(request_id) + logger.info(f"[stream_outputs] {request_id[:12]} cleanup done") async def generate( self, @@ -329,29 +428,35 @@ async def generate( if event is None: raise RuntimeError(f"No event for request {request_id}") - # Wait for the request to finish - await event.wait() + try: + # Wait for the request to finish + await event.wait() - # Get the final output from collector - collector = self._output_collectors.get(request_id) - if collector is None: - raise RuntimeError(f"No collector for request {request_id}") + # Get the final output from collector + collector = self._output_collectors.get(request_id) + if collector is None: + raise RuntimeError(f"No collector for request {request_id}") - # Drain all outputs and get the last one - final_output = None - while True: - output = collector.get_nowait() - if output is None: - break - final_output = output + # Drain all outputs and get the last one + final_output = None + while True: + output = collector.get_nowait() + if output is None: + break + final_output = output - # Cleanup - self._cleanup_request(request_id) + if final_output is None: + raise RuntimeError(f"No output for request {request_id}") + + return final_output - if final_output is None: - raise RuntimeError(f"No output for request {request_id}") + except (asyncio.CancelledError, GeneratorExit): + logger.info(f"[generate] {request_id[:12]} CANCELLED, aborting request") + self.scheduler.abort_request(request_id) + raise - return final_output + finally: + self._cleanup_request(request_id) def generate_batch_sync( self, @@ -416,6 +521,7 @@ def get_stats(self) -> Dict[str, Any]: "steps_executed": self._steps_executed, "active_requests": len(self._output_collectors), "stream_interval": self.config.stream_interval, + "requests": self.scheduler.get_running_requests_info(), **scheduler_stats, } @@ -423,6 +529,14 @@ def get_cache_stats(self) -> Optional[Dict[str, Any]]: """Get prefix cache statistics.""" return self.scheduler.get_cache_stats() + def save_cache_to_disk(self, cache_dir: str) -> bool: + """Save prefix cache to disk.""" + return self.scheduler.save_cache_to_disk(cache_dir) + + def load_cache_from_disk(self, cache_dir: str) -> int: + """Load prefix cache from disk.""" + return self.scheduler.load_cache_from_disk(cache_dir) + def _release_model(self) -> None: """Release model ownership.""" if self._owns_model and not self._closed: @@ -559,3 +673,11 @@ def get_stats(self) -> Dict[str, Any]: def get_cache_stats(self) -> Optional[Dict[str, Any]]: """Get prefix cache statistics.""" return self.engine.get_cache_stats() + + def save_cache_to_disk(self, cache_dir: str) -> bool: + """Save prefix cache to disk.""" + return self.engine.save_cache_to_disk(cache_dir) + + def load_cache_from_disk(self, cache_dir: str) -> int: + """Load prefix cache from disk.""" + return self.engine.load_cache_from_disk(cache_dir) diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index f414f321..902f33f7 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -24,7 +24,9 @@ from __future__ import annotations +import bisect import logging +import math from collections import OrderedDict from dataclasses import dataclass from typing import Any @@ -57,12 +59,38 @@ def _get_available_memory() -> int: return 0 +def _array_memory(arr) -> int: + """ + Estimate array memory from shape+dtype without triggering lazy eval. + + Accessing .nbytes on a lazy MLX array forces evaluation of the entire + computation graph, causing a VRAM spike. This function uses shape and + dtype metadata (which are always available without eval) to compute + the same value. + + Args: + arr: An MLX array or similar object. + + Returns: + Estimated memory in bytes. + """ + if hasattr(arr, "shape") and hasattr(arr, "dtype"): + dtype = arr.dtype + if hasattr(dtype, "size"): + return math.prod(arr.shape) * dtype.size + # Fallback for non-MLX arrays or objects without shape/dtype + if hasattr(arr, "nbytes"): + return arr.nbytes + return 0 + + def estimate_kv_cache_memory(cache: list[Any]) -> int: """ Estimate memory usage of a KV cache in bytes. This function inspects MLX arrays in the cache and calculates their - total memory footprint. + total memory footprint using shape+dtype metadata to avoid triggering + lazy evaluation (which would cause a VRAM spike). Args: cache: List of layer cache objects, each containing keys/values tensors. @@ -81,18 +109,14 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: if isinstance(layer_cache, dict) and "state" in layer_cache: # Extracted state dict keys, values = layer_cache["state"] - if hasattr(keys, "nbytes"): - total_bytes += keys.nbytes - if hasattr(values, "nbytes"): - total_bytes += values.nbytes + total_bytes += _array_memory(keys) + total_bytes += _array_memory(values) elif hasattr(layer_cache, "state") and not isinstance(layer_cache, dict): # Cache with state property returning (keys, values) try: keys, values = layer_cache.state - if hasattr(keys, "nbytes"): - total_bytes += keys.nbytes - if hasattr(values, "nbytes"): - total_bytes += values.nbytes + total_bytes += _array_memory(keys) + total_bytes += _array_memory(values) except (TypeError, ValueError): pass elif hasattr(layer_cache, "keys") and hasattr(layer_cache, "values"): @@ -100,10 +124,10 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: keys_attr = layer_cache.keys values_attr = layer_cache.values # Ensure these are arrays, not methods - if not callable(keys_attr) and hasattr(keys_attr, "nbytes"): - total_bytes += keys_attr.nbytes - if not callable(values_attr) and hasattr(values_attr, "nbytes"): - total_bytes += values_attr.nbytes + if not callable(keys_attr): + total_bytes += _array_memory(keys_attr) + if not callable(values_attr): + total_bytes += _array_memory(values_attr) return total_bytes @@ -209,6 +233,28 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry: ) +def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: + """Create shallow copies of KVCache layers with offset reduced by *trim_by*. + + This is used when returning a cached KV state to the scheduler so that + the last N positions are "freed" and the model will recompute them on the + next forward pass (preventing duplicate KV entries). + """ + from mlx_lm.models.cache import KVCache + + trimmed: list[Any] = [] + for layer_cache in cache: + if hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys"): + tc = KVCache.__new__(KVCache) + tc.keys = layer_cache.keys + tc.values = layer_cache.values + tc.offset = max(layer_cache.offset - trim_by, 0) + trimmed.append(tc) + else: + trimmed.append(layer_cache) + return trimmed + + class MemoryAwarePrefixCache: """ Prefix cache with memory-based eviction. @@ -246,6 +292,11 @@ def __init__( # Key: tuple(tokens), Value: _CacheEntry self._entries: OrderedDict[tuple[int, ...], _CacheEntry] = OrderedDict() + # Sorted index of token keys for efficient prefix/supersequence lookup. + # Tuple lexicographic ordering means a prefix key P is always < any + # extension of P, so bisect gives O(log N) range scans instead of O(N). + self._sorted_keys: list[tuple[int, ...]] = [] + # Memory tracking self._max_memory = self._config.compute_memory_limit() self._current_memory = 0 @@ -253,6 +304,9 @@ def __init__( # Statistics self._stats = CacheStats(max_memory_bytes=self._max_memory) + # Track the match type from the last fetch() call + self._last_match_type: str | None = None + logger.info( f"MemoryAwarePrefixCache initialized: " f"max_memory={self._max_memory / _BYTES_PER_MB:.1f}MB, " @@ -263,7 +317,10 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: """ Find cached KV state for the given tokens. - This method searches for exact matches and prefix matches. + This method searches for exact matches, prefix matches, supersequence + matches, and longest-common-prefix (LCP) matches. Uses a sorted key + index for O(log N) lookup instead of scanning all entries. + Returns the cached KV state directly (no copy) since MLX arrays are immutable and safe to share. @@ -277,47 +334,175 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: """ if not tokens: self._stats.misses += 1 + self._last_match_type = "miss" return None, tokens tokens_key = tuple(tokens) - # Check for exact match + # --- O(1) exact match --- if tokens_key in self._entries: entry = self._entries[tokens_key] - # Move to end (most recently used) self._entries.move_to_end(tokens_key) self._stats.hits += 1 self._stats.tokens_saved += len(tokens) - # Return reference directly - MLX arrays are immutable + self._last_match_type = "exact" return entry.cache, [] - # Check for prefix matches (shorter cached sequences) + # --- O(log N) prefix & supersequence match via sorted index --- best_match: _CacheEntry | None = None best_length = 0 + best_super: _CacheEntry | None = None + + sorted_keys = self._sorted_keys + if sorted_keys: + # Find insertion point for tokens_key in the sorted list. + # Keys that are prefixes of tokens_key or supersequences will be + # clustered around this position due to lexicographic ordering. + idx = bisect.bisect_left(sorted_keys, tokens_key) + + # Scan backwards from idx to find cached keys that are PREFIXES + # of tokens_key (shorter cached sequences). A prefix P of T + # satisfies P <= T lexicographically, so P is at idx-1 or earlier. + for i in range(idx - 1, -1, -1): + cached_key = sorted_keys[i] + cached_len = len(cached_key) + if cached_len >= len(tokens_key): + continue # Not a prefix (same length or longer) + # Check if cached_key is a prefix of tokens_key + if tokens_key[:cached_len] == cached_key: + if cached_len > best_length: + best_match = self._entries[cached_key] + best_length = cached_len + # Found best prefix — shorter entries can't be longer + break + # Once we go past the prefix range, stop + if cached_key[0] != tokens_key[0]: + break + + # Scan forward from idx to find cached keys that are SUPERSEQUENCES + # of tokens_key (longer cached sequences starting with tokens_key). + for i in range(idx, len(sorted_keys)): + cached_key = sorted_keys[i] + cached_len = len(cached_key) + if cached_len < len(tokens_key): + continue + # Check if tokens_key is a prefix of cached_key + if cached_key[: len(tokens_key)] == tokens_key: + if best_super is None or cached_len > len(best_super.tokens): + best_super = self._entries[cached_key] + else: + # Past the supersequence range + break + + # --- Supersequence match handling --- + if best_super is not None: + n_cached = len(best_super.tokens) + n_requested = len(tokens) + excess = n_cached - n_requested + + has_non_trimmable = any( + not (hasattr(lc, "offset") and hasattr(lc, "keys")) + for lc in best_super.cache + ) - for cached_key, entry in self._entries.items(): - cached_len = len(cached_key) - # Check if cached sequence is a prefix of requested tokens - if ( - cached_len < len(tokens) - and cached_len > best_length - and tokens_key[:cached_len] == cached_key - ): - best_match = entry - best_length = cached_len - + if excess > 0 and has_non_trimmable: + logger.debug( + "[cache_fetch] supersequence match skipped: " + "non-trimmable cache layers (hybrid model)" + ) + elif excess > 0: + trimmed_cache = _trim_cache_offset(best_super.cache, excess) + self._entries.move_to_end(best_super.tokens) + self._stats.hits += 1 + self._stats.tokens_saved += n_requested + self._last_match_type = "supersequence" + return trimmed_cache, [] + else: + self._entries.move_to_end(best_super.tokens) + self._stats.hits += 1 + self._stats.tokens_saved += n_requested + self._last_match_type = "supersequence" + return best_super.cache, [] + + # --- Prefix match --- if best_match is not None: - # Move matched entry to end (most recently used) self._entries.move_to_end(best_match.tokens) self._stats.hits += 1 self._stats.tokens_saved += best_length remaining = tokens[best_length:] + self._last_match_type = "prefix" return best_match.cache, remaining + # --- LCP (Longest Common Prefix) for divergent sequences --- + # This handles the agentic pattern: same system+context prefix + # but different final user message. Use the sorted index to find + # the nearest neighbor which likely shares the longest prefix. + best_lcp_entry: _CacheEntry | None = None + best_lcp_length = 0 + + if sorted_keys: + idx = bisect.bisect_left(sorted_keys, tokens_key) + # Check neighbors around insertion point (they share the most + # common prefix due to lexicographic ordering). + for i in (idx - 1, idx): + if i < 0 or i >= len(sorted_keys): + continue + cached_key = sorted_keys[i] + if cached_key == tokens_key: + continue # Skip exact (already handled) + min_len = min(len(cached_key), len(tokens_key)) + if min_len <= best_lcp_length: + continue + # Compute LCP length + lcp = 0 + for j in range(min_len): + if cached_key[j] != tokens_key[j]: + break + lcp = j + 1 + if lcp > best_lcp_length: + best_lcp_entry = self._entries[cached_key] + best_lcp_length = lcp + logger.debug( + f"[cache_fetch] LCP scan: cached_len={len(cached_key)} " + f"req_len={len(tokens_key)} lcp={lcp}" + ) + + if best_lcp_entry is not None and best_lcp_length > 0: + excess = len(best_lcp_entry.tokens) - best_lcp_length + + has_non_trimmable = any( + not (hasattr(lc, "offset") and hasattr(lc, "keys")) + for lc in best_lcp_entry.cache + ) + logger.debug( + f"[cache_fetch] LCP candidate: lcp={best_lcp_length} " + f"entry_len={len(best_lcp_entry.tokens)} excess={excess} " + f"non_trimmable={has_non_trimmable} " + f"cache_layers={len(best_lcp_entry.cache)} " + f"layer_types={[type(lc).__name__ for lc in best_lcp_entry.cache[:3]]}" + ) + + if not has_non_trimmable: + trimmed_cache = _trim_cache_offset(best_lcp_entry.cache, excess) + self._entries.move_to_end(best_lcp_entry.tokens) + self._stats.hits += 1 + self._stats.tokens_saved += best_lcp_length + remaining = tokens[best_lcp_length:] + logger.debug( + f"[cache_fetch] LCP hit: shared={best_lcp_length} " + f"trimmed={excess} remaining={len(remaining)}" + ) + self._last_match_type = "lcp" + return trimmed_cache, remaining + self._stats.misses += 1 + self._last_match_type = "miss" + return None, tokens - def store(self, tokens: list[int], cache: list[Any]) -> bool: + def store( + self, tokens: list[int], cache: list[Any], evict_prefixes: bool = True + ) -> bool: """ Store KV cache for future reuse. @@ -328,6 +513,11 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool: Args: tokens: Token sequence that was processed. cache: The computed KV cache to store. + evict_prefixes: If True, evict existing entries whose token + sequence is a strict prefix of ``tokens``. Set to False + when storing prompt+output entries to preserve prompt-only + entries created by prompt_cache_save (those are the entries + that future requests will actually match). Returns: True if stored successfully, False if rejected. @@ -353,6 +543,36 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool: ) return False + # Prefix-subset eviction: remove entries whose token sequence + # is a strict prefix of the new entry. Uses sorted index for + # O(log N + K) lookup instead of O(N) scan. + if evict_prefixes and self._sorted_keys: + to_remove = [] + idx = bisect.bisect_left(self._sorted_keys, tokens_key) + # Scan backwards — prefixes of tokens_key are immediately before idx + for i in range(idx - 1, -1, -1): + key = self._sorted_keys[i] + klen = len(key) + if klen >= len(tokens_key): + continue + if tokens_key[:klen] == key: + to_remove.append(key) + elif key[0] != tokens_key[0]: + break + for key in to_remove: + old = self._entries.pop(key) + self._current_memory -= old.memory_bytes + self._stats.evictions += 1 + self._remove_from_sorted(key) + logger.debug( + f"[prefix_evict] removed {len(key)} tokens, " + f"freed {old.memory_bytes / _BYTES_PER_MB:.2f}MB, " + f"new_entry={len(tokens_key)} tokens" + ) + if to_remove: + self._stats.entry_count = len(self._entries) + self._stats.current_memory_bytes = self._current_memory + # Evict until we have room while ( self._current_memory + entry.memory_bytes > self._max_memory @@ -363,6 +583,7 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool: # Store entry self._entries[tokens_key] = entry self._current_memory += entry.memory_bytes + bisect.insort(self._sorted_keys, tokens_key) self._stats.entry_count = len(self._entries) self._stats.current_memory_bytes = self._current_memory @@ -374,6 +595,12 @@ def store(self, tokens: list[int], cache: list[Any]) -> bool: return True + def _remove_from_sorted(self, key: tuple[int, ...]) -> None: + """Remove a key from the sorted index using bisect for O(log N).""" + idx = bisect.bisect_left(self._sorted_keys, key) + if idx < len(self._sorted_keys) and self._sorted_keys[idx] == key: + self._sorted_keys.pop(idx) + def _evict_lru(self) -> None: """Evict the least recently used entry.""" if not self._entries: @@ -382,12 +609,13 @@ def _evict_lru(self) -> None: # popitem(last=False) removes oldest entry (FIFO order = LRU) tokens_key, entry = self._entries.popitem(last=False) self._current_memory -= entry.memory_bytes + self._remove_from_sorted(tokens_key) self._stats.evictions += 1 self._stats.entry_count = len(self._entries) self._stats.current_memory_bytes = self._current_memory logger.debug( - f"Evicted cache: {len(tokens_key)} tokens, " + f"[lru_evict] removed {len(tokens_key)} tokens, " f"freed {entry.memory_bytes / _BYTES_PER_MB:.2f}MB" ) @@ -405,6 +633,7 @@ def remove(self, tokens: list[int]) -> bool: entry = self._entries.pop(tokens_key, None) if entry is not None: self._current_memory -= entry.memory_bytes + self._remove_from_sorted(tokens_key) self._stats.entry_count = len(self._entries) self._stats.current_memory_bytes = self._current_memory return True @@ -413,6 +642,7 @@ def remove(self, tokens: list[int]) -> bool: def clear(self) -> None: """Clear all cached entries.""" self._entries.clear() + self._sorted_keys.clear() self._current_memory = 0 self._stats = CacheStats(max_memory_bytes=self._max_memory) logger.debug("Cache cleared") @@ -446,3 +676,184 @@ def __len__(self) -> int: def __contains__(self, tokens: list[int]) -> bool: """Check if tokens are cached.""" return tuple(tokens) in self._entries + + # ----------------------------------------------------------------- + # Disk persistence — survives server restarts + # ----------------------------------------------------------------- + + def save_to_disk(self, cache_dir: str) -> bool: + """Save all cache entries to disk using mlx_lm's safetensors format. + + Directory layout:: + + cache_dir/ + index.json # token keys + metadata per entry + entry_0.safetensors # KV arrays for entry 0 + entry_1.safetensors + ... + + Returns True if at least one entry was saved. + """ + import json + import os + import time as _time + + if not self._entries: + logger.info("[cache_persist] nothing to save (0 entries)") + return False + + t0 = _time.monotonic() + os.makedirs(cache_dir, exist_ok=True) + + try: + from mlx_lm.models.cache import save_prompt_cache + except ImportError: + logger.warning("[cache_persist] mlx_lm not available, cannot save") + return False + + index = { + "version": 2, + "num_entries": len(self._entries), + "total_memory_bytes": self._current_memory, + "entries": [], + } + + saved = 0 + for i, (tokens_key, entry) in enumerate(self._entries.items()): + entry_path = os.path.join(cache_dir, f"entry_{i}.safetensors") + try: + save_prompt_cache( + entry_path, + entry.cache, + metadata={"num_tokens": str(len(tokens_key))}, + ) + # Save tokens separately (can be 100K+ ints → binary is smaller) + tokens_path = os.path.join(cache_dir, f"entry_{i}_tokens.bin") + import array as _array + + arr = _array.array("i", tokens_key) # 32-bit signed ints + with open(tokens_path, "wb") as f: + arr.tofile(f) + + index["entries"].append( + { + "index": i, + "num_tokens": len(tokens_key), + "memory_bytes": entry.memory_bytes, + } + ) + saved += 1 + logger.info( + f"[cache_persist] saved entry {i}: " + f"{len(tokens_key)} tokens, " + f"{entry.memory_bytes / _BYTES_PER_MB:.1f}MB KV, " + f"file={entry_path}" + ) + except Exception as e: + logger.warning(f"[cache_persist] failed to save entry {i}: {e}") + + index_path = os.path.join(cache_dir, "index.json") + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + + dt = _time.monotonic() - t0 + logger.info( + f"[cache_persist] SAVED {saved}/{len(self._entries)} entries " + f"to {cache_dir} in {dt:.1f}s " + f"({self._current_memory / _BYTES_PER_MB:.0f}MB total)" + ) + return saved > 0 + + def load_from_disk(self, cache_dir: str) -> int: + """Load cache entries from disk. + + Returns the number of entries successfully loaded. + """ + import json + import os + import time as _time + + index_path = os.path.join(cache_dir, "index.json") + if not os.path.exists(index_path): + logger.info(f"[cache_persist] no index at {index_path}, nothing to load") + return 0 + + t0 = _time.monotonic() + + try: + from mlx_lm.models.cache import load_prompt_cache + except ImportError: + logger.warning("[cache_persist] mlx_lm not available, cannot load") + return 0 + + with open(index_path) as f: + index = json.load(f) + + version = index.get("version", 1) + if version < 2: + logger.warning(f"[cache_persist] unsupported version {version}, skipping") + return 0 + + loaded = 0 + for entry_meta in index.get("entries", []): + i = entry_meta["index"] + entry_path = os.path.join(cache_dir, f"entry_{i}.safetensors") + tokens_path = os.path.join(cache_dir, f"entry_{i}_tokens.bin") + + if not os.path.exists(entry_path) or not os.path.exists(tokens_path): + logger.warning(f"[cache_persist] missing files for entry {i}, skipping") + continue + + try: + # Load tokens from binary + import array as _array + + arr = _array.array("i") + with open(tokens_path, "rb") as f: + arr.fromfile(f, entry_meta["num_tokens"]) + tokens = list(arr) + + # Load KV cache + cache = load_prompt_cache(entry_path) + + # Estimate memory + memory = estimate_kv_cache_memory(cache) + + # Check if it fits + if self._current_memory + memory > self._max_memory: + logger.info( + f"[cache_persist] entry {i} would exceed memory limit " + f"({(self._current_memory + memory) / _BYTES_PER_MB:.0f}MB > " + f"{self._max_memory / _BYTES_PER_MB:.0f}MB), stopping load" + ) + break + + tokens_key = tuple(tokens) + entry = _CacheEntry( + tokens=tokens_key, + cache=cache, + memory_bytes=memory, + ) + self._entries[tokens_key] = entry + self._current_memory += memory + bisect.insort(self._sorted_keys, tokens_key) + loaded += 1 + + logger.info( + f"[cache_persist] loaded entry {i}: " + f"{len(tokens)} tokens, " + f"{memory / _BYTES_PER_MB:.1f}MB KV" + ) + + except Exception as e: + logger.warning(f"[cache_persist] failed to load entry {i}: {e}") + + self._stats.entry_count = len(self._entries) + self._stats.current_memory_bytes = self._current_memory + + dt = _time.monotonic() - t0 + logger.info( + f"[cache_persist] LOADED {loaded} entries from {cache_dir} " + f"in {dt:.1f}s ({self._current_memory / _BYTES_PER_MB:.0f}MB total)" + ) + return loaded diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index 91067e30..fba3ae02 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -359,7 +359,7 @@ def __init__( self._old_wired_limit = None if mx.metal.is_available(): self._old_wired_limit = mx.set_wired_limit( - mx.metal.device_info()["max_recommended_working_set_size"] + mx.device_info()["max_recommended_working_set_size"] ) def close(self) -> None: diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 6a086025..764f0543 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -662,15 +662,24 @@ async def stream_outputs( if output_queue is None: return - while True: - output = await output_queue.get() - if output is None: - break - yield output - - # Cleanup queue - if request_id in self.output_queues: - del self.output_queues[request_id] + finished_normally = False + try: + while True: + output = await output_queue.get() + if output is None: + finished_normally = True + break + yield output + if output.finished: + finished_normally = True + break + finally: + if not finished_normally: + logger.info(f"Aborting orphaned MLLM request {request_id}") + self.abort_request(request_id) + # Cleanup queue + if request_id in self.output_queues: + del self.output_queues[request_id] async def generate( self, diff --git a/vllm_mlx/optimizations.py b/vllm_mlx/optimizations.py index ac4aafd9..624d5d77 100644 --- a/vllm_mlx/optimizations.py +++ b/vllm_mlx/optimizations.py @@ -86,7 +86,7 @@ def get_system_memory_gb() -> float: except Exception: # Fallback: try to get from MLX device info try: - device_info = mx.metal.device_info() + device_info = mx.device_info() if "memory_size" in device_info: return device_info["memory_size"] / (1024**3) except Exception: @@ -105,7 +105,7 @@ def detect_hardware() -> HardwareInfo: HardwareInfo with detected hardware specifications """ try: - device_info = mx.metal.device_info() + device_info = mx.device_info() device_name = device_info.get("device_name", "") actual_memory_gb = get_system_memory_gb() @@ -182,7 +182,7 @@ def get_optimization_status() -> dict: dict with hardware info and MLX configuration """ hw = detect_hardware() - device_info = mx.metal.device_info() + device_info = mx.device_info() flash_available = hasattr(mx, "fast") and hasattr( mx.fast, "scaled_dot_product_attention" ) diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index 3e79803b..e8f47a32 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -186,8 +186,8 @@ def fetch_cache(self, tokens: List[int]) -> Tuple[Optional[List[Any]], List[int] self.stats.hits += 1 self.stats.tokens_saved += len(tokens) self._touch_lru(tokens_tuple) - # Deep copy to prevent mutation - return copy.deepcopy(cache_entry.prompt_cache), [] + # No copy needed - MLX arrays are immutable + return cache_entry.prompt_cache, [] if shorter: # Shorter prefix cached - return cache and remaining tokens @@ -197,7 +197,8 @@ def fetch_cache(self, tokens: List[int]) -> Tuple[Optional[List[Any]], List[int] self.stats.tokens_saved += len(shorter) self._touch_lru(tuple(shorter)) remaining = tokens[len(shorter) :] - return copy.deepcopy(cache_entry.prompt_cache), remaining + # No copy needed - MLX arrays are immutable + return cache_entry.prompt_cache, remaining if longer: # Longer prefix cached - trim to match and return diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py index d7b9db1f..41679c0b 100644 --- a/vllm_mlx/request.py +++ b/vllm_mlx/request.py @@ -111,6 +111,7 @@ class Request: prompt_cache: Optional[List[Any]] = None # Cached KV state from prefix cache cached_tokens: int = 0 # Number of tokens retrieved from cache remaining_tokens: Optional[List[int]] = None # Tokens still needing processing + prefix_boundary: int = 0 # Token count for shared prefix (messages[:-1]) # Paged cache fields (for BlockAwarePrefixCache) block_table: Optional["BlockTable"] = None # Block table for paged cache @@ -129,6 +130,12 @@ class Request: # Metadata finish_reason: Optional[str] = None + first_token_time: Optional[float] = ( + None # Time when first output token was generated + ) + cache_hit_type: Optional[str] = ( + None # Type of cache hit: exact/prefix/supersequence/lcp/miss + ) @property def num_output_tokens(self) -> int: diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index c5f2c52d..26ef5315 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -17,6 +17,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Set, Tuple +import mlx.core as mx from mlx_lm.generate import BatchGenerator from mlx_lm.sample_utils import make_sampler @@ -77,6 +78,17 @@ class SchedulerConfig: paged_cache_block_size: int = 64 # Tokens per block max_cache_blocks: int = 1000 # Maximum number of cache blocks + # Chunked prefill: max tokens to prefill per scheduler step (0 = disabled) + # When enabled, large prompts are split into chunks so that active + # generation requests are not starved during long prefills. + chunked_prefill_tokens: int = 0 + + # Mid-prefill cache saving: save intermediate KV cache every N tokens + # during chunked prefill. If the client disconnects mid-prefill, the + # saved cache is reused for the next request with the same prefix. + # 0 = disabled. Only effective when chunked_prefill_tokens > 0. + mid_prefill_save_interval: int = 8192 + @dataclass class SchedulerOutput: @@ -98,6 +110,375 @@ class SchedulerOutput: has_work: bool = False +def _install_chunked_prefill( + batch_gen: "BatchGenerator", + budget: int, + mid_prefill_save=None, + prompt_cache_save=None, + pending_abort_ids: Optional[Set[str]] = None, + uid_to_request_id: Optional[Dict[int, str]] = None, + requests: Optional[Dict[str, Any]] = None, +) -> None: + """ + Monkey-patch a BatchGenerator instance so that large prefills are + broken into chunks of at most *budget* tokens each. + + Between chunks the generation loop gets a chance to produce one token + for every active request, preventing starvation during long prefills. + + Args: + batch_gen: The BatchGenerator to patch. + budget: Max tokens per prefill chunk. + mid_prefill_save: Optional callback(uid, processed, prompt_cache) + called after each chunk to save intermediate KV cache state. + """ + import time as _time + + from mlx_lm.generate import ( + Batch, + _left_pad_prompts, + _make_cache, + _merge_caches, + _right_pad_prompts, + ) + + # Keep references to originals + _orig_next = batch_gen._next + _orig_remove = batch_gen.remove + _orig_process_prompts = batch_gen._process_prompts + + # Partial prefill state (None when no prefill in progress) + batch_gen._partial = None + + # Monkey-patch _process_prompts to capture prompt-only cache state. + # At the point where _process_prompts returns, the Batch cache contains + # the exact prompt-only state: all prompt tokens have been processed + # through the model, but no output token has been fed back yet. + # This is the only safe capture point for hybrid Mamba+Transformer + # models whose MambaCache state is cumulative. + if prompt_cache_save is not None: + + def _patched_process_prompts(prompts, _self=batch_gen): + batch = _orig_process_prompts(prompts) + for e, uid in enumerate(batch.uids): + if batch.num_tokens[e] == 0: + try: + prompt_cache_save(uid, batch.extract_cache(e)) + except Exception: + pass + return batch + + batch_gen._process_prompts = _patched_process_prompts + + def _generation_step(self=batch_gen): + """Run one generation step on the active batch. Returns responses.""" + batch = self.active_batch + if batch is None or len(batch) == 0: + return [] + + tic_gen = _time.perf_counter() + y, logprobs = batch.y, batch.logprobs + for i, toks in enumerate(batch.tokens): + batch.tokens[i] = mx.concatenate((toks, y[i : i + 1])) + batch.y, batch.logprobs = self._step( + y[:, None], + batch.cache, + batch.samplers, + batch.logits_processors, + batch.tokens, + ) + mx.async_eval(batch.y, batch.logprobs) + + y = y.tolist() + self._stats.generation_time += _time.perf_counter() - tic_gen + + keep_idx = [] + end_idx = [] + responses = [] + for e, (t, uid, num_tok, max_tok) in enumerate( + zip(y, batch.uids, batch.num_tokens, batch.max_tokens) + ): + cache_out = None + num_tok += 1 + batch.num_tokens[e] = num_tok + if t in self.stop_tokens: + finish_reason = "stop" + end_idx.append(e) + elif num_tok >= max_tok: + finish_reason = "length" + end_idx.append(e) + else: + finish_reason = None + keep_idx.append(e) + if finish_reason is not None: + cache_out = batch.extract_cache(e) + responses.append( + self.Response(uid, t, logprobs[e], finish_reason, cache_out) + ) + + if len(end_idx): + if len(keep_idx) > 0: + batch.filter(keep_idx) + else: + self.active_batch = None + + self._stats.generation_tokens += len(responses) + return responses + + def _chunked_next(self=batch_gen): # noqa: C901 + """ + Replacement for _next() that chunks large prefills. + + Only intercepts when: + 1. A partial prefill is in progress (_partial is not None) + 2. The next prompt batch exceeds the budget + + Everything else delegates to the original _next(). + """ + # ----- Continue a partial prefill ----- + if self._partial is not None: + # Check for pending aborts BEFORE processing next chunk + if pending_abort_ids is not None and uid_to_request_id is not None: + partial_rids = {uid_to_request_id.get(u) for u in self._partial["uids"]} + aborted_rids = partial_rids & pending_abort_ids + if aborted_rids: + logger.info( + f"[chunked_prefill] abort detected mid-prefill, " + f"clearing partial for: {aborted_rids}" + ) + self._partial = None + mx.clear_cache() + return _generation_step() + + tic = _time.perf_counter() + partial = self._partial + inputs = partial["inputs"] + prompt_cache = partial["cache"] + remaining = inputs.shape[1] + + n_to_process = min(budget, remaining - 1) if remaining > 1 else 0 + + if n_to_process > 0: + self.model(inputs[:, :n_to_process], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) + inputs = inputs[:, n_to_process:] + partial["inputs"] = inputs + partial["processed"] += n_to_process + + self.prompt_progress_callback( + [ + (uid, partial["processed"], partial["total"]) + for uid in partial["uids"] + ] + ) + + # Save intermediate cache for disconnect resilience + if mid_prefill_save is not None and len(partial["uids"]) == 1: + mid_prefill_save( + partial["uids"][0], partial["processed"], prompt_cache + ) + + if partial.get("is_cached"): + mx.clear_cache() + + # Check if prefill is done (only 1 token left or 0) + if inputs.shape[1] <= 1: + # Finalize + if partial.get("is_cached"): + mx.eval([c.state for c in prompt_cache]) + inputs = partial["last_inputs"] + + for c in prompt_cache: + c.finalize() + mx.clear_cache() + + y, logprobs = self._step( + inputs, + prompt_cache, + partial["samplers"], + partial["logits_processors"], + partial["tokens"], + ) + mx.async_eval(y, logprobs) + + new_batch = Batch( + list(partial["uids"]), + y, + logprobs, + list(partial["max_tokens"]), + [0] * len(partial["uids"]), + prompt_cache, + list(partial["samplers"]), + list(partial["logits_processors"]), + partial["tokens"], + ) + + # Save prompt-only cache BEFORE merging into active batch. + # This is the chunked-prefill equivalent of the + # _patched_process_prompts hook — at this point the cache + # contains the exact prompt-only state (num_tokens == 0). + if prompt_cache_save is not None and len(partial["uids"]) == 1: + uid = partial["uids"][0] + try: + prompt_cache_save(uid, new_batch.extract_cache(0)) + except Exception: + pass + + if self.active_batch is None: + self.active_batch = new_batch + else: + self.active_batch.extend(new_batch) + + self._partial = None + self._stats.prompt_time += _time.perf_counter() - tic + else: + # Not done yet — record prompt time for this chunk + self._stats.prompt_time += _time.perf_counter() - tic + + # Generation step for active requests between chunks + return _generation_step() + + # ----- No partial — check if next prompt batch needs chunking ----- + num_active = len(self.active_batch) if self.active_batch else 0 + num_to_add = self.completion_batch_size - num_active + + if num_to_add >= self.prefill_batch_size and self.unprocessed_prompts: + batch_prompts = self.unprocessed_prompts[: self.prefill_batch_size] + if batch_prompts: + total_tokens = sum(len(p[1]) for p in batch_prompts) + + # Check if any prompt has a prefix_boundary that + # requires two-phase prefill for cache save at that boundary. + _needs_boundary_split = False + if requests is not None and uid_to_request_id is not None: + for _uid, _toks, *_ in batch_prompts: + _rid = uid_to_request_id.get(_uid) + _req = requests.get(_rid) if _rid else None + if _req and getattr(_req, "prefix_boundary", 0) > 0: + _needs_boundary_split = True + break + + if total_tokens > budget or _needs_boundary_split: + # Large prompt batch or prefix boundary — start partial prefill + tic = _time.perf_counter() + + # Eval outstanding generation tokens before switching + if self.active_batch is not None: + mx.eval(self.active_batch.y, self.active_batch.logprobs) + self._stats.generation_time += _time.perf_counter() - tic + tic = _time.perf_counter() + + ( + uids, + inputs_raw, + max_tokens_list, + caches, + samplers, + logits_processors, + ) = zip(*batch_prompts) + lengths = [len(p) for p in inputs_raw] + max_length = max(lengths) + padding = [max_length - ln for ln in lengths] + tokens = [mx.array(inp) for inp in inputs_raw] + is_cached = not all(c[0].empty() for c in caches) + + self._stats.prompt_tokens += sum(lengths) + + if not is_cached: + padded = _left_pad_prompts(inputs_raw, max_length=max_length) + prompt_cache = _make_cache(self.model, padding) + else: + last_inputs = mx.array([p[-1:] for p in inputs_raw]) + padded = _right_pad_prompts(inputs_raw, max_length=max_length) + prompt_cache = _merge_caches(caches) + for c in prompt_cache: + c.prepare( + lengths=[ln - 1 for ln in lengths], + right_padding=padding, + ) + + # Remove from unprocessed + self.unprocessed_prompts = self.unprocessed_prompts[ + self.prefill_batch_size : + ] + + # Process first chunk — if prefix_boundary is set, + # use it as the first chunk size so that mid_prefill_save + # can capture the exact prefix cache state (critical for + # hybrid Mamba+Transformer models where trim is unsafe). + # When the request already has cached tokens (cache hit), + # adjust the boundary relative to the remaining tokens. + _first_chunk = budget + if _needs_boundary_split and len(batch_prompts) == 1: + _uid0 = uids[0] + _rid0 = uid_to_request_id.get(_uid0) + _req0 = requests.get(_rid0) if _rid0 else None + _pb = getattr(_req0, "prefix_boundary", 0) if _req0 else 0 + _cached = getattr(_req0, "cached_tokens", 0) if _req0 else 0 + _adjusted_pb = _pb - _cached + if 0 < _adjusted_pb < padded.shape[1]: + _first_chunk = _adjusted_pb + n_to_process = min(_first_chunk, padded.shape[1] - 1) + if n_to_process > 0: + self.model(padded[:, :n_to_process], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) + padded = padded[:, n_to_process:] + if is_cached: + mx.clear_cache() + + self._partial = { + "uids": list(uids), + "inputs": padded, + "cache": prompt_cache, + "tokens": tokens, + "max_tokens": list(max_tokens_list), + "samplers": list(samplers), + "logits_processors": list(logits_processors), + "processed": n_to_process, + "total": max_length, + "is_cached": is_cached, + } + if is_cached: + self._partial["last_inputs"] = last_inputs + + self.prompt_progress_callback( + [ + (uid, n_to_process, max_length) + for uid in self._partial["uids"] + ] + ) + + # Save intermediate cache for disconnect resilience + if mid_prefill_save is not None and len(uids) == 1: + mid_prefill_save(uids[0], n_to_process, prompt_cache) + + self._stats.prompt_time += _time.perf_counter() - tic + + # Generation step for active requests + return _generation_step() + + # Small prompts, pure generation, or no work — delegate to original + return _orig_next() + + def _patched_remove(uids_to_remove, _self=batch_gen): + """Clear partial state if aborted request is being prefilled.""" + if _self._partial is not None: + partial_uids = set(_self._partial["uids"]) + if partial_uids & set(uids_to_remove): + logger.info( + f"[chunked_prefill] clearing partial state for aborted uids: " + f"{partial_uids & set(uids_to_remove)}" + ) + _self._partial = None + mx.clear_cache() # flush Metal encoders after dropping partial state + _orig_remove(uids_to_remove) + + batch_gen._next = _chunked_next + batch_gen.remove = _patched_remove + + logger.info(f"[chunked_prefill] installed with budget={budget} tokens per step") + + class Scheduler: """ Scheduler for continuous batching using mlx-lm BatchGenerator. @@ -192,11 +573,21 @@ def __init__( f"Prefix cache enabled with max_entries={self.config.prefix_cache_size}" ) + # Thread-safe set for deferred aborts (main thread → executor thread) + # CPython GIL guarantees set.add() and `x in set` are atomic. + self._pending_abort_ids: Set[str] = set() + # Statistics self.num_requests_processed = 0 self.total_prompt_tokens = 0 self.total_completion_tokens = 0 + # Memory management: periodic mx.clear_cache() to free Metal command buffers + # Lower interval = less VRAM spike during generation but slight throughput cost + self._step_count = 0 + self._clear_cache_interval = 32 + self._memory_log_interval = 256 + def _get_actual_tokenizer(self, tokenizer: Any) -> Any: """ Get the actual tokenizer from a processor or tokenizer. @@ -254,7 +645,16 @@ def _create_batch_generator( if sampling_params.stop_token_ids: stop_tokens.update(sampling_params.stop_token_ids) - return BatchGenerator( + def _prefill_progress(progress_list): + """Log prefill progress for each uid chunk.""" + for uid, processed, total in progress_list: + rid = self.uid_to_request_id.get(uid, "?") + logger.info( + f"[prefill] request={rid[:12] if isinstance(rid, str) else rid} " + f"tokens={processed}/{total}" + ) + + bg = BatchGenerator( model=self.model, max_tokens=sampling_params.max_tokens, stop_tokens=stop_tokens, @@ -262,8 +662,159 @@ def _create_batch_generator( prefill_batch_size=self.config.prefill_batch_size, completion_batch_size=self.config.completion_batch_size, prefill_step_size=self.config.prefill_step_size, + prompt_progress_callback=_prefill_progress, ) + # Install chunked prefill when explicitly configured OR when + # memory-aware cache is active (needed for prefix_boundary saves + # in agentic multi-turn workloads with hybrid Mamba+Transformer models). + chunked_budget = self.config.chunked_prefill_tokens + need_chunked = chunked_budget > 0 or self.memory_aware_cache is not None + if need_chunked: + if chunked_budget <= 0: + # No explicit budget — use a very large value so normal + # prompts pass through unchanged. Prefix boundary splits + # still trigger via _needs_boundary_split. + chunked_budget = 999_999 + mid_prefill_cb = None + save_interval = self.config.mid_prefill_save_interval + if save_interval > 0 and self.memory_aware_cache is not None: + mid_prefill_cb = self._make_mid_prefill_save_callback(save_interval) + logger.info(f"[mid_prefill_cache] enabled, interval={save_interval}") + prompt_cache_cb = None + if self.memory_aware_cache is not None: + prompt_cache_cb = self._make_prompt_cache_save_callback() + _install_chunked_prefill( + bg, + chunked_budget, + mid_prefill_cb, + prompt_cache_save=prompt_cache_cb, + pending_abort_ids=self._pending_abort_ids, + uid_to_request_id=self.uid_to_request_id, + requests=self.requests, + ) + + return bg + + def _make_prompt_cache_save_callback(self): + """Create a callback that stores prompt-only KV/Mamba cache. + + Called from ``_generation_step`` right before the first output token + is fed into the model. At that point ``num_tokens == 0`` and the + batch cache contains the exact prompt-only state (correct for both + KVCache and MambaCache/ArraysCache layers). + + The cache is stored with key = prompt_token_ids so that a future + request with the identical prompt gets an exact hit. + """ + import time as _time + + def _prompt_cache_save(uid, extracted_cache): + request_id = self.uid_to_request_id.get(uid) + if not request_id: + return + request = self.requests.get(request_id) + if not request or not request.prompt_token_ids: + return + + prompt_tokens = list(request.prompt_token_ids) + _t0 = _time.monotonic() + # evict_prefixes=False: keep mid-prefill boundary entries so + # that future requests with the same prefix but different + # suffix get a prefix cache hit (critical for agentic multi-turn). + stored = self.memory_aware_cache.store( + prompt_tokens, extracted_cache, evict_prefixes=False + ) + _dt = _time.monotonic() - _t0 + if stored: + logger.info( + f"[prompt_cache_save] request={request_id[:12]} " + f"prompt_tokens={len(prompt_tokens)} " + f"store_time={_dt:.3f}s" + ) + + return _prompt_cache_save + + def _make_mid_prefill_save_callback(self, save_interval: int): + """Create a callback for saving intermediate KV cache during chunked prefill. + + The callback is called after each chunk with (uid, processed_tokens, + prompt_cache). It extracts the cache state (immutable MLX array + snapshots), reconstructs KVCache objects, and stores them in the + memory-aware prefix cache so that a subsequent request with the same + prompt prefix can skip the already-computed tokens. + """ + import time as _time + + def _mid_prefill_save(uid, processed_tokens, prompt_cache): + request_id = self.uid_to_request_id.get(uid) + if not request_id: + return + request = self.requests.get(request_id) + if not request or not request.prompt_token_ids: + return + + total_cached = (request.cached_tokens or 0) + processed_tokens + + # Always save at prefix_boundary (message boundary for cache + # reuse with different final user messages). + prefix_boundary = getattr(request, "prefix_boundary", 0) + at_prefix_boundary = prefix_boundary > 0 and total_cached == prefix_boundary + + # Throttle: only save every save_interval tokens, + # unless we're at the prefix boundary. + last_save = getattr(request, "_mid_prefill_last_save", 0) + if not at_prefix_boundary and total_cached - last_save < save_interval: + return + + # Extract immutable state snapshots + extracted = self._extract_cache_states(prompt_cache) + if not extracted: + return + + # Reconstruct cache objects (directly usable by BatchGenerator) + reconstructed = self._reconstruct_cache_from_states(extracted) + if not reconstructed: + return + + prefix_tokens = list(request.prompt_token_ids[:total_cached]) + + # Remove previous intermediate entry to avoid memory waste + old_key = getattr(request, "_mid_prefill_cache_key", None) + if old_key is not None: + self.memory_aware_cache.remove(list(old_key)) + + _t0 = _time.monotonic() + stored = self.memory_aware_cache.store(prefix_tokens, reconstructed) + _dt = _time.monotonic() - _t0 + + if stored: + request._mid_prefill_last_save = total_cached + request._mid_prefill_cache_key = tuple(prefix_tokens) + logger.info( + f"[mid_prefill_cache] request={request_id[:12]} " + f"saved {total_cached}/{len(request.prompt_token_ids)} tokens " + f"({total_cached * 100 // len(request.prompt_token_ids)}%) " + f"store_time={_dt:.3f}s" + ) + else: + logger.debug( + f"[mid_prefill_cache] request={request_id[:12]} " + f"store rejected for {total_cached} tokens" + ) + + return _mid_prefill_save + + def _close_batch_generator(self) -> None: + """Properly close BatchGenerator to restore wired_limit.""" + if self.batch_generator is not None: + try: + if hasattr(self.batch_generator, "close"): + self.batch_generator.close() + except Exception as e: + logger.debug(f"Error closing BatchGenerator: {e}") + self.batch_generator = None + def _ensure_batch_generator(self, sampling_params: SamplingParams) -> None: """Ensure BatchGenerator exists with compatible settings.""" sampler_params = ( @@ -285,20 +836,26 @@ def _ensure_batch_generator(self, sampling_params: SamplingParams) -> None: ) return - # Clear prefix cache when BatchGenerator changes - # BatchKVCache objects are tied to their generator instance + # Keep prefix cache across BatchGenerator recreations. + # KV cache entries depend only on the input tokens, not on + # sampling params (temperature, top_p, min_p). Since the + # server runs a single model, the cache is always valid. if self.batch_generator is not None: + n_entries = 0 if self.memory_aware_cache is not None: - logger.debug( - "Clearing memory-aware cache: BatchGenerator being recreated" - ) - self.memory_aware_cache.clear() + n_entries = len(self.memory_aware_cache._entries) elif self.prefix_cache is not None: - logger.debug( - "Clearing prefix cache: BatchGenerator being recreated" + n_entries = ( + len(self.prefix_cache) + if hasattr(self.prefix_cache, "__len__") + else 0 ) - self.prefix_cache.clear() + logger.info( + f"[batch_generator] recreating (sampler params changed), " + f"keeping {n_entries} cache entries" + ) + self._close_batch_generator() self.batch_generator = self._create_batch_generator(sampling_params) self._current_sampler_params = sampler_params @@ -306,8 +863,11 @@ def _validate_cache(self, cache: Any) -> bool: """ Validate that a cache object is usable. - This prevents NoneType errors when mlx-lm's BatchKVCache - contains invalid/stale references. + Checks for None references AND shape compatibility. Restored + cache entries must have batch_size == 1 (single sequence) so + they can be merged into the running batch by _merge_caches. + A shape mismatch here (e.g. batch=2 from a stale entry) would + cause a concatenation crash inside _merge_caches. Args: cache: The cache object to validate @@ -331,6 +891,23 @@ def _validate_cache(self, cache: Any) -> bool: return False if hasattr(layer_cache, "values") and layer_cache.values is None: return False + # Validate batch dimension == 1 for KVCache layers + if hasattr(layer_cache, "keys") and layer_cache.keys is not None: + if layer_cache.keys.shape[0] != 1: + logger.debug( + f"Cache layer invalid: keys batch={layer_cache.keys.shape[0]}, expected 1" + ) + return False + # Validate batch dimension for MambaCache layers + if hasattr(layer_cache, "cache") and isinstance( + layer_cache.cache, list + ): + for arr in layer_cache.cache: + if arr is not None and arr.shape[0] != 1: + logger.debug( + f"Cache layer invalid: mamba batch={arr.shape[0]}, expected 1" + ) + return False # Check BatchKVCache structure if hasattr(cache, "caches"): @@ -363,13 +940,14 @@ def _extract_cache_states(self, raw_cache: List[Any]) -> List[Dict[str, Any]]: for layer_cache in raw_cache: try: if hasattr(layer_cache, "state") and hasattr(layer_cache, "meta_state"): - state = layer_cache.state # (keys, values) MLX arrays + state = layer_cache.state # (keys, values) or more for Mamba meta = layer_cache.meta_state # (offset,) as strings extracted.append( { "state": state, "meta_state": meta, "class_name": type(layer_cache).__name__, + "class_ref": type(layer_cache), } ) except Exception as e: @@ -378,6 +956,72 @@ def _extract_cache_states(self, raw_cache: List[Any]) -> List[Dict[str, Any]]: return extracted if len(extracted) == len(raw_cache) else [] + def _reconstruct_cache_from_states( + self, extracted_states: List[Dict[str, Any]] + ) -> Optional[List[Any]]: + """ + Reconstruct cache objects from extracted cache states. + + This is the inverse of _extract_cache_states(). Uses mlx-lm's + _BaseCache.from_state() to reconstruct any cache type (KVCache, + MambaCache, etc.) from its state/meta_state. + + Args: + extracted_states: List of dicts from _extract_cache_states() + + Returns: + List of cache objects, or None if reconstruction fails + """ + if not extracted_states: + return None + + try: + caches = [] + for layer_state in extracted_states: + state = layer_state.get("state") + meta_state = layer_state.get("meta_state") + cache_cls = layer_state.get("class_ref") + if state is None: + return None + + if cache_cls is not None and hasattr(cache_cls, "from_state"): + # BatchKVCache doesn't inherit from KVCache, so + # _merge_caches can't handle it. Convert to KVCache + # (safe because mid-prefill save is always batch_size=1). + from mlx_lm.models.cache import ( + BatchKVCache as _BatchKVCache, + KVCache as _KVCache, + ) + + if cache_cls is _BatchKVCache: + # BatchKVCache.state = (keys, values, offset, left_padding) + keys, values = state[0], state[1] + cache = _KVCache() + cache.keys = keys + cache.values = values + cache.offset = keys.shape[2] + else: + cache = cache_cls.from_state(state, meta_state) + else: + # Fallback: try KVCache manual reconstruction + from mlx_lm.models.cache import KVCache + + if len(state) != 2: + return None + cache = KVCache() + cache.keys, cache.values = state + cache.offset = ( + int(meta_state[0]) if meta_state else cache.keys.shape[2] + ) + + caches.append(cache) + + return caches + + except Exception as e: + logger.info(f"[mid_prefill_cache] reconstruct EXCEPTION: {e}") + return None + def add_request(self, request: Request) -> None: """ Add a new request to the scheduler. @@ -418,6 +1062,7 @@ def add_request(self, request: Request) -> None: request.prompt_token_ids, ) if block_table and block_table.num_tokens > 0: + request.cache_hit_type = "hit" # Reconstruct actual KVCache objects from stored tensor data reconstructed = self.block_aware_cache.reconstruct_cache(block_table) if reconstructed: @@ -433,30 +1078,44 @@ def add_request(self, request: Request) -> None: ) else: # Reconstruction failed, treat as cache miss + request.cache_hit_type = "miss" request.remaining_tokens = request.prompt_token_ids logger.debug( f"Request {request.request_id}: paged cache reconstruction failed" ) else: + request.cache_hit_type = "miss" request.remaining_tokens = request.prompt_token_ids elif self.memory_aware_cache is not None: # Use memory-aware prefix cache + import time as _time + + _fetch_t0 = _time.monotonic() cache, remaining = self.memory_aware_cache.fetch(request.prompt_token_ids) + _fetch_dt = _time.monotonic() - _fetch_t0 + request.cache_hit_type = self.memory_aware_cache._last_match_type if cache: request.prompt_cache = cache request.cached_tokens = len(request.prompt_token_ids) - len(remaining) request.remaining_tokens = remaining - logger.debug( - f"Request {request.request_id}: memory-aware cache hit, " - f"{request.cached_tokens} tokens cached, " - f"{len(remaining)} tokens remaining" + logger.info( + f"[cache_fetch] request={request.request_id[:12]} HIT " + f"prompt_tokens={len(request.prompt_token_ids)} " + f"cached={request.cached_tokens} remaining={len(remaining)} " + f"time={_fetch_dt:.3f}s" ) else: request.remaining_tokens = request.prompt_token_ids + logger.info( + f"[cache_fetch] request={request.request_id[:12]} MISS " + f"prompt_tokens={len(request.prompt_token_ids)} " + f"time={_fetch_dt:.3f}s entries={len(self.memory_aware_cache._entries)}" + ) elif self.prefix_cache is not None: # Use legacy prefix cache cache, remaining = self.prefix_cache.fetch_cache(request.prompt_token_ids) if cache: + request.cache_hit_type = "hit" request.prompt_cache = cache request.cached_tokens = len(request.prompt_token_ids) - len(remaining) request.remaining_tokens = remaining @@ -466,8 +1125,10 @@ def add_request(self, request: Request) -> None: f"{len(remaining)} tokens remaining" ) else: + request.cache_hit_type = "miss" request.remaining_tokens = request.prompt_token_ids else: + request.cache_hit_type = "miss" request.remaining_tokens = request.prompt_token_ids # Add to tracking @@ -480,41 +1141,82 @@ def add_request(self, request: Request) -> None: def abort_request(self, request_id: str) -> bool: """ - Abort a request. + Queue request for abort. Thread-safe, called from any thread. + + The actual abort is deferred to the executor thread (inside step()) + to avoid race conditions with in-flight Metal GPU operations. Args: request_id: The request ID to abort Returns: - True if request was found and aborted, False otherwise + True (abort is always enqueued) + """ + self._pending_abort_ids.add(request_id) + logger.info(f"[abort_request] {request_id[:12]} enqueued for deferred abort") + return True + + def _process_pending_aborts(self) -> None: + """Drain and process pending abort requests. Called from executor thread.""" + while self._pending_abort_ids: + request_id = self._pending_abort_ids.pop() + self._do_abort_request(request_id) + + def _do_abort_request(self, request_id: str) -> bool: + """ + Actually abort a request. Must be called from the executor thread. + + Handles the case where the request was already removed from + self.requests by _cleanup_request() but still lives in the + BatchGenerator (e.g. in _partial or active_batch). + + Args: + request_id: The request ID to abort + + Returns: + True if any cleanup was performed, False otherwise """ request = self.requests.get(request_id) - if request is None: - return False + was_waiting = False + was_running = False + removed_from_batch = False # Remove from waiting queue - if request.status == RequestStatus.WAITING: + if request is not None and request.status == RequestStatus.WAITING: + was_waiting = True try: self.waiting.remove(request) except ValueError: pass - # Remove from running (BatchGenerator) - if request.request_id in self.request_id_to_uid: - uid = self.request_id_to_uid[request.request_id] + # Remove from running (BatchGenerator) — do this even if request + # was already cleaned up from self.requests, because the UID may + # still be live inside the BatchGenerator (_partial / active_batch). + if request_id in self.request_id_to_uid: + was_running = True + uid = self.request_id_to_uid[request_id] if self.batch_generator is not None: self.batch_generator.remove([uid]) + removed_from_batch = True del self.uid_to_request_id[uid] - del self.request_id_to_uid[request.request_id] + del self.request_id_to_uid[request_id] if request_id in self.running: del self.running[request_id] - # Mark as aborted - request.set_finished(RequestStatus.FINISHED_ABORTED) + if request is not None: + request.set_finished(RequestStatus.FINISHED_ABORTED) self.finished_req_ids.add(request_id) - logger.debug(f"Aborted request {request_id}") + # Flush Metal encoders after removing arrays from batch + mx.clear_cache() + + logger.info( + f"[abort_request] {request_id[:12]} ABORTED " + f"was_waiting={was_waiting} was_running={was_running} " + f"removed_from_batch={removed_from_batch} " + f"remaining_running={len(self.running)} remaining_waiting={len(self.waiting)}" + ) return True def has_requests(self) -> bool: @@ -577,12 +1279,34 @@ def _schedule_waiting(self) -> List[Request]: request.remaining_tokens = request.prompt_token_ids tokens_to_process = request.prompt_token_ids - # Insert into BatchGenerator with optional cache - uids = self.batch_generator.insert( - [tokens_to_process], - max_tokens=[request.sampling_params.max_tokens], - caches=[cache_to_use] if cache_to_use else None, - ) + # Insert into BatchGenerator with optional cache. + # Wrap in try/except: if cache shapes are incompatible + # (e.g. stale entry after BatchGenerator recreation), + # fall back to no-cache insert instead of crashing. + try: + uids = self.batch_generator.insert( + [tokens_to_process], + max_tokens=[request.sampling_params.max_tokens], + caches=[cache_to_use] if cache_to_use else None, + ) + except Exception as e: + if cache_to_use is not None: + logger.warning( + f"[cache_insert_error] request={request.request_id[:12]} " + f"cache insert failed ({e}), retrying without cache" + ) + cache_to_use = None + request.prompt_cache = None + request.cached_tokens = 0 + request.remaining_tokens = request.prompt_token_ids + tokens_to_process = request.prompt_token_ids + uids = self.batch_generator.insert( + [tokens_to_process], + max_tokens=[request.sampling_params.max_tokens], + caches=None, + ) + else: + raise if uids: uid = uids[0] @@ -599,9 +1323,13 @@ def _schedule_waiting(self) -> List[Request]: if request.cached_tokens > 0 else "" ) - logger.debug( - f"Scheduled request {request.request_id} (uid={uid}) " - f"with {request.num_prompt_tokens} tokens{cache_info}" + tokens_to_prefill = len(tokens_to_process) + logger.info( + f"[schedule] request={request.request_id[:12]} uid={uid} " + f"prompt_tokens={request.num_prompt_tokens} " + f"tokens_to_prefill={tokens_to_prefill}{cache_info} " + f"max_tokens={request.sampling_params.max_tokens} " + f"running={len(self.running)} waiting={len(self.waiting)}" ) return scheduled @@ -633,8 +1361,17 @@ def _process_batch_responses( # Append token to request request.append_output_token(response.token) - # Decode the new token - new_text = self._decode_tokens([response.token]) + # Record first token time for TTFT metric + if request.first_token_time is None and request.num_output_tokens > 0: + import time as _time + + request.first_token_time = _time.time() + + # Decode the new token (skip stop tokens — they are not content) + if response.finish_reason == "stop": + new_text = "" + else: + new_text = self._decode_tokens([response.token]) # Create output output = RequestOutput( @@ -661,7 +1398,7 @@ def _process_batch_responses( output.output_text = self._decode_tokens(request.output_token_ids) request.output_text = output.output_text - # Extract cache for future reuse + # Extract cache for future reuse (critical for agentic multi-turn) if hasattr(response, "prompt_cache"): try: # prompt_cache may be callable or direct attribute @@ -735,6 +1472,11 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: # unused blocks when under memory pressure. elif self.memory_aware_cache is not None: + # Keep mid-prefill entry as prefix cache for future + # requests that share a common prefix (e.g. same system + # prompt + tools but different user message). LRU + # eviction handles memory pressure. + # Store in memory-aware prefix cache # Key includes both prompt and output tokens for multi-turn chat caching if ( @@ -745,13 +1487,33 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: full_token_sequence = list(request.prompt_token_ids) + list( request.output_token_ids ) - self.memory_aware_cache.store( + import time as _time + + _store_t0 = _time.monotonic() + stored = self.memory_aware_cache.store( full_token_sequence, request._extracted_cache, + evict_prefixes=False, ) - logger.debug( - f"Stored memory-aware cache for request {request_id} " - f"({len(full_token_sequence)} tokens: {len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output)" + _store_dt = _time.monotonic() - _store_t0 + # NOTE: We intentionally do NOT store a prompt-only + # cache entry. Hybrid Mamba+Transformer models + # (like Qwen3-Coder-Next) have MambaCache layers + # whose state is cumulative and cannot be trimmed + # back to "prompt only". Reusing such state causes + # the model to immediately produce EOS. + # The full prompt+output entry is stored above; a + # future request with the same prompt will hit the + # supersequence match path in the fetch, which is + # now disabled for safety (see memory_cache.py). + + logger.info( + f"[cache_store] request={request_id[:12]} " + f"tokens={len(full_token_sequence)} " + f"({len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output) " + f"stored={stored} time={_store_dt:.3f}s " + f"cache_entries={len(self.memory_aware_cache._entries)} " + f"cache_mem={self.memory_aware_cache._current_memory / 1e6:.0f}MB" ) except Exception as e: logger.debug( @@ -781,6 +1543,24 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: except Exception as e: logger.debug(f"Failed to store cache for {request_id}: {e}") + # Evaluate stored cache tensors incrementally (per-layer) to prevent + # a deferred batch evaluation spike when all lazy ops resolve at once. + # This spreads the VRAM cost across smaller per-layer evaluations. + if ( + request is not None + and hasattr(request, "_extracted_cache") + and request._extracted_cache + ): + for layer in request._extracted_cache: + if isinstance(layer, dict) and "state" in layer: + keys, values = layer["state"] + mx.eval(keys, values) + elif hasattr(layer, "keys") and hasattr(layer, "values"): + keys_attr = layer.keys + values_attr = layer.values + if not callable(keys_attr) and not callable(values_attr): + mx.eval(keys_attr, values_attr) + # Remove from running if request_id in self.running: del self.running[request_id] @@ -795,6 +1575,10 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: # Track as finished self.finished_req_ids.add(request_id) + # Free Metal command buffers after cleanup (prevents end-of-generation spike) + if finished_ids: + mx.clear_cache() + def _is_cache_corruption_error(self, error: Exception) -> bool: """Check if an error indicates cache corruption.""" error_str = str(error) @@ -802,8 +1586,8 @@ def _is_cache_corruption_error(self, error: Exception) -> bool: def _recover_from_cache_error(self) -> None: """Recover from cache corruption error.""" - # Clear batch generator (this is the source of the corruption) - self.batch_generator = None + # Properly close batch generator (this is the source of the corruption) + self._close_batch_generator() self._current_sampler_params = None # Clear caches @@ -856,6 +1640,9 @@ def step(self, max_retries: int = 1) -> SchedulerOutput: """ output = SchedulerOutput() + # Process pending aborts FIRST (in executor thread, safe for MLX) + self._process_pending_aborts() + for attempt in range(max_retries + 1): try: # Schedule waiting requests @@ -907,6 +1694,27 @@ def step(self, max_retries: int = 1) -> SchedulerOutput: old_finished = self.finished_req_ids self.finished_req_ids = set() + # Periodically clear Metal cache to prevent memory accumulation + self._step_count += 1 + if self._step_count % self._clear_cache_interval == 0: + mx.clear_cache() + + # Periodically log memory stats for monitoring + if self._step_count % self._memory_log_interval == 0: + try: + if mx.metal.is_available(): + active_gb = mx.get_active_memory() / 1e9 + peak_gb = mx.get_peak_memory() / 1e9 + cache_gb = mx.get_cache_memory() / 1e9 + logger.info( + f"[Metal memory] active={active_gb:.1f}GB " + f"peak={peak_gb:.1f}GB cache={cache_gb:.1f}GB " + f"step={self._step_count} " + f"running={len(self.running)} waiting={len(self.waiting)}" + ) + except Exception: + pass + return output def get_request(self, request_id: str) -> Optional[Request]: @@ -917,6 +1725,74 @@ def remove_finished_request(self, request_id: str) -> Optional[Request]: """Remove a finished request from tracking.""" return self.requests.pop(request_id, None) + def get_running_requests_info(self) -> List[Dict[str, Any]]: + """Per-request details for status endpoint.""" + import time as _time + + now = _time.time() + result = [] + + # Waiting requests + for req in self.waiting: + result.append( + { + "request_id": req.request_id, + "status": "waiting", + "phase": "queued", + "elapsed_s": round(now - req.arrival_time, 2), + "prompt_tokens": req.num_prompt_tokens, + "completion_tokens": 0, + "max_tokens": req.max_tokens, + "progress": 0.0, + "tokens_per_second": None, + "ttft_s": None, + "cache_hit_type": req.cache_hit_type, + "cached_tokens": req.cached_tokens, + } + ) + + # Running requests + for req in self.running.values(): + n_out = req.num_output_tokens + elapsed = now - req.arrival_time + + # Phase detection + if n_out == 0: + phase = "prefill" + else: + phase = "generation" + + # Tokens per second (generation phase only) + tok_s = None + ttft = None + if req.first_token_time is not None: + ttft = round(req.first_token_time - req.arrival_time, 3) + gen_elapsed = now - req.first_token_time + if gen_elapsed > 0 and n_out > 0: + tok_s = round(n_out / gen_elapsed, 1) + + # Progress: completion_tokens / max_tokens + progress = round(n_out / req.max_tokens, 3) if req.max_tokens > 0 else 0.0 + + result.append( + { + "request_id": req.request_id, + "status": "running", + "phase": phase, + "elapsed_s": round(elapsed, 2), + "prompt_tokens": req.num_prompt_tokens, + "completion_tokens": n_out, + "max_tokens": req.max_tokens, + "progress": min(progress, 1.0), + "tokens_per_second": tok_s, + "ttft_s": ttft, + "cache_hit_type": req.cache_hit_type, + "cached_tokens": req.cached_tokens, + } + ) + + return result + def get_stats(self) -> Dict[str, Any]: """Get scheduler statistics.""" stats = { @@ -926,6 +1802,15 @@ def get_stats(self) -> Dict[str, Any]: "total_prompt_tokens": self.total_prompt_tokens, "total_completion_tokens": self.total_completion_tokens, } + # Include Metal memory stats + try: + if mx.metal.is_available(): + stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2) + stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2) + stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2) + except Exception: + pass + # Include cache stats if self.block_aware_cache is not None: stats["paged_cache"] = self.block_aware_cache.get_stats() @@ -947,9 +1832,12 @@ def get_cache_stats(self) -> Optional[Dict[str, Any]]: def reset(self) -> None: """Reset the scheduler state.""" - # Abort all requests + # Drain any pending deferred aborts + self._pending_abort_ids.clear() + + # Abort all requests directly (reset is synchronous) for request_id in list(self.requests.keys()): - self.abort_request(request_id) + self._do_abort_request(request_id) self.waiting.clear() self.running.clear() @@ -957,7 +1845,7 @@ def reset(self) -> None: self.finished_req_ids.clear() self.request_id_to_uid.clear() self.uid_to_request_id.clear() - self.batch_generator = None + self._close_batch_generator() self._current_sampler_params = None # Clear caches @@ -997,3 +1885,21 @@ def deep_reset(self) -> None: gc.collect() logger.info("Deep reset completed - all caches cleared") + + # ----------------------------------------------------------------- + # Cache persistence + # ----------------------------------------------------------------- + + def save_cache_to_disk(self, cache_dir: str) -> bool: + """Save prefix cache to disk for persistence across restarts.""" + if self.memory_aware_cache is not None: + return self.memory_aware_cache.save_to_disk(cache_dir) + logger.info("[cache_persist] no memory-aware cache to save") + return False + + def load_cache_from_disk(self, cache_dir: str) -> int: + """Load prefix cache from disk. Returns number of entries loaded.""" + if self.memory_aware_cache is not None: + return self.memory_aware_cache.load_from_disk(cache_dir) + logger.info("[cache_persist] no memory-aware cache to load into") + return 0 diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index d36779ef..7b131067 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -49,7 +49,6 @@ import uuid from collections import defaultdict from collections.abc import AsyncIterator -from contextlib import asynccontextmanager import uvicorn from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile @@ -58,6 +57,8 @@ # Import from new modular API # Re-export for backwards compatibility with tests +from .api.anthropic_adapter import anthropic_to_openai, openai_to_anthropic +from .api.anthropic_models import AnthropicRequest from .api.models import ( AssistantMessage, # noqa: F401 ChatCompletionChoice, # noqa: F401 @@ -96,6 +97,7 @@ parse_tool_calls, ) from .api.utils import ( + SPECIAL_TOKENS_PATTERN, clean_output_text, extract_multimodal_content, is_mllm_model, # noqa: F401 @@ -157,7 +159,50 @@ def _resolve_top_p(request_value: float | None) -> float: _tool_parser_instance = None # Instantiated parser -@asynccontextmanager +def _load_prefix_cache_from_disk() -> None: + """Load prefix cache from disk during startup.""" + try: + d = _get_cache_dir() + logger.info(f"[lifespan] Loading prefix cache from {d}") + loaded = _engine.load_cache_from_disk(d) + if loaded > 0: + logger.info(f"[lifespan] Loaded {loaded} prefix cache entries") + else: + logger.info("[lifespan] No prefix cache entries found on disk") + except Exception as e: + logger.warning(f"[lifespan] Failed to load cache from disk: {e}", exc_info=True) + + +def _save_prefix_cache_to_disk() -> None: + """Save prefix cache to disk during shutdown.""" + try: + d = _get_cache_dir() + logger.info(f"[lifespan] Saving prefix cache to {d}") + saved = _engine.save_cache_to_disk(d) + if saved: + logger.info(f"[lifespan] Saved prefix cache to {d}") + else: + logger.info("[lifespan] No cache to save") + except Exception as e: + logger.warning(f"[lifespan] Failed to save cache to disk: {e}", exc_info=True) + + +def _get_cache_dir() -> str: + """Get cache persistence directory based on model name.""" + # Use global _model_name which is always a string, set during load_model() + model_name = _model_name if _model_name else "default" + logger.info( + f"[_get_cache_dir] _model_name={_model_name!r} type={type(_model_name)}" + ) + # Sanitize model name for filesystem + safe_name = str(model_name).replace("/", "--").replace("\\", "--") + cache_dir = os.path.join( + os.path.expanduser("~"), ".cache", "vllm-mlx", "prefix_cache", safe_name + ) + logger.info(f"[_get_cache_dir] cache_dir={cache_dir!r}") + return cache_dir + + async def lifespan(app: FastAPI): """FastAPI lifespan for startup/shutdown events.""" global _engine, _mcp_manager @@ -166,6 +211,10 @@ async def lifespan(app: FastAPI): if _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded: await _engine.start() + # Load persisted cache from disk (AFTER engine start — AsyncEngineCore must exist) + if _engine is not None and hasattr(_engine, "load_cache_from_disk"): + _load_prefix_cache_from_disk() + # Initialize MCP if config provided mcp_config = os.environ.get("VLLM_MLX_MCP_CONFIG") if mcp_config: @@ -173,6 +222,10 @@ async def lifespan(app: FastAPI): yield + # Shutdown: Save cache to disk BEFORE stopping engine + if _engine is not None and hasattr(_engine, "save_cache_to_disk"): + _save_prefix_cache_to_disk() + # Shutdown: Close MCP connections and stop engine if _mcp_manager is not None: await _mcp_manager.stop() @@ -300,9 +353,11 @@ def _parse_tool_calls_with_parser( """ global _tool_parser_instance + request_dict = request.model_dump() if request else None + # If auto tool choice is not enabled, use the generic parser if not _enable_auto_tool_choice or not _tool_call_parser: - return parse_tool_calls(output_text) + return parse_tool_calls(output_text, request_dict) # Initialize parser if needed if _tool_parser_instance is None: @@ -319,13 +374,13 @@ def _parse_tool_calls_with_parser( f"Failed to initialize tool parser '{_tool_call_parser}': {e}" ) logger.warning("Falling back to generic parser") - return parse_tool_calls(output_text) + return parse_tool_calls(output_text, request_dict) # Use the configured parser try: # Reset parser state between requests _tool_parser_instance.reset() - result = _tool_parser_instance.extract_tool_calls(output_text) + result = _tool_parser_instance.extract_tool_calls(output_text, request_dict) if result.tools_called: tool_calls = [ ToolCall( @@ -340,10 +395,12 @@ def _parse_tool_calls_with_parser( ] return result.content or "", tool_calls else: - return result.content or output_text, None + # Fallback: specific parser didn't find tool calls, + # try generic parser which handles more formats (e.g. Nemotron XML) + return parse_tool_calls(output_text, request_dict) except Exception as e: logger.warning(f"Tool parser error: {e}") - return parse_tool_calls(output_text) + return parse_tool_calls(output_text, request_dict) def _detect_native_tool_support() -> bool: @@ -505,6 +562,36 @@ async def health(): } +@app.get("/v1/status") +async def status(): + """Real-time status with per-request details for debugging and monitoring.""" + if _engine is None: + return {"status": "not_loaded", "model": None, "requests": []} + + stats = _engine.get_stats() + + return { + "status": "running" if stats.get("running") else "stopped", + "model": _model_name, + "uptime_s": round(stats.get("uptime_seconds", 0), 1), + "steps_executed": stats.get("steps_executed", 0), + "num_running": stats.get("num_running", 0), + "num_waiting": stats.get("num_waiting", 0), + "total_requests_processed": stats.get("num_requests_processed", 0), + "total_prompt_tokens": stats.get("total_prompt_tokens", 0), + "total_completion_tokens": stats.get("total_completion_tokens", 0), + "metal": { + "active_memory_gb": stats.get("metal_active_memory_gb"), + "peak_memory_gb": stats.get("metal_peak_memory_gb"), + "cache_memory_gb": stats.get("metal_cache_memory_gb"), + }, + "cache": stats.get("memory_aware_cache") + or stats.get("paged_cache") + or stats.get("prefix_cache"), + "requests": stats.get("requests", []), + } + + @app.get("/v1/cache/stats") async def cache_stats(): """Get cache statistics for debugging and monitoring.""" @@ -892,6 +979,184 @@ async def list_voices(model: str = "kokoro"): return {"voices": ["default"]} +# ============================================================================= +# Streaming disconnect detection +# ============================================================================= + + +async def _disconnect_guard( + generator: AsyncIterator[str], + raw_request: Request, + poll_interval: float = 0.5, +) -> AsyncIterator[str]: + """Wrap streaming generator to abort on client disconnect. + + Uses asyncio racing: each __anext__() on the inner generator is + raced against a disconnect poller. This catches disconnects even + during prefill when no chunks are being yielded for tens of seconds. + + On disconnect, aclose() propagates down the generator chain to + engine_core.stream_outputs() finally-block → abort_request(). + """ + import time as _time + + _t0 = _time.monotonic() + + def _elapsed(): + return f"{_time.monotonic() - _t0:.1f}s" + + logger.info(f"[disconnect_guard] START poll_interval={poll_interval}s") + + async def _wait_disconnect(): + poll_count = 0 + while True: + await asyncio.sleep(poll_interval) + poll_count += 1 + is_disc = await raw_request.is_disconnected() + if poll_count % 10 == 0 or is_disc: + logger.info( + f"[disconnect_guard] poll #{poll_count} " + f"disconnected={is_disc} elapsed={_elapsed()}" + ) + if is_disc: + return + + chunk_count = 0 + disconnect_task: asyncio.Task | None = None + anext_task: asyncio.Task | None = None + try: + aiter = generator.__aiter__() + disconnect_task = asyncio.create_task(_wait_disconnect()) + while True: + anext_task = asyncio.ensure_future(aiter.__anext__()) + done, _ = await asyncio.wait( + [anext_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + logger.info( + f"[disconnect_guard] CLIENT DISCONNECTED after " + f"{chunk_count} chunks, elapsed={_elapsed()}" + ) + anext_task.cancel() + try: + await anext_task + except (asyncio.CancelledError, StopAsyncIteration): + pass + break + try: + chunk = anext_task.result() + except StopAsyncIteration: + logger.info( + f"[disconnect_guard] generator exhausted normally, " + f"{chunk_count} chunks, elapsed={_elapsed()}" + ) + break + chunk_count += 1 + if chunk_count == 1: + logger.info( + f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}" + ) + yield chunk + except GeneratorExit: + logger.info( + f"[disconnect_guard] GeneratorExit after {chunk_count} chunks, elapsed={_elapsed()}" + ) + finally: + if disconnect_task and not disconnect_task.done(): + disconnect_task.cancel() + if anext_task and not anext_task.done(): + anext_task.cancel() + # NOTE: Do NOT call generator.aclose() here. With run_in_executor, + # scheduler.step() runs in a background thread. aclose() would throw + # GeneratorExit into the async-generator chain, which can trigger + # mlx::core::eval on the main thread while the executor thread is also + # mid-eval → Metal assertion failure → SIGABRT. + # + # Instead, rely on the task cancellation propagation: + # anext_task.cancel() → CancelledError in stream_outputs() + # → finally block → abort_request() → request removed from scheduler + logger.info( + f"[disconnect_guard] CLEANUP done, {chunk_count} chunks total, elapsed={_elapsed()}" + ) + + +async def _wait_with_disconnect( + coro, + raw_request: Request, + timeout: float, + poll_interval: float = 0.5, +): + """Run a coroutine with both timeout and client disconnect detection. + + For non-streaming requests where _disconnect_guard() can't be used. + Races the coroutine against a disconnect poller, same pattern as + _disconnect_guard but for awaitable (non-generator) coroutines. + """ + import time as _time + + _t0 = _time.monotonic() + + task = asyncio.ensure_future(coro) + + async def _wait_disconnect(): + poll_count = 0 + while True: + await asyncio.sleep(poll_interval) + poll_count += 1 + is_disc = await raw_request.is_disconnected() + if poll_count % 10 == 0 or is_disc: + logger.info( + f"[disconnect_guard] poll #{poll_count} " + f"disconnected={is_disc} elapsed={_time.monotonic() - _t0:.1f}s" + ) + if is_disc: + return + + disconnect_task = asyncio.create_task(_wait_disconnect()) + + try: + done, _ = await asyncio.wait( + [task, disconnect_task], + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + + if not done: + # Timeout + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + raise HTTPException( + status_code=504, + detail=f"Request timed out after {timeout:.1f} seconds", + ) + + if disconnect_task in done: + # Client disconnected + logger.info( + f"[disconnect_guard] CLIENT DISCONNECTED (non-stream) " + f"elapsed={_time.monotonic() - _t0:.1f}s" + ) + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + return None # Signal to caller that client disconnected + + # Task completed + return task.result() + + finally: + if not disconnect_task.done(): + disconnect_task.cancel() + if not task.done(): + task.cancel() + + # ============================================================================= # Completion Endpoints # ============================================================================= @@ -900,16 +1165,28 @@ async def list_voices(model: str = "kokoro"): @app.post( "/v1/completions", dependencies=[Depends(verify_api_key), Depends(check_rate_limit)] ) -async def create_completion(request: CompletionRequest): +async def create_completion(request: CompletionRequest, raw_request: Request): """Create a text completion.""" engine = get_engine() # Handle single prompt or list of prompts prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt] + # --- Detailed request logging --- + prompt_preview = prompts[0][:200] if prompts else "(empty)" + prompt_len = sum(len(p) for p in prompts) + logger.info( + f"[REQUEST] POST /v1/completions stream={request.stream} " + f"max_tokens={request.max_tokens} temp={request.temperature} " + f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}" + ) + if request.stream: return StreamingResponse( - stream_completion(engine, prompts[0], request), + _disconnect_guard( + stream_completion(engine, prompts[0], request), + raw_request, + ), media_type="text/event-stream", ) @@ -921,21 +1198,19 @@ async def create_completion(request: CompletionRequest): total_prompt_tokens = 0 for i, prompt in enumerate(prompts): - try: - output = await asyncio.wait_for( - engine.generate( - prompt=prompt, - max_tokens=request.max_tokens or _default_max_tokens, - temperature=_resolve_temperature(request.temperature), - top_p=_resolve_top_p(request.top_p), - stop=request.stop, - ), - timeout=timeout, - ) - except asyncio.TimeoutError: - raise HTTPException( - status_code=504, detail=f"Request timed out after {timeout:.1f} seconds" - ) + output = await _wait_with_disconnect( + engine.generate( + prompt=prompt, + max_tokens=request.max_tokens or _default_max_tokens, + temperature=_resolve_temperature(request.temperature), + top_p=_resolve_top_p(request.top_p), + stop=request.stop, + ), + raw_request, + timeout=timeout, + ) + if output is None: + return Response(status_code=499) # Client closed request choices.append( CompletionChoice( @@ -970,7 +1245,7 @@ async def create_completion(request: CompletionRequest): "/v1/chat/completions", dependencies=[Depends(verify_api_key), Depends(check_rate_limit)], ) -async def create_chat_completion(request: ChatCompletionRequest): +async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): """ Create a chat completion (supports multimodal content for VLM models). @@ -1014,6 +1289,27 @@ async def create_chat_completion(request: ChatCompletionRequest): """ engine = get_engine() + # --- Detailed request logging --- + n_msgs = len(request.messages) + msg_roles = [m.role for m in request.messages] + total_chars = 0 + last_user_preview = "" + for m in request.messages: + content = m.content if isinstance(m.content, str) else str(m.content) + total_chars += len(content) + if m.role == "user": + last_user_preview = content[:300] + has_tools = bool(request.tools) + n_tools = len(request.tools) if request.tools else 0 + logger.info( + f"[REQUEST] POST /v1/chat/completions stream={request.stream} " + f"model={request.model!r} max_tokens={request.max_tokens} " + f"temp={request.temperature} msgs={n_msgs} roles={msg_roles} " + f"total_chars={total_chars} tools={n_tools} " + f"response_format={request.response_format}" + ) + logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + # For MLLM models, keep original messages with embedded images # (MLLM.chat() extracts images from message content internally) if engine.is_mllm: @@ -1063,7 +1359,10 @@ async def create_chat_completion(request: ChatCompletionRequest): if request.stream: return StreamingResponse( - stream_chat_completion(engine, messages, request, **chat_kwargs), + _disconnect_guard( + stream_chat_completion(engine, messages, request, **chat_kwargs), + raw_request, + ), media_type="text/event-stream", ) @@ -1071,14 +1370,13 @@ async def create_chat_completion(request: ChatCompletionRequest): start_time = time.perf_counter() timeout = request.timeout or _default_timeout - try: - output = await asyncio.wait_for( - engine.chat(messages=messages, **chat_kwargs), timeout=timeout - ) - except asyncio.TimeoutError: - raise HTTPException( - status_code=504, detail=f"Request timed out after {timeout:.1f} seconds" - ) + output = await _wait_with_disconnect( + engine.chat(messages=messages, **chat_kwargs), + raw_request, + timeout=timeout, + ) + if output is None: + return Response(status_code=499) # Client closed request elapsed = time.perf_counter() - start_time tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0 @@ -1163,6 +1461,349 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: return messages +# ============================================================================= +# Anthropic Messages API Endpoints +# ============================================================================= + + +@app.post("/v1/messages") +async def create_anthropic_message( + request: Request, +): + """ + Anthropic Messages API endpoint. + + Translates Anthropic-format requests to OpenAI format, runs inference + through the existing engine, and converts the response back. + + Supports both streaming and non-streaming modes. + """ + engine = get_engine() + + # Parse the raw body to handle Anthropic request format + body = await request.json() + anthropic_request = AnthropicRequest(**body) + + # --- Detailed request logging --- + n_msgs = len(anthropic_request.messages) + total_chars = 0 + last_user_preview = "" + for m in anthropic_request.messages: + content = m.content if isinstance(m.content, str) else str(m.content) + total_chars += len(content) + if m.role == "user": + last_user_preview = content[:300] + sys_chars = len(anthropic_request.system) if anthropic_request.system else 0 + n_tools = len(anthropic_request.tools) if anthropic_request.tools else 0 + logger.info( + f"[REQUEST] POST /v1/messages (anthropic) stream={anthropic_request.stream} " + f"model={anthropic_request.model!r} max_tokens={anthropic_request.max_tokens} " + f"msgs={n_msgs} total_chars={total_chars} system_chars={sys_chars} " + f"tools={n_tools}" + ) + logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + + # Convert Anthropic request -> OpenAI request + openai_request = anthropic_to_openai(anthropic_request) + + if anthropic_request.stream: + return StreamingResponse( + _disconnect_guard( + _stream_anthropic_messages(engine, openai_request, anthropic_request), + request, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + # Non-streaming: run inference through existing engine + messages, images, videos = extract_multimodal_content( + openai_request.messages, + preserve_native_format=engine.preserve_native_tool_format, + ) + + chat_kwargs = { + "max_tokens": openai_request.max_tokens or _default_max_tokens, + "temperature": openai_request.temperature, + "top_p": openai_request.top_p, + } + + if openai_request.tools: + chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) + + start_time = time.perf_counter() + timeout = _default_timeout + + output = await _wait_with_disconnect( + engine.chat(messages=messages, **chat_kwargs), + request, + timeout=timeout, + ) + if output is None: + return Response(status_code=499) # Client closed request + + elapsed = time.perf_counter() - start_time + tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Anthropic messages: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + ) + + # Parse tool calls + cleaned_text, tool_calls = _parse_tool_calls_with_parser( + output.text, openai_request + ) + + # Clean output text + final_content = None + if cleaned_text: + final_content = clean_output_text(cleaned_text) + + # Determine finish reason + finish_reason = "tool_calls" if tool_calls else output.finish_reason + + # Build OpenAI response to convert + openai_response = ChatCompletionResponse( + model=openai_request.model, + choices=[ + ChatCompletionChoice( + message=AssistantMessage( + content=final_content, + tool_calls=tool_calls, + ), + finish_reason=finish_reason, + ) + ], + usage=Usage( + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + total_tokens=output.prompt_tokens + output.completion_tokens, + ), + ) + + # Convert to Anthropic response + anthropic_response = openai_to_anthropic(openai_response, anthropic_request.model) + return Response( + content=anthropic_response.model_dump_json(exclude_none=True), + media_type="application/json", + ) + + +@app.post("/v1/messages/count_tokens") +async def count_anthropic_tokens(request: Request): + """ + Count tokens for an Anthropic Messages API request. + + Uses the model's tokenizer for accurate counting. + Claude Code calls this endpoint for token budgeting. + Note: Don't parse via AnthropicRequest — count_tokens requests + from Claude Code don't include max_tokens. + """ + body = await request.json() + + engine = get_engine() + tokenizer = engine.tokenizer + + total_tokens = 0 + + # System message + system = body.get("system", "") + if isinstance(system, str) and system: + total_tokens += len(tokenizer.encode(system)) + elif isinstance(system, list): + for block in system: + if isinstance(block, dict): + text = block.get("text", "") + if text: + total_tokens += len(tokenizer.encode(text)) + + # Messages + for msg in body.get("messages", []): + content = msg.get("content", "") + if isinstance(content, str): + if content: + total_tokens += len(tokenizer.encode(content)) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict): + text = block.get("text", "") + if text: + total_tokens += len(tokenizer.encode(text)) + # tool_use input + if block.get("input"): + total_tokens += len( + tokenizer.encode(json.dumps(block["input"])) + ) + # tool_result content + sub_content = block.get("content", "") + if isinstance(sub_content, str) and sub_content: + total_tokens += len(tokenizer.encode(sub_content)) + elif isinstance(sub_content, list): + for item in sub_content: + if isinstance(item, dict): + item_text = item.get("text", "") + if item_text: + total_tokens += len(tokenizer.encode(item_text)) + + # Tools + for tool in body.get("tools", []): + name = tool.get("name", "") + if name: + total_tokens += len(tokenizer.encode(name)) + desc = tool.get("description", "") + if desc: + total_tokens += len(tokenizer.encode(desc)) + if tool.get("input_schema"): + total_tokens += len(tokenizer.encode(json.dumps(tool["input_schema"]))) + + return {"input_tokens": total_tokens} + + +async def _stream_anthropic_messages( + engine: BaseEngine, + openai_request: ChatCompletionRequest, + anthropic_request: AnthropicRequest, +) -> AsyncIterator[str]: + """ + Stream Anthropic Messages API SSE events. + + Converts OpenAI streaming chunks to Anthropic event format: + message_start -> content_block_start -> content_block_delta* -> + content_block_stop -> message_delta -> message_stop + """ + msg_id = f"msg_{uuid.uuid4().hex[:24]}" + start_time = time.perf_counter() + + # Extract messages for engine + messages, images, videos = extract_multimodal_content( + openai_request.messages, + preserve_native_format=engine.preserve_native_tool_format, + ) + + chat_kwargs = { + "max_tokens": openai_request.max_tokens or _default_max_tokens, + "temperature": openai_request.temperature, + "top_p": openai_request.top_p, + } + + if openai_request.tools: + chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) + + # Emit message_start + message_start = { + "type": "message_start", + "message": { + "id": msg_id, + "type": "message", + "role": "assistant", + "model": anthropic_request.model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": 0, + "output_tokens": 0, + }, + }, + } + yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" + + # Emit content_block_start for text + content_block_start = { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n" + + # Stream content deltas + accumulated_text = "" + completion_tokens = 0 + + async for output in engine.stream_chat(messages=messages, **chat_kwargs): + delta_text = output.new_text + + # Track token counts + if hasattr(output, "completion_tokens") and output.completion_tokens: + completion_tokens = output.completion_tokens + + if delta_text: + # Filter special tokens + content = SPECIAL_TOKENS_PATTERN.sub("", delta_text) + + if content: + accumulated_text += content + delta_event = { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": content}, + } + yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n" + + # Check for tool calls in accumulated text + _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request) + + # Emit content_block_stop for text block + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n" + + # If there are tool calls, emit tool_use blocks + if tool_calls: + for i, tc in enumerate(tool_calls): + tool_index = i + 1 + try: + tool_input = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError): + tool_input = {} + + # content_block_start for tool_use + tool_block_start = { + "type": "content_block_start", + "index": tool_index, + "content_block": { + "type": "tool_use", + "id": tc.id, + "name": tc.function.name, + "input": {}, + }, + } + yield f"event: content_block_start\ndata: {json.dumps(tool_block_start)}\n\n" + + # Send input as a single delta + input_json = json.dumps(tool_input) + input_delta = { + "type": "content_block_delta", + "index": tool_index, + "delta": {"type": "input_json_delta", "partial_json": input_json}, + } + yield f"event: content_block_delta\ndata: {json.dumps(input_delta)}\n\n" + + # content_block_stop + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': tool_index})}\n\n" + + # Determine stop reason + stop_reason = "tool_use" if tool_calls else "end_turn" + + # Emit message_delta with stop_reason and usage + message_delta = { + "type": "message_delta", + "delta": {"stop_reason": stop_reason, "stop_sequence": None}, + "usage": {"output_tokens": completion_tokens}, + } + yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n" + + # Log throughput + elapsed = time.perf_counter() - start_time + tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + ) + + # Emit message_stop + yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n" + + # ============================================================================= # Streaming Helpers # ============================================================================= @@ -1209,6 +1850,7 @@ async def stream_chat_completion( ) -> AsyncIterator[str]: """Stream chat completion response.""" response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" + start_time = time.perf_counter() # Check if we should include usage in the final chunk include_usage = request.stream_options and request.stream_options.include_usage @@ -1242,6 +1884,28 @@ async def stream_chat_completion( completion_tokens = 0 last_output = None + # Tool call streaming state + global _tool_parser_instance + tool_parser = None + tool_accumulated_text = "" + tool_calls_detected = False + tool_markup_possible = False # Fast path: skip parsing until '<' seen + if _enable_auto_tool_choice and _tool_call_parser: + # Initialize parser if needed (same as _parse_tool_calls_with_parser) + if _tool_parser_instance is None: + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = None + if _engine is not None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + _tool_parser_instance = parser_cls(tokenizer) + logger.info(f"Initialized tool call parser: {_tool_call_parser}") + except Exception as e: + logger.warning(f"Failed to init tool parser for streaming: {e}") + if _tool_parser_instance is not None: + tool_parser = _tool_parser_instance + tool_parser.reset() + # Stream content async for output in engine.stream_chat(messages=messages, **kwargs): delta_text = output.new_text @@ -1284,11 +1948,60 @@ async def stream_chat_completion( # Standard path without reasoning parsing content = delta_text + # Filter special tokens that may leak into streaming output + if content: + content = SPECIAL_TOKENS_PATTERN.sub("", content) + # Add prefix on first content chunk for thinking models if is_thinking_model and not think_prefix_sent and content: content = "" + content think_prefix_sent = True + # Tool call streaming parsing + if tool_parser and delta_text: + # Fast path: skip full parsing until '<' is seen in the stream, + # which could start tool markup (e.g. ). This avoids + # per-token string scanning on the growing accumulated text. + if not tool_markup_possible and "<" not in delta_text: + tool_accumulated_text += delta_text + # No tool markup yet, fall through to normal chunk emission + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += delta_text + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, delta_text + ) + + if tool_result is None: + # Inside tool markup - suppress output + continue + + if "tool_calls" in tool_result: + # Emit structured tool calls + tool_calls_detected = True + chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=tool_result["tool_calls"] + ), + finish_reason=( + "tool_calls" if output.finished else None + ), + ) + ], + usage=get_usage(output) if output.finished else None, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + continue + + # Normal content from tool parser + content = tool_result.get("content", "") + chunk = ChatCompletionChunk( id=response_id, model=request.model, @@ -1297,13 +2010,59 @@ async def stream_chat_completion( delta=ChatCompletionChunkDelta( content=content if content else None ), - finish_reason=output.finish_reason if output.finished else None, + finish_reason=( + "tool_calls" + if (output.finished and tool_calls_detected) + else (output.finish_reason if output.finished else None) + ), ) ], usage=get_usage(output) if output.finished else None, ) yield f"data: {chunk.model_dump_json()}\n\n" + # Fallback: if tool parser accumulated text but never emitted tool_calls + # (e.g., never arrived - incomplete tool call) + if ( + tool_parser + and tool_accumulated_text + and not tool_calls_detected + and "" in tool_accumulated_text + ): + result = tool_parser.extract_tool_calls(tool_accumulated_text) + if result.tools_called: + tool_chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=[ + { + "index": i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(result.tool_calls) + ] + ), + finish_reason="tool_calls", + ) + ], + ) + yield f"data: {tool_chunk.model_dump_json()}\n\n" + + # Log throughput + elapsed = time.perf_counter() - start_time + tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Chat completion (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + ) + # Send final chunk with usage if requested if include_usage: usage_chunk = ChatCompletionChunk( diff --git a/vllm_mlx/tool_parsers/hermes_tool_parser.py b/vllm_mlx/tool_parsers/hermes_tool_parser.py index 0605e640..da91a816 100644 --- a/vllm_mlx/tool_parsers/hermes_tool_parser.py +++ b/vllm_mlx/tool_parsers/hermes_tool_parser.py @@ -36,12 +36,23 @@ class HermesToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser hermes are set. """ + # Qwen3 / Hermes chat templates handle role="tool" and tool_calls natively. + # Without this, tool history is converted to "[Calling tool: ...]" text, + # which causes the model to mimic that text format instead of producing + # proper XML after a few rounds of tool use. + SUPPORTS_NATIVE_TOOL_FORMAT = True + # Standard format: {"name": ..., "arguments": ...} TOOL_CALL_PATTERN = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) # Lenient format: followed by JSON (handles malformed tags) TOOL_CALL_LENIENT_PATTERN = re.compile( r'v + NEMOTRON_PATTERN = re.compile( + r"\s*]+)>(.*?)\s*", re.DOTALL + ) + PARAM_PATTERN = re.compile(r"]+)>\s*(.*?)\s*", re.DOTALL) REASONING_PATTERN = re.compile( r"(.*?)", re.DOTALL ) @@ -91,7 +102,29 @@ def extract_tool_calls( if matches: cleaned_text = self.TOOL_CALL_PATTERN.sub("", cleaned_text).strip() - # Fallback 1: try lenient pattern for malformed tags like + # Try Nemotron XML format if no JSON tool calls found + if not tool_calls: + nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text) + for name, params_block in nemotron_matches: + params = self.PARAM_PATTERN.findall(params_block) + arguments = {} + for p_name, p_value in params: + val = p_value.strip() + try: + arguments[p_name.strip()] = json.loads(val) + except (json.JSONDecodeError, ValueError): + arguments[p_name.strip()] = val + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + if nemotron_matches: + cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip() + + # Fallback: try lenient pattern for malformed tags like if not tool_calls: lenient_matches = self.TOOL_CALL_LENIENT_PATTERN.findall(cleaned_text) for match in lenient_matches[:1]: # Only first to avoid hallucinations @@ -117,16 +150,14 @@ def extract_tool_calls( except json.JSONDecodeError: continue - # Fallback 2: try raw JSON format if no tagged tool calls found + # Fallback: try raw JSON format if no tagged tool calls found # Only parse the FIRST valid tool call to avoid hallucinated multiple calls if not tool_calls: raw_matches = self.RAW_JSON_TOOL_PATTERN.findall(cleaned_text) if raw_matches: - # Only take the first match to avoid hallucinated tool calls name, args_str = raw_matches[0] try: arguments = json.loads(args_str) - # Validate: only accept if tool name exists in request tools valid_tool = True if request and "tools" in request: tool_names = [ @@ -144,7 +175,6 @@ def extract_tool_calls( "arguments": json.dumps(arguments, ensure_ascii=False), } ) - # Remove the matched tool call from text cleaned_text = self.RAW_JSON_TOOL_PATTERN.sub( "", cleaned_text, count=1 ).strip() diff --git a/vllm_mlx/worker.py b/vllm_mlx/worker.py index 7c3a9047..02f44573 100644 --- a/vllm_mlx/worker.py +++ b/vllm_mlx/worker.py @@ -212,7 +212,7 @@ def shutdown(self) -> None: try: import mlx.core as mx - mx.metal.clear_cache() + mx.clear_cache() except Exception: pass