
fix: patch Gemma 4 attention and RotatingKVCache for BatchKVCache#256

Merged
waybarrios merged 1 commit into waybarrios:main from janhilgard:fix/gemma4-batched-rope
Apr 10, 2026

Conversation

@janhilgard
Collaborator

@janhilgard janhilgard commented Apr 5, 2026

Summary

Fixes for Gemma 4 models running with --mllm and continuous batching, plus Anthropic endpoint improvements:

1. BatchKVCache offset bug fix

  • mx.array.__iadd__ (+=) mutates in place, so a reference to cache.offset captured before update_and_fetch() ends up observing the post-update value (offset+1 instead of offset)
  • Causes incorrect RoPE positions → token repetition starting ~3 tokens into generation
  • Fix: defensive copy of cache.offset before cache mutation
  • Also fixes RotatingKVCache.max_size returning mx.array instead of int
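
The aliasing behavior behind the bug can be reproduced with a minimal stand-in class (the real object is an mx.array; `FakeArrayOffset` here is purely illustrative). `+=` returns the same mutated object, so only an explicit `+ 0` copy preserves the pre-update value:

```python
class FakeArrayOffset:
    """Minimal stand-in for an mx.array scalar: += mutates in place."""
    def __init__(self, value):
        self.value = value
    def __iadd__(self, other):
        self.value += other   # in-place mutation, same object returned
        return self
    def __add__(self, other):
        return FakeArrayOffset(self.value + other)  # new object

cache_offset = FakeArrayOffset(10)

offset_ref = cache_offset           # bare reference: aliases the cache's offset
offset_snapshot = cache_offset + 0  # defensive copy: "+ 0" forces a new object

cache_offset += 1  # simulates update_and_fetch() advancing the offset

assert offset_ref.value == 11       # alias sees the post-update value (the bug)
assert offset_snapshot.value == 10  # snapshot keeps the pre-update value (the fix)
```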

2. Gemma 4 reasoning parser (--reasoning-parser gemma4)

  • Extracts thinking content from <|channel> (token 100) / <channel|> (token 101) into reasoning_content field
  • Works for both streaming and non-streaming responses
  • When no tags present, treats output as plain content (Gemma 4 doesn't inject tags in prompt)

3. MLLM stop tokens from generation_config.json

  • tokenizer.eos_token_id only returns primary EOS (token 1 <eos>)
  • Gemma 4's generation_config.json defines additional EOS: <turn|> (106), <|tool_response> (50)
  • Without this fix, model generates past turn boundaries, producing garbage after the response
  • Fix: _get_stop_tokens() now also reads generation_config.json
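
A sketch of the merge logic, assuming the generation_config.json convention where eos_token_id may be either a single int or a list of ints:

```python
import json

def get_stop_tokens(tokenizer_eos: int, gen_config_text: str) -> set[int]:
    """Merge the tokenizer's primary EOS with any extra EOS ids declared
    in generation_config.json (int or list of ints)."""
    stop = {tokenizer_eos}
    cfg = json.loads(gen_config_text)
    eos = cfg.get("eos_token_id", [])
    stop.update([eos] if isinstance(eos, int) else eos)
    return stop

# Gemma 4-style config: primary <eos> plus turn/tool-response terminators
config = '{"eos_token_id": [1, 106, 50]}'
assert get_stop_tokens(1, config) == {1, 106, 50}
```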

4. RotatingKVCache prefix cache fix (sliding window models)

Problem: Models with RotatingKVCache layers (Gemma 4 uses sliding window attention with max_size=1024 on 25/30 layers) crash with ValueError: [full] Negative dimensions not allowed when prefix cache attempts exact-hit or LCP-match restoration.

Root cause: During generation with a full rotating buffer, _update_in_place decrements left_padding by S on each step. After N generation tokens, left_padding = -N. Then extract() uses this as a slice start, truncating the 1024-entry buffer. The truncated cache violates RotatingKVCache's invariant: when offset >= max_size, the buffer MUST be full.

Fix:

  • MLLMBatch.extract_cache(): Custom extraction for BatchRotatingKVCache that clamps left_padding to >= 0
  • _trim_cache_offset(): Handle RotatingKVCache circular buffer trimming
  • _QuantizedCacheWrapper: Type-preserving quantization wrapper
  • _quantize_cache(): Skip RotatingKVCache (rotation state can't survive quantize/dequantize)
  • _dequantize_cache(): Deep-copy all cache layers to prevent model mutations from corrupting stored entries
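
The clamp can be illustrated with plain Python slicing (`clamp_left_padding` is a hypothetical helper; the real fix lives in MLLMBatch.extract_cache()). A negative slice start silently truncates the buffer, which is exactly what violated the full-buffer invariant:

```python
def clamp_left_padding(left_padding: int, buffer_len: int) -> slice:
    """After N generation steps with a full rotating buffer, left_padding
    has been decremented to -N. Using it raw as a slice start would
    truncate the buffer; clamping to >= 0 keeps the extracted cache full."""
    return slice(max(0, left_padding), buffer_len)

buf = list(range(1024))            # stand-in for a full 1024-entry KV buffer
bad = buf[slice(-5, 1024)]         # negative start keeps only the last 5 entries
good = buf[clamp_left_padding(-5, 1024)]
assert len(bad) == 5
assert len(good) == 1024
```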

5. Anthropic endpoint JSON escape handling

  • Some clients (e.g. Claude Code) send JSON with invalid escape sequences (\s, \d in regex patterns within tool definitions)
  • Python's json.loads is strict per RFC 8259 and rejects these
  • Fix: catch JSONDecodeError with "Invalid \escape" and sanitize lone backslashes
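
A hedged sketch of the retry path (the sanitizing regex is illustrative: it doubles any backslash that does not begin a valid JSON escape):

```python
import json
import re

def loads_lenient(raw: str):
    """Retry json.loads after doubling backslashes that don't start a
    valid JSON escape (\" \\ \/ \b \f \n \r \t \\uXXXX)."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        if "Invalid \\escape" not in e.msg:
            raise
        return json.loads(re.sub(r'\\(?!["\\/bfnrtu])', r'\\\\', raw))

# \s and \d are invalid JSON escapes that strict json.loads rejects
raw = '{"pattern": "\\s+\\d{3}"}'
assert loads_lenient(raw)["pattern"] == "\\s+\\d{3}"
```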

6. Strip billing header for prefix cache (13x speedup!)

  • Claude Code injects x-anthropic-billing-header: ...cch=HASH... into the system prompt
  • The cch= hash changes with every request, causing token sequences to diverge at position ~40
  • This completely defeats prefix cache (60K token full prefill on every request: ~50s)
  • Fix: strip the billing header from system prompt before tokenization
  • Result: 50s → 3.65s per request (13.7x speedup, 42 tok/s)
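
A sketch of the stripping step; the exact header format shown is illustrative (the PR elides it), but the effect is that the volatile cch= hash never reaches the tokenizer, so consecutive requests produce identical token prefixes:

```python
import re

# Assumed header shape: a line in the system prompt starting with the
# header name; the real implementation lives in anthropic_adapter.py.
BILLING_RE = re.compile(r"^x-anthropic-billing-header:.*$\n?", re.MULTILINE)

def strip_billing_header(system_prompt: str) -> str:
    return BILLING_RE.sub("", system_prompt)

# Two requests that differ only in the per-request cch= hash become
# byte-identical after stripping, so the prefix cache can match them.
a = strip_billing_header("You are helpful.\nx-anthropic-billing-header: cch=abc123\nBe concise.")
b = strip_billing_header("You are helpful.\nx-anthropic-billing-header: cch=zzz999\nBe concise.")
assert a == b == "You are helpful.\nBe concise."
```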

Files changed

  • vllm_mlx/patches/gemma4_mllm.py — Runtime monkey-patch for Gemma 4 attention + RotatingKVCache
  • vllm_mlx/mllm_batch_generator.py — Register patch, RotatingKVCache prefix cache extraction
  • vllm_mlx/memory_cache.py — RotatingKVCache support in trim/quantize/dequantize
  • vllm_mlx/reasoning/gemma4_parser.py — New reasoning parser for <|channel>...<channel|> tags
  • vllm_mlx/reasoning/__init__.py — Register gemma4 parser
  • vllm_mlx/mllm_scheduler.py — Read additional EOS tokens from generation_config.json
  • vllm_mlx/server.py — Anthropic JSON escape fix
  • vllm_mlx/api/anthropic_adapter.py — Strip billing header from system prompt

Test plan

  • Gemma 4 26B-A4B generates coherent text (no repetition) with continuous batching
  • Gemma 4 31B generates coherent text with continuous batching
  • Non-MLLM models unaffected by patch
  • Concurrent requests work correctly
  • --reasoning-parser gemma4 correctly separates thinking from content
  • Model stops at <turn|> (token 106) — no garbage after response
  • Prefix cache exact hit on Gemma 4 (RotatingKVCache) — no crash
  • Prefix cache LCP match on Gemma 4 — shared system prompt reuse works
  • Anthropic endpoint handles invalid JSON escapes (e.g. \s in tool schemas)
  • Billing header stripped: prefix cache matches 60K/60K tokens across turns
  • Claude Code Anthropic requests: 50s → 3.65s with prefix cache (13.7x speedup)

🤖 Generated with Claude Code

@janhilgard
Collaborator Author

Benchmark: vllm-mlx vs LM Studio — Gemma 4 26B-A4B (4-bit)

Tested on Apple M3 Ultra 256GB, same model (mlx-community/gemma-4-26b-a4b-it-4bit), 500 tokens, temperature=0, single request, 3 runs each.

| Engine | Run 1 | Run 2 | Run 3 | Avg tok/s |
|---|---|---|---|---|
| LM Studio | 82.5 | 85.8 | 85.5 | 84.6 |
| vllm-mlx | 91.1 | 94.4 | 94.6 | 93.4 |

vllm-mlx is 10.4% faster than LM Studio on single-request throughput.

@janhilgard
Collaborator Author

Batch throughput benchmark: vllm-mlx vs LM Studio

Same setup (M3 Ultra 256GB, Gemma 4 26B-A4B 4-bit), 300 max tokens per request, temperature=0.

| Concurrency | LM Studio | vllm-mlx | Difference |
|---|---|---|---|
| 1 | 84.6 tok/s | 93.4 tok/s | +10% |
| 2 | 112.8 tok/s | 139.0 tok/s | +23% |
| 4 | 133.8 tok/s | 201.1 tok/s | +50% |
| 8 | 134.4 tok/s | 274.6 tok/s | +104% |

LM Studio saturates around ~134 tok/s at 4+ concurrent requests. vllm-mlx with continuous batching keeps scaling with concurrency, reaching 2x LM Studio's throughput at 8 concurrent requests.

@janhilgard
Collaborator Author

Gemma 4 31B dense model tested — patch works correctly

Deployed mlx-community/gemma-4-31b-it-4bit (dense 31B, all params active, ~19 GB 4-bit) with --mllm --continuous-batching on Apple M3 Ultra 256GB. The BatchKVCache offset patch applies automatically and produces correct output.

Test results (7/7 passed)

| Test | Status | Details |
|---|---|---|
| Basic generation (5 prompts) | ✅ 5/5 | Math, haiku, primes, translation, code — all coherent |
| Multi-turn conversation | ✅ | Remembers context across 2 turns |
| Tool calling (3 scenarios) | ✅ 3/3 | get_weather(Tokyo), search_web(...), no-tool-needed |
| Streaming (SSE) | ✅ | 101 chunks, [DONE] received |
| Concurrent batching (8 requests) | ✅ 8/8 | All 8 parallel responses correct |
| Long context (~500 tokens input) | ✅ | Correct summarization |
| System prompt adherence | ✅ | Follows pirate persona |

Throughput

| Mode | Performance |
|---|---|
| Single request | ~29–31 tok/s |
| 8× concurrent batch | 6–10 tok/s per request, 52 tok/s aggregate |
| Streaming | 30.1 tok/s |

Config

vllm-mlx serve mlx-community/gemma-4-31b-it-4bit \
    --port 1239 --host 0.0.0.0 \
    --continuous-batching --max-num-seqs 8 \
    --cache-memory-mb 8192 --max-tokens 131072 \
    --enable-auto-tool-choice --tool-call-parser gemma4 \
    --mllm

This confirms the patch handles both MoE (26B-A4B) and dense (31B) Gemma 4 architectures correctly with BatchKVCache.

Note: dboris/gemma-4-31b-abliterated-mlx-4bit produces garbage output — broken quantization unrelated to this patch.

@janhilgard janhilgard requested a review from waybarrios April 5, 2026 21:26
@janhilgard janhilgard force-pushed the fix/gemma4-batched-rope branch 2 times, most recently from dfe2e8f to 90f83e1 on April 6, 2026 09:50
@janhilgard janhilgard changed the title from "fix: patch Gemma 4 attention for BatchKVCache offset mutation" to "fix: patch Gemma 4 attention and RotatingKVCache for BatchKVCache" Apr 6, 2026
@janhilgard janhilgard force-pushed the fix/gemma4-batched-rope branch 3 times, most recently from 9f54984 to 8e547a7 on April 6, 2026 11:35
@janhilgard janhilgard force-pushed the fix/gemma4-batched-rope branch from 8e547a7 to 047f523 on April 6, 2026 12:55
@Thump604
Collaborator

Thump604 commented Apr 7, 2026

@waybarrios, @janhilgard: cross-reference note. The BatchKVCache offset bug you describe in section 1 is the same root cause as a bug I fixed at the mlx_vlm layer in mlx_vlm/models/gemma4/language.py:Attention.__call__. The mechanism is identical: cache.offset is an mx.array, __iadd__ mutates in place, so a bare reference captured before update_and_fetch() sees the post-mutation value at Q-rope time.

The fix at the mlx_vlm layer is the cache.offset + 0 snapshot pattern (the + 0 forces a new array because int + 0 rebinds to a new int but mx.array + 0 returns a new array).
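
That int-vs-array distinction can be demonstrated with plain Python types (a list stands in for mx.array here, since both implement in-place __iadd__; ints, being immutable, rebind on +=):

```python
# Immutable int: "+=" rebinds the name, so an earlier alias keeps the
# old value automatically — no snapshot needed.
offset = 7
alias = offset
offset += 1
assert alias == 7

# Mutable object: __iadd__ mutates in place, so the alias tracks the
# mutation — hence the explicit "+ 0" snapshot for mx.array offsets.
offset = [7]
alias = offset
offset += [8]        # list.__iadd__ extends the same object
assert alias == [7, 8]
```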

The two fixes are complementary. Yours patches the BatchKVCache layer in vllm_mlx, mine patches the per-attention-call usage in mlx_vlm. Both layers can hit the same root cause depending on which path is exercised.

Mergeable on current main per the PR JSON. Sound fix on the BatchKVCache side, the additional RotatingKVCache.max_size fix is also legitimate.

@janhilgard
Collaborator Author

@Thump604 Thanks for the cross-reference and confirmation — really helpful to have independent validation of the root cause.

Agreed the two fixes are complementary. Since vllm-mlx doesn't control the upstream mlx_vlm code, the runtime monkey-patch in vllm_mlx/patches/gemma4_mllm.py is necessary to guarantee correct behavior regardless of the mlx_vlm version installed. Your fix at the mlx_vlm attention layer is the more principled long-term solution.

Ideally the cache.offset + 0 defensive copy pattern would land directly in mlx_vlm upstream — either in the BatchKVCache.offset property itself (returning a copy) or in the Attention.__call__ implementations that read it before update_and_fetch(). That way downstream projects like vllm-mlx wouldn't need to carry patches.

Worth noting this bug isn't Gemma 4-specific — any mlx_vlm model that captures cache.offset into a local variable before calling update_and_fetch() will see the same incorrect RoPE positions with BatchKVCache. Might be worth auditing other model attention implementations in mlx_vlm for the same pattern.

Collaborator

@Thump604 Thump604 left a comment


Both fixes are the right call together. Landing the runtime monkey-patch here guarantees correct behavior against any mlx_vlm version a user might have installed, including older releases that don't yet carry the upstream defensive copy. The mlx-vlm#966 fix (now merged) handles the forward path for fresh installs.

Good catch that this isn't Gemma 4-specific — any mlx_vlm attention implementation that captures cache.offset into a local variable before update_and_fetch() has the same race against BatchKVCache.offset being advanced mid-call. I'd been treating it as a Gemma 4 bug, but your framing as a general mlx_vlm pattern-level hazard is more accurate. A follow-up audit of the other mlx_vlm model attention implementations for the same pattern would be worth filing as a separate issue on Blaizzy/mlx-vlm, tagged for coordination.

Approving.

keegoid added a commit to keegoid/vllm-mlx that referenced this pull request Apr 10, 2026
PR waybarrios#256 added a trim-before-merge fix for RotatingKVCache in the MLLM
continuous-batching path, but left the isinstance guard above it
unchanged to require KVCache specifically. RotatingKVCache is a sibling
of KVCache under _BaseCache in mlx_lm, not a subclass, so the guard
always fires before the trim logic can run.

Gemma 4 uses sliding-window attention and returns RotatingKVCache
natively, so --mllm --continuous-batching fails at first inference with:

  ValueError: MLLM continuous batching requires standard KVCache but
  got RotatingKVCache.

Relax the guard to accept both KVCache and RotatingKVCache.
QuantizedKVCache (the original rejection target of the guard) is still
rejected correctly because it's not in the isinstance tuple. The
downstream .merge() call works for both cache types — RotatingKVCache
produces a BatchRotatingKVCache, KVCache produces a BatchKVCache — both
of which the BatchedEngine handles.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@keegoid

keegoid commented Apr 10, 2026

Thanks for this @janhilgard — the BatchKVCache offset snapshot and the extra EOS tokens from generation_config.json both nailed fixes I was hitting independently on mlx-community/gemma-4-31b-it-8bit.

While integrating this on top of #254 for a production pm_peng workload I hit one issue I wanted to flag: the RotatingKVCache trim logic added in this PR is unreachable on Gemma 4. In mllm_batch_generator.py the isinstance guard at line ~680 still rejects anything that isn't KVCache:

sample_cache = per_request_caches[0][0]
if not isinstance(sample_cache, KVCache):
    raise ValueError(
        f"MLLM continuous batching requires standard KVCache but got "
        f"{type(sample_cache).__name__}. Disable --kv-cache-quantization "
        f"when using multimodal models with --continuous-batching."
    )

# Fix: RotatingKVCache._update_concat does NOT trim on first call —
# ...
for rc in per_request_caches:
    for layer_cache in rc:
        if isinstance(layer_cache, RotatingKVCache):   # ← never reached
            ...

In mlx_lm/models/cache.py, RotatingKVCache is a sibling of KVCache under _BaseCache, not a subclass, so isinstance(rotating, KVCache) is False. Gemma 4 uses sliding-window attention natively, so every request through MLLMBatchGenerator._process_prompts ends up tripping the guard before the trim logic can run. First inference request against gemma-4-31b-it-8bit with --mllm --continuous-batching fails with:

ValueError: MLLM continuous batching requires standard KVCache but got RotatingKVCache.
Disable --kv-cache-quantization when using multimodal models with --continuous-batching.

Proposed one-line fix — relax the guard to accept both cache types. QuantizedKVCache (the guard's original rejection target) is still rejected correctly since it's not in the tuple:

if not isinstance(sample_cache, (KVCache, RotatingKVCache)):
    raise ValueError(
        f"MLLM continuous batching requires standard KVCache or "
        f"RotatingKVCache but got {type(sample_cache).__name__}. "
        f"Disable --kv-cache-quantization when using multimodal "
        f"models with --continuous-batching."
    )
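
The hierarchy point can be checked with stand-in classes mirroring mlx_lm/models/cache.py (illustrative only; only the subclass relationships matter):

```python
# RotatingKVCache is a sibling of KVCache under _BaseCache, not a subclass,
# so an isinstance check against KVCache alone rejects it.
class _BaseCache: pass
class KVCache(_BaseCache): pass
class RotatingKVCache(_BaseCache): pass

rotating = RotatingKVCache()
assert not isinstance(rotating, KVCache)                    # original guard fires
assert isinstance(rotating, (KVCache, RotatingKVCache))     # relaxed guard passes
assert isinstance(KVCache(), (KVCache, RotatingKVCache))    # plain KVCache still passes
```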

After this one-line change on top of your PR + #254, I'm getting clean end-to-end behavior on gemma-4-31b-it-8bit with --mllm --continuous-batching --enable-auto-tool-choice --tool-call-parser gemma4 --reasoning-parser gemma4:

  • Deterministic prompt ("Reply with exactly: PM PENG OK") → exact reply with finish_reason: stop in ~0.9s
  • Tool-calling probe (get_weather with a Seattle query) → clean tool_calls emission with finish_reason: tool_calls and well-formed JSON arguments

The rest of your patch works exactly as described — the attention snapshot fix, the channel reasoning parser, and the extra EOS tokens all do their jobs. Just the guard above the trim loop needs relaxing.

Happy to open a separate PR for the one-liner if that's easier, or you can fold it into this PR — whichever you prefer. Reference commit on my fork: keegoid@42b85db

(Also — the mllm_batch_generator.py error-wrapping you added around _process_prompts in this PR was a lifesaver. Converting the ValueError into a clean finish_reason: error instead of a 500 made this trivial to diagnose from the client side. Nice touch.)

@Thump604
Collaborator

I took @keegoid's report and packaged it as a minimal follow-up PR: #273.

What it does:

  • relaxes the _process_prompts() guard to accept RotatingKVCache alongside KVCache
  • adds a regression test that drives _process_prompts() with a real RotatingKVCache and asserts the merged cache is BatchRotatingKVCache

Validation on the follow-up branch:

  • python -m pytest tests/test_mllm_continuous_batching.py -q -> 24 passed, 3 deselected
  • python -m black --check --target-version py312 vllm_mlx/mllm_batch_generator.py tests/test_mllm_continuous_batching.py

If you'd rather keep the fix inside this PR, #273 should be easy to cherry-pick.

@janhilgard janhilgard force-pushed the fix/gemma4-batched-rope branch from 047f523 to edcb33e on April 10, 2026 16:44
@janhilgard janhilgard force-pushed the fix/gemma4-batched-rope branch 2 times, most recently from 19f64f9 to 6fab18d on April 10, 2026 20:59
@janhilgard
Collaborator Author

@keegoid Great catch — you're absolutely right. The isinstance(sample_cache, KVCache) guard was blocking RotatingKVCache because it's a sibling class under _BaseCache, not a subclass of KVCache. Gemma 4 uses sliding window attention on 25/30 layers, so the very first cache entry is RotatingKVCache and the guard rejects it immediately.

I've folded the fix into this PR: the guard now accepts (KVCache, RotatingKVCache). Also removed _make_batch_cache() which was dead code with the same limitation.

@Thump604 Thanks for packaging the fix as #273 — since it's now included here, #273 can be closed if you prefer.

@janhilgard janhilgard force-pushed the fix/gemma4-batched-rope branch from 6fab18d to 5276bc2 on April 10, 2026 21:48
- Fix BatchKVCache offset bug: mx.array.__iadd__ mutates in-place,
  causing incorrect RoPE positions and token repetition
- Fix RotatingKVCache.max_size returning mx.array instead of int
- Add Gemma 4 reasoning parser (--reasoning-parser gemma4)
- Read additional EOS tokens from generation_config.json
- Fix RotatingKVCache prefix cache extraction (negative left_padding)
- Relax isinstance guard to accept RotatingKVCache for sliding window
  models like Gemma 4 (fixes ValueError on continuous batching)
- Remove unused _make_batch_cache() dead code
- Fix Anthropic endpoint JSON parsing for clients sending invalid
  escape sequences (e.g. \s, \d in regex patterns within tool defs)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@waybarrios waybarrios force-pushed the fix/gemma4-batched-rope branch from 5276bc2 to 6f0efc2 on April 10, 2026 21:52
@waybarrios
Owner

waybarrios commented Apr 10, 2026

Hey @janhilgard I rebased this on main after merging #268. The overlapping files (reasoning parser, mllm_scheduler, patches, etc.) resolved cleanly since #268 already covered those.

What's left in this PR after the rebase is the stuff #268 didn't touch:

  • anthropic_models.py: thinking block support for the Anthropic endpoint
  • memory_cache.py: RotatingKVCache support in prefix cache trim
  • mllm_batch_generator.py: BatchRotatingKVCache extract fix for negative left_padding
  • server.py: reasoning extraction and thinking blocks in /v1/messages, streaming support

Tests pass, black is clean. This looks good to merge as the Anthropic API complement to #268.

Owner

@waybarrios waybarrios left a comment


Rebased on main, tests pass, black clean. The remaining changes (Anthropic thinking blocks, RotatingKVCache in memory_cache, BatchRotatingKVCache extract fix) are solid and complement #268 well.
