diff --git a/README.md b/README.md index b77d116d..36a25d2d 100644 --- a/README.md +++ b/README.md @@ -377,6 +377,136 @@ Submit PRs to: [https://github.com/waybarrios/vllm-mlx](https://github.com/wayba Apache 2.0 - see [LICENSE](LICENSE) for details. +## Baseline Benchmark — MiniMax-M2.5 on M3 Ultra + +Benchmarks measured on **Mac Studio M3 Ultra (256GB)** running **MiniMax-M2.5-MLX-4bit (229B MoE, ~120GB)** with `--reasoning-parser deepseek_r1 --tool-call-parser minimax --enable-auto-tool-choice --enable-prefix-cache`. + +Date: 2026-02-24 | Branch: `feat-minimax-parser` | Server mode: Simple (no continuous batching) + +### TTFT (Time To First Token) + +| Prompt Size | TTFT | Decode Speed | +|-------------|------|-------------| +| Short (~50 tok) | 0.331s | 54.1 tok/s | +| Medium (~500 tok) | 0.753s | 53.5 tok/s | +| Long (~2K tok) | 1.396s | 52.1 tok/s | + +### Decode Throughput + +| Output Length | Speed | Total Time | +|--------------|-------|------------| +| 128 tokens | 52.8 tok/s | 2.8s | +| 512 tokens | 52.0 tok/s | 10.2s | +| 2048 tokens | 49.6 tok/s | 41.7s | + +### Prefix Cache (Multi-Turn Conversation) + +| Turn | TTFT | Prompt Growth | +|------|------|--------------| +| Turn 1 | 0.497s | baseline | +| Turn 2 | 0.585s | +assistant reply | +| Turn 3 | 0.871s | +2 turns | +| Turn 4 | 1.040s | +3 turns | + +Turn 4 / Turn 1 TTFT ratio: **2.09x** (with perfect prefix-cache hits this ratio would stay close to 1.0x) + +### Tool Calling + +| Test | Result | Latency | Tools Called | +|------|--------|---------|-------------| +| Simple (weather) | OK | 2.00s | get_weather | +| Multi-arg (search) | OK | 2.51s | search_web | +| Code execution | OK | 3.99s | run_python | +| Multi-tool selection | OK | 3.04s | get_weather, search_web | + +Accuracy: **4/4 (100%)** | Avg latency: **2.89s** + +### Reasoning Separation (deepseek_r1 parser) + +| Test | Reasoning | Content | Separated | +|------|-----------|---------|-----------| +| Math (17*23) | 1013 chars | 259 chars | OK | +| 
Logic | 2154 chars | 0 chars | PARTIAL | +| Code | 2011 chars | 0 chars | PARTIAL | + +Properly separated: **1/3** (reasoning parser needs improvement for non-math tasks) + +### Long Generation Stability + +| Metric | Value | +|--------|-------| +| Max tokens | 8192 | +| Completed | Yes (finish_reason=length) | +| Decode speed | 33.0 tok/s | +| Total time | 249.0s | +| Output chars | 36,609 | + +### Run the Benchmark + +```bash +# Start server +python -m vllm_mlx.cli serve \ + --tool-call-parser minimax --enable-auto-tool-choice \ + --reasoning-parser deepseek_r1 --enable-prefix-cache + +# Run all tests +python benchmark_minmax.py --output results.json + +# Run specific test +python benchmark_minmax.py --test ttft +python benchmark_minmax.py --test tool_call +``` + +--- + +## Optimization Roadmap + +Prioritized optimization plan based on analysis of upstream vLLM features, adapted for MLX on Apple Silicon. Focused on large-model serving (MiniMax-M2.5 229B MoE, ~120GB on M3 Ultra 256GB). + +See [docs/plans/](docs/plans/) for detailed implementation plans. + +### Tier 0: Quick Wins (In Progress) + +Low-risk, high-impact changes. [Full plan](docs/plans/tier0-pinned-prefix-cache.md). 
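Of the items in the table below, GC control is small enough to sketch directly: pause Python's cyclic garbage collector around the hot generation loop, then run a single deliberate collection afterwards. This is a minimal illustration using only the standard-library `gc` module; the actual vLLM-MLX hook points, and the upstream vLLM implementation referenced in the table, may differ.

```python
import gc
from contextlib import contextmanager


@contextmanager
def gc_paused():
    """Disable Python's cyclic GC for the duration of a hot loop.

    Cyclic collection can fire mid-decode and add multi-millisecond
    pauses; deferring it to one collect() after generation trades a
    little peak memory for steadier per-token latency.
    """
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if was_enabled:
            gc.enable()
        gc.collect()  # one deliberate collection instead of many implicit ones
```

In a server this would wrap each decode step or request, e.g. `with gc_paused(): ... stream tokens ...` (the surrounding generation loop here is hypothetical).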
+ +| Optimization | Impact | Upstream Ref | +|-------------|--------|-------------| +| **GC Control** -- disable Python GC during generation | Eliminates latency spikes | [vLLM #33575](https://github.com/vllm-project/vllm/pull/33575) | +| **Detokenizer Caching** -- incremental token counting | O(n^2) -> O(1) per chunk | [vLLM #20413](https://github.com/vllm-project/vllm/pull/20413) | +| **Schema Error Hardening** -- graceful fallback on bad schemas | Prevents server crashes | [vLLM #30346](https://github.com/vllm-project/vllm/pull/30346) | +| **Pinned Prefix Cache** -- prevent system prompt eviction | Turn4/Turn1 TTFT 2.09x -> ~1.0x | [vLLM #27097](https://github.com/vllm-project/vllm/pull/27097) | + +### Tier 1: Implement Soon (High Impact, Medium Complexity) + +| Optimization | Expected Impact | Upstream Ref | +|-------------|----------------|-------------| +| **Jump-Forward Decoding** -- skip constrained tokens in guided gen | 2-5x faster JSON generation | [vLLM #15490](https://github.com/vllm-project/vllm/pull/15490) | +| **Structural Tags for Tool Calling** -- guarantee valid JSON args | <1% tool call failure rate | [vLLM #32202](https://github.com/vllm-project/vllm/pull/32202) | +| **Frequency-Aware Cache Eviction** -- smarter LRU for multi-turn | Better cache retention for conversations | [vLLM #27539](https://github.com/vllm-project/vllm/pull/27539) | +| **KV Block Reuse Ordering** -- improved prefix cache hit rates | Higher cache hit ratio | [vLLM #33710](https://github.com/vllm-project/vllm/pull/33710) | +| **Fix Streaming Tool Truncation** -- reliable streaming tool calls | Eliminates truncated tool args | [vLLM #34101](https://github.com/vllm-project/vllm/pull/34101) | +| **LogitsProcessor Interface** -- clean extensibility for custom logits | Plugin architecture for constrained decoding | [vLLM #13360](https://github.com/vllm-project/vllm/pull/13360) | + +### Tier 2: Evaluate Later (High Potential, High Complexity) + +| Optimization | Expected Impact | 
Upstream Ref | +|-------------|----------------|-------------| +| **Beam Search + Structured Output** -- best-quality tool arguments | Highest quality constrained generation | [vLLM #35022](https://github.com/vllm-project/vllm/pull/35022) | +| **Draft Model Speculative Decoding** -- faster generation overall | 1.5-2x decode speedup | [vLLM #32662](https://github.com/vllm-project/vllm/pull/32662) | +| **Inter-Prefill-Budget** -- better TTFT under concurrent load | 37% TTFT reduction for concurrent requests | [vLLM #33743](https://github.com/vllm-project/vllm/pull/33743) | +| **Tool Calling Redesign** -- better architecture for tool call parsing | Cleaner, more maintainable tool call path | [vLLM #22977](https://github.com/vllm-project/vllm/pull/22977) | + +### Estimated Combined Impact (Tier 0 + Tier 1) + +For 10-turn agentic conversations with tool calling (e.g., OpenClaw on MiniMax-M2.5): + +- **TTFT**: 2x improvement (pinned prefix cache + GC control) +- **Tool call reliability**: 95% -> 99%+ (structural tags + streaming fix) +- **Throughput consistency**: Eliminate GC-induced spikes +- **Long generation**: O(n^2) -> O(1) per-chunk overhead + +--- + ## Citation If you use vLLM-MLX in your research or project, please cite: diff --git a/baseline_benchmark.json b/baseline_benchmark.json new file mode 100644 index 00000000..cc73eb1e --- /dev/null +++ b/baseline_benchmark.json @@ -0,0 +1,340 @@ +{ + "timestamp": "2026-02-24 20:21:20", + "model": "/Users/raullenstudio/.lmstudio/models/lmstudio-community/MiniMax-M2.5-MLX-4bit", + "results": [ + { + "test": "ttft", + "results": [ + { + "label": "Short (50 tok)", + "ttft": 0.331, + "total_time": 0.923, + "decode_time": 0.592, + "decode_tps": 54.1, + "prompt_tokens": 0, + "completion_tokens": 32, + "content": "", + "reasoning": "The user asks: \"What is 2+2? 
Answer in one sentence.\"\n\nWe need to respond with a single sentence answer: \"2+2 equals", + "tool_calls": [], + "finish_reason": "length", + "chunks": 31 + }, + { + "label": "Medium (500 tok)", + "ttft": 0.753, + "total_time": 1.351, + "decode_time": 0.598, + "decode_tps": 53.5, + "prompt_tokens": 0, + "completion_tokens": 32, + "content": "", + "reasoning": "The user is asking for a summary in 3 bullet points of the complete history of computing from Charles Babbage to modern GPUs. The user repeats the request many", + "tool_calls": [], + "finish_reason": "length", + "chunks": 31 + }, + { + "label": "Long (2K tok)", + "ttft": 1.396, + "total_time": 2.011, + "decode_time": 0.615, + "decode_tps": 52.1, + "prompt_tokens": 0, + "completion_tokens": 32, + "content": "", + "reasoning": "We need to respond to user: \"You are a senior software architect. Review this code and explain every design pattern used: [code repeated many times]\" The", + "tool_calls": [], + "finish_reason": "length", + "chunks": 32 + } + ] + }, + { + "test": "decode_throughput", + "results": [ + { + "max_tokens": 128, + "ttft": 0.341, + "total_time": 2.763, + "decode_time": 2.422, + "decode_tps": 52.8, + "prompt_tokens": 0, + "completion_tokens": 128, + "content": "", + "reasoning": "We need to write exactly 128 words about the history of artificial intelligence. Must be detailed. Provide exactly 128 words. Count carefully. The content must be about the history of AI, with detail. Must be exactly 128 words. We need to output only the 128-word text.\n\nWe must ensure it's exactly 128 words. Let's count.\n\nWe need to produce a paragraph or multiple paragraphs but total words = 128.\n\nWe must be careful: count words. We'll need to write a detailed history. 
Use typical timeline: 1950s, Dartmouth workshop 1956, early AI, symbolic AI, expert systems, AI", + "tool_calls": [], + "finish_reason": "length", + "chunks": 120 + }, + { + "max_tokens": 512, + "ttft": 0.341, + "total_time": 10.182, + "decode_time": 9.842, + "decode_tps": 52.0, + "prompt_tokens": 0, + "completion_tokens": 512, + "content": "", + "reasoning": "The user asks: \"Write exactly 512 words about the history of artificial intelligence. Be detailed.\"\n\nWe need to produce exactly 512 words. The content should be detailed about the history of AI. Need to be careful about word count. Must be exactly 512 words. Must be a single response. No extra content, no extra explanation, just the text.\n\nWe need to count words precisely. Let's prepare a 512-word essay.\n\nFirst, think of key points: early concepts, logic, Turing test, Dartmouth conference, early AI (symbolic, perceptrons), AI winters, expert systems, resurgence with machine learning, deep learning, modern achievements, future.\n\nWe need to ensure we have exactly 512 words.\n\nWe need to be careful about counting words. Let's draft a text and count manually.\n\nWe'll write in a single paragraph or multiple, but overall 512 words.\n\nWe must count each word. Let's start building.\n\nWe can do a rough count first, then adjust.\n\nWrite a text about AI history.\n\nWe'll need to count words precisely. Let's try to write around 512 words. Use standard word counting: each token separated by spaces counts as a word. Hyphenated words count as one? Typically hyphenated words count as one. We'll avoid hyphens to simplify.\n\nWe must ensure exactly 512 words. Let's start.\n\nWe'll write:\n\n\"Artificial intelligence (AI) as a field traces its roots to the mid‑twentieth century, when mathematicians and logicians began exploring the possibility of mechanizing reasoning. 
In 1950, Alan Turing proposed the “imitation game,” later known as the Turing test, as a criterion for judging whether a machine could exhibit intelligent behavior indistinguishable from that of a human. This philosophical benchmark sparked early speculation and laid groundwork for future research.\n\nThe official birth of AI is commonly dated to the Dartmouth Conference of 1956, organized by John McCarthy, Marvin Minsky, Nathaniel Rochester, and Claude Shannon. The workshop’s proposal defined AI as “every aspect of learning or any other feature of intelligence that can in principle be described so precisely that a machine can be made to simulate it.” Early work focused on symbolic methods, theorem proving, and game playing. Allen Newell and Herbert A. Simon developed the Logic Theorist (1956) and the General Problem Solver (1957), demonstrating that machines could emulate human problem‑solving steps.\n\nDuring the 1960s, AI research expanded with the creation of perceptrons by Frank Rosenblatt (1958) and the first natural language processing", + "tool_calls": [], + "finish_reason": "length", + "chunks": 501 + }, + { + "max_tokens": 2048, + "ttft": 0.347, + "total_time": 41.66, + "decode_time": 41.313, + "decode_tps": 49.6, + "prompt_tokens": 0, + "completion_tokens": 2048, + "content": "", + "reasoning": "The user asks: \"Write exactly 2048 words about the history of artificial intelligence. Be detailed.\"\n\nWe need to produce exactly 2048 words. The content must be about the history of AI, detailed. We must ensure exactly 2048 words. Must count words.\n\nWe need to carefully write an essay of exactly 2048 words. Must be careful with counting. The user didn't ask for any disclaimer or anything else. So we must comply. There's no disallowed content. It's permissible. Must be exactly 2048 words. Let's produce a comprehensive history of AI.\n\nWe need to count words. 
Let's think: The content should be about the history of AI: early concepts, Turing, Dartmouth conference, early AI research, perceptrons, expert systems, AI winters, resurgence, machine learning, deep learning, modern achievements, future.\n\nWe need to produce exactly 2048 words. Must count words accurately. Let's plan: maybe about 20-30 paragraphs. But we need to count words precisely.\n\nWe can write and then count. Let's aim to produce final content of exactly 2048 words. We'll need a systematic method to count. Let's start drafting. We'll aim for about 2048 words. We'll need to track word count.\n\nWe'll start writing, and after each paragraph, we'll note approximate count. Use a systematic approach: We'll write in a text editor-like style, counting manually. We need to ensure exactly 2048 words, no more, no less. Let's start.\n\nWe'll write a detailed essay: maybe 30-40 paragraphs. But must count precisely.\n\nWe'll start with introduction: \"Artificial intelligence (AI) is a field of computer science...\" We'll count words.\n\nWe'll keep track: We'll manually count words, but it's huge. Let's approximate and then adjust.\n\nWe need exactly 2048 words. Let's try to produce around 2048 words. We'll count as we write. Let's do a detailed plan: We'll produce content with roughly 200-250 words per paragraph? That would be too many paragraphs. Let's maybe 10-12 paragraphs of ~170-200 words each yields about 2000 words. Then adjust with filler.\n\nAlternatively, we can produce a single large block and count manually. Let's create an outline and then write each paragraph, counting words.\n\nWe'll need to be meticulous.\n\nWe'll do this: Write a paragraph, count words, adjust. We'll keep a running total.\n\nWe must ensure no extra words (like extra spaces or line breaks). Word count counts tokens separated by whitespace. So we must count words exactly. We'll count manually.\n\nWe'll write with proper punctuation. But we must ensure each word is separated by spaces. 
Hyphenated words count as one? Typically hyphenated counts as one word? In typical word count, hyphenated words can be counted as one. But we can avoid hyphens. Use simple words.\n\nWe need to be careful: \"AI\" counts as one word. \"Artificial\" counts as one. \"Intelligence\" counts as one.\n\nWe'll produce final answer. Let's start counting.\n\nWe'll have a table: We'll keep a running total.\n\nParagraph 1: introduction.\n\nLet's write:\n\n\"Artificial intelligence (AI) is a discipline that seeks to create machines capable of mimicking, extending, or augmenting human cognition. The concept of intelligent machines appears in ancient myths and early philosophical speculation, but the scientific pursuit of AI began in the mid‑twentieth century. This overview traces the major milestones, breakthroughs, and setbacks that have shaped the field from its inception to the contemporary era of deep learning and autonomous systems.\"\n\nNow count words.\n\nWe'll count each word. Let's write them out with numbers:\n\n\"Artificial(1) intelligence(2) (AI)(3) is(4) a(5) discipline(6) that(7) seeks(8) to(9) create(10) machines(11) capable(12) of(13) mimicking,(14) extending,(15) or(16) augmenting(17) human(18) cognition.(19) The(20) concept(21) of(22) intelligent(23) machines(24) appears(25) in(26) ancient(27) myths(28) and(29) early(30) philosophical(31) speculation,(32) but(33) the(34) scientific(35) pursuit(36) of(37) AI(38) began(39) in(40) the(41) mid‑twentieth(42) century.(43) This(44) overview(45) traces(46) the(47) major(48) milestones,(49) breakthroughs,(50) and(51) setbacks(52) that(53) have(54) shaped(55) the(56) field(57) from(58) its(59) inception(60) to(61) the(62) contemporary(63) era(64) of(65) deep(66) learning(67) and(68) autonomous(69) systems.(70)\n\nSo paragraph 1 has 70 words.\n\nParagraph 2: Early foundations.\n\nWrite:\n\n\"The intellectual groundwork for AI was laid by mathematicians and logicians such as Alan Turing, Alonzo Church, and Kurt Gödel. 
In 1936, Turing introduced the notion of a universal computing machine, later called a Turing machine, which formalized the concept of algorithmic computation. Church formulated the lambda calculus, providing a theoretical framework for functions and recursion. Gödel’s incompleteness theorems demonstrated the limits of formal systems, raising profound questions about mechanizing reasoning. These ideas established a mathematical basis for the possibility of machine intelligence.\"\n\nNow count words.\n\n\"The(1) intellectual(2) groundwork(3) for(4) AI(5) was(6) laid(7) by(8) mathematicians(9) and(10) logicians(11) such(12) as(13) Alan(14) Turing,(15) Alonzo(16) Church,(17) and(18) Kurt(19) Gödel.(20) In(21) 1936,(22) Turing(23) introduced(24) the(25) notion(26) of(27) a(28) universal(29) computing(30) machine,(31) later(32) called(33) a(34) Turing(35) machine,(36) which(37) formalized(38) the(39) concept(40) of(41) algorithmic(42) computation.(43) Church(44) formulated(45) the(46) lambda(47) calculus,(48) providing(49) a(50) theoretical(51) framework(52) for(53) functions(54) and(55) recursion.(56) Gödel’s(57) incompleteness(58) theorems(59) demonstrated(60) the(61) limits(62) of(63) formal(64) systems,(65) raising(66) profound(67) questions(68) about(69) mechanizing(70) reasoning.(71) These(72) ideas(73) established(74) a(75) mathematical(76) basis(77) for(78) the(79) possibility(80) of(81) machine(82) intelligence.(83)\n\nParagraph 2: 83 words. Running total: 70 + 83 = 153.\n\nParagraph 3: The Dartmouth conference and birth of AI.\n\nWrite:\n\n\"The official birth of AI as a research field is commonly dated to the Dartmouth Summer Research Project on Artificial Intelligence in 1956. Organized by John McCarthy, Marvin Minsky, Nathaniel Rochester, and Claude Shannon, the conference brought together researchers interested in symbolic reasoning, heuristic search, and problem‑solving programs. 
McCarthy coined the term “artificial intelligence” during the proposal, defining it as “the science and engineering of making intelligent machines.” The meeting set an ambitious agenda: to develop programs that could simulate every aspect of human learning or intelligence. It also catalyzed funding from the United States Department of Defense and major corporations, spawning early AI labs at MIT, Stanford, and Carnegie Mellon.\"\n\nNow count words.\n\n\"The(1) official(2) birth(3) of(4) AI(5) as(6) a(7) research(8) field(9) is(10) commonly(11) dated(12) to(13) the(14) Dartmouth(15) Summer(16) Research(17) Project(18) on(19) Artificial(20) Intelligence(21) in(22) 1956.(23) Organized(24) by(25) John(26) McCarthy,(27) Marvin(28) Minsky,(29) Nathaniel(30) Rochester,(31) and(32) Claude(33) Shannon,(34) the(35) conference(36) brought(37) together(38) researchers(39) interested(40) in(41) symbolic(42) reasoning,(43) heuristic(44) search,(45) and(46) problem‑solving(47) programs.(48) McCarthy(49) coined(50) the(51) term(52) “artificial(53) intelligence”(54) during(55) the(56) proposal,(57) defining(58) it(59) as(60) “the(61) science(62) and(63) engineering(64) of(65) making(66) intelligent(67) machines.”(68) The(69) meeting(70) set(71) an(72) ambitious(73) agenda:(74) to(75) develop(76) programs(77) that(78) could(79) simulate(80) every(81) aspect(82) of(83) human(84) learning(85) or(86) intelligence.(87) It(88) also(89) catalyzed(90) funding(91) from(92)", + "tool_calls": [], + "finish_reason": "length", + "chunks": 2018 + } + ] + }, + { + "test": "prefix_cache", + "results": [ + { + "turn": 1, + "user_msg": "How do I implement a thread-safe LRU cache in Python?", + "ttft": 0.497, + "total_time": 5.389, + "decode_time": 4.892, + "decode_tps": 52.3, + "prompt_tokens": 0, + "completion_tokens": 256, + "content": "", + "reasoning": "The user is asking about implementing a thread-safe LRU (Least Recently Used) cache in Python. Let me think about the best approach:\n\n1. 
**LRU Cache concepts**: An LRU cache evicts the least recently used items when the cache reaches its capacity. It needs to support:\n - `get(key)` - returns the value if key exists, else None\n - `put(key, value)` - adds or updates the key-value pair, evicting LRU item if capacity is reached\n\n2. **Thread-safety**: Multiple threads can access the cache concurrently without data corruption. This requires proper locking.\n\n3. **Implementation approaches**:\n - Use `collections.OrderedDict` (Python 3.7+) which maintains insertion order\n - Use `threading.Lock` or `threading.RLock` for thread safety\n - Or use `functools.lru_cache` for simple cases (but it's not easily customizable for thread-safe LRU with maxsize)\n\nLet me implement a comprehensive solution:\n\n1. **Basic approach using OrderedDict + Lock**\n2. **Alternative using threading + dict + doubly linked list**\n3. **Using concurrent.futures for a more advanced approach**\n\nLet me provide a well-documented implementation with", + "tool_calls": [], + "finish_reason": "length", + "chunks": 255 + }, + { + "turn": 2, + "user_msg": "Can you add TTL (time-to-live) expiration to that cache?", + "ttft": 0.585, + "total_time": 5.488, + "decode_time": 4.903, + "decode_tps": 52.2, + "prompt_tokens": 0, + "completion_tokens": 256, + "content": "# Thread-Safe LRU Cache with TTLHere's an extended implementation that adds time-to-live expiration:\n\n```pythonimport threadingimport timefrom collections import OrderedDictfrom typing import Any, Optional, Tupleimport logginglogging.basicConfig(level=logging.DEBUG)\nlogger = logging.getLogger(__name__)\n\n\nclass ThreadSafeLRUCacheWithTTL:\n \"\"\"\n Thread-safe LRU Cache with TTL (Time-To-Live) expiration.\n Features:\n - LRU eviction when capacity is reached - TTL expiration for entries - Thread-safe operations - Automatic cleanup of expired entries \"\"\"\n def __init__(self, capacity: int, default_ttl: float", + "reasoning": "The user wants me to extend the thread-safe 
LRU cache with TTL (time-to-live) expiration. This means each cache entry should expire after a certain time period.\n\nI'll need to:\n1. Track when each entry was added/accessed\n2. Periodically clean up expired entries\n3. Check expiration on access\n\nLet me implement this with:\n- OrderedDict or a custom doubly-linked list for LRU ordering\n- Thread-safe operations using threading locks\n- Timestamp tracking for each entry\n- A cleanup mechanism for expired entries\n", + "tool_calls": [], + "finish_reason": "length", + "chunks": 232 + }, + { + "turn": 3, + "user_msg": "Now make it work with async/await. Show the full implementat", + "ttft": 0.871, + "total_time": 5.796, + "decode_time": 4.925, + "decode_tps": 52.0, + "prompt_tokens": 0, + "completion_tokens": 256, + "content": "# Async LRU Cache with TTLHere's the full async implementation with all the same features:\n\n```pythonimport asyncioimport timefrom collections import OrderedDictfrom typing import Any, Optional, Dict, Tupleimport loggingfrom contextlib import asynccontextmanagerlogging.basicConfig(level=logging.DEBUG)\nlogger = logging.getLogger(__name__)\n\n\nclass AsyncLRUCacheWithTTL:\n \"\"\"\n Async-safe LRU Cache with TTL (Time-To-Live) expiration.\n Features:\n - LRU eviction when capacity is reached - TTL expiration for individual entries - Async-safe operations with asyncio locks - Optional default values - Statistics tracking - Context manager support \"\"\"\n def __init__(\n self, capacity: int = 100, default_ttl: Optional[float", + "reasoning": "The user wants me to create an async version of the thread-safe LRU cache with TTL. I'll need to:\n\n1. Use asyncio locks instead of threading locks\n2. Make the methods async\n3. Possibly add async context managers for cleaner syntax\n4. 
Keep all the existing functionality (LRU, TTL, thread-safety)\n\nLet me create a comprehensive async LRU cache implementation.\n", + "tool_calls": [], + "finish_reason": "length", + "chunks": 221 + }, + { + "turn": 4, + "user_msg": "What are the performance characteristics? Add benchmarks.", + "ttft": 1.04, + "total_time": 6.024, + "decode_time": 4.984, + "decode_tps": 51.4, + "prompt_tokens": 0, + "completion_tokens": 256, + "content": "# Performance Analysis & BenchmarksHere's a comprehensive analysis with benchmarks:\n\n```pythonimport asyncioimport timeimport threadingimport statisticsfrom collections import OrderedDictfrom typing import Any, Optional, Dict, Tupleimport loggingfrom functools import lru_cache as std_lru_cachefrom contextlib import asynccontextmanagerlogging.basicConfig(level=logging.WARNING)\nlogger = logging.getLogger(__name__)\n\n\n# ============================================================================\n# Async LRU Cache with TTL (Full Implementation)\n# ============================================================================\n\nclass AsyncLRUCacheWithTTL:\n \"\"\"Async-safe LRU Cache with TTL (Time-To-Live) expiration.\"\"\"\n\n def __init__(self, capacity: int = 100, default_ttl: float = 300):\n self._capacity = capacity self._default_ttl =", + "reasoning": "The user wants to know the performance characteristics of the async LRU cache with TTL and see benchmarks. I should:\n\n1. Explain the time/space complexity\n2. 
Add benchmark code to measure:\n - Get operations\n - Set operations\n - Concurrent access performance\n - TTL expiration performance\n - Comparison with other implementations (functools.lru_cache, cachetools)\n\nLet me create a comprehensive benchmark suite.\n", + "tool_calls": [], + "finish_reason": "length", + "chunks": 236 + } + ] + }, + { + "test": "tool_call", + "accuracy_pct": 100.0, + "correct": 4, + "total": 4, + "results": [ + { + "label": "Simple tool call", + "has_tool_call": true, + "json_valid": true, + "correct": true, + "tool_names": [ + "get_weather" + ], + "ttft": 1.049, + "total_time": 2.005, + "decode_time": 0.956, + "decode_tps": 52.3, + "prompt_tokens": 364, + "completion_tokens": 50, + "content": "", + "reasoning": "The user wants to know the current weather in Tokyo. I'll use the get_weather function to retrieve this information.\n", + "tool_calls": [ + "ChoiceDeltaToolCall(index=0, id='call_c8b3ac20', function=ChoiceDeltaToolCallFunction(arguments='{\"location\": \"Tokyo\"}', name='get_weather'), type='function')" + ], + "finish_reason": "tool_calls", + "chunks": 24 + }, + { + "label": "Multi-arg tool call", + "has_tool_call": true, + "json_valid": true, + "correct": true, + "tool_names": [ + "search_web" + ], + "ttft": 1.039, + "total_time": 2.514, + "decode_time": 1.475, + "decode_tps": 52.2, + "prompt_tokens": 374, + "completion_tokens": 77, + "content": "", + "reasoning": "The user wants me to search for information about 'vLLM MLX benchmarks' and show them 5 results. 
I'll use the search_web function to find this information.\n", + "tool_calls": [ + "ChoiceDeltaToolCall(index=0, id='call_a4cd4a2b', function=ChoiceDeltaToolCallFunction(arguments='{\"query\": \"vLLM MLX benchmarks\", \"num_results\": 5}', name='search_web'), type='function')" + ], + "finish_reason": "tool_calls", + "chunks": 35 + }, + { + "label": "Code execution", + "has_tool_call": true, + "json_valid": true, + "correct": true, + "tool_names": [ + "run_python" + ], + "ttft": 1.037, + "total_time": 3.993, + "decode_time": 2.956, + "decode_tps": 52.1, + "prompt_tokens": 370, + "completion_tokens": 154, + "content": "", + "reasoning": "The user wants me to calculate the Fibonacci sequence up to n=20 using Python. I'll write a Python script to generate the Fibonacci sequence.\n", + "tool_calls": [ + "ChoiceDeltaToolCall(index=0, id='call_152bc085', function=ChoiceDeltaToolCallFunction(arguments='{\"code\": \"def fibonacci(n):\\\\n fib_seq = [0, 1]\\\\n for i in range(2, n):\\\\n fib_seq.append(fib_seq[-1] + fib_seq[-2])\\\\n return fib_seq[:n]\\\\n\\\\n# Calculate Fibonacci sequence up to n=20\\\\nn = 20\\\\nresult = fibonacci(n)\\\\nprint(f\\\\\"Fibonacci sequence up to n={n}:\\\\\")\\\\nfor i, val in enumerate(result):\\\\n print(f\\\\\"F({i}) = {val}\\\\\")\"}', name='run_python'), type='function')" + ], + "finish_reason": "tool_calls", + "chunks": 29 + }, + { + "label": "Tool selection", + "has_tool_call": true, + "json_valid": true, + "correct": true, + "tool_names": [ + "get_weather", + "search_web" + ], + "ttft": 1.06, + "total_time": 3.042, + "decode_time": 1.982, + "decode_tps": 52.0, + "prompt_tokens": 372, + "completion_tokens": 103, + "content": "", + "reasoning": "The user wants two things:\n1. Get the weather in Paris\n2. 
Search for restaurants in Paris\n\nThese two requests are independent of each other, so I can make both calls at the same time.\n", + "tool_calls": [ + "ChoiceDeltaToolCall(index=0, id='call_f9ba3378', function=ChoiceDeltaToolCallFunction(arguments='{\"location\": \"Paris\"}', name='get_weather'), type='function')", + "ChoiceDeltaToolCall(index=1, id='call_8b27a832', function=ChoiceDeltaToolCallFunction(arguments='{\"num_results\": 10, \"query\": \"best restaurants in Paris\"}', name='search_web'), type='function')" + ], + "finish_reason": "tool_calls", + "chunks": 43 + } + ] + }, + { + "test": "reasoning", + "results": [ + { + "label": "Math", + "has_reasoning": true, + "has_content": true, + "separated": true, + "reasoning_chars": 1013, + "content_chars": 259, + "ttft": 0.34, + "total_time": 8.984, + "decode_time": 8.644, + "decode_tps": 52.1, + "prompt_tokens": 51, + "completion_tokens": 450, + "content": "Sure! Let’s multiply17by23 step by step.\n\n1. **Break down one of the numbers** Write 23 as \\(20 + 3\\).\n\n2. **Multiply each part by 17** - \\(17 \\times 20 = 340\\) - \\(17 \\times 3 = 51\\)\n\n3. **Add the two products** \\(340 + 51 = 391\\)\n\nSo,\\[\n17 \\times 23 = 391\\]", + "reasoning": "The user asks: \"What is 17 * 23? Think step by step.\"\n\nWe need to compute 17 * 23. That's straightforward: 17 * 23 = 391? Let's compute: 17*20 = 340, 17*3 = 51, sum = 391. Yes.\n\nBut they want step by step. So we can show multiplication: 17 * 23 = (17 * (20 + 3)) = 17*20 + 17*3 = 340 + 51 = 391. Or we can use standard multiplication: 23*17 = 23*10 + 23*7 = 230 + 161 = 391. Or we can do 17*23 = (10+7)*23 = 230 + 161 = 391. So answer is 391.\n\nWe can also show using long multiplication: 17 * 23 = 17*3 + 17*20 = 51 + 340 = 391.\n\nThus answer: 391.\n\nWe should respond with step-by-step reasoning. That is fine.\n\nWe must ensure we follow policy. The user is asking a simple arithmetic question, no policy issues. 
So we can comply.\n\nThus final answer: 391 with explanation.\n\nWe can also mention that we can verify: 23*17 = 23*10 + 23*7 = 230 + 161 = 391.\n\nThus answer: 391.\n\nWe should respond in a friendly tone, step by step.\n\nWe can also mention that we can break down 23 as 20+3, multiply each, sum.\n\nThus final.\n", + "tool_calls": [], + "finish_reason": null, + "chunks": 374 + }, + { + "label": "Logic", + "has_reasoning": true, + "has_content": false, + "separated": false, + "reasoning_chars": 2154, + "content_chars": 0, + "ttft": 0.373, + "total_time": 10.255, + "decode_time": 9.881, + "decode_tps": 51.8, + "prompt_tokens": 0, + "completion_tokens": 512, + "content": "", + "reasoning": "We need to answer a logical reasoning question: \"If all roses are flowers and some flowers fade quickly, can we conclude all roses fade quickly? Explain your reasoning.\"\n\nWe need to discuss the logical structure: All roses are flowers (universal affirmative: All R are F). Some flowers fade quickly (particular affirmative: Some F are Q, where Q = \"fade quickly\"). The question: can we infer that all roses fade quickly? That would be \"All R are Q\". Is that valid? No, because we only know that some flowers fade quickly, not that all flowers fade quickly, and we don't know whether the roses are among those that fade quickly. The premises do not guarantee that. It's a logical fallacy: the categorical syllogism would need a second premise linking roses to the subset of flowers that fade quickly. So we cannot conclude that all roses fade quickly. The reasoning: The set of roses is a subset of the set of flowers, but the set of flowers that fade quickly is some subset of the set of flowers. There's no guarantee that the roses are in that subset. The only possible conclusion is \"some roses fade quickly\" (i.e., some R are Q). That follows: from \"All R are F\" and \"Some F are Q\", we can infer \"Some R are Q\" (by conversion? 
Actually, we can use the syllogism: All R are F (universal affirmative), Some F are Q (particular affirmative). The middle term is F, which is distributed in the first premise (All R are F) and in the second premise (Some F are Q). The conclusion can be \"Some R are Q\" (particular affirmative). This is valid: it's a valid syllogism in Barbara? Wait, Barbara is All M are P, All S are M => All S are P. That's universal to universal. But we have particular in second premise. The valid mood is \"AII-1\"? Let's recall the valid syllogism forms: The mood \"AII-1\" (All M are P, Some M are S, therefore Some S are P) is valid? Actually, \"All M are P\" (A), \"Some M are S\" (I) yields \"Some S are P\" (I). That's valid: it's the first figure. So yes, we can infer \"some roses fade quickly\". But not \"all roses fade quickly\". So answer: No, we cannot conclude that. Provide explanation: The premises do not provide enough info, it's a", + "tool_calls": [], + "finish_reason": "length", + "chunks": 512 + }, + { + "label": "Code", + "has_reasoning": true, + "has_content": false, + "separated": false, + "reasoning_chars": 2011, + "content_chars": 0, + "ttft": 0.377, + "total_time": 10.239, + "decode_time": 9.862, + "decode_tps": 51.9, + "prompt_tokens": 0, + "completion_tokens": 512, + "content": "", + "reasoning": "The user asks: \"What's wrong with this code: `x = [1,2,3]; x[5]`? Think carefully.\"\n\nWe need to respond with an answer. The code is in some language, likely Python: x = [1,2,3]; x[5]. In Python, indexing a list of length 3 with index 5 yields an IndexError: list index out of range. So the problem is that index 5 is out of range. The code tries to access element at index 5, but the list only has indices 0,1,2. So it's an IndexError.\n\nBut perhaps the user expects a deeper answer: The code is trying to access index 5 which is beyond the length of the list, causing an IndexError. The code is also missing a semicolon? 
In Python, semicolons are allowed but not required. The code might be in JavaScript? In JavaScript, x = [1,2,3]; x[5] returns undefined (since arrays are sparse). In JavaScript, accessing index 5 returns undefined, not an error. But the code is ambiguous. The user didn't specify language. The syntax `x = [1,2,3]; x[5]` could be either Python or JavaScript. In Python, it's an IndexError. In JavaScript, it's undefined. The question: \"What's wrong with this code: `x = [1,2,3]; x[5]`? Think carefully.\"\n\nWe need to think: The code is trying to access index 5 of a list of length 3. That's out-of-range. So the answer: It's an IndexError. Possibly also mention that in some languages like JavaScript it returns undefined, but in Python it's an error. The user likely expects the answer: The index is out of bounds.\n\nWe can also mention that the semicolon is unnecessary in Python, but it's allowed. The code is trying to access the sixth element (index 5) but there are only three elements (indices 0-2). So it's an IndexError.\n\nWe should also mention that if they intended to get the last element, they could use x[-1] or x[2]. Or they could extend the list.\n\nThus the answer: The code is wrong because it tries to access an index that doesn't exist. In Python, it raises an IndexError. 
In JavaScript, it returns undefined, but it's still not correct if you expect a value.\n\nWe should also", + "tool_calls": [], + "finish_reason": "length", + "chunks": 501 + } + ] + }, + { + "test": "long_generation", + "completed": true, + "ttft": 0.457, + "total_time": 248.967, + "decode_time": 248.51, + "decode_tps": 33.0, + "prompt_tokens": 0, + "completion_tokens": 8192, + "content": "# Building a Distributed System on Apple Silicon: A Comprehensive Technical Guide## IntroductionThe advent of Apple Silicon represents a paradigm shift in computing architecture, bringing ARM-based systems with unified memory architecture, high-performance GPU cores, and dedicated Neural Engine processing units to professional workloads. These capabilities open unprecedented opportunities for building distributed systems that can leverage tight integration between compute, memory, and specialized hardware. This guide provides a comprehensive technical exploration of designing, implementing, and operating a distributed system specifically optimized for Apple Silicon hardware.\n\nApple Silicon's M-series chips (M1, M2, M3, and their variants) feature a unified memory architecture that eliminates the traditional separation between CPU and GPU memory spaces. This design choice dramatically reduces latency for data transfer between processing units—a characteristic that proves transformative for distributed systems where inter-node communication latency often dominates performance profiles. Additionally, the efficiency of these chips enables high-density deployment scenarios, allowing operators to run more instances per rack unit while maintaining favorable power consumption characteristics.\n\nThis guide assumes familiarity with distributed systems concepts, proficiency in Swift or Go (we'll use both for different components), and access to a cluster of Apple Silicon machines. 
We'll build a reference implementation of a distributed key-value store with consensus capabilities, though the architectural patterns discussed apply broadly to various distributed system designs.\n\n---\n\n## 1. Architecture Design### 1.1 Apple Silicon Architectural ConsiderationsBefore diving into system design, we must understand the unique characteristics of Apple Silicon that influence architectural decisions:\n\n**Unified Memory Architecture (UMA)**: Unlike traditional x86_64 systems where CPU and GPU maintain separate memory pools, Apple Silicon provides a single pool of high-bandwidth memory accessible by all processing units. This eliminates the need for explicit memory copying between CPU and GPU workloads—a consideration that becomes relevant when designing systems that offload computation to the Neural Engine or GPU.\n\n**Core Configuration**: M-series chips combine performance cores (Firestorm) and efficiency cores (Icestorm), along with GPU cores and a 16-core Neural Engine. A well-designed distributed system should leverage efficiency cores for I/O-bound operations like network handling while utilizing performance cores for compute-intensive consensus protocols.\n\n**Scale-Out Capabilities**: While individual Apple Silicon units offer substantial compute, the primary advantage lies in horizontal scaling. The Mac Studio can support up to 128GB unified memory, while Mac Mini with M2 Pro supports up to 32GB—suitable for memory-intensive distributed workloads.\n\n### 1.2 System Architecture OverviewOur reference architecture follows a replicated state machine approach, where multiple nodes maintain identical state and coordinate through a consensus protocol. 
The system consists of three primary layers:\n\n```\n┌─────────────────────────────────────────────────────────────────┐│ APPLICATION LAYER │\n│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │\n│ │ Client API │ │ Query Path │ │ Admin API │ │\n│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │\n└─────────┼────────────────┼────────────────┼─────────────────────┘ │ │ │\n ▼ ▼ ▼┌─────────────────────────────────────────────────────────────────┐│ SERVICE LAYER │\n│ ┌─────────────────────────────────────────────────────────┐ │\n│ │ Request Router / Load Balancer │ │\n│ └─────────────────────────────────────────────────────────┘ │\n│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │\n│ │ Read Path │ │ Write Path │ │ Health Svc │ │\n│ └──────────────┘ └──────────────┘ └──────────────┘ │\n└─────────────────────────────────────────────────────────────────┘ │ │ │\n ▼ ▼ ▼┌─────────────────────────────────────────────────────────────────┐│ CONSENSUS LAYER │\n│ ┌─────────────────────────────────────────────────────────┐ │\n│ │ Raft Consensus Module │ │\n│ │ ┌────────────┐ ┌────────────┐ ┌────────────┐ │ │\n│ │ │ Log │ │ State │ │ Config │ │ │\n│ │ │ Manager │ │ Machine │ │ Manager │ │ │\n│ │ └────────────┘ └────────────┘ └────────────┘ │ │\n│ └─────────────────────────────────────────────────────────┘ │\n└─────────────────────────────────────────────────────────────────┘ │ │ │\n ▼ ▼ ▼┌─────────────────────────────────────────────────────────────────┐│ REPLICATION LAYER │\n│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │\n│ │ Network │ │ Failure │ │ Snapshot │ │\n│ │ Transport │ │ Detector │ │ Manager │ │\n│ └──────────────┘ └──────────────┘ └──────────────┘ │\n└─────────────────────────────────────────────────────────────────┘```\n\n### 1.3 Component Design#### 1.3.1 Node ArchitectureEach node in our distributed system runs as a Swift application with the following internal structure:\n\n```swiftimport Foundationimport NIOimport Logging/// Core node configuration for 
Apple Silicon deploymentpublic struct NodeConfig {\n public let nodeId: UUID public let listenAddress: String public let port: Int public let dataDirectory: URL public let consensusPort: Int public let metricsPort: Int public let useNeuralEngine: Bool // Enable ML-based failure detection public init(\n nodeId: UUID = UUID(),\n listenAddress: String = \"0.0.0.0\",\n port: Int = 8080,\n dataDirectory: URL = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!,\n consensusPort: Int = 8081,\n metricsPort: Int = 9090,\n useNeuralEngine: Bool = false ) {\n self.nodeId = nodeId self.listenAddress = listenAddress self.port = port self.dataDirectory = dataDirectory self.consensusPort = consensusPort self.metricsPort = metricsPort self.useNeuralEngine = useNeuralEngine }\n}\n\n/// Main node coordinator managing all system componentspublic final class DistributedNode: @unchecked Sendable {\n private let config: NodeConfig private let eventLoopGroup: MultiThreadedEventLoopGroup private let consensusModule: RaftConsensusModule private let transport: NetworkTransport private let stateMachine: KeyValueStateMachine private let failureDetector: FailureDetector private let metricsCollector: MetricsCollector private var isRunning = false public init(config: NodeConfig) throws {\n self.config = config // Optimize event loop configuration for Apple Silicon // Use performance cores for consensus processing let threadCount = System.coreCount > 8 ? 
8 : System.coreCount self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: threadCount)\n // Initialize core components self.stateMachine = try KeyValueStateMachine(\n dataDirectory: config.dataDirectory,\n nodeId: config.nodeId )\n self.consensusModule = try RaftConsensusModule(\n nodeId: config.nodeId,\n nodeAddress: config.listenAddress,\n port: config.consensusPort,\n stateMachine: stateMachine,\n eventLoopGroup: eventLoopGroup )\n self.transport = NetworkTransport(\n eventLoopGroup: eventLoopGroup,\n listenAddress: config.listenAddress,\n port: config.port,\n consensusPort: config.consensusPort )\n self.failureDetector = FailureDetector(\n nodeId: config.nodeId,\n eventLoopGroup: eventLoopGroup,\n useNeuralEngine: config.useNeuralEngine )\n self.metricsCollector = MetricsCollector(\n nodeId: config.nodeId,\n port: config.metricsPort )\n }\n public func start() async throws {\n guard !isRunning else { return }\n // Start metrics endpoint first for observability try await metricsCollector.start()\n // Initialize and start consensus module try await consensusModule.initialize()\n // Start network transport try await transport.start()\n // Register with cluster discovery try await registerWithCluster()\n isRunning = true print(\"Node \\(config.nodeId) started on \\(config.listenAddress):\\(config.port)\")\n }\n public func stop() async throws {\n guard isRunning else { return }\n try await consensusModule.shutdown()\n try await transport.shutdown()\n try await metricsCollector.shutdown()\n try await eventLoopGroup.shutdownGracefully()\n isRunning = false }\n}\n```\n\n#### 1.3.2 Memory Optimization for Apple SiliconThe unified memory architecture allows us to implement zero-copy data paths where appropriate. 
Here's how we optimize memory usage:\n\n```swiftimport Foundationimport Accelerate/// Memory-optimized buffer management leveraging unified memorypublic final class SiliconMemoryAllocator: @unchecked Sendable {\n private let pageSize: Int private let totalMemory: UInt64 private var availableMemory: UInt64 public init() {\n // Get system memory info var stats = vm_statistics64()\n var count = mach_msg_type_number_t(MemoryLayout.size / MemoryLayout.size)\n \n let result = withUnsafeMutablePointer(to: &stats) {\n $0.withMemoryRebound(to: integer_t.self, capacity: Int(count)) {\n host_statistics64(mach_host_self(), HOST_VM_INFO64, $0, &count)\n }\n }\n \n self.pageSize = Int(vm_page_size)\n self.totalMemory = ProcessInfo.processInfo.physicalMemory\n \n if result == KERN_SUCCESS {\n let freePages = UInt64(stats.free_count)\n let inactivePages = UInt64(stats.inactive_count)\n self.availableMemory = (freePages + inactivePages) * UInt64(pageSize)\n } else {\n self.availableMemory = totalMemory / 2\n }\n }\n \n /// Allocate buffer optimized for Neural Engine access\n public func allocateNeuralBuffer(size: Int) -> UnsafeMutableRawPointer? {\n // Neural Engine prefers 16-byte aligned buffers\n let alignment = 16\n var ptr: UnsafeMutableRawPointer?\n \n let result = posix_memalign(&ptr, alignment, size)\n guard result == 0, let allocated = ptr else {\n return nil\n }\n \n // Touch pages to ensure physical allocation in unified memory\n let pageAlignedSize = (size + pageSize - 1) / pageSize * pageSize\n madvise(allocated, pageAlignedSize, MADV_WILLNEED)\n \n return allocated\n }\n \n /// Calculate optimal buffer sizes based on available memory\n public func optimalBufferSize(forRequest requestSize: Int, connectionCount: Int) -> Int {\n let memoryPerConnection = availableMemory / UInt64(connectionCount * 4)\n let maxBuffer = min(UInt64(requestSize * 16), memoryPerConnection)\n return Int(min(maxBuffer, UInt64(Int.max)))\n }\n}\n```\n\n---\n\n## 2. 
Networking\n\n### 2.1 Network Topology Design\n\nNetwork architecture for Apple Silicon distributed systems must account for the physical limitations of Mac hardware while maximizing throughput. We recommend a hierarchical topology:\n\n```\n ┌─────────────────┐\n │ Load Balancer │\n │ (Software) │\n └────────┬────────┘\n │\n ┌──────────────────────┼──────────────────────┐\n │ │ │\n ▼ ▼ ▼\n ┌─────────────┐ ┌─────────────┐ ┌─────────────┐\n │ Node 1 │ │ Node 2 │ │ Node 3 │\n │ 192.168.1.10│◄─────►│ 192.168.1.11│◄─────►│ 192.168.1.12│\n │ │ gRPC │ │ gRPC │ │\n └─────────────┘ └─────────────┘ └─────────────┘\n │ │ │\n └──────────────────────┼──────────────────────┘\n │\n ▼\n ┌─────────────────────────┐\n │ Inter-node Consensus │\n │ (Dedicated Ports) │\n └─────────────────────────┘\n```\n\nKey networking principles:\n\n1. **Dedicated Consensus Ports**: Separate ports for consensus traffic (typically 8081) and client traffic (8080) allow independent scaling and QoS management.\n2. **gRPC for Inter-Node Communication**: We use gRPC with Protocol Buffers for efficient serialization and built-in streaming capabilities.\n3. 
**HTTP/1.1 for Client API**: REST-style endpoints for client compatibility, with HTTP/2 under the hood for multiplexing.\n\n### 2.2 Network Transport Implementation\n\n```swift\nimport Foundation\nimport NIO\nimport NIOSSL\nimport Logging\n\n/// Network transport layer handling all inter-node and client communication\npublic final class NetworkTransport: @unchecked Sendable {\n private let eventLoopGroup: MultiThreadedEventLoopGroup\n private let serverBootstrap: ServerBootstrap\n private let clientBootstrap: ClientBootstrap\n private let channel: Channel?\n \n private let listenAddress: String\n private let port: Int\n private let consensusPort: Int\n \n private let logger: Logger\n private var activeConnections = Set()\n private let connectionLock = NSLock()\n \n public init(\n eventLoopGroup: MultiThreadedEventLoopGroup,\n listenAddress: String,\n port: Int,\n consensusPort: Int\n ) {\n self.eventLoopGroup = eventLoopGroup\n self.listenAddress = listenAddress\n self.port = port\n self.consensusPort = consensusPort\n \n var logger = Logger(label: \"com.distributed.transport\")\n logger.logLevel = .info\n self.logger = logger\n \n // Configure server bootstrap optimized for Apple Silicon\n self.serverBootstrap = ServerBootstrap(group: eventLoopGroup)\n .serverChannelOption(ChannelOptions.backlog, value: 256)\n .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)\n .serverChannelOption(ChannelOptions.maxMessagesPerRead, value: 16)\n .serverChannelOption(ChannelOptions.recvAllocator, value: AdaptiveRecvByteBufferAllocator())\n .childChannelInitializer { channel in\n // Add timeout and idle state handlers\n let idleHandler = IdleStateHandler(\n readTimeout: .seconds(30),\n writeTimeout: .seconds(10),\n allTimeout: .seconds(60)\n )\n \n return channel.pipeline.addHandler(idleHandler).flatMap {\n // Add protobuf codec\n let protobufHandler = ProtobufPacketHandler()\n return channel.pipeline.addHandler(protobufHandler)\n }\n }\n 
.childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)\n .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 16)\n .childChannelOption(ChannelOptions.recvAllocator, value: AdaptiveRecvByteBufferAllocator())\n .childChannelOption(ChannelOptions.writeBufferWaterMark, low: 32 * 1024, high: 128 * 1024)\n \n // Configure client bootstrap\n self.clientBootstrap = ClientBootstrap(group: eventLoopGroup)\n .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)\n .channelOption(ChannelOptions.maxMessagesPerRead, value: 16)\n .channelOption(ChannelOptions.recvAllocator, value: AdaptiveRecvByteBufferAllocator())\n .channelOption(ChannelOptions.connectTimeout, value: .seconds(10))\n }\n \n public func start() async throws {\n let channel = try await serverBootstrap.bind(host: listenAddress, port: port).get()\n self.channel = channel\n logger.info(\"Transport started on \\(listenAddress):\\(port)\")\n }\n \n /// Send message to specific node with retry logic\n public func sendMessage(\n to address: String,\n port: Int,\n message: NetworkMessage,\n retryCount: Int = 3\n ) async throws -> NetworkMessage {\n var lastError: Error?\n \n for attempt in 0.. 
[PeerAddress: NetworkMessage] {\n try await withThrowingTaskGroup(of: (PeerAddress, NetworkMessage).self) { group in\n for peer in peers {\n group.addTask {\n let response = try await self.sendMessage(\n to: peer.address,\n port: peer.consensusPort,\n message: message\n )\n return (peer, response)\n }\n }\n \n var results: [PeerAddress: NetworkMessage] = [:]\n for try await (peer, response) in group {\n results[peer] = response\n }\n return results\n }\n }\n}\n```\n\n### 2.3 Service Discovery\n\nService discovery enables nodes to find each other dynamically:\n\n```swift\nimport Foundation\nimport NIO\n\n/// Peer address representation\npublic struct PeerAddress: Hashable, Codable, Sendable {\n public let nodeId: UUID\n public let address: String\n public let port: Int\n public let consensusPort: Int\n public let isLeader: Bool\n \n public init(nodeId: UUID, address: String, port: Int, consensusPort: Int, isLeader: Bool = false) {\n self.nodeId = nodeId\n self.address = address\n self.port = port\n self.consensusPort = consensusPort\n self.isLeader = isLeader\n }\n}\n\n/// Service discovery module using gossip protocol\npublic final class ServiceDiscovery: @unchecked Sendable {\n private let nodeId: UUID\n private let eventLoopGroup: MultiThreadedEventLoopGroup\n private var peers: [UUID: PeerAddress] = [:]\n private let peersLock = NSLock()\n \n private var gossipTimer: RepeatedTask?\n private let gossipInterval: TimeAmount = .seconds(3)\n \n private let logger: Logger\n \n public init(nodeId: UUID, eventLoopGroup: MultiThreadedEventLoopGroup) {\n self.nodeId = nodeId\n self.eventLoopGroup = eventLoopGroup\n \n var logger = Logger(label: \"com.distributed.discovery\")\n logger.logLevel = .debug\n self.logger = logger\n }\n \n /// Register a new peer in the cluster\n public func registerPeer(_ peer: PeerAddress) {\n peersLock.lock()\n defer { peersLock.unlock() }\n \n let existingPeer = peers[peer.nodeId]\n peers[peer.nodeId] = peer\n \n if existingPeer == nil {\n 
logger.info(\"Registered new peer: \\(peer.nodeId) at \\(peer.address):\\(peer.port)\")\n } else {\n logger.debug(\"Updated peer: \\(peer.nodeId)\")\n }\n }\n \n /// Remove peer from cluster\n public func removePeer(nodeId: UUID) {\n peersLock.lock()\n defer { peersLock.unlock() }\n \n if peers.removeValue(forKey: nodeId) != nil {\n logger.info(\"Removed peer: \\(nodeId)\")\n }\n }\n \n /// Get all known peers\n public func getAllPeers() -> [PeerAddress] {\n peersLock.lock()\n defer { peersLock.unlock() }\n return Array(peers.values)\n }\n \n /// Get current leader if known\n public func getLeader() -> PeerAddress? {\n peersLock.lock()\n defer { peersLock.unlock() }\n return peers.values.first { $0.isLeader }\n }\n \n /// Start gossip protocol for peer information exchange\n public func startGossip(existingPeers: [PeerAddress]) {\n for peer in existingPeers {\n registerPeer(peer)\n }\n \n gossipTimer = eventLoopGroup.next().scheduleRepeatedTask(\n initialDelay: gossipInterval,\n delay: gossipInterval\n ) { [weak self] _ in\n Task {\n await self?.performGossipRound()\n }\n }\n \n logger.info(\"Gossip protocol started\")\n }\n \n private func performGossipRound() async {\n peersLock.lock()\n let currentPeers = Array(peers.values)\n peersLock.unlock()\n \n guard currentPeers.count > 1 else { return }\n \n // Select random peers for gossip (typically 3)\n let selectedPeers = Array(currentPeers.shuffled().prefix(3))\n \n for peer in selectedPeers where peer.nodeId != nodeId {\n do {\n let gossipMessage = GossipMessage(\n senderId: nodeId,\n peerList: currentPeers.map { PeerInfo(from: $0) },\n timestamp: Date()\n )\n \n // Send gossip and merge received peers\n let response = try await sendGossip(\n to: peer.address,\n port: peer.consensusPort,\n message: gossipMessage\n )\n \n await mergePeerInfo(response.peerList)\n } catch {\n logger.debug(\"Gossip exchange with \\(peer.nodeId) failed: \\(error)\")\n }\n }\n }\n \n private func mergePeerInfo(_ incomingPeers: 
[PeerInfo]) async {\n peersLock.lock()\n \n for peerInfo in incomingPeers {\n let peer = PeerAddress(\n nodeId: peerInfo.nodeId,\n address: peerInfo.address,\n port: peerInfo.port,\n consensusPort: peerInfo.consensusPort,\n isLeader: peerInfo.isLeader\n )\n \n // Merge: keep newer information\n if let existing = peers[peer.nodeId] {\n if peerInfo.timestamp > Date() { // Simplified comparison\n peers[peer.nodeId] = peer\n }\n } else {\n peers[peer.nodeId] = peer\n }\n }\n \n peersLock.unlock()\n }\n}\n```\n\n---\n\n## 3. Consensus Algorithms\n\n### 3.1 Consensus Protocol Selection\n\nFor our distributed key-value store, we implement the Raft consensus algorithm due to its understandability and practical implementability. Raft provides:\n\n- **Leader election**: Automatic leader selection with failover\n- **Log replication**: Consistent log across all replicas\n- **Safety**: Guarantees state machine safety despite node failures\n\nWe chose Raft over Paxos because its leader-based approach simplifies the implementation while maintaining equivalent guarantees. 
The algorithm's structure maps well to our Apple Silicon deployment, where we can leverage the efficiency cores for I/O-bound consensus message handling while using performance cores for log computation.\n\n### 3.2 Raft Implementation\n\n```swift\nimport Foundation\nimport NIO\n\n/// Raft server states\npublic enum ServerState: String, Codable, Sendable {\n case follower\n case candidate\n case leader\n}\n\n/// Log entry structure\npublic struct LogEntry: Codable, Sendable {\n public let term: UInt64\n public let index: UInt64\n public let command: Command\n public let timestamp: Date\n \n public init(term: UInt64, index: UInt64, command: Command, timestamp: Date = Date()) {\n self.term = term\n self.index = index\n self.command = command\n self.timestamp = timestamp\n }\n}\n\n/// Command types for our key-value store\npublic enum Command: Codable, Sendable {\n case put(key: String, value: Data)\n case delete(key: String)\n case set(key: String, value: Data, ttl: TimeInterval?)\n}\n\n/// Raft consensus module implementation\npublic final class RaftConsensusModule: @unchecked Sendable {\n private let nodeId: UUID\n private let nodeAddress: String\n private let port: Int\n \n private var currentTerm: UInt64 = 0\n private var votedFor: UUID?\n private var state: ServerState = .follower\n private var log: [LogEntry] = []\n \n // Leader-specific state\n private var nextIndex: [UUID: UInt64] = [:]\n private var matchIndex: [UUID: UInt64] = [:]\n \n // Leader election\n private var lastHeartbeat: Date = Date()\n private var electionTimer: RepeatedTask?\n private var heartbeatTimer: RepeatedTask?\n \n // Configuration\n private let electionTimeoutMin: TimeAmount = .milliseconds(150)\n private let electionTimeoutMax: TimeAmount = .milliseconds(300)\n private let heartbeatInterval: TimeAmount = .milliseconds(50)\n \n // Dependencies\n private let stateMachine: KeyValueStateMachine\n private let eventLoopGroup: MultiThreadedEventLoopGroup\n private let transport: 
NetworkTransport\n private let logger: Logger\n \n private let stateLock = NSLock()\n \n public init(\n nodeId: UUID,\n nodeAddress: String,\n port: Int,\n stateMachine: KeyValueStateMachine,\n eventLoopGroup: MultiThreadedEventLoopGroup\n ) throws {\n self.nodeId = nodeId\n self.nodeAddress = nodeAddress\n self.port = port\n self.stateMachine = stateMachine\n self.eventLoopGroup = eventLoopGroup\n \n var logger = Logger(label: \"com.distributed.raft.\\(nodeId)\")\n logger.logLevel = .debug\n self.logger = logger\n \n // Load persisted state\n try loadPersistedState()\n }\n \n public func initialize() async throws {\n startElectionTimer()\n logger.info(\"Raft module initialized, term: \\(currentTerm)\")\n }\n \n // MARK: - Server State Management\n \n private func becomeFollower(term: UInt64) {\n stateLock.lock()\n defer { stateLock.unlock() }\n \n state = .follower\n currentTerm = term\n votedFor = nil\n \n stopHeartbeatTimer()\n startElectionTimer()\n \n logger.info(\"Became follower, term: \\(term)\")\n }\n \n private func becomeCandidate() {\n stateLock.lock()\n defer { stateLock.unlock() }\n \n state = .candidate\n currentTerm += 1\n votedFor = nodeId\n \n // Reset election timeout\n resetElectionTimer()\n \n logger.info(\"Became candidate for term: \\(currentTerm)\")\n }\n \n private func becomeLeader() {\n stateLock.lock()\n defer { stateLock.unlock() }\n \n state = .leader\n \n // Initialize leader state\n let lastLogIndex = log.last?.index ?? 
0\n for peerId in getAllPeerIds() {\n nextIndex[peerId] = lastLogIndex + 1\n matchIndex[peerId] = 0\n }\n \n stopElectionTimer()\n startHeartbeatTimer()\n \n logger.info(\"Became leader for term: \\(currentTerm)\")\n }\n \n // MARK: - Election Logic\n \n private func startElectionTimer() {\n stopElectionTimer()\n \n let eventLoop = eventLoopGroup.next()\n let timeout = UInt64.random(in: electionTimeoutMin.nanoseconds...electionTimeoutMax.nanoseconds)\n \n electionTimer = eventLoop.scheduleTask(deadline: .now() + .nanoseconds(Int(timeout))) { [weak self] in\n Task {\n await self?.handleElectionTimeout()\n }\n }\n }\n \n private func resetElectionTimer() {\n startElectionTimer()\n }\n \n private func stopElectionTimer() {\n electionTimer?.cancel()\n electionTimer = nil\n }\n \n private func handleElectionTimeout() async {\n stateLock.lock()\n let currentState = state\n stateLock.unlock()\n \n guard currentState != .leader else { return }\n \n logger.debug(\"Election timeout, starting election\")\n becomeCandidate()\n \n // Request votes from all peers\n let voteGranted = await requestVotes()\n \n stateLock.lock()\n let stillCandidate = state == .candidate\n stateLock.unlock()\n \n if voteGranted && stillCandidate {\n becomeLeader()\n } else {\n becomeFollower(term: currentTerm)\n }\n }\n \n private func requestVotes() async -> Bool {\n let peers = getAllPeers()\n var voteCount = 1 // Vote for self\n \n let lastLogIndex = log.last?.index ?? 0\n let lastLogTerm = log.last?.term ?? 
0\n \n let results = await withTaskGroup(of: (UUID, Bool).self) { group in\n for peer in peers where peer.nodeId != nodeId {\n group.addTask {\n do {\n let response = try await self.sendRequestVote(\n to: peer,\n term: self.currentTerm,\n lastLogIndex: lastLogIndex,\n lastLogTerm: lastLogTerm\n )\n return (peer.nodeId, response.voteGranted)\n } catch {\n self.logger.debug(\"RequestVote to \\(peer.nodeId) failed: \\(error)\")\n return (peer.nodeId, false)\n }\n }\n }\n \n var results: [(UUID, Bool)] = []\n for await result in group {\n results.append(result)\n }\n return results\n }\n \n for (_, voteGranted) in results where voteGranted {\n voteCount += 1\n }\n \n // Majority check\n let majority = (peers.count + 1) / 2 + 1\n return voteCount >= majority\n }\n \n private func sendRequestVote(\n to peer: PeerAddress,\n term: UInt64,\n lastLogIndex: UInt64,\n lastLogTerm: UInt64\n ) async throws -> RequestVoteResponse {\n let request = RequestVoteRequest(\n term: term,\n candidateId: nodeId,\n lastLogIndex: lastLogIndex,\n lastLogTerm: lastLogTerm\n )\n \n // Implementation would use network transport\n // This is a simplified representation\n return try await transport.sendRequestVote(to: peer, request: request)\n }\n \n // MARK: - Log Replication\n \n /// Append a new entry to the log (client write path)\n public func appendEntry(command: Command) async throws -> UInt64 {\n stateLock.lock()\n let isLeader = state == .leader\n stateLock.unlock()\n \n guard isLeader else {\n throw RaftError.notLeader\n }\n \n stateLock.lock()\n let entry = LogEntry(\n term: currentTerm,\n index: (log.last?.index ?? 
0) + 1,\n command: command\n )\n log.append(entry)\n let entryIndex = entry.index\n stateLock.unlock()\n \n // Replicate to followers\n await replicateToFollowers()\n \n // Apply to state machine if committed\n await checkAndApplyCommitted()\n \n return entryIndex\n }\n \n private func replicateToFollowers() async {\n let peers = getAllPeers()\n \n await withTaskGroup(of: Void.self) { group in\n for peer in peers {\n group.addTask {\n await self.replicateLog(to: peer)\n }\n }\n }\n }\n \n private func replicateLog(to peer: PeerAddress) async {\n stateLock.lock()\n let nextIdx = nextIndex[peer.nodeId] ?? 1\n let currentTerm = self.currentTerm\n stateLock.unlock()\n \n stateLock.lock()\n let entries = Array(log.drop { $0.index >= nextIdx })\n stateLock.unlock()\n \n do {\n let response = try await sendAppendEntries(\n to: peer,\n term: currentTerm,\n prevLogIndex: nextIdx - 1,\n prevLogTerm: log.first { $0.index == nextIdx - 1 }?.term ?? 0,\n entries: entries,\n leaderCommit: getCommitIndex()\n )\n \n // Handle success or failure\n if response.success {\n stateLock.lock()\n nextIndex[peer.nodeId] = nextIdx + UInt64(entries.count)\n matchIndex[peer.nodeId] = nextIdx + UInt64(entries.count) - 1\n stateLock.unlock()\n } else {\n // Decrement nextIndex and retry\n stateLock.lock()\n nextIndex[peer.nodeId] = max(1, nextIndex[peer.nodeId] ?? 
1 - 1)\n stateLock.unlock()\n \n // Retry with decremented index\n await replicateLog(to: peer)\n }\n } catch {\n logger.debug(\"ReplicateLog to \\(peer.nodeId) failed: \\(error)\")\n }\n }\n \n private func getCommitIndex() -> UInt64 {\n // Simple majority-based commit index calculation\n // In production, use a proper algorithm\n stateLock.lock()\n defer { stateLock.unlock() }\n \n guard !log.isEmpty else { return 0 }\n \n var indices = log.map { $0.index }\n indices.sort()\n \n let majorityIndex = (getAllPeers().count + 1) / 2\n guard indices.count >= majorityIndex else { return 0 }\n \n return indices[majorityIndex - 1]\n }\n \n // MARK: - Heartbeat\n \n private func startHeartbeatTimer() {\n let eventLoop = eventLoopGroup.next()\n \n heartbeatTimer = eventLoop.scheduleRepeatedTask(\n initialDelay: heartbeatInterval,\n delay: heartbeatInterval\n ) { [weak self] _ in\n Task {\n await self?.sendHeartbeats()\n }\n }\n }\n \n private func stopHeartbeatTimer() {\n heartbeatTimer?.cancel()\n heartbeatTimer = nil\n }\n \n private func sendHeartbeats() async {\n let peers = getAllPeers()\n \n await withTaskGroup(of: Void.self) { group in\n for peer in peers {\n group.addTask {\n await self.sendAppendEntries(to: peer)\n }\n }\n }\n }\n \n // MARK: - Persistence\n \n private func loadPersistedState() throws {\n // Load term, votedFor, and log from disk\n // Implementation depends on storage backend\n }\n \n private func savePersistedState() throws {\n // Persist term, votedFor, and log to disk\n }\n}\n```\n\n### 3.3 Consensus Optimization for Apple Silicon\n\nWe can leverage Apple Silicon's capabilities for consensus performance:\n\n```swift\nimport Foundation\nimport Accelerate\n\n/// Performance optimizations for Raft on Apple Silicon\npublic final class RaftOptimizations {\n \n /// Batch process log entries using SIMD for performance\n public static func batchValidateEntries(\n _ entries: [LogEntry],\n expectedTerm: UInt64\n ) -> [Bool] {\n var results = 
[Bool](repeating: false, count: entries.count)\n \n // Use vectorized operations for term validation\n var expectedTerms = [UInt64](repeating: expectedTerm, count: entries.count)\n var actualTerms = entries.map { $0.term }\n \n // Compare terms in batches\n let count = vDSP_Length(entries.count)\n vDSP_vcmpEQ(actualTerms, 1, expectedTerms, 1, &results, count)\n \n return results\n }\n \n /// Parallel log compaction using Grand Central Dispatch\n public static func compactLog(\n inout log: [LogEntry],\n keepFromIndex: UInt64\n ) {\n // Filter log entries to keep\n let keptEntries = log.filter { $0.index >= keepFromIndex }\n \n // Parallel hash calculation for integrity verification\n let dispatchQueue = DispatchQueue(label: \"com.distributed.raft.compaction\", attributes: .concurrent)\n \n var hashResults = [Int: Data]()\n let group = DispatchGroup()\n \n let chunkSize = max(1, keptEntries.count / ProcessInfo.processInfo.activeProcessorCount)\n \n for (chunkIndex", + "reasoning": "The user wants a comprehensive 5000-word technical guide about building a distributed system on Apple Silicon. The guide should cover architecture design, networking, consensus algorithms, fault tolerance, and monitoring. It should use code examples and diagrams described in text.\n\nI need to produce a long, detailed technical guide. The user wants immediate start, so I should just start writing. No need for any meta discussion.\n\nLet me plan the structure:\n\n1. Introduction - overview of distributed systems on Apple Silicon\n2. Architecture Design\n - Apple Silicon specifics (ARM, unified memory, Neural Engine, etc.)\n - System architecture diagram (text description)\n - Component design\n3. Networking\n - Network topology\n - gRPC and protocol buffers\n - Service discovery\n - Code examples for networking\n4. Consensus Algorithms\n - Overview of consensus (Raft, Paxos)\n - Implementing Raft on Apple Silicon\n - Code examples for consensus\n5. 
Fault Tolerance\n - Replication strategies\n - Leader election\n - Failure detection\n - Code examples for fault tolerance\n6. Monitoring\n - Metrics collection\n - Logging\n - Distributed tracing\n - Code examples for monitoring\n7. Conclusion\n\nI need to make sure I hit around 5000 words. Let me write in detail, with substantial content for each section.\n\nLet me start writing:\n", + "tool_calls": [], + "finish_reason": "length", + "chunks": 7705 + } + ] +} \ No newline at end of file diff --git a/benchmark_minmax.py b/benchmark_minmax.py new file mode 100644 index 00000000..9b423ebb --- /dev/null +++ b/benchmark_minmax.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 +""" +Comprehensive benchmark for vllm-mlx + MiniMax-M2.5 on M3 Ultra. + +Measures 6 dimensions that map to our optimization tiers: + 1. TTFT (Time To First Token) — prefill speed + 2. Decode throughput (tok/s) — generation speed + 3. Prefix cache effectiveness — multi-turn speedup + 4. Tool calling latency + correctness — agentic reliability + 5. Reasoning separation — thinking vs content split + 6. Long context stability — sustained generation without crash + +Usage: + # Start server first: + python -m vllm_mlx.cli serve --port 8000 ... 
+ + # Run all benchmarks: + python benchmark_minmax.py + + # Run specific test: + python benchmark_minmax.py --test ttft + python benchmark_minmax.py --test decode + python benchmark_minmax.py --test prefix_cache + python benchmark_minmax.py --test tool_call + python benchmark_minmax.py --test reasoning + python benchmark_minmax.py --test long_gen + + # Save results to JSON: + python benchmark_minmax.py --output results.json +""" + +import argparse +import json +import statistics +import time + +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY") +MODEL = None # auto-detected + + +def detect_model(): + """Auto-detect the model name from the server.""" + global MODEL + try: + models = client.models.list() + MODEL = models.data[0].id + print(f"Detected model: {MODEL}") + except Exception: + MODEL = "MiniMax-M2.5-MLX-4bit" + print(f"Could not detect model, using default: {MODEL}") + + +def stream_and_measure(messages, max_tokens=512, temperature=0.7, tools=None): + """Core measurement function. 
Returns detailed metrics.""" + kwargs = dict( + model=MODEL, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + stream=True, + stream_options={"include_usage": True}, + ) + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = "auto" + + t_start = time.perf_counter() + response = client.chat.completions.create(**kwargs) + + t_first_token = None + chunks = 0 + content_parts = [] + reasoning_parts = [] + tool_calls_received = [] + usage = None + finish_reason = None + + for chunk in response: + if chunk.usage: + usage = chunk.usage + + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + fr = chunk.choices[0].finish_reason + if fr: + finish_reason = fr + + has_content = delta.content is not None and len(delta.content) > 0 + has_reasoning = hasattr(delta, "reasoning") and delta.reasoning + has_tool = hasattr(delta, "tool_calls") and delta.tool_calls + + if has_content or has_reasoning or has_tool: + if t_first_token is None: + t_first_token = time.perf_counter() + chunks += 1 + + if has_content: + content_parts.append(delta.content) + if has_reasoning: + reasoning_parts.append(delta.reasoning) + if has_tool: + tool_calls_received.extend(delta.tool_calls) + + t_end = time.perf_counter() + + total_time = t_end - t_start + ttft = (t_first_token - t_start) if t_first_token else total_time + decode_time = (t_end - t_first_token) if t_first_token else 0 + + prompt_tokens = usage.prompt_tokens if usage else 0 + completion_tokens = usage.completion_tokens if usage else chunks + + decode_tps = completion_tokens / decode_time if decode_time > 0 else 0 + + return { + "ttft": round(ttft, 3), + "total_time": round(total_time, 3), + "decode_time": round(decode_time, 3), + "decode_tps": round(decode_tps, 1), + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "content": "".join(content_parts), + "reasoning": "".join(reasoning_parts), + "tool_calls": tool_calls_received, + "finish_reason": finish_reason, + 
"chunks": chunks, + } + + +# --------------------------------------------------------------------------- +# Test 1: TTFT — Time To First Token (prefill speed) +# --------------------------------------------------------------------------- +def test_ttft(): + """Measure TTFT across different prompt sizes.""" + print("\n" + "=" * 70) + print(" TEST 1: TTFT (Time To First Token)") + print("=" * 70) + + prompts = [ + ("Short (50 tok)", "What is 2+2? Answer in one sentence."), + ( + "Medium (500 tok)", + "You are a helpful AI assistant. " + + "Please explain the complete history of computing, " + + "from Charles Babbage to modern GPUs. " * 15 + + "\nSummarize in 3 bullet points.", + ), + ( + "Long (2K tok)", + "You are a senior software architect. " + + "Review this code and explain every design pattern used:\n" + + '```python\nclass AbstractFactory:\n def create(self): pass\n```\n' * 40 + + "\nList all patterns found.", + ), + ] + + results = [] + for label, prompt in prompts: + # Warm-up run (fills cache) + _ = stream_and_measure( + [{"role": "user", "content": prompt}], max_tokens=32 + ) + # Measured run + m = stream_and_measure( + [{"role": "user", "content": prompt}], max_tokens=32 + ) + results.append({"label": label, **m}) + print( + f" {label:20s} TTFT={m['ttft']:.3f}s " + f"prompt_tok={m['prompt_tokens']} " + f"decode={m['decode_tps']:.1f} tok/s" + ) + + return {"test": "ttft", "results": results} + + +# --------------------------------------------------------------------------- +# Test 2: Decode throughput (sustained generation speed) +# --------------------------------------------------------------------------- +def test_decode(): + """Measure sustained decode speed at different output lengths.""" + print("\n" + "=" * 70) + print(" TEST 2: Decode Throughput") + print("=" * 70) + + targets = [128, 512, 2048] + results = [] + + for max_tok in targets: + m = stream_and_measure( + [ + { + "role": "user", + "content": f"Write exactly {max_tok} words about the 
history of artificial intelligence. Be detailed.", + } + ], + max_tokens=max_tok, + temperature=0.7, + ) + results.append({"max_tokens": max_tok, **m}) + print( + f" max_tokens={max_tok:5d} " + f"generated={m['completion_tokens']:5d} tok " + f"decode={m['decode_tps']:.1f} tok/s " + f"total={m['total_time']:.1f}s" + ) + + return {"test": "decode_throughput", "results": results} + + +# --------------------------------------------------------------------------- +# Test 3: Prefix cache effectiveness (multi-turn simulation) +# --------------------------------------------------------------------------- +def test_prefix_cache(): + """Simulate multi-turn conversation to measure cache hit benefits.""" + print("\n" + "=" * 70) + print(" TEST 3: Prefix Cache (Multi-Turn)") + print("=" * 70) + + system_prompt = ( + "You are an expert AI coding assistant with deep knowledge of Python, " + "TypeScript, Rust, and system design. You always provide working code " + "examples and explain your reasoning step by step. You have access to " + "tools for file manipulation, web search, and code execution. " + "Always consider edge cases, error handling, and performance implications." + ) + + # Build up conversation turns + turns = [ + "How do I implement a thread-safe LRU cache in Python?", + "Can you add TTL (time-to-live) expiration to that cache?", + "Now make it work with async/await. Show the full implementation.", + "What are the performance characteristics? 
Add benchmarks.", + ] + + messages = [{"role": "system", "content": system_prompt}] + results = [] + + for i, user_msg in enumerate(turns): + messages.append({"role": "user", "content": user_msg}) + + m = stream_and_measure(messages, max_tokens=256, temperature=0.7) + + results.append({"turn": i + 1, "user_msg": user_msg[:60], **m}) + print( + f" Turn {i+1}: TTFT={m['ttft']:.3f}s " + f"prompt={m['prompt_tokens']} tok " + f"decode={m['decode_tps']:.1f} tok/s" + ) + + # Add assistant response to history for next turn + messages.append({"role": "assistant", "content": m["content"][:500]}) + + # Key metric: TTFT should decrease or stay flat despite growing context + ttfts = [r["ttft"] for r in results] + print(f"\n TTFT trend: {' → '.join(f'{t:.3f}s' for t in ttfts)}") + if len(ttfts) >= 2: + ratio = ttfts[-1] / ttfts[0] if ttfts[0] > 0 else 0 + print(f" Turn 4 / Turn 1 TTFT ratio: {ratio:.2f}x") + + return {"test": "prefix_cache", "results": results} + + +# --------------------------------------------------------------------------- +# Test 4: Tool calling latency + correctness +# --------------------------------------------------------------------------- +def test_tool_call(): + """Measure tool call generation speed and JSON validity.""" + print("\n" + "=" * 70) + print(" TEST 4: Tool Calling (Latency + Correctness)") + print("=" * 70) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name, e.g. 
San Francisco", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_web", + "description": "Search the web for information", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "num_results": { + "type": "integer", + "description": "Number of results", + }, + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "run_python", + "description": "Execute Python code and return the output", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute", + }, + }, + "required": ["code"], + }, + }, + }, + ] + + test_prompts = [ + ("Simple tool call", "What's the weather in Tokyo?"), + ("Multi-arg tool call", "Search for 'vLLM MLX benchmarks' and show me 5 results"), + ( + "Code execution", + "Calculate the fibonacci sequence up to n=20 using Python", + ), + ( + "Tool selection", + "I need to know the weather in Paris and also search for restaurants there", + ), + ] + + results = [] + correct = 0 + total = 0 + + for label, prompt in test_prompts: + total += 1 + m = stream_and_measure( + [{"role": "user", "content": prompt}], + max_tokens=1024, + temperature=0.3, + tools=tools, + ) + + # Check correctness + has_tool = m["finish_reason"] == "tool_calls" or len(m["tool_calls"]) > 0 + json_valid = True + + if has_tool and m["tool_calls"]: + for tc in m["tool_calls"]: + if hasattr(tc, "function") and tc.function: + try: + args = tc.function.arguments + if isinstance(args, str): + json.loads(args) + except (json.JSONDecodeError, AttributeError): + json_valid = False + break + + is_correct = has_tool and json_valid + if is_correct: + correct += 1 + + tool_names = [] + if m["tool_calls"]: + for tc in m["tool_calls"]: + if hasattr(tc, "function") and tc.function: + 
tool_names.append(tc.function.name) + + results.append( + { + "label": label, + "has_tool_call": has_tool, + "json_valid": json_valid, + "correct": is_correct, + "tool_names": tool_names, + **m, + } + ) + + status = "OK" if is_correct else "FAIL" + tools_str = ", ".join(tool_names) if tool_names else "none" + print( + f" [{status}] {label:25s} TTFT={m['ttft']:.3f}s " + f"total={m['total_time']:.2f}s tools=[{tools_str}]" + ) + + accuracy = correct / total * 100 if total > 0 else 0 + print(f"\n Tool call accuracy: {correct}/{total} ({accuracy:.0f}%)") + + return { + "test": "tool_call", + "accuracy_pct": accuracy, + "correct": correct, + "total": total, + "results": results, + } + + +# --------------------------------------------------------------------------- +# Test 5: Reasoning separation (thinking vs content) +# --------------------------------------------------------------------------- +def test_reasoning(): + """Test that reasoning content is properly separated from final answer.""" + print("\n" + "=" * 70) + print(" TEST 5: Reasoning Separation") + print("=" * 70) + + prompts = [ + ("Math", "What is 17 * 23? Think step by step."), + ("Logic", "If all roses are flowers and some flowers fade quickly, can we conclude all roses fade quickly? Explain your reasoning."), + ("Code", "What's wrong with this code: `x = [1,2,3]; x[5]`? 
Think carefully."), + ] + + results = [] + for label, prompt in prompts: + m = stream_and_measure( + [{"role": "user", "content": prompt}], + max_tokens=512, + temperature=0.3, + ) + + has_reasoning = len(m["reasoning"]) > 0 + has_content = len(m["content"]) > 0 + separated = has_reasoning and has_content + + results.append( + { + "label": label, + "has_reasoning": has_reasoning, + "has_content": has_content, + "separated": separated, + "reasoning_chars": len(m["reasoning"]), + "content_chars": len(m["content"]), + **m, + } + ) + + status = "OK" if separated else ("PARTIAL" if has_reasoning or has_content else "FAIL") + print( + f" [{status:7s}] {label:10s} " + f"reasoning={len(m['reasoning']):5d} chars " + f"content={len(m['content']):5d} chars " + f"TTFT={m['ttft']:.3f}s" + ) + + separated_count = sum(1 for r in results if r["separated"]) + print(f"\n Properly separated: {separated_count}/{len(results)}") + + return {"test": "reasoning", "results": results} + + +# --------------------------------------------------------------------------- +# Test 6: Long generation stability +# --------------------------------------------------------------------------- +def test_long_gen(): + """Test sustained long generation without crash.""" + print("\n" + "=" * 70) + print(" TEST 6: Long Generation Stability") + print("=" * 70) + + m = stream_and_measure( + [ + { + "role": "user", + "content": ( + "Write a comprehensive 5000-word technical guide about building " + "a distributed system on Apple Silicon. Cover: architecture design, " + "networking, consensus algorithms, fault tolerance, and monitoring. " + "Use code examples and diagrams described in text. " + "Start writing immediately." 
+ ), + } + ], + max_tokens=8192, + temperature=0.7, + ) + + completed = m["finish_reason"] in ("stop", "length") + print(f" Completed: {completed} finish_reason={m['finish_reason']}") + print( + f" Generated: {m['completion_tokens']} tokens in {m['total_time']:.1f}s" + ) + print(f" Decode speed: {m['decode_tps']:.1f} tok/s") + print(f" Output length: {len(m['content'])} chars") + + return { + "test": "long_generation", + "completed": completed, + **m, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def print_summary(all_results): + """Print a compact summary table.""" + print("\n" + "=" * 70) + print(" SUMMARY") + print("=" * 70) + + for r in all_results: + test = r["test"] + + if test == "ttft": + ttfts = [x["ttft"] for x in r["results"]] + print(f" TTFT: {' / '.join(f'{t:.3f}s' for t in ttfts)} (short/med/long)") + + elif test == "decode_throughput": + tps_list = [x["decode_tps"] for x in r["results"]] + print(f" Decode tok/s: {' / '.join(f'{t:.1f}' for t in tps_list)} (128/512/2048 tokens)") + + elif test == "prefix_cache": + ttfts = [x["ttft"] for x in r["results"]] + ratio = ttfts[-1] / ttfts[0] if ttfts[0] > 0 else 0 + print(f" Prefix cache: Turn1={ttfts[0]:.3f}s → Turn4={ttfts[-1]:.3f}s (ratio={ratio:.2f}x)") + + elif test == "tool_call": + print(f" Tool calling: {r['correct']}/{r['total']} correct ({r['accuracy_pct']:.0f}%)") + avg_time = statistics.mean(x["total_time"] for x in r["results"]) + print(f" avg latency={avg_time:.2f}s") + + elif test == "reasoning": + sep = sum(1 for x in r["results"] if x["separated"]) + print(f" Reasoning: {sep}/{len(r['results'])} properly separated") + + elif test == "long_generation": + print( + f" Long gen: {'PASS' if r['completed'] else 'FAIL'} " + f"{r['completion_tokens']} tok @ {r['decode_tps']:.1f} tok/s" + ) + + print("=" * 70) + + +def main(): + parser = 
argparse.ArgumentParser(description="vllm-mlx MiniMax benchmark") + parser.add_argument( + "--test", + choices=["ttft", "decode", "prefix_cache", "tool_call", "reasoning", "long_gen", "all"], + default="all", + help="Which test to run (default: all)", + ) + parser.add_argument("--output", type=str, help="Save results to JSON file") + args = parser.parse_args() + + detect_model() + + test_map = { + "ttft": test_ttft, + "decode": test_decode, + "prefix_cache": test_prefix_cache, + "tool_call": test_tool_call, + "reasoning": test_reasoning, + "long_gen": test_long_gen, + } + + if args.test == "all": + tests_to_run = list(test_map.keys()) + else: + tests_to_run = [args.test] + + all_results = [] + for test_name in tests_to_run: + try: + result = test_map[test_name]() + all_results.append(result) + except Exception as e: + print(f"\n ERROR in {test_name}: {e}") + all_results.append({"test": test_name, "error": str(e)}) + + print_summary(all_results) + + # Save results + output_file = args.output or f"benchmark_results_{int(time.time())}.json" + + def _serialize(obj): + """Handle non-serializable objects.""" + if hasattr(obj, "__dict__"): + return str(obj) + return repr(obj) + + with open(output_file, "w") as f: + json.dump( + { + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "model": MODEL, + "results": all_results, + }, + f, + indent=2, + default=_serialize, + ensure_ascii=False, + ) + print(f"\nResults saved to: {output_file}") + + +if __name__ == "__main__": + main() diff --git a/docs/plans/tier0-pinned-prefix-cache.md b/docs/plans/tier0-pinned-prefix-cache.md new file mode 100644 index 00000000..fa540235 --- /dev/null +++ b/docs/plans/tier0-pinned-prefix-cache.md @@ -0,0 +1,160 @@ +# Tier 0 + Pinned Prefix Cache Implementation Plan + +**Status**: Planned +**Target**: MiniMax-M2.5-MLX-4bit on M3 Ultra 256GB (serving OpenClaw) +**Branch**: `feat-minimax-parser` +**Baseline**: See [README benchmark 
section](../../README.md#baseline-benchmark--minimax-m25-on-m3-ultra)
+
+## Context
+
+After establishing a baseline benchmark for MiniMax-M2.5-MLX-4bit on M3 Ultra 256GB, we're implementing the first round of optimizations:
+
+- **Tier 0**: GC control, detokenizer caching, grammar error handling (low-risk, high-impact quick wins)
+- **Pinned Prefix Cache**: Prevent system prompt eviction under memory pressure
+
+Goal: measurable improvement in TTFT, throughput stability, and error resilience. Results will be benchmarked against the baseline and packaged as a PR.
+
+---
+
+## 1. GC Control During Generation
+
+**Problem**: Python's GC runs unpredictably during streaming, causing latency spikes (especially with a 120GB+ model).
+
+**File**: `vllm_mlx/server.py` -- `lifespan()` + `stream_chat_completion()`
+
+**Implementation**:
+- At server startup in `lifespan()`: call `gc.set_threshold(100_000, 50, 50)` to dramatically reduce GC frequency (vLLM upstream uses a similar approach, see [PR #33575](https://github.com/vllm-project/vllm/pull/33575))
+- In `stream_chat_completion()` and `chat_completion()`: wrap generation with `gc.disable()` / `gc.enable()`, plus `gc.collect()` after completion
+- Add a `--gc-control` CLI flag (default: enabled) for easy toggling
+
+**Key pattern** (in the streaming path):
+```python
+gc.disable()
+try:
+    async for chunk in engine.stream_chat(...):
+        yield chunk
+finally:
+    gc.enable()
+    gc.collect()
+```
+
+**Expected impact**: Eliminates GC-induced latency spikes during generation. Most noticeable with large models (120GB+), where GC cycle scanning is expensive.
+
+---
+
+## 2. Detokenizer Result Caching
+
+**Problem**: `mlx_lm.stream_generate()` returns text deltas, but `server.py` re-encodes the full accumulated text to count tokens on every chunk. This is O(n^2) over the response.
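To make the quadratic re-encode concrete, here is a small self-contained sketch of the before/after counting pattern. The whitespace `encode()` and the variable names are illustrative stand-ins, not the actual `server.py` code:

```python
# Stand-in tokenizer: splits on whitespace. The real code calls the model
# tokenizer's encode(); the cost pattern is the same.
def encode(text: str) -> list[str]:
    return text.split()

deltas = ["Hello", " world", ", this", " is", " a", " stream"]

# Before: re-encode the full accumulated text on every chunk.
# Each iteration re-tokenizes everything seen so far -> O(n^2) total work.
accumulated = ""
for delta in deltas:
    accumulated += delta
    quadratic_count = len(encode(accumulated))

# After: encode only the delta and accumulate the count -> O(1) per chunk.
incremental_count = 0
for delta in deltas:
    incremental_count += len(encode(delta))
```

One caveat: encoding deltas in isolation can disagree slightly with encoding the full text when a token straddles a chunk boundary (here the stand-in gives 7 incrementally vs. 6 for the full text, because the comma in `", this"` merges into `"world,"` when the text is encoded whole). For per-chunk usage reporting this approximation is usually acceptable, and the engine's final usage counts remain authoritative.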
+
+**File**: `vllm_mlx/server.py` -- `stream_chat_completion()`
+
+**Implementation**:
+- Track `token_count` incrementally instead of re-encoding on each chunk
+- Call the tokenizer's `encode()` only on the delta text and accumulate the count
+- This avoids the quadratic re-tokenization pattern
+
+**Key change**: In the streaming loop, replace any `len(tokenizer.encode(accumulated_text))` with an incrementally updated counter.
+
+**Reference**: vLLM upstream [PR #20413](https://github.com/vllm-project/vllm/pull/20413) achieves a 13.7x speedup with a similar approach.
+
+**Expected impact**: Reduces per-chunk overhead from O(n) to O(1), which is significant for long generations (2K+ tokens).
+
+---
+
+## 3. Grammar/Schema Error Handling Hardening
+
+**Problem**: Bad JSON schemas from clients (OpenClaw sends diverse schemas) can crash the guided generation path, potentially killing the server.
+
+**Files**:
+- `vllm_mlx/api/guided.py` -- `json_schema_to_pydantic()` already has a try/except but returns None
+- `vllm_mlx/server.py` -- schema extraction and the guided generation call
+
+**Implementation**:
+- In `server.py`: when guided generation returns None or raises, fall back to standard generation with a warning (not a 500 error)
+- Log the problematic schema at DEBUG level for debugging
+- Ensure the `_disconnect_guard` error handler properly catches schema-related errors
+
+**Reference**: vLLM upstream [PR #30346](https://github.com/vllm-project/vllm/pull/30346)
+
+**Expected impact**: The server no longer crashes on malformed schemas; it falls back gracefully to unconstrained generation.
+
+---
+
+## 4. Pinned Prefix Cache
+
+**Problem**: OpenClaw's system prompt (~2K tokens) gets evicted under memory pressure, causing repeated re-computation. With MiniMax-M2.5 on 256GB, the first request after eviction adds 0.5-1s of TTFT.
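To illustrate the intended behavior, a toy LRU cache with pinning shows why a pinned system prompt stays resident while unpinned neighbors are recycled. This is a concept sketch only; `ToyLRUCache` and its fields are invented for illustration and are not the real `PagedCacheManager` or block structures:

```python
from collections import OrderedDict

class ToyLRUCache:
    """Minimal LRU cache with pinning. Illustrates why a pinned system
    prompt survives memory pressure; not the real cache implementation."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.entries: OrderedDict[str, str] = OrderedDict()
        self.pinned: set[str] = set()

    def put(self, key: str, value: str) -> None:
        self.entries[key] = value
        self.entries.move_to_end(key)
        while len(self.entries) > self.capacity:
            self._evict_one()

    def pin(self, key: str) -> None:
        self.pinned.add(key)

    def _evict_one(self) -> None:
        # Evict the least-recently-used unpinned entry. Skipping pinned
        # entries plays the role of skipping is_pinned blocks in popleft().
        for key in self.entries:
            if key not in self.pinned:
                del self.entries[key]
                return
        raise RuntimeError("all entries pinned; nothing to evict")

cache = ToyLRUCache(capacity=3)
cache.put("system_prompt", "<~2K-token system prompt KV>")
cache.pin("system_prompt")
for i in range(5):  # memory pressure: later prompts compete for space
    cache.put(f"turn_{i}", "<per-turn KV>")

# The pinned system prompt is still resident; unpinned turns were recycled.
assert "system_prompt" in cache.entries
```

Without the `pin()` call, the system prompt would be the first entry evicted once capacity is exceeded, which is exactly the repeated re-computation described above.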
+
+**Files**:
+- `vllm_mlx/paged_cache.py` -- `CacheBlock`, `FreeKVCacheBlockQueue.popleft()`, `_maybe_evict_cached_block()`
+- `vllm_mlx/prefix_cache.py` -- `PrefixCacheManager`, `BlockAwarePrefixCache`
+- `vllm_mlx/server.py` -- new CLI flag + auto-pin logic
+
+### 4a. CacheBlock pinning (`paged_cache.py`)
+
+Add `is_pinned: bool = False` to the `CacheBlock` dataclass.
+
+In `FreeKVCacheBlockQueue.popleft()`: skip pinned blocks when popping:
+```python
+block = self.fake_head.next_free_block
+while block is not self.fake_tail and block.is_pinned:
+    block = block.next_free_block
+```
+
+In `_maybe_evict_cached_block()`: refuse to evict pinned blocks:
+```python
+if block.is_pinned:
+    return False
+```
+
+### 4b. Pin/unpin API (`paged_cache.py`)
+
+Add `pin_blocks(block_ids: list[int])` and `unpin_blocks(block_ids: list[int])` methods to `PagedCacheManager`. These set `is_pinned = True/False` on the specified blocks.
+
+### 4c. Prefix-level pinning (`prefix_cache.py`)
+
+Add `pin_prefix(token_ids: list[int])` to `PrefixCacheManager` / `BlockAwarePrefixCache`. It finds the blocks covering the given token prefix and pins them. Add `unpin_prefix(token_ids: list[int])` for cleanup.
+
+### 4d. Auto-pin system prompt (`server.py`)
+
+- Add a `--pin-system-prompt` CLI flag (default: False)
+- When enabled: after the first request with a system message, pin the prefix cache blocks for that system prompt
+- Track the pinned prefix hash to avoid re-pinning on every request
+
+**Reference**: vLLM upstream [PR #27097](https://github.com/vllm-project/vllm/pull/27097)
+
+**Expected impact**: Turn4/Turn1 TTFT ratio drops from 2.09x toward 1.0x. The system prompt KV cache stays warm indefinitely.
+
+---
+
+## 5. CLI Flags
+
+**File**: `vllm_mlx/cli.py`
+
+New flags:
+- `--gc-control / --no-gc-control` (default: enabled)
+- `--pin-system-prompt` (default: False)
+
+---
+
+## File Change Summary
+
+| File | Changes |
+|------|---------|
+| `vllm_mlx/server.py` | GC control in lifespan + generation paths, schema fallback hardening, auto-pin system prompt |
+| `vllm_mlx/paged_cache.py` | `is_pinned` field, skip-pinned eviction, pin/unpin API |
+| `vllm_mlx/prefix_cache.py` | `pin_prefix()` / `unpin_prefix()` methods |
+| `vllm_mlx/cli.py` | `--gc-control`, `--pin-system-prompt` flags |
+| `vllm_mlx/api/guided.py` | Minor: ensure schema errors are logged at DEBUG |
+
+---
+
+## Verification Plan
+
+1. **Correctness**: Run the server with all new flags and send test requests (tool calling, reasoning, long gen) to confirm no regressions
+2. **Benchmark**: Run `benchmark_minmax.py` with the same parameters as the baseline
+3. **Compare**: Verify improvements in:
+   - TTFT (should improve with GC control + prefix pinning)
+   - Throughput consistency (GC control reduces spikes)
+   - Prefix cache turn ratio (pinning should keep Turn4/Turn1 <= 1.5x)
+   - Error resilience (bad schemas no longer crash)
+4. **PR**: Create PR with baseline vs post-optimization comparison table
diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index 2ba8b75e..ac134ad0 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -38,6 +38,7 @@ vllm-mlx serve [options]
 | `--paged-cache-block-size` | Tokens per cache block | 64 |
 | `--max-cache-blocks` | Maximum cache blocks | 1000 |
 | `--max-num-seqs` | Max concurrent sequences | 256 |
+| `--gpu-memory-utilization` | Fraction of device memory for Metal allocation limit (0.0-1.0) | 0.90 |
 | `--default-temperature` | Default temperature when not specified in request | None |
 | `--default-top-p` | Default top_p when not specified in request | None |
 | `--reasoning-parser` | Parser for reasoning models (`qwen3`, `deepseek_r1`) | None |
@@ -88,6 +89,11 @@ vllm-mlx serve mlx-community/granite-4.0-tiny-preview-4bit \
 # With API key authentication
 vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --api-key your-secret-key
 
+# Large models (200GB+) — raise memory limit to avoid cache thrashing
+vllm-mlx serve mlx-community/Qwen3.5-397B-A17B-nvfp4 \
+  --continuous-batching \
+  --gpu-memory-utilization 0.95
+
 # Production setup with security options
 vllm-mlx serve mlx-community/Qwen3-4B-4bit \
   --api-key your-secret-key \
diff --git a/pyproject.toml b/pyproject.toml
index fa92e960..b036c78e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,10 @@ vision = [
     "torch>=2.3.0",
     "torchvision>=0.18.0",
 ]
+# Guided decoding with outlines for structured JSON output
+guided = [
+    "outlines[mlxlm]>=1.0.0",
+]
 # Audio dependencies for TTS/STT (mlx-audio)
 audio = [
     "mlx-audio>=0.2.9",
diff --git a/vllm_mlx/api/guided.py b/vllm_mlx/api/guided.py
new file mode 100644
index 00000000..454a7944
--- /dev/null
+++ b/vllm_mlx/api/guided.py
@@ -0,0 +1,239 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Guided generation for structured JSON output using outlines.
+ +This module provides constrained decoding for JSON schema enforcement, +ensuring model outputs strictly adhere to specified schemas. +""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +# Check for outlines availability +try: + import mlx_lm + import outlines + + HAS_OUTLINES = True +except ImportError: + HAS_OUTLINES = False + outlines = None + mlx_lm = None + + +def is_guided_available() -> bool: + """Check if guided generation with outlines is available.""" + return HAS_OUTLINES + + +def json_schema_to_pydantic(schema: dict[str, Any]) -> type | None: + """ + Convert a JSON schema to a Pydantic model dynamically. + + Args: + schema: JSON schema dict + + Returns: + Dynamically created Pydantic model class, or None if conversion fails + """ + try: + from pydantic import create_model + + # Extract properties from schema + properties = schema.get("properties", {}) + required = set(schema.get("required", [])) + + # Build field definitions for Pydantic + field_definitions = {} + + type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "null": type(None), + } + + for prop_name, prop_spec in properties.items(): + prop_type = prop_spec.get("type", "string") + + # Handle array type + if prop_type == "array": + items_type = prop_spec.get("items", {}).get("type", "string") + inner_type = type_mapping.get(items_type, str) + python_type = list[inner_type] + # Handle object type (nested) + elif prop_type == "object": + # For nested objects, use dict + python_type = dict + else: + python_type = type_mapping.get(prop_type, str) + + # Make optional if not required + if prop_name not in required: + python_type = python_type | None + default = None + else: + default = ... 
+ + field_definitions[prop_name] = (python_type, default) + + # Create the model dynamically + model = create_model("DynamicModel", **field_definitions) + return model + + except Exception as e: + logger.warning(f"Failed to convert JSON schema to Pydantic: {e}") + logger.debug(f"Problematic schema: {schema}") + return None + + +class GuidedGenerator: + """ + Guided generation using outlines for constrained JSON decoding. + + This class wraps an MLX model to provide structured output generation + that guarantees valid JSON matching a specified schema. + """ + + def __init__(self, model, tokenizer): + """ + Initialize the guided generator. + + Args: + model: MLX model instance + tokenizer: Tokenizer instance + """ + if not HAS_OUTLINES: + raise ImportError( + "outlines is required for guided generation. " + "Install with: pip install 'vllm-mlx[guided]'" + ) + + self._model = model + self._tokenizer = tokenizer + self._outlines_model = None + + def _get_outlines_model(self): + """Get or create the outlines model wrapper.""" + if self._outlines_model is None: + self._outlines_model = outlines.from_mlxlm(self._model, self._tokenizer) + return self._outlines_model + + def generate_json( + self, + prompt: str, + json_schema: dict[str, Any], + max_tokens: int = 256, + temperature: float = 0.7, + ) -> str: + """ + Generate JSON output constrained to a schema. 
+ + Args: + prompt: Input prompt + json_schema: JSON schema to constrain output + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + + Returns: + JSON string matching the schema + """ + # Convert schema to Pydantic model + pydantic_model = json_schema_to_pydantic(json_schema) + + if pydantic_model is None: + logger.warning( + "Could not convert schema to Pydantic, falling back to raw generation" + ) + return None + + try: + outlines_model = self._get_outlines_model() + + # Generate with schema constraint + result = outlines_model( + prompt, + output_type=pydantic_model, + max_tokens=max_tokens, + ) + + # result is a JSON string, validate and return + return result + + except Exception as e: + logger.error(f"Guided generation failed: {e}") + return None + + def generate_json_object( + self, + prompt: str, + max_tokens: int = 256, + temperature: float = 0.7, + ) -> str: + """ + Generate any valid JSON object. + + Args: + prompt: Input prompt + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + + Returns: + JSON string + """ + try: + from outlines import generate + + outlines_model = self._get_outlines_model() + + # Use regex to constrain to valid JSON + json_regex = r"\{[^{}]*\}" + generator = generate.regex(outlines_model, json_regex) + result = generator(prompt, max_tokens=max_tokens) + + return result + + except Exception as e: + logger.error(f"JSON object generation failed: {e}") + return None + + +def generate_with_schema( + model, + tokenizer, + prompt: str, + json_schema: dict[str, Any], + max_tokens: int = 256, + temperature: float = 0.7, +) -> str | None: + """ + Convenience function for one-shot guided JSON generation. 
+ + Args: + model: MLX model + tokenizer: Tokenizer + prompt: Input prompt + json_schema: JSON schema + max_tokens: Maximum tokens + temperature: Sampling temperature + + Returns: + JSON string or None if guided generation unavailable/failed + """ + if not HAS_OUTLINES: + return None + + try: + generator = GuidedGenerator(model, tokenizer) + return generator.generate_json( + prompt=prompt, + json_schema=json_schema, + max_tokens=max_tokens, + temperature=temperature, + ) + except Exception as e: + logger.error(f"generate_with_schema failed: {e}") + return None diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py index 1443c167..764159b6 100644 --- a/vllm_mlx/api/tool_calling.py +++ b/vllm_mlx/api/tool_calling.py @@ -21,6 +21,42 @@ from .models import FunctionCall, ResponseFormat, ToolCall +def _is_tool_call_json(obj: dict) -> bool: + """ + Check if a JSON object looks like a tool call. + + A tool call must have: + - "name" key with a string value (function name) + - "arguments" key (the function arguments) + + This prevents false positives where regular JSON like {"name": "John", "age": 30} + would be incorrectly parsed as a tool call. + + Args: + obj: JSON object to check + + Returns: + True if object appears to be a tool call + """ + if not isinstance(obj, dict): + return False + + # Must have both "name" and "arguments" keys + if "name" not in obj or "arguments" not in obj: + return False + + # "name" must be a non-empty string (function name) + if not isinstance(obj["name"], str) or not obj["name"].strip(): + return False + + # "arguments" must be a dict or string + args = obj["arguments"] + if not isinstance(args, (dict, str)): + return False + + return True + + def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: """ Parse raw JSON tool calls from model output. 
@@ -30,6 +66,9 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: - Multiple objects separated by commas: {...}, {...} - JSON array: [{...}, {...}] + Only objects with BOTH "name" AND "arguments" keys are considered tool calls. + This prevents false positives with regular JSON objects. + Args: text: Raw model output text @@ -46,7 +85,7 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: try: parsed = json.loads(text) if isinstance(parsed, list) and all( - isinstance(item, dict) and "name" in item for item in parsed + _is_tool_call_json(item) for item in parsed ): return [ {"name": item["name"], "arguments": item.get("arguments", {})} @@ -71,7 +110,8 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: json_str = text[start : i + 1] try: obj = json.loads(json_str) - if isinstance(obj, dict) and "name" in obj: + # Only consider as tool call if it has both "name" AND "arguments" + if _is_tool_call_json(obj): tool_calls.append( {"name": obj["name"], "arguments": obj.get("arguments", {})} ) @@ -524,8 +564,14 @@ def build_json_system_prompt( if format_type == "json_object": return ( - "You must respond with valid JSON only. " - "Do not include any explanation or text outside the JSON object." 
+        "⚠️ JSON OUTPUT REQUIRED ⚠️\n\n"
+        "You MUST respond with ONLY a valid JSON object.\n\n"
+        "RULES:\n"
+        "- Start response with { or [\n"
+        "- NO text before or after JSON\n"
+        "- NO thinking or explanations\n"
+        "- NO markdown code blocks (```)\n"
+        "- ONLY the raw JSON object"
     )
 
     if format_type == "json_schema":
@@ -533,14 +579,89 @@
         schema = json_schema_spec.get("schema", {})
         name = json_schema_spec.get("name", "response")
         description = json_schema_spec.get("description", "")
+        strict = json_schema_spec.get("strict", False)
+
+        # If strict mode is enabled and guided generation is available,
+        # the server should use guided decoding instead of prompt injection
+        if strict:
+            # Return stronger instruction for strict mode
+            prompt = (
+                f"⚠️ STRICT JSON OUTPUT REQUIRED ⚠️\n\n"
+                f"You MUST respond with ONLY a valid JSON object matching the '{name}' schema.\n"
+            )
+            if description:
+                prompt += f"Purpose: {description}\n"
+            prompt += (
+                f"\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
+                "STRICT RULES:\n"
+                "- Start response with { or [\n"
+                "- NO text before or after JSON\n"
+                "- NO thinking, reasoning, or explanations\n"
+                "- NO markdown code blocks\n"
+                "- ONLY the JSON object"
+            )
+            return prompt
 
-    prompt = f"You must respond with valid JSON matching the '{name}' schema."
+        # Standard (non-strict) mode - still strong but less aggressive
+        prompt = (
+            f"⚠️ JSON OUTPUT REQUIRED ⚠️\n\n"
+            f"Respond with a valid JSON object matching the '{name}' schema.\n"
+        )
         if description:
-            prompt += f" {description}"
+            prompt += f"Purpose: {description}\n"
         prompt += (
-            f"\n\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
-            "Respond with only the JSON object, no additional text or explanation."
+            f"\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n"
+            "RULES:\n"
+            "- Start response with { or [\n"
+            "- NO text before or after JSON\n"
+            "- NO markdown code blocks\n"
+            "- ONLY the JSON object"
         )
         return prompt
 
     return None
+
+
+def extract_json_schema_for_guided(
+    response_format: Optional[Union[ResponseFormat, Dict[str, Any]]] = None,
+) -> Optional[Dict[str, Any]]:
+    """
+    Extract JSON schema from response_format for guided generation.
+
+    Returns the schema dict if response_format specifies json_schema type,
+    otherwise returns None.
+
+    Args:
+        response_format: ResponseFormat specification
+
+    Returns:
+        JSON schema dict or None
+    """
+    if response_format is None:
+        return None
+
+    # Normalize to dict
+    if isinstance(response_format, ResponseFormat):
+        rf_dict = {"type": response_format.type, "json_schema": None}
+        if response_format.json_schema:
+            rf_dict["json_schema"] = {
+                "name": response_format.json_schema.name,
+                "description": response_format.json_schema.description,
+                "schema": response_format.json_schema.schema_,
+                "strict": response_format.json_schema.strict,
+            }
+    else:
+        rf_dict = response_format
+
+    format_type = rf_dict.get("type", "text")
+
+    if format_type != "json_schema":
+        return None
+
+    # Guard against an explicit None value stored under "json_schema"
+    json_schema_spec = rf_dict.get("json_schema") or {}
+    schema = json_schema_spec.get("schema", {})
+
+    if not schema:
+        return None
+
+    return schema
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index e916ae19..4426a319 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -17,7 +17,8 @@
     r"<\|im_end\|>|<\|im_start\|>|<\|endoftext\|>|"
     r"<\|end\|>|<\|eot_id\|>|<\|start_header_id\|>|<\|end_header_id\|>|"
     r"<\|channel\|>|<\|message\|>|<\|start\|>|<\|return\|>|<\|call\|>|<\|constrain\|>|"
-    r"|||\[PAD\]|\[SEP\]|\[CLS\]"
+    r"|||\[PAD\]|\[SEP\]|\[CLS\]|"
+    r"\[e~\[|\]~b\][a-z]*|\]~!b\["
 )
 
 
@@ -101,6 +102,86 @@
     return text
 
 
+# Pattern to match <think>...</think> blocks
+THINK_PATTERN = re.compile(r"<think>[\s\S]*?</think>\s*", re.DOTALL)
+
+
+def strip_thinking_tags(text: str) -> str:
+    """
+    Remove <think>...</think> blocks from model output.
+
+    Used when the client expects pure content (e.g., JSON) without
+    reasoning blocks that would break parsing.
+
+    Args:
+        text: Model output that may contain thinking blocks
+
+    Returns:
+        Text with thinking blocks removed
+    """
+    if not text:
+        return text
+    return THINK_PATTERN.sub("", text).strip()
+
+
+def extract_json_from_response(text: str) -> str:
+    """
+    Extract JSON object/array from model response that may contain reasoning text.
+
+    Qwen3 and other reasoning models often output:
+    "Let me think... reasoning text... {\"result\": 123}"
+
+    This function extracts just the JSON part when present.
+
+    Args:
+        text: Model output that may contain text before/after JSON
+
+    Returns:
+        Extracted JSON string if found, otherwise original text
+    """
+    if not text:
+        return text
+
+    text = text.strip()
+
+    # If already valid JSON, return as-is
+    if (text.startswith("{") and text.endswith("}")) or (
+        text.startswith("[") and text.endswith("]")
+    ):
+        return text
+
+    # Try to find JSON object at the end of the response
+    # Find the last { and match to the end
+    last_brace = text.rfind("{")
+    if last_brace != -1 and text.endswith("}"):
+        potential_json = text[last_brace:]
+        # Validate it's balanced
+        depth = 0
+        for char in potential_json:
+            if char == "{":
+                depth += 1
+            elif char == "}":
+                depth -= 1
+            if depth == 0:
+                return potential_json
+
+    # Try to find JSON array at the end
+    last_bracket = text.rfind("[")
+    if last_bracket != -1 and text.endswith("]"):
+        potential_json = text[last_bracket:]
+        depth = 0
+        for char in potential_json:
+            if char == "[":
+                depth += 1
+            elif char == "]":
+                depth -= 1
+            if depth == 0:
+                return potential_json
+
+    # No JSON found, return original
+    return text
+
+
 # =============================================================================
 # Model Detection
 # 
============================================================================= diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index dcbee8ac..dc7a2a7b 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -37,6 +37,13 @@ def serve_command(args): print("Example: --enable-auto-tool-choice --tool-call-parser mistral") sys.exit(1) + # Validate gpu-memory-utilization range + if not (0.0 < args.gpu_memory_utilization <= 1.0): + print( + "Error: --gpu-memory-utilization must be between 0.0 (exclusive) and 1.0 (inclusive)" + ) + sys.exit(1) + # Configure server security settings server._api_key = args.api_key server._default_timeout = args.timeout @@ -45,6 +52,13 @@ def serve_command(args): requests_per_minute=args.rate_limit, enabled=True ) + # Configure GC control + gc_control = args.gc_control and not args.no_gc_control + server._gc_control = gc_control + + # Configure system prompt pinning + server._pin_system_prompt = args.pin_system_prompt + # Configure tool calling if args.enable_auto_tool_choice and args.tool_call_parser: server._enable_auto_tool_choice = True @@ -103,10 +117,17 @@ def serve_command(args): print(f" Reasoning: ENABLED (parser: {args.reasoning_parser})") else: print(" Reasoning: Use --reasoning-parser to enable") + print(f" GC control: {'ENABLED' if gc_control else 'DISABLED'}") + if args.pin_system_prompt: + print(" Pin system prompt: ENABLED") print("=" * 60) print(f"Loading model: {args.model}") print(f"Default max tokens: {args.max_tokens}") + if args.draft_model: + print("Speculative decoding: ENABLED") + print(f" Draft model: {args.draft_model}") + print(f" Draft tokens: {args.num_draft_tokens}") # Store MCP config path for FastAPI startup if args.mcp_config: @@ -187,6 +208,9 @@ def serve_command(args): stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, force_mllm=args.mllm, + gpu_memory_utilization=args.gpu_memory_utilization, + draft_model=args.draft_model, + 
num_draft_tokens=args.num_draft_tokens, ) # Start server @@ -681,6 +705,14 @@ def main(): action="store_true", help="Enable continuous batching for multiple concurrent users (slower for single user)", ) + serve_parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.90, + help="Fraction of device memory for Metal allocation limit and emergency " + "cache clear threshold (0.0-1.0, default: 0.90). Increase to 0.95 for " + "large models (200GB+) that need more memory headroom.", + ) # Paged cache options (experimental) serve_parser.add_argument( "--use-paged-cache", @@ -778,14 +810,28 @@ def main(): "xlam", "functionary", "glm47", + "minimax", ], help=( "Select the tool call parser for the model. Options: " "auto (auto-detect), mistral, qwen, qwen3_coder, llama, hermes, " - "deepseek, kimi, granite, nemotron, xlam, functionary, glm47. " + "deepseek, kimi, granite, nemotron, xlam, functionary, glm47, minimax. " "Required for --enable-auto-tool-choice." ), ) + # Speculative decoding options + serve_parser.add_argument( + "--draft-model", + type=str, + default=None, + help="Draft model for speculative decoding (must use same tokenizer as main model)", + ) + serve_parser.add_argument( + "--num-draft-tokens", + type=int, + default=4, + help="Number of tokens to generate speculatively per step (default: 4)", + ) # Reasoning parser options - choices loaded dynamically from registry from .reasoning import list_parsers @@ -801,6 +847,25 @@ def main(): f"Options: {', '.join(reasoning_choices)}." 
), ) + # GC control (Tier 0 optimization) + serve_parser.add_argument( + "--gc-control", + action="store_true", + default=True, + help="Disable Python GC during generation to avoid latency spikes (default: enabled)", + ) + serve_parser.add_argument( + "--no-gc-control", + action="store_true", + help="Disable GC control (allow normal GC during generation)", + ) + # Pinned prefix cache (Tier 0 optimization) + serve_parser.add_argument( + "--pin-system-prompt", + action="store_true", + default=False, + help="Auto-pin system prompt in prefix cache to prevent eviction under memory pressure", + ) # Multimodal option serve_parser.add_argument( "--mllm", diff --git a/vllm_mlx/engine/__init__.py b/vllm_mlx/engine/__init__.py index f6625abd..d33c4bc4 100644 --- a/vllm_mlx/engine/__init__.py +++ b/vllm_mlx/engine/__init__.py @@ -12,6 +12,7 @@ from .base import BaseEngine, GenerationOutput from .simple import SimpleEngine from .batched import BatchedEngine +from .hybrid import HybridEngine # Re-export from parent engine.py for backwards compatibility from ..engine_core import EngineCore, AsyncEngineCore, EngineConfig @@ -21,6 +22,7 @@ "GenerationOutput", "SimpleEngine", "BatchedEngine", + "HybridEngine", # Core engine components "EngineCore", "AsyncEngineCore", diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ce33e628..8bea4b63 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -137,6 +137,7 @@ def __init__( scheduler_config: Any | None = None, stream_interval: int = 1, force_mllm: bool = False, + gpu_memory_utilization: float = 0.90, ): """ Initialize the batched engine. 
@@ -147,11 +148,14 @@ def __init__( scheduler_config: Optional scheduler configuration stream_interval: Tokens to batch before streaming (1=every token) force_mllm: Force loading as MLLM even if not auto-detected + gpu_memory_utilization: Fraction of device memory for Metal allocation + limit and emergency threshold (0.0-1.0, default 0.90) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._scheduler_config = scheduler_config self._stream_interval = stream_interval + self._gpu_memory_utilization = gpu_memory_utilization self._is_mllm = force_mllm or is_mllm_model(model_name) self._model = None @@ -161,6 +165,7 @@ def __init__( self._mllm_scheduler = None # MLLMScheduler for MLLM self._mllm_instance = None # MLXMultimodalLM instance self._loaded = False + self._engine_started = False # Track if engine loop is running @property def model_name(self) -> str: @@ -283,13 +288,14 @@ async def _start_llm(self) -> None: device_info.get("memory_size", 0), ) if max_recommended > 0: - soft_limit = int(max_recommended * 0.90) + soft_limit = int(max_recommended * self._gpu_memory_utilization) mx.set_memory_limit(soft_limit) mx.set_cache_limit(32 * 1024 * 1024 * 1024) # 32GB + pct = self._gpu_memory_utilization * 100 logger.info( f"Metal memory limits set: " f"allocation_limit={soft_limit / 1e9:.1f}GB " - f"(90% of {max_recommended / 1e9:.1f}GB), " + f"({pct:.0f}% of {max_recommended / 1e9:.1f}GB), " f"cache_limit=32GB" ) except Exception as e: @@ -301,6 +307,7 @@ async def _start_llm(self) -> None: model_name=self._model_name, scheduler_config=scheduler_config, stream_interval=self._stream_interval, + gpu_memory_utilization=self._gpu_memory_utilization, ) # Create async engine @@ -800,3 +807,52 @@ def load_cache_from_disk(self, cache_dir: str) -> int: if self._engine: return self._engine.load_cache_from_disk(cache_dir) return 0 + + async def _inject_shared_model( + self, + model, + tokenizer, + start_engine: bool = True, + ) -> None: + """ + 
Inject a pre-loaded shared model instead of loading a new one. + + This is used by HybridEngine to share a single model instance + between SimpleEngine and BatchedEngine, saving ~44GB of RAM. + + Args: + model: Pre-loaded MLX model + tokenizer: Pre-loaded tokenizer + start_engine: Whether to start the engine loop immediately. + Set to False for HybridEngine (lazy start on first use). + """ + from ..engine_core import AsyncEngineCore, EngineConfig + from ..scheduler import SchedulerConfig + + self._model = model + self._tokenizer = tokenizer + + # Create engine config + scheduler_config = self._scheduler_config or SchedulerConfig() + engine_config = EngineConfig( + model_name=self._model_name, + scheduler_config=scheduler_config, + stream_interval=self._stream_interval, + ) + + # Create async engine with shared model + self._engine = AsyncEngineCore( + model=self._model, + tokenizer=self._tokenizer, + config=engine_config, + ) + + # Only start engine loop if requested (HybridEngine starts lazily) + if start_engine: + await self._engine.engine.start() + + self._loaded = True + self._engine_started = start_engine + logger.info( + f"BatchedEngine injected with shared model: {self._model_name} (started={start_engine})" + ) diff --git a/vllm_mlx/engine/hybrid.py b/vllm_mlx/engine/hybrid.py new file mode 100644 index 00000000..c0b7c2aa --- /dev/null +++ b/vllm_mlx/engine/hybrid.py @@ -0,0 +1,509 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Hybrid engine combining speculative decoding with continuous batching. + +This engine shares a single model instance between SimpleEngine and BatchedEngine, +saving ~44GB RAM while providing: +- Speculative decoding (80+ tok/s) for single-user scenarios +- Continuous batching (60-70 tok/s) for multi-user scenarios + +The engine automatically switches between modes based on concurrent request count. 
+ +Architecture: + HybridEngine + ├── Shared components (loaded once) + │ ├── _shared_model (44GB) + │ ├── _shared_tokenizer + │ └── _ownership_lock + │ + ├── SimpleEngine (speculative) + │ └── + draft_model (350MB) + │ + └── BatchedEngine (batching) + └── Scheduler + BatchGenerator + +Mode switching logic: + if active_requests < switch_threshold: + → SimpleEngine (speculative, 80+ tok/s) + else: + → BatchedEngine (batching, 60-70 tok/s) +""" + +import asyncio +import logging +from collections.abc import AsyncIterator +from typing import Any + +from mlx_lm import load + +from .base import BaseEngine, GenerationOutput +from .batched import BatchedEngine +from .simple import SimpleEngine +from ..model_registry import get_registry + +logger = logging.getLogger(__name__) + + +class HybridEngine(BaseEngine): + """ + Hybrid engine: speculative decoding for single-user, + continuous batching for multi-user. Shares model instance. + + This engine provides the best of both worlds: + - When serving a single user: uses SimpleEngine with speculative decoding + for maximum throughput (80+ tok/s with Qwen3-0.6B draft model) + - When serving multiple concurrent users: switches to BatchedEngine + for efficient continuous batching + + The key innovation is sharing a single model instance (~44GB for Qwen3-Next-80B) + between both engines, cutting memory usage in half compared to running + separate servers. + """ + + def __init__( + self, + model_name: str, + draft_model: str | None = None, + num_draft_tokens: int = 5, + scheduler_config: Any | None = None, + stream_interval: int = 1, + trust_remote_code: bool = True, + force_mllm: bool = False, + switch_threshold: int = 2, + ): + """ + Initialize the hybrid engine. 
+ + Args: + model_name: HuggingFace model name or local path + draft_model: Draft model for speculative decoding (e.g., Qwen3-0.6B-4bit) + num_draft_tokens: Number of tokens to generate speculatively (default: 5) + scheduler_config: Scheduler config for batched mode + stream_interval: Tokens to batch before streaming (batched mode only) + trust_remote_code: Whether to trust remote code + force_mllm: Force loading as MLLM even if not auto-detected + switch_threshold: Number of concurrent requests to trigger batch mode (default: 2) + """ + self._model_name = model_name + self._draft_model_name = draft_model + self._num_draft_tokens = num_draft_tokens + self._scheduler_config = scheduler_config + self._stream_interval = stream_interval + self._trust_remote_code = trust_remote_code + self._force_mllm = force_mllm + self._switch_threshold = switch_threshold + + # Shared resources (loaded once) + self._shared_model = None + self._shared_tokenizer = None + + # Engine instances + self._simple: SimpleEngine | None = None + self._batched: BatchedEngine | None = None + self._current_mode: str | None = None # 'simple' or 'batched' + + # Concurrency tracking + self._active_requests = 0 + self._lock = asyncio.Lock() + self._switch_lock = asyncio.Lock() + + # State + self._loaded = False + self._is_mllm = False + + @property + def model_name(self) -> str: + """Get the model name.""" + return self._model_name + + @property + def is_mllm(self) -> bool: + """Check if this is a multimodal model.""" + return self._is_mllm + + @property + def tokenizer(self) -> Any: + """Get the tokenizer.""" + return self._shared_tokenizer + + async def start(self) -> None: + """Start the engine (load shared model and initialize sub-engines).""" + if self._loaded: + return + + logger.info(f"HybridEngine loading shared model: {self._model_name}") + + # Load model once using mlx-lm + self._shared_model, self._shared_tokenizer = load(self._model_name) + + # Check if MLLM + from ..api.utils import 
is_mllm_model + + self._is_mllm = self._force_mllm or is_mllm_model(self._model_name) + + if self._is_mllm: + logger.warning( + "HybridEngine does not support MLLM models yet. " + "Using BatchedEngine only." + ) + # For MLLM, just use BatchedEngine (no speculative) + self._batched = BatchedEngine( + model_name=self._model_name, + scheduler_config=self._scheduler_config, + stream_interval=self._stream_interval, + force_mllm=True, + ) + await self._batched.start() + self._current_mode = "batched" + else: + # Create SimpleEngine with draft model support + self._simple = SimpleEngine( + model_name=self._model_name, + trust_remote_code=self._trust_remote_code, + draft_model=self._draft_model_name, + num_draft_tokens=self._num_draft_tokens, + ) + # Inject shared model instead of loading again + await self._simple._inject_shared_model( + self._shared_model, + self._shared_tokenizer, + ) + + # Create BatchedEngine (lazy start - don't start engine loop yet) + self._batched = BatchedEngine( + model_name=self._model_name, + trust_remote_code=self._trust_remote_code, + scheduler_config=self._scheduler_config, + stream_interval=self._stream_interval, + ) + # Inject shared model but DON'T start engine loop yet + # Engine will be started on first switch to batched mode + await self._batched._inject_shared_model( + self._shared_model, + self._shared_tokenizer, + start_engine=False, # Lazy start for HybridEngine + ) + self._batched_engine_started = False + + # Start in simple mode (speculative decoding) + self._current_mode = "simple" + + self._loaded = True + + spec_info = "" + if self._draft_model_name and not self._is_mllm: + spec_info = f", draft={self._draft_model_name}, k={self._num_draft_tokens}" + + logger.info( + f"HybridEngine ready: {self._model_name} " + f"(mode={self._current_mode}, threshold={self._switch_threshold}{spec_info})" + ) + + async def stop(self) -> None: + """Stop the engine and cleanup resources.""" + if self._simple: + await self._simple.stop() + 
self._simple = None + + if self._batched: + await self._batched.stop() + self._batched = None + + self._shared_model = None + self._shared_tokenizer = None + self._loaded = False + self._current_mode = None + + logger.info("HybridEngine stopped") + + def _get_engine_for_request(self) -> BaseEngine: + """ + Get the appropriate engine for the current request. + + Note: This doesn't switch modes - it returns the engine based on + current mode. Mode switching happens separately. + """ + # For MLLM, always use batched + if self._is_mllm: + return self._batched + + return self._simple if self._current_mode == "simple" else self._batched + + async def _switch_to_mode(self, target_mode: str) -> None: + """ + Switch to the specified mode, handling ownership transfer. + + This method ensures proper model ownership transfer between engines + to prevent KV cache conflicts. + + Args: + target_mode: 'simple' or 'batched' + """ + if self._current_mode == target_mode: + return + + async with self._switch_lock: + # Double-check after acquiring lock + if self._current_mode == target_mode: + return + + old_mode = self._current_mode + + if target_mode == "batched": + # Switching to batched mode + if self._batched and self._batched._engine: + # Start BatchedEngine's engine loop if not started yet (lazy start) + if not getattr(self, "_batched_engine_started", True): + logger.info("HybridEngine: starting BatchedEngine (lazy start)") + await self._batched._engine.engine.start() + self._batched_engine_started = True + + # BatchedEngine needs model ownership for its BatchGenerator + registry = get_registry() + try: + registry.acquire( + model=self._shared_model, + engine=self._batched._engine.engine, + engine_id=self._batched._engine.engine.engine_id, + force=True, # Force transfer from SimpleEngine + ) + logger.info( + "HybridEngine: model ownership transferred to BatchedEngine" + ) + except Exception as e: + logger.warning(f"Ownership transfer failed: {e}") + # Continue anyway - the 
transfer may have succeeded partially + else: + # Switching to simple mode + # SimpleEngine doesn't need registry ownership (uses mlx_lm.generate directly) + # Just reset BatchedEngine's scheduler to clear KV cache state + if self._batched and self._batched._engine: + try: + self._batched._engine.engine.scheduler.deep_reset() + logger.debug("Reset BatchedEngine scheduler for mode switch") + except Exception as e: + logger.warning(f"Failed to reset BatchedEngine: {e}") + + self._current_mode = target_mode + logger.info(f"HybridEngine: mode switched {old_mode} -> {target_mode}") + + async def _decide_and_switch_mode(self, entering: bool = True) -> str: + """ + Decide which mode to use and switch if necessary. + + Args: + entering: True when entering a request, False when exiting + + Returns: + The mode to use for this request ('simple' or 'batched') + """ + if self._is_mllm: + return "batched" + + # Decide based on active request count + # When entering: count includes this request + # When exiting: count excludes this request + active = self._active_requests + + if active >= self._switch_threshold: + target_mode = "batched" + else: + target_mode = "simple" + + # Only switch when safe (no active requests in the other engine) + # This is a simplified heuristic - we switch when crossing the threshold + if target_mode != self._current_mode: + # Log the decision + logger.debug( + f"HybridEngine: active_requests={active}, " + f"threshold={self._switch_threshold}, " + f"current={self._current_mode}, target={target_mode}" + ) + + # Only switch to batched immediately, switch to simple when quiet + if target_mode == "batched": + await self._switch_to_mode("batched") + elif active == 0: + # Switch back to simple only when completely idle + await self._switch_to_mode("simple") + + return self._current_mode + + async def generate( + self, + prompt: str, + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + stop: list[str] | None = None, + **kwargs, + ) -> 
GenerationOutput: + """Generate a complete response (non-streaming).""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + return await engine.generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ) + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + async def stream_generate( + self, + prompt: str, + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + stop: list[str] | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Stream generation token by token.""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + async for output in engine.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ): + yield output + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + async def chat( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + images: list[str] | None = None, + videos: list[str] | None = None, + **kwargs, + ) -> GenerationOutput: + """Chat completion (non-streaming).""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = 
await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + return await engine.chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + images=images, + videos=videos, + **kwargs, + ) + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + async def stream_chat( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + images: list[str] | None = None, + videos: list[str] | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Stream chat completion token by token.""" + if not self._loaded: + await self.start() + + async with self._lock: + self._active_requests += 1 + + try: + # Decide mode and switch if needed + mode = await self._decide_and_switch_mode(entering=True) + engine = self._simple if mode == "simple" else self._batched + + async for output in engine.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + images=images, + videos=videos, + **kwargs, + ): + yield output + finally: + async with self._lock: + self._active_requests -= 1 + # Check if we should switch back to simple mode + await self._decide_and_switch_mode(entering=False) + + def get_stats(self) -> dict[str, Any]: + """Get engine statistics.""" + stats = { + "engine_type": "hybrid", + "model_name": self._model_name, + "is_mllm": self._is_mllm, + "loaded": self._loaded, + "current_mode": self._current_mode, + "active_requests": self._active_requests, + "switch_threshold": self._switch_threshold, + "draft_model": self._draft_model_name, + "num_draft_tokens": self._num_draft_tokens, + } + + if self._simple: + stats["simple_engine"] = self._simple.get_stats() + if self._batched: + stats["batched_engine"] 
= self._batched.get_stats() + + return stats + + def get_cache_stats(self) -> dict[str, Any] | None: + """Get cache statistics from active engine.""" + if self._current_mode == "simple" and self._simple: + return self._simple.get_cache_stats() + elif self._batched: + return self._batched.get_cache_stats() + return None diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index ed118bbb..df42b3e8 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -17,6 +17,15 @@ logger = logging.getLogger(__name__) +# Check for guided generation availability +try: + from ..api.guided import GuidedGenerator, is_guided_available + + HAS_GUIDED = is_guided_available() +except ImportError: + HAS_GUIDED = False + GuidedGenerator = None + class SimpleEngine(BaseEngine): """ @@ -32,6 +41,8 @@ def __init__( trust_remote_code: bool = True, enable_cache: bool = True, force_mllm: bool = False, + draft_model: str | None = None, + num_draft_tokens: int = 4, ): """ Initialize the simple engine. 
@@ -41,11 +52,15 @@ def __init__( trust_remote_code: Whether to trust remote code enable_cache: Enable VLM cache for multimodal models force_mllm: Force loading as MLLM even if not auto-detected + draft_model: Optional draft model path for speculative decoding + num_draft_tokens: Number of tokens to generate speculatively per step """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._enable_cache = enable_cache self._is_mllm = force_mllm or is_mllm_model(model_name) + self._draft_model_name = draft_model + self._num_draft_tokens = num_draft_tokens self._model = None self._loaded = False @@ -85,17 +100,27 @@ async def start(self) -> None: trust_remote_code=self._trust_remote_code, enable_cache=self._enable_cache, ) + if self._draft_model_name: + logger.warning("Speculative decoding is not supported with MLLM models") else: from ..models.llm import MLXLanguageModel self._model = MLXLanguageModel( self._model_name, trust_remote_code=self._trust_remote_code, + draft_model=self._draft_model_name, + num_draft_tokens=self._num_draft_tokens, ) self._model.load() self._loaded = True - logger.info(f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm})") + + spec_info = "" + if self._draft_model_name and not self._is_mllm: + spec_info = f", speculative={self._draft_model_name}" + logger.info( + f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm}{spec_info})" + ) async def stop(self) -> None: """Stop the engine and cleanup resources.""" @@ -440,3 +465,174 @@ def get_cache_stats(self) -> dict[str, Any] | None: if self._is_mllm and self._model is not None: return self._model.get_cache_stats() return None + + async def _inject_shared_model( + self, + model, + tokenizer, + ) -> None: + """ + Inject a pre-loaded shared model instead of loading a new one. + + This is used by HybridEngine to share a single model instance + between SimpleEngine and BatchedEngine, saving ~44GB of RAM. 
+ + Args: + model: Pre-loaded MLX model + tokenizer: Pre-loaded tokenizer + """ + from ..models.llm import MLXLanguageModel + + # Create MLXLanguageModel wrapper without loading + self._model = MLXLanguageModel.__new__(MLXLanguageModel) + self._model.model_name = self._model_name + self._model.tokenizer_name = self._model_name + self._model.trust_remote_code = self._trust_remote_code + self._model.draft_model_name = self._draft_model_name + self._model.num_draft_tokens = self._num_draft_tokens + self._model.model = model + self._model.tokenizer = tokenizer + self._model.draft_model = None + self._model._loaded = True + + # Load draft model separately if specified + if self._draft_model_name: + from mlx_lm import load as mlx_load + + logger.info( + f"Loading draft model for speculative decoding: {self._draft_model_name}" + ) + self._model.draft_model, draft_tokenizer = mlx_load(self._draft_model_name) + + # Validate tokenizer compatibility + if draft_tokenizer.vocab_size != tokenizer.vocab_size: + logger.warning( + f"Draft model tokenizer vocab size ({draft_tokenizer.vocab_size}) " + f"differs from main model ({tokenizer.vocab_size}). " + "This may reduce speculative decoding effectiveness." + ) + + logger.info( + f"Speculative decoding enabled: draft={self._draft_model_name}, " + f"num_draft_tokens={self._num_draft_tokens}" + ) + + self._loaded = True + logger.info(f"SimpleEngine injected with shared model: {self._model_name}") + + @property + def supports_guided_generation(self) -> bool: + """Check if guided generation is available.""" + return HAS_GUIDED and not self._is_mllm + + async def generate_with_schema( + self, + messages: list[dict[str, Any]], + json_schema: dict[str, Any], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + **kwargs, + ) -> GenerationOutput: + """ + Generate JSON output constrained to a schema using guided decoding. 
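`_inject_shared_model` above relies on `__new__` to build a wrapper object without running `__init__`, and therefore without loading the weights a second time. The trick in isolation, with a toy `Engine` class standing in for `MLXLanguageModel`:

```python
class Engine:
    def __init__(self, name: str):
        self.name = name
        self.weights = self._load_weights()  # pretend this allocates ~GBs

    def _load_weights(self) -> list:
        return [0.0] * 1_000_000

def inject_shared(name: str, weights: list) -> Engine:
    # Bypass __init__ so no second copy of the weights is created,
    # mirroring how SimpleEngine and BatchedEngine share one model.
    eng = Engine.__new__(Engine)
    eng.name = name
    eng.weights = weights
    return eng

shared = [0.0] * 10  # stand-in for an already-loaded model
a = inject_shared("simple", shared)
b = inject_shared("batched", shared)
assert a.weights is b.weights  # both wrappers point at the same tensors
```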
+ + This method uses outlines for constrained generation to guarantee + the output is valid JSON matching the specified schema. + + Args: + messages: List of chat messages + json_schema: JSON schema to constrain output + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling + **kwargs: Additional parameters + + Returns: + GenerationOutput with JSON text matching the schema + """ + if not self.supports_guided_generation: + raise RuntimeError( + "Guided generation not available. " + "Install with: pip install 'vllm-mlx[guided]'" + ) + + if not self._loaded: + await self.start() + + # Build prompt from messages + tokenizer = self._model.tokenizer + if hasattr(tokenizer, "apply_chat_template"): + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages) + prompt += "\nassistant:" + + async with self._generation_lock: + # Run guided generation in thread pool + result = await asyncio.to_thread( + self._run_guided_generation, + prompt=prompt, + json_schema=json_schema, + max_tokens=max_tokens, + temperature=temperature, + ) + + if result is None: + # Fallback to regular generation + logger.warning( + "Guided generation failed, falling back to regular generation" + ) + return await self.generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + # Tokenize for completion count + tokens = tokenizer.encode(result) + + return GenerationOutput( + text=result, + tokens=tokens, + prompt_tokens=len(tokenizer.encode(prompt)), + completion_tokens=len(tokens), + finish_reason="stop", + ) + + def _run_guided_generation( + self, + prompt: str, + json_schema: dict[str, Any], + max_tokens: int, + temperature: float, + ) -> str | None: + """ + Run guided generation synchronously (called from thread pool). 
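The guided path above degrades gracefully: if constrained decoding fails, the engine falls back to regular generation. The control flow, reduced to plain functions (`guided_fn` and `plain_fn` are hypothetical stand-ins for the engine calls):

```python
import json
import logging

def generate_json(prompt: str, schema: dict, guided_fn, plain_fn) -> str:
    """Guided decoding with a plain-generation fallback (sketch)."""
    try:
        return guided_fn(prompt, schema)
    except Exception as exc:
        logging.warning("Guided generation failed, falling back: %s", exc)
        text = plain_fn(prompt)
        json.loads(text)  # surface malformed fallback output early
        return text

def broken_guided(prompt: str, schema: dict) -> str:
    raise RuntimeError("unsupported schema feature")

result = generate_json("hi", {}, broken_guided, lambda p: '{"ok": true}')
assert result == '{"ok": true}'
```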
+ + Args: + prompt: Input prompt + json_schema: JSON schema + max_tokens: Maximum tokens + temperature: Sampling temperature + + Returns: + JSON string or None if failed + """ + try: + generator = GuidedGenerator(self._model.model, self._model.tokenizer) + return generator.generate_json( + prompt=prompt, + json_schema=json_schema, + max_tokens=max_tokens, + temperature=temperature, + ) + except Exception as e: + logger.error(f"Guided generation error: {e}") + return None diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py index aaa0ccc4..d210c473 100644 --- a/vllm_mlx/engine_core.py +++ b/vllm_mlx/engine_core.py @@ -36,6 +36,7 @@ class EngineConfig: scheduler_config: Optional[SchedulerConfig] = None step_interval: float = 0.001 # 1ms between steps stream_interval: int = 1 # Tokens to batch before streaming (1=every token) + gpu_memory_utilization: float = 0.90 # Fraction of device memory for allocation class EngineCore: @@ -150,8 +151,16 @@ async def _engine_loop(self) -> None: stream_interval = self.config.stream_interval use_simple_streaming = stream_interval == 1 - # Emergency memory pressure threshold (200GB) - _memory_pressure_threshold = 200 * 1024 * 1024 * 1024 + # Emergency memory pressure threshold — dynamic based on gpu_memory_utilization + # Uses a 5% gap above the soft limit (capped at 99%) to allow temporary spikes. 
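The dynamic memory-pressure threshold just described replaces the old fixed 200GB limit. Its arithmetic as a standalone helper (parameter names are illustrative):

```python
def pressure_threshold(
    device_mem_bytes: int,
    gpu_memory_utilization: float,
    headroom: float = 0.05,
    cap: float = 0.99,
) -> int:
    # The emergency limit sits `headroom` above the soft allocation limit,
    # capped so it never reaches 100% of device memory.
    return int(device_mem_bytes * min(gpu_memory_utilization + headroom, cap))

# A 256GB M3 Ultra at the 0.90 default trips the emergency path near 95%
gib = 1024**3
threshold = pressure_threshold(256 * gib, 0.90)
assert abs(threshold / (256 * gib) - 0.95) < 1e-9
```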
+ _gpu_mem_util = self.config.gpu_memory_utilization + try: + _device_mem = mx.device_info().get("memory_size", 200 * 1024 * 1024 * 1024) + _memory_pressure_threshold = int( + _device_mem * min(_gpu_mem_util + 0.05, 0.99) + ) + except Exception: + _memory_pressure_threshold = 200 * 1024 * 1024 * 1024 _memory_check_interval = 64 while self._running: diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 092c060e..1c75fb3a 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -50,6 +50,8 @@ def __init__( model_name: str, tokenizer_name: str | None = None, trust_remote_code: bool = False, + draft_model: str | None = None, + num_draft_tokens: int = 4, ): """ Initialize the MLX language model. @@ -58,13 +60,18 @@ def __init__( model_name: HuggingFace model name or local path tokenizer_name: Optional separate tokenizer name trust_remote_code: Whether to trust remote code + draft_model: Optional draft model path for speculative decoding + num_draft_tokens: Number of tokens to generate speculatively per step """ self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name self.trust_remote_code = trust_remote_code + self.draft_model_name = draft_model + self.num_draft_tokens = num_draft_tokens self.model = None self.tokenizer = None + self.draft_model = None self._loaded = False def load(self) -> None: @@ -91,6 +98,28 @@ def load(self) -> None: tokenizer_config=tokenizer_config, ) + # Load draft model for speculative decoding if specified + if self.draft_model_name: + logger.info( + f"Loading draft model for speculative decoding: {self.draft_model_name}" + ) + from mlx_lm import load as mlx_load + + self.draft_model, draft_tokenizer = mlx_load(self.draft_model_name) + + # Validate tokenizer compatibility + if draft_tokenizer.vocab_size != self.tokenizer.vocab_size: + logger.warning( + f"Draft model tokenizer vocab size ({draft_tokenizer.vocab_size}) " + f"differs from main model ({self.tokenizer.vocab_size}). 
" + "This may reduce speculative decoding effectiveness." + ) + + logger.info( + f"Speculative decoding enabled: draft={self.draft_model_name}, " + f"num_draft_tokens={self.num_draft_tokens}" + ) + self._loaded = True logger.info(f"Model loaded successfully: {self.model_name}") @@ -147,15 +176,31 @@ def generate( # Create sampler with parameters sampler = self._create_sampler(temperature, top_p) - # Generate text - output_text = generate( - self.model, - self.tokenizer, - prompt=prompt, - max_tokens=max_tokens, - sampler=sampler, - verbose=False, - ) + # Note: mlx_lm.generate() doesn't support draft_model directly, + # speculative decoding is only available via stream_generate() + if self.draft_model is not None: + # Use streaming with draft model and collect result + output_text = "" + for chunk in self.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + ): + output_text += chunk.text + if chunk.finished: + break + else: + # Generate text without speculative decoding + output_text = generate( + self.model, + self.tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + verbose=False, + ) # Tokenize output to get token IDs tokens = self.tokenizer.encode(output_text) @@ -203,12 +248,22 @@ def stream_generate( token_count = 0 accumulated_text = "" + # Build generation kwargs + gen_kwargs = { + "max_tokens": max_tokens, + "sampler": sampler, + } + + # Add draft model for speculative decoding if available + if self.draft_model is not None: + gen_kwargs["draft_model"] = self.draft_model + gen_kwargs["num_draft_tokens"] = self.num_draft_tokens + for response in stream_generate( self.model, self.tokenizer, prompt=prompt, - max_tokens=max_tokens, - sampler=sampler, + **gen_kwargs, ): token_count += 1 # response.text is the new token text (not accumulated) diff --git a/vllm_mlx/paged_cache.py b/vllm_mlx/paged_cache.py index b7357792..360c6bbd 100644 --- a/vllm_mlx/paged_cache.py +++ 
b/vllm_mlx/paged_cache.py @@ -110,6 +110,7 @@ class CacheBlock: # Special flags is_null: bool = False + is_pinned: bool = False # Pinned blocks are never evicted (e.g., system prompt) # Actual tensor data for this block # List of (keys, values) per layer, shape: (1, n_kv_heads, block_tokens, head_dim) @@ -202,7 +203,10 @@ def __init__(self, blocks: List[CacheBlock]) -> None: def popleft(self) -> CacheBlock: """ - Pop and return the first (LRU) free block. + Pop and return the first (LRU) non-pinned free block. + + Pinned blocks are skipped to prevent eviction of important + prefix cache entries (e.g., system prompts). Raises: ValueError: If no free blocks available @@ -210,13 +214,21 @@ def popleft(self) -> CacheBlock: if self.fake_head.next_free_block is self.fake_tail: raise ValueError("No free blocks available") + # Skip pinned blocks block = self.fake_head.next_free_block + while block is not self.fake_tail and block.is_pinned: + block = block.next_free_block + + if block is self.fake_tail: + raise ValueError("No free blocks available (all free blocks are pinned)") + assert block is not None - # Remove from list - self.fake_head.next_free_block = block.next_free_block + # Remove from list (may not be head if pinned blocks were skipped) + if block.prev_free_block: + block.prev_free_block.next_free_block = block.next_free_block if block.next_free_block: - block.next_free_block.prev_free_block = self.fake_head + block.next_free_block.prev_free_block = block.prev_free_block block.prev_free_block = None block.next_free_block = None @@ -226,7 +238,10 @@ def popleft(self) -> CacheBlock: def popleft_n(self, n: int) -> List[CacheBlock]: """ - Pop n blocks from the front. + Pop n non-pinned blocks from the front. + + Pinned blocks are skipped to prevent eviction of important + prefix cache entries (e.g., system prompts). 
Args: n: Number of blocks to allocate @@ -247,19 +262,27 @@ def popleft_n(self, n: int) -> List[CacheBlock]: result = [] curr = self.fake_head.next_free_block - for _ in range(n): - assert curr is not None and curr is not self.fake_tail - result.append(curr) - last = curr + while len(result) < n: + assert curr is not None and curr is not self.fake_tail, ( + f"Need {n} blocks, only found {len(result)} non-pinned" + ) + + if curr.is_pinned: + curr = curr.next_free_block + continue + + # Remove this block from the list + block = curr curr = curr.next_free_block - # Clear pointers - last.prev_free_block = None - last.next_free_block = None - # Reconnect list - self.fake_head.next_free_block = curr - if curr: - curr.prev_free_block = self.fake_head + if block.prev_free_block: + block.prev_free_block.next_free_block = block.next_free_block + if block.next_free_block: + block.next_free_block.prev_free_block = block.prev_free_block + + block.prev_free_block = None + block.next_free_block = None + result.append(block) self.num_free_blocks -= n return result @@ -613,6 +636,9 @@ def _maybe_evict_cached_block(self, block: CacheBlock) -> bool: Returns: True if block was evicted from cache """ + if block.is_pinned: + return False + if block.block_hash is None: return False @@ -1138,6 +1164,64 @@ def get_memory_usage(self) -> Dict[str, Any]: ), } + # ========================================================================= + # Block Pinning (prevent eviction of important prefixes) + # ========================================================================= + + def pin_blocks(self, block_ids: List[int]) -> int: + """ + Pin blocks to prevent eviction (e.g., system prompt cache). 
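The free-queue changes above make `popleft` walk past pinned blocks instead of handing them out for reuse. The same semantics with a plain list standing in for the doubly linked list (a sketch; the real structure keeps O(1) removal via prev/next pointers):

```python
from dataclasses import dataclass

@dataclass
class Block:
    block_id: int
    is_pinned: bool = False

class FreeQueue:
    """LRU free queue that never hands out pinned blocks (sketch)."""

    def __init__(self, blocks: list):
        self._blocks = list(blocks)

    def popleft(self) -> Block:
        for i, block in enumerate(self._blocks):
            if not block.is_pinned:
                return self._blocks.pop(i)
        raise ValueError("No free blocks available (all free blocks are pinned)")

q = FreeQueue([Block(0, is_pinned=True), Block(1), Block(2)])
assert q.popleft().block_id == 1  # block 0 is pinned, so it is skipped
assert q.popleft().block_id == 2
```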
+ + Args: + block_ids: Block IDs to pin + + Returns: + Number of blocks actually pinned + """ + with self._lock: + pinned = 0 + for block_id in block_ids: + block = self.allocated_blocks.get(block_id) + if block and not block.is_null and not block.is_pinned: + block.is_pinned = True + pinned += 1 + if pinned: + logger.info(f"Pinned {pinned} cache blocks: {block_ids}") + return pinned + + def unpin_blocks(self, block_ids: List[int]) -> int: + """ + Unpin blocks to allow eviction. + + Args: + block_ids: Block IDs to unpin + + Returns: + Number of blocks actually unpinned + """ + with self._lock: + unpinned = 0 + for block_id in block_ids: + block = self.allocated_blocks.get(block_id) + if not block: + # Also check blocks in the free queue + if block_id < len(self.blocks): + block = self.blocks[block_id] + if block and block.is_pinned: + block.is_pinned = False + unpinned += 1 + if unpinned: + logger.info(f"Unpinned {unpinned} cache blocks: {block_ids}") + return unpinned + + def get_pinned_block_ids(self) -> List[int]: + """Get list of all pinned block IDs.""" + with self._lock: + return [ + b.block_id for b in self.blocks + if b.is_pinned + ] + def reset_stats(self) -> None: """Reset statistics counters.""" with self._lock: diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index e8f47a32..68301985 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -347,6 +347,52 @@ def clear(self) -> None: self._lru.clear() self.reset_stats() + def pin_prefix(self, tokens: List[int]) -> bool: + """ + Pin a prefix in the cache to prevent eviction. + + For the trie-based cache, this removes the entry from the LRU queue + so it is never evicted. The entry remains accessible for lookups. 
+
+        Args:
+            tokens: Token sequence of the prefix to pin
+
+        Returns:
+            True if prefix was found and pinned
+        """
+        tokens_tuple = tuple(tokens)
+        key = (self.model_key, tokens_tuple)
+        try:
+            self._lru.remove(key)
+            logger.info(f"Pinned prefix ({len(tokens)} tokens) - removed from LRU")
+            return True
+        except ValueError:
+            logger.warning("Cannot pin prefix: not found in cache")
+            return False
+
+    def unpin_prefix(self, tokens: List[int]) -> bool:
+        """
+        Unpin a prefix, making it eligible for LRU eviction again.
+
+        Args:
+            tokens: Token sequence of the prefix to unpin
+
+        Returns:
+            True if prefix was found and unpinned
+        """
+        tokens_tuple = tuple(tokens)
+        key = (self.model_key, tokens_tuple)
+        # Check the entry exists in the trie
+        entry = self._get_cache_entry(tokens)
+        if entry is None:
+            return False
+        # Re-add to LRU (at MRU end)
+        if key not in self._lru:
+            self._lru.append(key)
+            logger.info(f"Unpinned prefix ({len(tokens)} tokens) - added back to LRU")
+            return True
+        return False
+
+    def __len__(self) -> int:
+        """Return number of cached entries."""
+        return len(self._lru)
@@ -947,6 +993,65 @@ def clear(self) -> None:
+        self.paged_cache.clear()
+        self.reset_stats()
+
+    def pin_prefix(self, tokens: List[int]) -> bool:
+        """
+        Pin blocks covering a token prefix to prevent eviction.
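For the trie-based cache, pinning is just removing the key from the LRU queue while leaving the entry in the store, exactly as `pin_prefix` above does with `self._lru.remove(key)`. A compact model of that idea:

```python
from collections import deque

class PinnableLRU:
    """Key/value store whose eviction order skips pinned keys (sketch)."""

    def __init__(self):
        self._store = {}
        self._lru = deque()  # eviction candidates, oldest first

    def put(self, key, value):
        self._store[key] = value
        self._lru.append(key)

    def pin(self, key) -> bool:
        try:
            self._lru.remove(key)  # no longer an eviction candidate
            return True
        except ValueError:
            return False  # not cached (or already pinned)

    def evict_one(self):
        key = self._lru.popleft()
        return self._store.pop(key)

cache = PinnableLRU()
cache.put("system-prompt", "kv-cache-A")
cache.put("chat-turn", "kv-cache-B")
assert cache.pin("system-prompt")
assert cache.evict_one() == "kv-cache-B"  # the pinned entry survives
assert "system-prompt" in cache._store
```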
+ + Args: + tokens: Token sequence of the prefix to pin + + Returns: + True if blocks were found and pinned + """ + # Find blocks covering this prefix + shared_block_ids, _ = self.paged_cache.find_shared_prefix(tokens) + if shared_block_ids: + pinned = self.paged_cache.pin_blocks(shared_block_ids) + if pinned > 0: + logger.info( + f"Pinned prefix: {pinned} blocks, " + f"{len(tokens)} tokens" + ) + return True + + # Try prefix index + best_match = self._find_best_prefix_match(tokens) + if best_match: + _, block_ids = best_match + pinned = self.paged_cache.pin_blocks(block_ids) + if pinned > 0: + logger.info( + f"Pinned prefix via index: {pinned} blocks, " + f"{len(tokens)} tokens" + ) + return True + + logger.warning(f"Cannot pin prefix: no cached blocks found for {len(tokens)} tokens") + return False + + def unpin_prefix(self, tokens: List[int]) -> bool: + """ + Unpin blocks covering a token prefix. + + Args: + tokens: Token sequence of the prefix to unpin + + Returns: + True if blocks were found and unpinned + """ + shared_block_ids, _ = self.paged_cache.find_shared_prefix(tokens) + if shared_block_ids: + unpinned = self.paged_cache.unpin_blocks(shared_block_ids) + return unpinned > 0 + + best_match = self._find_best_prefix_match(tokens) + if best_match: + _, block_ids = best_match + unpinned = self.paged_cache.unpin_blocks(block_ids) + return unpinned > 0 + + return False + def __len__(self) -> int: """Return number of active request entries.""" return len(self._request_tables) diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 138c166a..067fe6ca 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -39,6 +39,8 @@ import argparse import asyncio +import gc +import hashlib import json import logging import os @@ -93,16 +95,25 @@ from .api.tool_calling import ( build_json_system_prompt, convert_tools_for_template, + extract_json_schema_for_guided, parse_json_output, parse_tool_calls, ) from .api.utils import ( SPECIAL_TOKENS_PATTERN, clean_output_text, 
+ extract_json_from_response, extract_multimodal_content, is_mllm_model, # noqa: F401 + strip_thinking_tags, +) +from .engine import ( + BaseEngine, + BatchedEngine, + GenerationOutput, + HybridEngine, + SimpleEngine, ) -from .engine import BaseEngine, BatchedEngine, GenerationOutput, SimpleEngine from .tool_parsers import ToolParserManager logging.basicConfig(level=logging.INFO) @@ -158,6 +169,82 @@ def _resolve_top_p(request_value: float | None) -> float: _tool_call_parser: str | None = None # Parser name: auto, mistral, qwen, llama, hermes _tool_parser_instance = None # Instantiated parser +# GC control (Tier 0 optimization) +_gc_control: bool = True # Disable GC during generation to avoid latency spikes + +# Pinned prefix cache (Tier 0 optimization) +_pin_system_prompt: bool = False # Auto-pin system prompt prefix cache blocks +_pinned_system_prompt_hash: str | None = None # Hash of pinned system prompt + + +def _maybe_pin_system_prompt(messages: list) -> None: + """Auto-pin system prompt prefix cache blocks on first request.""" + global _pinned_system_prompt_hash + + if not _pin_system_prompt or _engine is None: + return + + # Extract system message content + system_content = None + for msg in messages: + role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None) + if role == "system": + content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", None) + if isinstance(content, str): + system_content = content + break + + if not system_content: + return + + # Check if this system prompt is already pinned + prompt_hash = hashlib.sha256(system_content.encode()).hexdigest()[:16] + if prompt_hash == _pinned_system_prompt_hash: + return # Already pinned + + # Try to pin via engine's cache manager + try: + # Get the tokenizer to encode the system prompt + tokenizer = None + if hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + elif hasattr(_engine, "_model") and hasattr(_engine._model, "tokenizer"): + 
tokenizer = _engine._model.tokenizer + + if tokenizer is None: + return + + system_tokens = tokenizer.encode(system_content) + if not system_tokens or len(system_tokens) < 16: + return # Too short to pin + + # Try BlockAwarePrefixCache (paged cache mode) + if hasattr(_engine, "_prefix_cache") and _engine._prefix_cache is not None: + cache = _engine._prefix_cache + if hasattr(cache, "pin_prefix"): + if cache.pin_prefix(system_tokens): + _pinned_system_prompt_hash = prompt_hash + logger.info( + f"Auto-pinned system prompt: {len(system_tokens)} tokens, " + f"hash={prompt_hash}" + ) + return + + # Try PrefixCacheManager (trie-based cache mode) + if hasattr(_engine, "_cache_manager") and _engine._cache_manager is not None: + cache = _engine._cache_manager + if hasattr(cache, "pin_prefix"): + if cache.pin_prefix(system_tokens): + _pinned_system_prompt_hash = prompt_hash + logger.info( + f"Auto-pinned system prompt (trie): {len(system_tokens)} tokens, " + f"hash={prompt_hash}" + ) + return + + except Exception as e: + logger.debug(f"System prompt pinning failed: {e}") + def _load_prefix_cache_from_disk() -> None: """Load prefix cache from disk during startup.""" @@ -207,6 +294,11 @@ async def lifespan(app: FastAPI): """FastAPI lifespan for startup/shutdown events.""" global _engine, _mcp_manager + # GC control: raise thresholds to reduce GC frequency with large models + if _gc_control: + gc.set_threshold(100_000, 50, 50) + logger.info("GC control enabled: thresholds set to (100000, 50, 50)") + # Startup: Start engine if loaded (needed for BatchedEngine in uvicorn's event loop) if _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded: await _engine.start() @@ -467,6 +559,9 @@ def load_model( stream_interval: int = 1, max_tokens: int = 32768, force_mllm: bool = False, + gpu_memory_utilization: float = 0.90, + draft_model: str | None = None, + num_draft_tokens: int = 4, ): """ Load a model (auto-detects MLLM vs LLM). 
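`_maybe_pin_system_prompt` above pins each distinct system prompt once, keyed by a truncated SHA-256 digest so repeat requests with the same prompt are a no-op. The dedupe logic on its own:

```python
import hashlib

_pinned_hash = None

def should_pin(system_prompt: str) -> bool:
    """Return True only the first time a given system prompt is seen (sketch)."""
    global _pinned_hash
    prompt_hash = hashlib.sha256(system_prompt.encode()).hexdigest()[:16]
    if prompt_hash == _pinned_hash:
        return False  # same prompt as last time: already pinned
    _pinned_hash = prompt_hash
    return True

assert should_pin("You are a helpful assistant.")
assert not should_pin("You are a helpful assistant.")  # dedupe on repeat
assert should_pin("You are a terse assistant.")        # new prompt re-pins
```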
@@ -478,6 +573,10 @@ def load_model( stream_interval: Tokens to batch before streaming (batched mode only) max_tokens: Default max tokens for generation force_mllm: Force loading as MLLM even if not auto-detected + gpu_memory_utilization: Fraction of device memory for Metal allocation + limit and emergency threshold (0.0-1.0, default 0.90) + draft_model: Optional draft model for speculative decoding + num_draft_tokens: Number of tokens to generate speculatively per step """ global _engine, _model_name, _default_max_tokens, _tool_parser_instance @@ -489,20 +588,47 @@ def load_model( if force_mllm: logger.info("Force MLLM mode enabled via --mllm flag") - if use_batching: + if draft_model: + logger.info(f"Speculative decoding enabled with draft model: {draft_model}") + logger.info(f" num_draft_tokens: {num_draft_tokens}") + + if use_batching and draft_model: + # Hybrid mode: shared model with speculative decoding + continuous batching + logger.info(f"Loading model with HybridEngine: {model_name}") + logger.info( + " Hybrid mode: speculative decoding for single user, " + "batching for multiple users" + ) + _engine = HybridEngine( + model_name=model_name, + draft_model=draft_model, + num_draft_tokens=num_draft_tokens, + scheduler_config=scheduler_config, + stream_interval=stream_interval, + force_mllm=force_mllm, + ) + # HybridEngine will be started in lifespan (uvicorn's event loop) + logger.info(f"Model loaded (hybrid mode): {model_name}") + elif use_batching: logger.info(f"Loading model with BatchedEngine: {model_name}") _engine = BatchedEngine( model_name=model_name, scheduler_config=scheduler_config, stream_interval=stream_interval, force_mllm=force_mllm, + gpu_memory_utilization=gpu_memory_utilization, ) # BatchedEngine will be started in lifespan (uvicorn's event loop) # Just log for now logger.info(f"Model loaded (batched mode): {model_name}") else: logger.info(f"Loading model with SimpleEngine: {model_name}") - _engine = SimpleEngine(model_name=model_name, 
force_mllm=force_mllm) + _engine = SimpleEngine( + model_name=model_name, + force_mllm=force_mllm, + draft_model=draft_model, + num_draft_tokens=num_draft_tokens, + ) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) loop = asyncio.new_event_loop() @@ -1052,6 +1178,26 @@ async def _wait_disconnect(): f"{chunk_count} chunks, elapsed={_elapsed()}" ) break + except Exception as exc: + logger.error( + f"[disconnect_guard] generator raised {type(exc).__name__}: " + f"{exc}, {chunk_count} chunks, elapsed={_elapsed()}", + exc_info=True, + ) + # Yield an error event so the client sees something before [DONE] + import json as _json + + error_data = _json.dumps( + { + "error": { + "message": f"Internal error during streaming: {exc}", + "type": type(exc).__name__, + } + } + ) + yield f"data: {error_data}\n\n" + yield "data: [DONE]\n\n" + break chunk_count += 1 if chunk_count == 1: logger.info( @@ -1329,6 +1475,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re has_media = bool(images or videos) + # Auto-pin system prompt prefix cache blocks (non-blocking, first request only) + if _pin_system_prompt: + _maybe_pin_system_prompt(messages) + # Handle response_format - inject system prompt if needed response_format = request.response_format if response_format: @@ -1370,11 +1520,57 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re start_time = time.perf_counter() timeout = request.timeout or _default_timeout - output = await _wait_with_disconnect( - engine.chat(messages=messages, **chat_kwargs), - raw_request, - timeout=timeout, - ) + # Disable GC during generation to avoid latency spikes + gc_was_enabled = gc.isenabled() + if _gc_control and gc_was_enabled: + gc.disable() + + # Check if we should use guided generation for JSON schema + use_guided = False + json_schema = None + if response_format and not request.tools: + 
json_schema = extract_json_schema_for_guided(response_format) + if json_schema and hasattr(engine, "supports_guided_generation"): + use_guided = engine.supports_guided_generation + if use_guided: + logger.info("Using guided generation for JSON schema enforcement") + + try: + if use_guided and json_schema: + # Use guided generation for constrained JSON output + # Fall back to standard generation if guided fails (bad schema, etc.) + try: + output = await _wait_with_disconnect( + engine.generate_with_schema( + messages=messages, + json_schema=json_schema, + **chat_kwargs, + ), + raw_request, + timeout=timeout, + ) + except Exception as guided_err: + logger.warning( + f"Guided generation failed, falling back to standard: {guided_err}" + ) + logger.debug(f"Problematic schema: {json_schema}") + output = await _wait_with_disconnect( + engine.chat(messages=messages, **chat_kwargs), + raw_request, + timeout=timeout, + ) + else: + # Standard generation + output = await _wait_with_disconnect( + engine.chat(messages=messages, **chat_kwargs), + raw_request, + timeout=timeout, + ) + finally: + if _gc_control and gc_was_enabled: + gc.enable() + gc.collect() + if output is None: return Response(status_code=499) # Client closed request @@ -1408,12 +1604,22 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re # Determine finish reason finish_reason = "tool_calls" if tool_calls else output.finish_reason + # Clean and strip thinking tags from content + # Thinking tags break JSON parsing in clients expecting pure content + # Also extract JSON if model outputs reasoning text before JSON + final_content = None + if cleaned_text: + final_content = strip_thinking_tags(clean_output_text(cleaned_text)) + # If response looks like it ends with JSON, extract just the JSON part + # This handles Qwen3 reasoning mode: "Let me think... 
{json}" + final_content = extract_json_from_response(final_content) + return ChatCompletionResponse( model=request.model, choices=[ ChatCompletionChoice( message=AssistantMessage( - content=clean_output_text(cleaned_text) if cleaned_text else None, + content=final_content, reasoning=reasoning_text, tool_calls=tool_calls, ), @@ -1432,7 +1638,11 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: """ Inject JSON instruction into messages. - If a system message exists, append to it. Otherwise, prepend a new system message. + PREPENDS instruction to system message for better instruction following. + JSON formatting instructions at the beginning of system message are more + likely to be followed than instructions at the end. + + If a system message exists, prepend to it. Otherwise, prepend a new system message. """ messages = list(messages) # Make a copy @@ -1445,14 +1655,15 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: break if system_idx is not None: - # Append to existing system message + # PREPEND to existing system message (not append!) 
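The prepend-over-append change is easy to verify in isolation. A sketch of the new `_inject_json_instruction` behavior for dict-style messages (the real function also handles attribute-style message objects):

```python
def inject_instruction(messages: list, instruction: str) -> list:
    """Prepend the JSON instruction to the system message, adding one if absent."""
    messages = [dict(m) for m in messages]  # do not mutate the caller's list
    for msg in messages:
        if msg.get("role") == "system":
            # Prepend: leading instructions are followed more reliably
            msg["content"] = f"{instruction}\n\n{msg.get('content', '')}"
            return messages
    return [{"role": "system", "content": instruction}] + messages

msgs = [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "hi"}]
out = inject_instruction(msgs, "Respond with valid JSON only.")
assert out[0]["content"].startswith("Respond with valid JSON only.")
assert msgs[0]["content"] == "Be brief."  # original messages untouched
```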
+            # Instructions at the start are more effective
         msg = messages[system_idx]
         if isinstance(msg, dict):
             existing = msg.get("content", "")
-            msg["content"] = f"{existing}\n\n{instruction}"
+            msg["content"] = f"{instruction}\n\n{existing}"
         else:
             existing = getattr(msg, "content", "") or ""
-            msg.content = f"{existing}\n\n{instruction}"
+            msg.content = f"{instruction}\n\n{existing}"
     else:
         # Prepend new system message
         messages.insert(0, {"role": "system", "content": instruction})
@@ -1848,6 +2059,10 @@ async def stream_chat_completion(
     **kwargs,
 ) -> AsyncIterator[str]:
     """Stream chat completion response."""
+    gc_was_enabled = gc.isenabled()
+    if _gc_control and gc_was_enabled:
+        gc.disable()
+
     response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
     start_time = time.perf_counter()
@@ -1928,16 +2143,99 @@
                     # Skip this chunk (e.g. <think> token itself)
                     continue
+                content = delta_msg.content
+                reasoning = delta_msg.reasoning
+
+                # Some models (e.g. MiniMax) wrap tool calls in <think>
+                # blocks, so reasoning parser captures tool call XML as
+                # reasoning while content stays None. Redirect reasoning
+                # to the content stream so the tool parser can handle it.
+                if tool_parser and reasoning and not content:
+                    _check = tool_accumulated_text + reasoning
+                    if (
+                        "<minimax:tool_call>" in _check
+                        or "</minimax:tool_call>" in _check
+                        or "<tool_call" in _check
+                    ):
+                        content = reasoning
+                        reasoning = None
+        # Maps n-gram tuple -> [positions]
+        self._ngram_index: dict[tuple, list[int]] = defaultdict(list)
+
+        # Statistics
+        self.total_drafts = 0
+        self.successful_drafts = 0
+        self.total_draft_tokens = 0
+        self.accepted_tokens = 0
+
+    def reset(self):
+        """Reset the decoder state for a new generation."""
+        self._token_history = []
+        self._ngram_index = defaultdict(list)
+
+    def add_prompt_tokens(self, tokens: list[int]):
+        """
+        Add prompt tokens to the history for lookup.
+
+        Args:
+            tokens: List of prompt token IDs
+        """
+        for token in tokens:
+            self._add_token(token)
+
+    def _add_token(self, token: int):
+        """Add a single token to history and update n-gram index."""
+        self._token_history.append(token)
+
+        # Update n-gram index for all n-gram sizes up to ngram_size
+        pos = len(self._token_history) - 1
+        for n in range(1, min(self.ngram_size + 1, pos + 1)):
+            if pos >= n:
+                ngram = tuple(self._token_history[pos - n : pos])
+                self._ngram_index[ngram].append(pos)
+
+    def add_generated_token(self, token: int):
+        """
+        Add a generated token to history.
+
+        Args:
+            token: Generated token ID
+        """
+        self._add_token(token)
+
+    def get_draft_tokens(self) -> list[int]:
+        """
+        Get draft tokens based on n-gram lookup.
+
+        Returns:
+            List of draft token IDs (may be empty if no match found)
+        """
+        if len(self._token_history) < self.ngram_size:
+            return []
+
+        # Get the last ngram_size tokens as the query
+        query = tuple(self._token_history[-self.ngram_size :])
+
+        # Look up positions where this n-gram occurred
+        positions = self._ngram_index.get(query, [])
+
+        if not positions:
+            return []
+
+        # Find the best match (most recent, or one with longest continuation)
+        draft_tokens = []
+        best_continuation_length = 0
+
+        for pos in positions:
+            # Skip if this is the current position
+            if pos == len(self._token_history) - 1:
+                continue
+
+            # Check if there are tokens after this position
+            if pos < len(self._token_history) - 1:
+                # Get continuation tokens
+                continuation_end = min(
+                    pos + self.num_draft_tokens + 1, len(self._token_history)
+                )
+                continuation = self._token_history[pos + 1 : continuation_end]
+
+                if len(continuation) > best_continuation_length:
+                    best_continuation_length = len(continuation)
+                    draft_tokens = continuation[: self.num_draft_tokens]
+
+        if len(draft_tokens) >= self.min_matches:
+            self.total_drafts += 1
+            self.total_draft_tokens += len(draft_tokens)
+            return draft_tokens
+
+        return []
+
+    def record_accepted(self, num_accepted: int):
+        """Record statistics about accepted draft tokens."""
+        if num_accepted > 0:
+            self.successful_drafts += 1
+            self.accepted_tokens += num_accepted
+
+    def get_stats(self) -> dict:
+        """Get decoder statistics."""
+        acceptance_rate = 0.0
+        if self.total_draft_tokens > 0:
+            acceptance_rate = self.accepted_tokens / self.total_draft_tokens
+
+        return {
+            "total_drafts": self.total_drafts,
+            "successful_drafts": self.successful_drafts,
+            "total_draft_tokens": self.total_draft_tokens,
+            "accepted_tokens": self.accepted_tokens,
+            "acceptance_rate": acceptance_rate,
+            "history_size": len(self._token_history),
+        }
+
+
+def prompt_lookup_generate_step(
+    prompt: mx.array,
+    model,
+    *,
+    num_draft_tokens: int = 4,
+    ngram_size: int = 3,
+    max_tokens: int = 256,
+    sampler=None,
+    logits_processors=None,
+    prompt_cache=None,
+    prefill_step_size: int = 512,
+):
+    """
+    Generator for token generation with prompt lookup speculation.
+
+    This is a drop-in replacement for generate_step that uses n-gram
+    matching for speculation instead of a draft model.
+
+    Args:
+        prompt: Input token array
+        model: The main model
+        num_draft_tokens: Number of tokens to draft
+        ngram_size: N-gram size for matching
+        max_tokens: Maximum tokens to generate
+        sampler: Token sampler (default: argmax)
+        logits_processors: Optional logits processors
+        prompt_cache: Optional KV cache
+        prefill_step_size: Prefill chunk size
+
+    Yields:
+        Tuple of (token, logprobs, from_draft)
+    """
+    from mlx_lm.models import cache
+
+    y = prompt.astype(mx.uint32)
+
+    # Create KV cache if not provided
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model)
+
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+    # Initialize prompt lookup decoder
+    decoder = PromptLookupDecoder(
+        num_draft_tokens=num_draft_tokens,
+        ngram_size=ngram_size,
+    )
+
+    # Add prompt tokens to decoder
+    decoder.add_prompt_tokens(prompt.tolist())
+
+    def _step(tokens: mx.array, n_predict: int = 1):
+        """Run model on tokens and get predictions."""
+        logits = model(tokens[None], cache=prompt_cache)
+        logits = logits[:, -n_predict:, :]
+
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+        sampled = sampler(logprobs)
+
+        return sampled.squeeze(0), logprobs.squeeze(0)
+
+    def _prefill(tokens: mx.array):
+        """Process prompt tokens in chunks."""
+        while tokens.size > prefill_step_size:
+            model(tokens[:prefill_step_size][None], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            tokens = tokens[prefill_step_size:]
+            mx.clear_cache()
+        return tokens
+
+    # Prefill prompt
+    y = _prefill(y)
+
+    # Generate first token
+    current_token, logprobs = _step(y)
+    mx.eval(current_token, logprobs)
+
+    ntoks = 0
+
+    while ntoks < max_tokens:
+        # Yield current token
+        yield current_token.item(), logprobs, False
+        ntoks += 1
+
+        if ntoks >= max_tokens:
+            break
+
+        # Add token to decoder history
+        decoder.add_generated_token(current_token.item())
+
+        # Try to get draft tokens via n-gram lookup
+        draft_tokens = decoder.get_draft_tokens()
+
+        if draft_tokens and len(draft_tokens) > 0:
+            # Speculative path: verify draft tokens
+            verify_input = mx.array([current_token.item()] + draft_tokens, mx.uint32)
+            verified_tokens, verified_logprobs = _step(
+                verify_input, n_predict=len(draft_tokens) + 1
+            )
+            mx.eval(verified_tokens, verified_logprobs)
+
+            verified_tokens = verified_tokens.tolist()
+
+            # Check how many draft tokens were accepted
+            n_accepted = 0
+            for i, (draft_t, verify_t) in enumerate(
+                zip(draft_tokens, verified_tokens[:-1])
+            ):
+                if draft_t == verify_t:
+                    n_accepted += 1
+                    ntoks += 1
+                    decoder.add_generated_token(draft_t)
+                    yield draft_t, verified_logprobs[i], True  # from_draft=True
+                    if ntoks >= max_tokens:
+                        break
+                else:
+                    break
+
+            decoder.record_accepted(n_accepted)
+
+            if ntoks >= max_tokens:
+                break
+
+            # Trim cache for rejected tokens
+            if n_accepted < len(draft_tokens):
+                cache.trim_prompt_cache(prompt_cache, len(draft_tokens) - n_accepted)
+
+            # Next token is the first non-accepted verification result
+            current_token = mx.array(verified_tokens[n_accepted], mx.uint32)
+            logprobs = verified_logprobs[n_accepted]
+        else:
+            # Standard path: single token generation
+            next_token, next_logprobs = _step(
+                mx.array([current_token.item()], mx.uint32)
+            )
+            mx.eval(next_token, next_logprobs)
+            current_token = next_token
+            logprobs = next_logprobs
+
+        if ntoks % 256 == 0:
+            mx.clear_cache()
+
+    # Log final stats
+    stats = decoder.get_stats()
+    if stats["total_drafts"] > 0:
+        logger.info(
+            f"Prompt Lookup stats: {stats['accepted_tokens']}/{stats['total_draft_tokens']} "
+            f"tokens accepted ({stats['acceptance_rate']:.1%})"
+        )
diff --git a/vllm_mlx/tool_parsers/__init__.py b/vllm_mlx/tool_parsers/__init__.py
index 16f74408..7598b88e 100644
--- a/vllm_mlx/tool_parsers/__init__.py
+++ b/vllm_mlx/tool_parsers/__init__.py
@@ -57,6 +57,7 @@
 from .xlam_tool_parser import xLAMToolParser
 from .glm47_tool_parser import Glm47ToolParser
 from .harmony_tool_parser import HarmonyToolParser
+from .minimax_tool_parser import MiniMaxToolParser
 
 __all__ = [
     # Base classes
@@ -77,4 +78,5 @@
     "FunctionaryToolParser",
     "Glm47ToolParser",
     "HarmonyToolParser",
+    "MiniMaxToolParser",
 ]
diff --git a/vllm_mlx/tool_parsers/minimax_tool_parser.py b/vllm_mlx/tool_parsers/minimax_tool_parser.py
new file mode 100644
index 00000000..7459fe97
--- /dev/null
+++ b/vllm_mlx/tool_parsers/minimax_tool_parser.py
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+MiniMax tool call parser for vllm-mlx.
+
+Parses the MiniMax-M2 native XML tool call format:
+
+<minimax:tool_call>
+<invoke name="tool_name">
+<parameter name="param">param-value</parameter>
+</invoke>
+</minimax:tool_call>
+"""
+
+import json
+import re
+import uuid
+from collections.abc import Sequence
+from typing import Any
+
+from .abstract_tool_parser import (
+    ExtractedToolCallInformation,
+    ToolParser,
+    ToolParserManager,
+)
+
+
+def generate_tool_id() -> str:
+    return f"call_{uuid.uuid4().hex[:8]}"
+
+
+@ToolParserManager.register_module(["minimax", "minimax_m2"])
+class MiniMaxToolParser(ToolParser):
+    """
+    Parser for MiniMax-M2 tool call format.
+
+    Format:
+    <minimax:tool_call>
+    <invoke name="tool_name">
+    <parameter name="param">value</parameter>
+    </invoke>
+    </minimax:tool_call>
+    """
+
+    TOOL_CALL_BLOCK = re.compile(
+        r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
+    )
+    INVOKE_PATTERN = re.compile(r'<invoke name="([^"]+)">(.*?)</invoke>', re.DOTALL)
+    PARAM_PATTERN = re.compile(
+        r'<parameter name="([^"]+)">(.*?)</parameter>', re.DOTALL
+    )
+    THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
+
+    def _extract_invokes(self, text: str) -> list[dict[str, Any]]:
+        """Extract tool calls from invoke elements, with or without wrapper."""
+        tool_calls: list[dict[str, Any]] = []
+        invokes = self.INVOKE_PATTERN.findall(text)
+        for func_name, params_block in invokes:
+            params = self.PARAM_PATTERN.findall(params_block)
+            # Skip bare <invoke> tags without parameters (hallucinated junk)
+            if not params:
+                continue
+            arguments = {}
+            for p_name, p_value in params:
+                p_value = p_value.strip()
+                try:
+                    arguments[p_name] = json.loads(p_value)
+                except (json.JSONDecodeError, ValueError):
+                    arguments[p_name] = p_value
+
+            tool_calls.append(
+                {
+                    "id": generate_tool_id(),
+                    "name": func_name.strip(),
+                    "arguments": json.dumps(arguments, ensure_ascii=False),
+                }
+            )
+        return tool_calls
+
+    def extract_tool_calls(
+        self, model_output: str, request: dict[str, Any] | None = None
+    ) -> ExtractedToolCallInformation:
+        # Try wrapped format first: <minimax:tool_call>...</minimax:tool_call>
+        blocks = self.TOOL_CALL_BLOCK.findall(model_output)
+        if blocks:
+            tool_calls: list[dict[str, Any]] = []
+            for block in blocks:
+                tool_calls.extend(self._extract_invokes(block))
+
+            cleaned = self.TOOL_CALL_BLOCK.sub("", model_output).strip()
+            cleaned = self.THINK_PATTERN.sub("", cleaned).strip()
+            cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip()
+
+            return ExtractedToolCallInformation(
+                tools_called=bool(tool_calls),
+                tool_calls=tool_calls,
+                content=cleaned if cleaned else None,
+            )
+
+        # Fallback: bare <invoke> without wrapper
+        # (model sometimes emits tool calls inside <think> without wrapper)
+        tool_calls = self._extract_invokes(model_output)
+        if tool_calls:
+            # Strip matched invoke blocks and thinking from content
+            cleaned = self.INVOKE_PATTERN.sub("", model_output).strip()
+            cleaned = self.THINK_PATTERN.sub("", cleaned).strip()
+            cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip()
+            # Remove leftover closing tags
+            cleaned = cleaned.replace("</minimax:tool_call>", "").strip()
+
+            return ExtractedToolCallInformation(
+                tools_called=True,
+                tool_calls=tool_calls,
+                content=cleaned if cleaned else None,
+            )
+
+        return ExtractedToolCallInformation(
+            tools_called=False, tool_calls=[], content=model_output
+        )
+
+    def _has_tool_start(self, text: str) -> bool:
+        """Check if text contains the start of a tool call block."""
+        return "<minimax:tool_call>" in text or (
+            '<invoke name=' in text
+        )
+
+    def _has_tool_end(self, current: str, previous: str) -> bool:
+        if "<minimax:tool_call>" in current:
+            return (
+                "</minimax:tool_call>" in current
+                and "</minimax:tool_call>" not in previous
+            )
+        # Bare invoke: </invoke> just appeared
+        if "</invoke>" in current and "</invoke>" not in previous:
+            return True
+        return False
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int] | None = None,
+        current_token_ids: Sequence[int] | None = None,
+        delta_token_ids: Sequence[int] | None = None,
+        request: dict[str, Any] | None = None,
+    ) -> dict[str, Any] | None:
+        # Not inside a tool call block yet — pass content through
+        if not self._has_tool_start(current_text):
+            return {"content": delta_text}
+
+        # Tool call block just completed
+        if self._has_tool_end(current_text, previous_text):
+            result = self.extract_tool_calls(current_text)
+            if result.tools_called:
+                return {
+                    "tool_calls": [
+                        {
+                            "index": i,
+                            "id": tc["id"],
+                            "type": "function",
+                            "function": {
+                                "name": tc["name"],
+                                "arguments": tc["arguments"],
+                            },
+                        }
+                        for i, tc in enumerate(result.tool_calls)
+                    ]
+                }
+
+        # Inside tool call block but not yet complete — suppress output
+        return None
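The n-gram matching at the heart of the `PromptLookupDecoder` added above can be sketched standalone. This is a simplified, illustrative version (`draft_tokens` is not the module's API): match the last `n` generated tokens against earlier occurrences in the token history, and propose the tokens that followed the longest such match as draft tokens for batched verification.

```python
def draft_tokens(history: list[int], n: int = 3, k: int = 4) -> list[int]:
    """Propose up to k draft tokens by prompt lookup: find an earlier
    occurrence of the last n tokens and replay what followed it."""
    if len(history) <= n:
        return []
    suffix = tuple(history[-n:])
    best: list[int] = []
    # Scan occurrences strictly before the current suffix position,
    # keeping the longest available continuation.
    for i in range(len(history) - n):
        if tuple(history[i : i + n]) == suffix:
            continuation = history[i + n : i + n + k]
            if len(continuation) > len(best):
                best = continuation
    return best


if __name__ == "__main__":
    # The suffix (1, 2, 3) occurred earlier; replay the tokens after it.
    print(draft_tokens([1, 2, 3, 4, 5, 1, 2, 3], n=3, k=4))  # -> [4, 5, 1, 2]
```

The real implementation avoids rescanning by incrementally indexing n-grams in a `defaultdict` of positions, and the proposed tokens are verified with a single forward pass in `prompt_lookup_generate_step`, which accepts the longest prefix that matches the model's own predictions and trims the KV cache for the rest.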