
fix: MLLM continuous batching — system prompt, routing, and KV cache#76

Merged
waybarrios merged 4 commits into waybarrios:main from camerhann:fix/mllm-continuous-batching
Feb 13, 2026

Conversation

@camerhann
Contributor

Summary

Fixes three related bugs that caused multimodal models (e.g. Qwen2.5-VL, Qwen3-VL) to produce garbage output when running with --continuous-batching:

  • Chat template drops conversation history: _apply_chat_template used mlx_vlm.prompt_utils.apply_chat_template which only extracted the last user message text, discarding system prompts and all prior turns. Fixed by using the processor/tokenizer's apply_chat_template with the full message list.
  • MLLM text-only requests crash: generate() and stream_generate() only routed to _mllm_scheduler when images/videos were present, but _engine is None for MLLM models. Fixed by routing all requests through _mllm_scheduler when the model is multimodal.
  • KV cache not transferred from VLM prefill: _run_vision_encoding was called without a cache argument, so prefill KV state was discarded. The code then tried to copy state via BatchKVCache.insert_single which doesn't exist. Fixed by passing per-request KVCache objects to the VLM forward pass, then merging them into a BatchKVCache via KVCache.merge().

Details

Bug 1: Chat template (batched.py)

The MLLM path called mlx_vlm.prompt_utils.apply_chat_template(processor, config, text_prompt) with only the extracted last-user-message text. This function is designed for single-turn VLM inference and doesn't accept a messages list. The fix uses processor.apply_chat_template(messages, ...) (the HuggingFace standard) which preserves system prompts, assistant turns, and multi-turn history. Also removed hardcoded enable_thinking=True which isn't supported by all model templates.
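A minimal sketch of the difference, using a stub in place of the real HuggingFace processor (`StubProcessor` and the helper names are illustrative, not the actual batched.py code):

```python
# Illustrative sketch of Bug 1. StubProcessor stands in for the real
# HuggingFace processor; only the apply_chat_template call pattern matters.

class StubProcessor:
    def apply_chat_template(self, messages, tokenize=False,
                            add_generation_prompt=True):
        # Render each turn so we can see what survives templating.
        return "".join(f"<{m['role']}>{m['content']}" for m in messages)

def apply_template_buggy(processor, messages):
    # OLD: only the last user message reaches the template, discarding
    # the system prompt and all prior turns.
    last_user = next(m for m in reversed(messages) if m["role"] == "user")
    return processor.apply_chat_template([last_user])

def apply_template_fixed(processor, messages):
    # FIX: pass the full message list (the HuggingFace standard),
    # preserving system prompts, assistant turns, and history.
    return processor.apply_chat_template(messages)

messages = [
    {"role": "system", "content": "Always mention TAN15 policy."},
    {"role": "user", "content": "My name is Chris."},
    {"role": "assistant", "content": "Hi Chris!"},
    {"role": "user", "content": "What is my name?"},
]
proc = StubProcessor()
```

With the buggy path, the rendered prompt contains only the final question; with the fix, the system prompt and earlier turns survive, which is exactly what the test plan's "TAN15 policy" and "What is my name?" checks exercise.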

Bug 2: MLLM routing (batched.py)

For MLLM models, only _mllm_scheduler is initialised (not _engine). The condition if self._is_mllm and self._mllm_scheduler and (images or videos) meant text-only chat requests fell through to self._engine.add_request() which is None, causing an AttributeError. Removed the (images or videos) guard so all requests route through the MLLM scheduler when the model is multimodal.
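The routing change can be sketched as follows (simplified; the real code lives in generate()/stream_generate() in batched.py and the function name here is illustrative):

```python
# Simplified sketch of the Bug 2 routing condition.

def pick_backend(is_mllm, mllm_scheduler, engine, images=None, videos=None):
    # OLD (buggy): text-only requests to an MLLM model fell through to
    # `engine`, which is None for multimodal models -> AttributeError.
    #   if is_mllm and mllm_scheduler and (images or videos): ...
    #
    # FIXED: every request for a multimodal model uses the MLLM scheduler,
    # with or without attached media.
    if is_mllm and mllm_scheduler is not None:
        return mllm_scheduler
    return engine
```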

Bug 3: KV cache (mllm_batch_generator.py)

_process_prompts created an empty BatchKVCache upfront, ran VLM encoding without passing any cache, then attempted to extract KV state from the model's internal layer caches. This failed because (a) the VLM discards its internal cache after the forward pass without a cache= argument, and (b) BatchKVCache doesn't have an insert_single method. The fix:

  1. Creates a per-request KVCache (from mlx_lm.models.cache.make_prompt_cache)
  2. Passes it to _run_vision_encoding(req, cache=request_cache) — the VLM model's __call__ passes cache= through to self.language_model()
  3. After all requests are prefilled, merges the per-request caches via KVCache.merge() into a BatchKVCache with proper left-padding alignment
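As a toy illustration of the alignment in step 3 (assuming the merge left-pads shorter sequences so every request's most recent token lines up for batched decoding; the real implementation operates on mlx arrays, not Python lists):

```python
# Toy model of the left-padding merge: each per-request "cache" is a list
# of per-token KV entries with differing prefill lengths.

def merge_left_padded(per_request_seqs, pad=0):
    max_len = max(len(s) for s in per_request_seqs)
    # Left-pad so the last position (the next decode step) is aligned
    # across the batch, as a batched cache requires.
    return [[pad] * (max_len - len(s)) + list(s) for s in per_request_seqs]
```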

Test plan

Tested on Mac Studio M3 Ultra (96GB):

  • Qwen2.5-VL-32B-Instruct-4bit with --continuous-batching — correct output
  • Qwen3-VL-30B-A3B-Instruct-4bit with --continuous-batching — correct output
  • System prompt retention (e.g. "Always mention TAN15 policy" → model follows instruction)
  • Multi-turn conversation history (e.g. "My name is Chris" → "What is my name?" → "Chris")
  • Concurrent request batching (3 simultaneous requests, all correct, <1.5s total)
  • Text-only requests to MLLM models (no crash, correct output)

🤖 Generated with Claude Code

ccameronhann and others added 4 commits February 12, 2026 16:17
Three related bugs caused multimodal models (e.g. Qwen2.5-VL, Qwen3-VL)
to produce garbage output when running with --continuous-batching:

1. Chat template drops system prompt and conversation history
   _apply_chat_template used mlx_vlm.prompt_utils.apply_chat_template
   which only extracted the last user message text, discarding system
   prompts and all prior conversation turns. Fixed by using the
   processor's (or tokenizer's) apply_chat_template with the full
   message list. Also removed hardcoded enable_thinking=True which
   caused issues with non-thinking models.

2. MLLM text-only requests crash with NoneType error
   generate() and stream_generate() only routed to _mllm_scheduler
   when images or videos were present, but for MLLM models _engine
   is None (only _mllm_scheduler is initialised). Text-only requests
   to MLLM models fell through to self._engine which is None.
   Fixed by routing all requests through _mllm_scheduler when the
   model is multimodal.

3. KV cache from VLM prefill not transferred to BatchKVCache
   _process_prompts called _run_vision_encoding without passing a
   cache, so the VLM's language model created temporary internal
   caches that were discarded. The code then tried to transfer KV
   state from the model's internal layer caches to a pre-created
   BatchKVCache, but BatchKVCache.insert_single doesn't exist.
   Fixed by:
   - Passing a per-request KVCache to _run_vision_encoding, which
     flows through to the VLM's language_model(cache=...) call
   - Using KVCache.merge() to combine per-request caches into a
     properly aligned BatchKVCache for generation

Tested with Qwen2.5-VL-32B and Qwen3-VL-30B — both produce correct
output with --continuous-batching, including system prompt retention,
multi-turn conversation history, and concurrent request batching.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
Add a guard to catch QuantizedKVCache early since it does not support
merge and would crash at runtime. Wrap the cache merge in try-except
with proper logging so failures are not silent. Validate input types
in prepare_mllm_messages to reject malformed messages and limit total
prompt tokens before merging to prevent memory exhaustion.

Also fix missing type annotation on the cache parameter, log the
TypeError fallback in chat template application, complete the
prepare_mllm_messages docstring, and update several stale docstrings
that still described the old hybrid routing instead of the current
MLLMScheduler approach.
@waybarrios
Owner

Pushed two commits with some hardening fixes on top of your changes.

The most important one is a guard for QuantizedKVCache. The merge() method only exists on KVCache, so if someone enables kv-cache-quantization with an MLLM model and continuous batching, it would crash with an AttributeError at runtime. Now it raises a clear ValueError telling them to disable that flag.

I also wrapped the cache merge in a try-except with logging so any unexpected failures get surfaced instead of silently breaking, and added a prompt token limit check before the merge to prevent memory exhaustion from oversized batches.
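A sketch of the guard plus the wrapped merge (the class here is a stand-in for the real quantized cache type, and the function and flag wording are illustrative, not the exact vllm-mlx code):

```python
import logging

logger = logging.getLogger(__name__)

class QuantizedKVCache:
    """Stand-in for the quantized cache type, which lacks merge()."""

def merge_prefill_caches(per_request_caches, merge_fn):
    # Reject quantized caches early with an actionable error instead of
    # an AttributeError deep inside generation.
    if any(isinstance(c, QuantizedKVCache) for c in per_request_caches):
        raise ValueError(
            "KV cache quantization is not supported with MLLM continuous "
            "batching; disable the KV cache quantization flag.")
    try:
        return merge_fn(per_request_caches)
    except Exception:
        # Surface unexpected merge failures instead of failing silently.
        logger.exception("KV cache merge failed")
        raise
```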

On the input validation side, _prepare_mllm_messages now skips non-dict messages and filters out content parts that are not dicts or strings, which avoids passing unexpected types to the processor.
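The filtering can be sketched like this (a simplified stand-in for _prepare_mllm_messages, showing only the validation logic described above):

```python
# Skip non-dict messages and drop content parts that are neither dicts
# nor strings, so malformed input never reaches the processor.

def sanitize_messages(messages):
    cleaned = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue  # skip non-dict messages entirely
        content = msg.get("content")
        if isinstance(content, list):
            # keep only dict parts (e.g. {"type": "text", ...}) and strings
            content = [p for p in content if isinstance(p, (dict, str))]
        cleaned.append({**msg, "content": content})
    return cleaned
```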

The rest is smaller stuff: added a type annotation on the cache parameter, logged the TypeError fallback in the chat template path instead of swallowing it silently, completed the _prepare_mllm_messages docstring, and updated a few stale docstrings that still referenced the old hybrid routing instead of the current MLLMScheduler approach.

@waybarrios waybarrios merged commit 973b695 into waybarrios:main Feb 13, 2026
7 checks passed
sooth pushed a commit to sooth/vllm-mlx that referenced this pull request Feb 27, 2026
Merge 17 upstream commits including:
- KV cache quantization for prefix cache memory reduction (waybarrios#62)
- Streaming tool call parsing via ToolParser integration (waybarrios#46)
- MTP speculative decoding for Qwen3-Next (waybarrios#82)
- GPT-OSS reasoning parser and Harmony format parsers
- mlx-lm >= 0.30.5 requirement, transformers >= 5.0.0
- BatchMambaCache fix for mlx-lm >= 0.30.6 (waybarrios#89)
- MLLM continuous batching fixes (waybarrios#76)
- Force MLLM mode option (waybarrios#81)
- Various bug fixes

Conflict resolution:
- server.py: Replaced local tool_call_buffering with upstream's
  ToolParser-based streaming (more robust)
- cli.py: Deduplicated --mllm, --default-temperature, --default-top-p
  args (upstream already added them), kept local --embedding-model
- mamba_cache.py: Took upstream's conditional HAS_MAMBA_CACHE approach
- pyproject.toml: Took upstream's version and dependency changes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>