diff --git a/tests/test_mllm_hybrid_cache.py b/tests/test_mllm_hybrid_cache.py new file mode 100644 index 000000000..6e11f6739 --- /dev/null +++ b/tests/test_mllm_hybrid_cache.py @@ -0,0 +1,524 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for MLLM continuous batching with hybrid model caches. + +Hybrid models (Qwen 3.5, Nemotron 3 Super) mix attention layers (KVCache) +with recurrent/SSM layers (ArraysCache). The MLLM batch generator must +handle both cache types during merge, filter, extract, and extend operations. +""" + +import pytest + +try: + import mlx.core as mx + from mlx_lm.models.cache import ( + ArraysCache, + BatchKVCache, + KVCache, + RotatingKVCache, + BatchRotatingKVCache, + ) + + HAS_MLX = True +except ImportError: + HAS_MLX = False + +pytestmark = pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +# --------------------------------------------------------------------------- +# Helpers — simulate Qwen 3.5 cache layout (12 KVCache + 36 ArraysCache) +# --------------------------------------------------------------------------- + + +def _make_hybrid_cache(n_kv=12, n_arrays=36, arrays_size=2): + """Create a hybrid cache list like Qwen 3.5's make_cache(). + + Qwen 3.5 layout: is_linear = (layer_idx + 1) % 4 != 0 + So layers 0,1,2 are ArraysCache, layer 3 is KVCache, etc. + For simplicity, we just create n_arrays ArraysCache + n_kv KVCache + interleaved with the real pattern. + """ + full_attention_interval = 4 + total = n_kv + n_arrays + cache = [] + for i in range(total): + is_linear = (i + 1) % full_attention_interval != 0 + if is_linear: + cache.append(ArraysCache(size=arrays_size)) + else: + cache.append(KVCache()) + return cache + + +def _populate_kv_cache( + cache: KVCache, seq_len: int, n_kv_heads: int = 4, head_dim: int = 8 +): + """Populate a KVCache with dummy data to simulate a completed prefill.""" + # KVCache.update_and_fetch expects 4D: (batch, n_kv_heads, seq_len, head_dim) + keys = mx.random.normal((1, n_kv_heads, seq_len, head_dim)) + values = mx.random.normal((1, n_kv_heads, seq_len, head_dim)) + cache.update_and_fetch(keys, values) + + +def _populate_arrays_cache( + cache: ArraysCache, batch_size: int = 1, state_dim: int = 16 +): + """Populate an ArraysCache with dummy SSM state.""" + for i in range(len(cache.cache)): + cache.cache[i] = mx.random.normal((batch_size, state_dim)) + + +def _make_populated_hybrid_cache( + seq_len: int = 10, n_kv_heads: int = 4, head_dim: int = 8, state_dim: int = 16 +): + """Create and populate a hybrid cache simulating a completed vision encoding prefill.""" + cache = _make_hybrid_cache() + for c in cache: + if isinstance(c, KVCache): + _populate_kv_cache(c, seq_len, n_kv_heads, head_dim) + elif isinstance(c, ArraysCache): + _populate_arrays_cache(c, batch_size=1, state_dim=state_dim) + return cache + + +# --------------------------------------------------------------------------- +# Test: _make_batch_cache handles all cache types +# --------------------------------------------------------------------------- + + +class TestMakeBatchCache: + """Test _make_batch_cache() with hybrid model caches.""" + + def test_hybrid_cache_creates_correct_types(self): + """_make_batch_cache returns BatchKVCache for KVCache layers, ArraysCache for ArraysCache layers.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + # Mock model with make_cache returning hybrid layout + class FakeModel: + def make_cache(self): + return _make_hybrid_cache() + + left_padding = [0, 2] # 2-request batch, different prompt 
lengths + batch_cache = _make_batch_cache(FakeModel(), left_padding) + + assert len(batch_cache) == 48 # 12 KV + 36 Arrays + for i, c in enumerate(batch_cache): + is_linear = (i + 1) % 4 != 0 + if is_linear: + # ArraysCache is returned as-is with left_padding set + assert isinstance(c, ArraysCache), f"Layer {i} should be ArraysCache" + assert c.left_padding is not None + else: + assert isinstance(c, BatchKVCache), f"Layer {i} should be BatchKVCache" + + def test_pure_kv_cache_still_works(self): + """Regression: pure attention models (all KVCache) still work.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return [KVCache() for _ in range(24)] + + batch_cache = _make_batch_cache(FakeModel(), [0, 1]) + assert all(isinstance(c, BatchKVCache) for c in batch_cache) + + def test_pure_arrays_cache_works(self): + """Pure SSM models (all ArraysCache) work.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return [ArraysCache(size=2) for _ in range(24)] + + batch_cache = _make_batch_cache(FakeModel(), [0, 1]) + assert all(isinstance(c, ArraysCache) for c in batch_cache) + + def test_rotating_kv_cache_works(self): + """RotatingKVCache (keep=0) gets wrapped in BatchRotatingKVCache.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return [RotatingKVCache(max_size=1024, keep=0) for _ in range(4)] + + batch_cache = _make_batch_cache(FakeModel(), [0]) + assert all(isinstance(c, BatchRotatingKVCache) for c in batch_cache) + + def test_rotating_kv_cache_with_keep_rejected(self): + """RotatingKVCache with keep > 0 is rejected.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return [RotatingKVCache(max_size=1024, keep=4)] + + with pytest.raises(ValueError, match="keep tokens is not supported"): + _make_batch_cache(FakeModel(), [0]) + + def test_unsupported_cache_type_rejected(self): + """Cache types without batching support are rejected with clear error.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class UnsupportedCache: + pass + + class FakeModel: + def make_cache(self): + return [UnsupportedCache()] + + with pytest.raises(ValueError, match="does not support"): + _make_batch_cache(FakeModel(), [0]) + + +# --------------------------------------------------------------------------- +# Test: Merge loop works with mixed cache types +# --------------------------------------------------------------------------- + + +class TestHybridCacheMerge: + """Test the per-layer merge loop from _process_prompts.""" + + def test_merge_hybrid_per_request_caches(self): + """Merging per-request hybrid caches produces correct batched types.""" + # Simulate 2 requests, each with a populated hybrid cache + caches = [ + _make_populated_hybrid_cache(seq_len=10), + _make_populated_hybrid_cache(seq_len=15), + ] + + # This is the exact merge loop from _process_prompts (lines 679-685) + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in caches]) + for layer_idx in range(len(caches[0])) + ] + + assert len(batch_cache) == 48 + for i, c in enumerate(batch_cache): + is_linear = (i + 1) % 4 != 0 + if is_linear: + assert isinstance( + c, ArraysCache + ), f"Layer {i}: merged ArraysCache should stay ArraysCache" + # Merged arrays should have batch dimension = 2 + for arr in c.cache: + if arr is not None: + assert arr.shape[0] == 2, f"Layer {i}: batch dim 
should be 2" + else: + assert isinstance( + c, BatchKVCache + ), f"Layer {i}: merged KVCache should become BatchKVCache" + + def test_merge_single_request(self): + """Single-request merge works (degenerate case).""" + caches = [_make_populated_hybrid_cache(seq_len=10)] + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in caches]) + for layer_idx in range(len(caches[0])) + ] + assert len(batch_cache) == 48 + + def test_type_guard_rejects_unmergeable_cache(self): + """The capability check rejects caches without merge().""" + from vllm_mlx.mllm_batch_generator import _validate_caches_mergeable + + class NoMergeCache: + pass + + per_request_caches = [[NoMergeCache(), KVCache()]] + with pytest.raises(ValueError, match="lacks a merge"): + _validate_caches_mergeable(per_request_caches) + + +# --------------------------------------------------------------------------- +# Test: Filter on merged batch +# --------------------------------------------------------------------------- + + +class TestHybridCacheFilter: + """Test filter() on merged hybrid batches.""" + + def test_filter_keeps_correct_requests(self): + """Filter on merged batch keeps correct batch elements for both cache types.""" + caches = [ + _make_populated_hybrid_cache(seq_len=10), + _make_populated_hybrid_cache(seq_len=15), + _make_populated_hybrid_cache(seq_len=12), + ] + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in caches]) + for layer_idx in range(len(caches[0])) + ] + + # Keep only request 0 and 2 + keep_idx = mx.array([0, 2], mx.int32) + for c in batch_cache: + if hasattr(c, "filter"): + c.filter(keep_idx) + + # Verify batch dimension is now 2 + for i, c in enumerate(batch_cache): + is_linear = (i + 1) % 4 != 0 + if is_linear and isinstance(c, ArraysCache): + for arr in c.cache: + if arr is not None: + assert ( + arr.shape[0] == 2 + ), f"Layer {i}: filtered batch dim should be 2" + + +# --------------------------------------------------------------------------- +# Test: Extract from merged batch +# --------------------------------------------------------------------------- + + +class TestHybridCacheExtract: + """Test extract() on merged hybrid batches.""" + + def test_extract_returns_correct_types(self): + """Extracting a single request returns correct unbatched types.""" + caches = [ + _make_populated_hybrid_cache(seq_len=10), + _make_populated_hybrid_cache(seq_len=15), + ] + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in caches]) + for layer_idx in range(len(caches[0])) + ] + + # Extract request 0 + extracted = [ + c.extract(0) if hasattr(c, "extract") else None for c in batch_cache + ] + + for i, c in enumerate(extracted): + is_linear = (i + 1) % 4 != 0 + if c is None: + continue + if is_linear: + assert isinstance( + c, ArraysCache + ), f"Layer {i}: extracted should be ArraysCache" + for arr in c.cache: + if arr is not None: + assert ( + arr.shape[0] == 1 + ), f"Layer {i}: extracted batch dim should be 1" + else: + assert isinstance(c, KVCache), f"Layer {i}: extracted should be KVCache" + + +# --------------------------------------------------------------------------- +# Test: Extend merged batches +# --------------------------------------------------------------------------- + + +class TestHybridCacheExtend: + """Test extend() combining two merged hybrid batches.""" + + def test_extend_combines_batches(self): + """Extending one merged batch with another works for both cache types.""" + caches_a = [ + _make_populated_hybrid_cache(seq_len=10), + 
_make_populated_hybrid_cache(seq_len=15), + ] + caches_b = [ + _make_populated_hybrid_cache(seq_len=12), + ] + batch_a = [ + caches_a[0][layer_idx].merge([c[layer_idx] for c in caches_a]) + for layer_idx in range(len(caches_a[0])) + ] + batch_b = [ + caches_b[0][layer_idx].merge([c[layer_idx] for c in caches_b]) + for layer_idx in range(len(caches_b[0])) + ] + + # Extend batch_a with batch_b + for c, o in zip(batch_a, batch_b): + if c is not None and o is not None and hasattr(c, "extend"): + if not c.empty() and not o.empty(): + c.extend(o) + + # Verify combined batch has 3 elements + for i, c in enumerate(batch_a): + is_linear = (i + 1) % 4 != 0 + if is_linear and isinstance(c, ArraysCache): + for arr in c.cache: + if arr is not None: + assert ( + arr.shape[0] == 3 + ), f"Layer {i}: extended batch dim should be 3" + + +# --------------------------------------------------------------------------- +# Test: Message normalization +# --------------------------------------------------------------------------- + + +class TestNormalizeMessages: + """Test _normalize_messages() for handling real-world client formats.""" + + def test_merge_consecutive_system_messages(self): + """Consecutive system messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "Always respond in JSON."}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert "helpful assistant" in result[0]["content"] + assert "JSON" in result[0]["content"] + assert result[1]["role"] == "user" + assert result[1]["content"] == "Hello" + + def test_merge_consecutive_user_messages(self): + """Consecutive user messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "First part"}, + {"role": "user", "content": "Second part"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[1]["role"] == "user" + assert "First part" in result[1]["content"] + assert "Second part" in result[1]["content"] + + def test_opencode_format(self): + """OpenCode's system+system+user+user format is normalized.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "System prompt part 1"}, + {"role": "system", "content": "System prompt part 2"}, + {"role": "user", "content": "User instruction"}, + {"role": "user", "content": "User question"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_already_alternating_unchanged(self): + """Well-formed alternating messages pass through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "Bye"}, + ] + result = _normalize_messages(messages) + assert result == messages + + def test_single_message_unchanged(self): + """Single message passes through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [{"role": "user", "content": "Hello"}] + result = _normalize_messages(messages) + assert result == messages + + def test_empty_messages(self): + """Empty message 
list passes through.""" + from vllm_mlx.server import _normalize_messages + + assert _normalize_messages([]) == [] + + def test_multimodal_content_preserved(self): + """Messages with list content (multimodal) are preserved during merge.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "user", "content": "Describe this:"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": {"url": "http://example.com/img.png"}, + }, + ], + }, + ] + result = _normalize_messages(messages) + # When one message has list content and previous has string, + # they can't be trivially merged — keep them or convert + # At minimum, no crash + assert len(result) >= 1 + + def test_preserves_non_content_fields(self): + """Fields other than role/content are preserved.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1", "name": "sys1"}, + {"role": "system", "content": "Part 2"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + # First message retains extra fields from first of the merged group + assert result[0]["role"] == "system" + + def test_null_content_not_merged(self): + """Messages with None content (tool_calls pattern) are not merged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]}, + {"role": "assistant", "content": "Follow-up"}, + ] + result = _normalize_messages(messages) + # None content can't be merged with string — kept separate + assert len(result) == 2 + + def test_three_consecutive_system_messages(self): + """Three consecutive system messages merge into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1"}, + {"role": "system", "content": "Part 2"}, + {"role": "system", "content": "Part 3"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert "Part 1" in result[0]["content"] + assert "Part 3" in result[0]["content"] + + +# --------------------------------------------------------------------------- +# Test: Empty cache extend guard +# --------------------------------------------------------------------------- + + +class TestEmptyCacheExtend: + """Test the empty() guard in extend prevents crashes on unpopulated caches.""" + + def test_extend_skips_empty_caches(self): + """Extending when one cache is empty does not crash.""" + populated = _make_populated_hybrid_cache(seq_len=10) + + # Merge populated into single-request batch + batch_pop = [populated[i].merge([populated[i]]) for i in range(len(populated))] + + # Create empty caches directly (don't merge — merge() can't handle all-None) + batch_empty = _make_hybrid_cache() + + # Extend should not crash — empty guard should skip + for c, o in zip(batch_pop, batch_empty): + if c is not None and o is not None and hasattr(c, "extend"): + if not c.empty() and not o.empty(): + c.extend(o) + # If we get here without crash, test passes diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index ee8d8da7b..e7f5f0718 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -30,6 +30,21 @@ logger = logging.getLogger(__name__) +def _validate_caches_mergeable(per_request_caches: List[List[Any]]) -> None: + """Validate that all cache layers support merge() for 
batch creation. + + Raises ValueError if any layer lacks a merge() method (e.g. QuantizedKVCache). + Called before the merge loop in _process_prompts(). + """ + for layer_idx, layer_cache in enumerate(per_request_caches[0]): + if not hasattr(layer_cache, "merge"): + raise ValueError( + f"MLLM continuous batching requires mergeable cache types " + f"but layer {layer_idx} has {type(layer_cache).__name__} " + f"which lacks a merge() method." + ) + + @dataclass class MLLMBatchRequest: """ @@ -139,20 +154,17 @@ def extend(self, other: "MLLMBatch") -> None: self.max_tokens.extend(other.max_tokens) self.requests.extend(other.requests) - # Extend cache - handle None and incompatible caches + # Extend cache - each cache type's extend() handles its own validation. + # Uses empty() (universal via _BaseCache) instead of checking .keys + # (KVCache-specific). This supports hybrid models with ArraysCache + # layers that use .cache instead of .keys/.values. for c, o in zip(self.cache, other.cache): if c is not None and o is not None and hasattr(c, "extend"): try: - # Only extend if both caches have valid keys - if ( - hasattr(c, "keys") - and c.keys is not None - and hasattr(o, "keys") - and o.keys is not None - ): + if not c.empty() and not o.empty(): c.extend(o) except Exception as e: - logger.warning(f"Failed to extend cache: {e}") + logger.warning(f"Failed to extend cache layer: {e}") def extract_cache(self, idx: int) -> List[Any]: """ @@ -207,22 +219,52 @@ def to_dict(self) -> Dict[str, Any]: def _make_batch_cache(model: nn.Module, left_padding: List[int]) -> List[Any]: """ - Create batch-aware KV cache for the language model. + Create batch-aware cache for the language model. + + Handles all cache types from hybrid models: + - KVCache → BatchKVCache (attention layers) + - ArraysCache → ArraysCache with left_padding (SSM/recurrent layers) + - RotatingKVCache → BatchRotatingKVCache + - CacheList → recursive conversion Args: model: The language model (model.language_model from VLM) left_padding: Padding amounts for left-padded prompts Returns: - List of BatchKVCache objects for each layer + List of batch-aware cache objects for each layer """ - from mlx_lm.models.cache import BatchKVCache, KVCache + from mlx_lm.models.cache import ( + ArraysCache, + BatchKVCache, + BatchRotatingKVCache, + CacheList, + KVCache, + RotatingKVCache, + ) def to_batch_cache(c): - if isinstance(c, KVCache): + # Strict type identity for KVCache — avoid catching QuantizedKVCache + if type(c) is KVCache: return BatchKVCache(left_padding) + elif isinstance(c, ArraysCache): + # ArraysCache handles batching natively — just set left_padding + c.left_padding = mx.array(left_padding) + return c + elif isinstance(c, RotatingKVCache): + if c.keep > 0: + raise ValueError( + "RotatingKVCache with keep tokens is not supported " + "in MLLM continuous batching." + ) + return BatchRotatingKVCache(c.max_size, left_padding) + elif isinstance(c, CacheList): + return CacheList(*(to_batch_cache(sub_c) for sub_c in c.caches)) else: - raise ValueError(f"{type(c)} does not yet support batching") + raise ValueError( + f"MLLM continuous batching does not support {type(c).__name__}. " + f"Supported: KVCache, ArraysCache, RotatingKVCache, CacheList." 
+ ) if hasattr(model, "make_cache"): cache = model.make_cache() @@ -324,6 +366,11 @@ def __init__( "MLLMBatchGenerator: Model does not have language_model, using model directly" ) + # Patch Qwen3.5 attention for BatchKVCache compatibility + from .patches.qwen3_5_mllm import patch_qwen35_attention_for_batching + + patch_qwen35_attention_for_batching() + self.max_tokens = max_tokens self.stop_tokens = stop_tokens or set() self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) @@ -340,6 +387,9 @@ def __init__( # Statistics self._stats = MLLMBatchStats() + # Error responses for requests that failed during preprocessing + self._pending_error_responses: List[MLLMBatchResponse] = [] + # Vision embedding cache for repeated images self.vision_cache = VisionEmbeddingCache( max_pixel_entries=vision_cache_size, @@ -613,23 +663,48 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: tic = time.perf_counter() - # Preprocess all requests + # Preprocess all requests (per-request error handling) + failed_requests = [] for req in requests: - self._preprocess_request(req) + try: + self._preprocess_request(req) + except Exception as e: + logger.error( + f"Failed to preprocess request {req.request_id}: " + f"{type(e).__name__}: {e}" + ) + failed_requests.append(req) + + # Remove failed requests and create error responses + if failed_requests: + for req in failed_requests: + requests.remove(req) + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + if not requests: + return None total_prompt_tokens = sum( req.input_ids.size if req.input_ids is not None else 1 for req in requests ) self._stats.prompt_tokens += total_prompt_tokens - # Guard against excessive memory usage during cache merge. - # Each token in the batch requires KV entries across all layers. + # Log large prompts for monitoring (was previously a hard error that + # could cause infinite retry loops). max_batch_tokens = self.prefill_step_size * len(requests) if total_prompt_tokens > max_batch_tokens: - raise ValueError( - f"Total prompt tokens ({total_prompt_tokens}) exceeds safe limit " - f"({max_batch_tokens}) for {len(requests)} requests. " - f"Reduce prompt length or batch size." + logger.warning( + f"Large batch prefill: {total_prompt_tokens} tokens " + f"(step_size={self.prefill_step_size}, requests={len(requests)}). " + f"Processing may be slow." ) # Run vision encoding for each request with its own KVCache. @@ -662,19 +737,11 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: per_request_caches.append(request_cache) - # Merge per-request KVCaches into a single BatchKVCache. - # KVCache.merge() creates a BatchKVCache with proper left-padding - # alignment, so all requests share a single batched cache for - # subsequent generation steps. - from mlx_lm.models.cache import KVCache - - sample_cache = per_request_caches[0][0] - if not isinstance(sample_cache, KVCache): - raise ValueError( - f"MLLM continuous batching requires standard KVCache but got " - f"{type(sample_cache).__name__}. Disable --kv-cache-quantization " - f"when using multimodal models with --continuous-batching." - ) + # Merge per-request caches into a single batched cache. + # Each cache type's merge() returns the correct batched representation: + # KVCache.merge() → BatchKVCache, ArraysCache.merge() → batched ArraysCache. + # This supports hybrid models mixing attention + SSM layers. 
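+        # Fail fast here: a clear ValueError naming the offending layer type
+        # (e.g. QuantizedKVCache) is easier to act on than an AttributeError
+        # raised from inside the merge loop below.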
+ _validate_caches_mergeable(per_request_caches) try: batch_cache = [ @@ -769,10 +836,16 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = new_batch prompt_processing = True + # Collect any pending error responses (from failed preprocessing) + error_responses = [] + if self._pending_error_responses: + error_responses = list(self._pending_error_responses) + self._pending_error_responses.clear() + # Generate next token for active batch batch = self.active_batch if batch is None: - return [] + return error_responses y, logprobs = batch.y, batch.logprobs batch.y, batch.logprobs = self._step(y[:, None], batch.cache) @@ -841,7 +914,7 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None self._stats.generation_tokens += len(responses) - return responses + return error_responses + responses def next(self) -> List[MLLMBatchResponse]: """ diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 80c9f39c2..99657a508 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -29,6 +29,8 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple +from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer + from .mllm_batch_generator import ( MLLMBatchGenerator, MLLMBatchRequest, @@ -198,6 +200,9 @@ def __init__( self.request_id_to_uid: Dict[str, int] = {} self.uid_to_request_id: Dict[int, str] = {} + # Per-request streaming detokenizers for UTF-8-safe incremental decode + self._detokenizer_pool: Dict[str, Any] = {} + # Output queues for async streaming self.output_queues: Dict[str, asyncio.Queue] = {} @@ -447,12 +452,42 @@ def _process_batch_responses( if request is None: continue + # Handle error responses from failed preprocessing + if response.finish_reason == "error": + output = RequestOutput( + request_id=request_id, + new_token_ids=[], + new_text="", + output_token_ids=[], + prompt_tokens=0, + completion_tokens=0, + finished=True, + finish_reason="error", + ) + request.status = RequestStatus.FINISHED_ABORTED + request.output_text = "" + request.finish_reason = "error" + finished_ids.add(request_id) + self.num_requests_processed += 1 + logger.warning(f"Request {request_id} failed during preprocessing") + outputs.append(output) + continue + # Append token to request request.output_tokens.append(response.token) request.num_output_tokens = len(request.output_tokens) - # Decode the new token - new_text = tokenizer.decode([response.token]) + # Decode the new token using streaming detokenizer (UTF-8 safe) + if request_id not in self._detokenizer_pool: + if hasattr(tokenizer, "detokenizer"): + detok = tokenizer.detokenizer + else: + detok = NaiveStreamingDetokenizer(tokenizer) + detok.reset() + self._detokenizer_pool[request_id] = detok + detok = self._detokenizer_pool[request_id] + detok.add_token(response.token) + new_text = detok.last_segment # Create output output = RequestOutput( @@ -475,8 +510,13 @@ def _process_batch_responses( output.finish_reason = response.finish_reason finished_ids.add(request_id) - # Decode full output - output.output_text = tokenizer.decode(request.output_tokens) + # Finalize streaming detokenizer and get full output + detok = self._detokenizer_pool.pop(request_id, None) + if detok is not None: + detok.finalize() + output.output_text = detok.text + else: + output.output_text = tokenizer.decode(request.output_tokens) request.output_text = output.output_text request.finish_reason = response.finish_reason diff --git a/vllm_mlx/patches/qwen3_5_mllm.py b/vllm_mlx/patches/qwen3_5_mllm.py new 
file mode 100644 index 000000000..c592928da --- /dev/null +++ b/vllm_mlx/patches/qwen3_5_mllm.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime patch for mlx-vlm's Qwen3.5 attention to support BatchKVCache. + +mlx-vlm's Qwen3_5Attention uses cache.offset directly for kv_seq_len +computation and mask slicing. BatchKVCache stores offset as mx.array +(per-batch-item), not int, causing: + + mask = mask[..., :kv_seq_len] + ValueError: Slice indices must be integers or None. + +This patch replaces Qwen3_5Attention.__call__ with a version that +converts cache.offset to int before using it for arithmetic/slicing, +while leaving the actual cache.offset untouched so update_and_fetch +still works correctly with per-batch offsets. +""" + +import logging +from typing import Optional + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +def _cache_offset_to_int(cache) -> int: + """Extract cache offset as int, handling BatchKVCache mx.array offset.""" + if cache is None: + return 0 + off = cache.offset + if isinstance(off, int): + return off + if isinstance(off, mx.array): + return int(off.max().item()) if off.ndim > 0 else int(off.item()) + return int(off) + + +def patch_qwen35_attention_for_batching() -> bool: + """Monkey-patch Qwen3_5Attention.__call__ to handle BatchKVCache. + + Returns True if patch was applied, False if mlx-vlm is not installed + or Qwen3.5 module not available. + """ + try: + from mlx_vlm.models.qwen3_5.language import ( + Qwen3_5Attention, + apply_multimodal_rotary_pos_emb, + ) + from mlx_lm.models.base import scaled_dot_product_attention + except ImportError: + logger.debug("[Qwen3.5 patch] mlx-vlm Qwen3.5 module not available") + return False + + if getattr(Qwen3_5Attention, "_batch_patched", False): + logger.debug("[Qwen3.5 patch] Already patched") + return True + + def _patched_call( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache=None, + position_ids: Optional[mx.array] = None, + ) -> mx.array: + B, L, D = x.shape + + q_proj_output = self.q_proj(x) + queries, gate = mx.split( + q_proj_output.reshape(B, L, self.num_attention_heads, -1), + 2, + axis=-1, + ) + gate = gate.reshape(B, L, -1) + + keys, values = self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries).transpose(0, 2, 1, 3) + keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( + 0, 2, 1, 3 + ) + + kv_seq_len = keys.shape[-2] + + # Convert cache.offset to int for slice compatibility. + # BatchKVCache stores offset as mx.array (per-batch-item), + # but kv_seq_len must be int for mask[..., :kv_seq_len]. 
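+        # Only this local copy becomes a Python int; cache.offset itself is
+        # left untouched, so update_and_fetch() still sees per-request offsets.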
+ _offset = _cache_offset_to_int(cache) + + if position_ids is None: + kv_seq_len += _offset + 1 + position_ids = mx.arange(_offset, _offset + L) + position_ids = mx.expand_dims(position_ids, axis=0) + position_ids = mx.tile(position_ids, (3, 1, 1)) + else: + kv_seq_len += _offset + 1 if cache is not None else 0 + + cos, sin = self.rotary_emb(values, position_ids) + + if mask is not None and isinstance(mask, mx.array): + mask = mask[..., :kv_seq_len] + + queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output * mx.sigmoid(gate)) + + Qwen3_5Attention.__call__ = _patched_call + Qwen3_5Attention._batch_patched = True + logger.info("[Qwen3.5 patch] Attention patched for BatchKVCache support") + return True diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index a0038d5f8..930aa024b 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1365,12 +1365,14 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re messages.append(msg_dict) images, videos = [], [] # MLLM extracts these from messages logger.debug(f"MLLM: Processing {len(messages)} messages") + messages = _normalize_messages(messages) else: # For LLM, extract text, images, and videos separately messages, images, videos = extract_multimodal_content( request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) has_media = bool(images or videos) if engine.is_mllm and not has_media: @@ -1496,6 +1498,64 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) +def _normalize_messages(messages: list[dict]) -> list[dict]: + """Normalize message roles and merge consecutive same-role messages. + + 1. Maps non-standard roles to standard ones (e.g. ``developer`` → ``system``). + 2. Merges consecutive same-role messages to satisfy chat template constraints + (Qwen 3.5, Llama, etc. require alternating roles). + + Only merges when both messages have string content. Messages with list + content (multimodal) are left as-is to preserve image/video attachments. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Returns: + New list with normalized roles and consecutive same-role messages merged. + """ + # OpenAI Responses API uses "developer" instead of "system". + # Map it so chat templates don't fail and fall back to raw prefill. 
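+    # Roles not listed here (e.g. "assistant", "tool") pass through unchanged.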
+ _ROLE_MAP = {"developer": "system"} + + if not messages: + return messages + + merged = [messages[0].copy()] + if merged[0]["role"] in _ROLE_MAP: + merged[0]["role"] = _ROLE_MAP[merged[0]["role"]] + for msg in messages[1:]: + prev = merged[-1] + role = _ROLE_MAP.get(msg["role"], msg["role"]) + if ( + role == prev["role"] + and isinstance(prev.get("content"), str) + and isinstance(msg.get("content"), str) + ): + # Merge string content with double newline separator + prev["content"] = prev["content"] + "\n\n" + msg["content"] + logger.debug( + f"Merged consecutive {role} messages " + f"({len(prev['content'])} chars total)" + ) + else: + copy = msg.copy() + copy["role"] = role + merged.append(copy) + + mapped_roles = sum(1 for m in messages if m["role"] in _ROLE_MAP) + merged_count = len(messages) - len(merged) + if mapped_roles or merged_count: + parts = [] + if mapped_roles: + parts.append(f"mapped {mapped_roles} role(s)") + if merged_count: + parts.append(f"merged {len(messages)} → {len(merged)}") + logger.info(f"Normalized messages: {', '.join(parts)}") + + return merged + + def _inject_json_instruction(messages: list, instruction: str) -> list: """ Inject JSON instruction into messages. @@ -1593,6 +1653,7 @@ async def create_anthropic_message( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens, @@ -1750,6 +1811,7 @@ async def _stream_anthropic_messages( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens, diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index a50883951..bb96f69cf 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -28,6 +28,56 @@ def _needs_tokenizer_fallback(model_name: str) -> bool: return any(pattern.lower() in model_lower for pattern in FALLBACK_MODELS) +def _needs_strict_false(model_name: str) -> bool: + """Check if model needs strict=False loading (VLM models with extra weights). + + VLM models (e.g., Qwen3.5) have vision_tower weights that don't match + the text-only model class. Loading with strict=True fails and wastes + memory by loading all weights (~100 GB) before raising ValueError. + Detect these models up-front to avoid the double-load penalty. + """ + from mlx_lm.utils import _download, load_config + + try: + model_path = _download(model_name) + config = load_config(model_path) + except Exception: + return False + if "vision_config" in config and "text_config" in config: + return True + return False + + +def _load_strict_false(model_name: str, tokenizer_config: dict = None): + """Load model with strict=False to discard extra weights. + + Handles models with extra parameters that the text-only model class + doesn't define (e.g., vision tower weights in VLM models, MTP layers). 
+ """ + import mlx.core as mx + from mlx_lm.utils import _download, load_model, load_tokenizer + + model_path = _download(model_name) + model, config = load_model(model_path, strict=False) + + from mlx.utils import tree_flatten + + params = tree_flatten(model.parameters()) + total_params = len(params) + zero_params = sum(1 for _, v in params if mx.all(v == 0).item()) + logger.info( + f"[strict=False] Loaded {total_params} parameters, " + f"{zero_params} all-zero tensors" + ) + + tokenizer = load_tokenizer( + model_path, + tokenizer_config or {}, + eos_token_ids=config.get("eos_token_id", None), + ) + return model, tokenizer + + def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): """ Load model and tokenizer with fallback for non-standard tokenizers. @@ -50,6 +100,11 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): ) return _load_with_tokenizer_fallback(model_name) + # VLM models: skip strict=True attempt to avoid double-loading ~100GB weights + if _needs_strict_false(model_name): + logger.info(f"Model {model_name} detected as VLM, loading with strict=False") + return _load_strict_false(model_name, tokenizer_config) + try: model, tokenizer = load(model_name, tokenizer_config=tokenizer_config) except ValueError as e: @@ -57,79 +112,24 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): if "TokenizersBackend" in str(e) or "Tokenizer class" in str(e): logger.warning(f"Standard tokenizer loading failed, using fallback: {e}") return _load_with_tokenizer_fallback(model_name) - # Fallback for models with extra weights (e.g., vision tower, MTP layers). - # Retry with strict=False to discard extra weights. - if "parameters not in model" in str(e): + elif "parameters not in model" in str(e): logger.warning( - f"Extra parameters found (e.g., vision tower / MTP weights), " - f"retrying with strict=False: {e}" + "Extra parameters found (e.g., MTP/vision weights), " + "retrying with strict=False" ) - return _load_strict_false(model_name, tokenizer_config) - raise - + # Clear traceback references to free memory from the failed load + e.__traceback__ = None + del e + import gc -def _load_strict_false(model_name: str, tokenizer_config: dict = None): - """Load model with strict=False to discard extra weights (e.g., vision tower, MTP).""" - from mlx_lm.utils import load_model, load_tokenizer - - local_path = Path(model_name) - if local_path.is_dir(): - model_path = local_path - else: - from huggingface_hub import snapshot_download - - model_path = Path(snapshot_download(model_name)) + gc.collect() + return _load_strict_false(model_name, tokenizer_config) + else: + raise - model, config = load_model(model_path, strict=False) - tokenizer = load_tokenizer( - model_path, - tokenizer_config or {}, - eos_token_ids=config.get("eos_token_id", None), - ) - # Inject MTP support if model has MTP config + weights - _try_inject_mtp(model, model_path, config) return model, tokenizer -def _try_inject_mtp(model, model_path, config): - """Inject MTP support if model has MTP config + weights.""" - if config.get("num_nextn_predict_layers", 0) > 0: - from ..patches.qwen3_next_mtp import inject_mtp_support - - inject_mtp_support(model, model_path, config) - - -def _try_inject_mtp_post_load(model, model_name): - """Check if MTP weights exist but were stripped by sanitize(), and inject.""" - import json - - from mlx_lm.utils import _download - - model_path = _download(model_name) - config_path = Path(model_path) / "config.json" - if not 
config_path.exists(): - return - with open(config_path) as f: - config = json.load(f) - # Also check text_config for nested configs - num_mtp = config.get("num_nextn_predict_layers", 0) - if num_mtp == 0: - text_config = config.get("text_config", {}) - num_mtp = text_config.get("num_nextn_predict_layers", 0) - if num_mtp > 0 and getattr(model, "mtp", None) is None: - mtp_file = Path(model_path) / "model-mtp.safetensors" - if mtp_file.exists(): - logger.info( - f"[MTP] Found MTP config (layers={num_mtp}) and weights, injecting..." - ) - _try_inject_mtp(model, model_path, config) - else: - logger.info( - f"[MTP] Config has num_nextn_predict_layers={num_mtp} " - "but model-mtp.safetensors not found, skipping MTP." - ) - - def _load_with_tokenizer_fallback(model_name: str): """Load model with fallback tokenizer for non-standard models like Nemotron.""" from mlx_lm.utils import load_model