From 9f17ec3dc6f55d750d3e6885841819a389c43ef0 Mon Sep 17 00:00:00 2001 From: gaurav Date: Sun, 8 Feb 2026 13:12:49 +0530 Subject: [PATCH 1/2] Refactor Mamba cache to support ArraysCache and older mlx-lm versions Signed-off-by: gaurav --- tests/test_prefix_cache.py | 20 ++++++++-- vllm_metal/v1/model_runner.py | 75 +++++++++-------------------------- 2 files changed, 36 insertions(+), 59 deletions(-) diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index eb22ca3a..17a245d7 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -9,10 +9,16 @@ import vllm_metal.v1.model_runner as mr -class StubMambaCache: +class StubArraysCache: + def __init__(self): + self.cache = [None, None] + + def __getitem__(self, idx): + return self.cache[idx] + @property def state(self): - return [] + return self.cache class TestPrefixCacheHybridGuard: @@ -32,7 +38,7 @@ def fake_make_prompt_cache(model): kv.keys = mx.zeros((1, 8, 0, 64)) kv.values = mx.zeros((1, 8, 0, 64)) kv.offset = 0 - return [kv, StubMambaCache(), kv] + return [kv, StubArraysCache(), kv] monkeypatch.setattr(mr, "make_prompt_cache", fake_make_prompt_cache) @@ -51,6 +57,10 @@ def fake_make_prompt_cache(model): frequency_penalty=0, presence_penalty=0, repetition_penalty=1.0, + logprobs=None, + seed=None, + logit_bias=None, + logits_processors=None, ) runner._prefill_single("req-1", token_ids, sampling_params) @@ -84,6 +94,10 @@ def fake_make_prompt_cache(model): frequency_penalty=0, presence_penalty=0, repetition_penalty=1.0, + logprobs=None, + seed=None, + logit_bias=None, + logits_processors=None, ) runner._prefill_single("req-1", token_ids, sampling_params) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 7084385e..2ee6c890 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -25,7 +25,6 @@ ArraysCache, BatchKVCache, KVCache, - MambaCache, make_prompt_cache, ) @@ -193,59 +192,12 @@ def get_stats(self) -> dict: # Type alias for any cache type supported by the model -AnyCache: TypeAlias = KVCache | MambaCache | ArraysCache - - -class BatchMambaCache: - """Batched cache for Mamba/SSM layers. - - Wraps multiple MambaCache instances into a single batched cache - for efficient batched forward passes on hybrid models. - """ - - def __init__(self, caches: list[MambaCache | ArraysCache]): - """Create a batched Mamba cache from individual caches. - - Args: - caches: List of MambaCache instances to batch - """ - self._batch_size = len(caches) - self._cache_size = len(caches[0].cache) if caches else 0 - - # Stack each state array across the batch dimension - self.cache: list[mx.array | None] = [] - for i in range(self._cache_size): - states = [c.cache[i] for c in caches] - if all(s is not None for s in states): - self.cache.append(mx.concatenate(states, axis=0)) - else: - self.cache.append(None) - - def __getitem__(self, idx: int) -> mx.array | None: - return self.cache[idx] - - def __setitem__(self, idx: int, value: mx.array | None) -> None: - self.cache[idx] = value - - def extract(self, idx: int) -> MambaCache: - """Extract a single request's cache from the batch. - - Args: - idx: Index of the request in the batch - - Returns: - MambaCache for the individual request - """ - cache = MambaCache() - for i in range(self._cache_size): - if self.cache[i] is not None: - cache.cache[i] = self.cache[i][idx : idx + 1] - return cache +AnyCache: TypeAlias = KVCache | ArraysCache | Any def _is_mamba_cache(cache: AnyCache) -> bool: - """Check if a cache is a Mamba-style cache (ArraysCache or MambaCache).""" - return isinstance(cache, (MambaCache, ArraysCache)) + """Check if a cache is a Mamba-style cache (has .cache attribute).""" + return hasattr(cache, "cache") and not isinstance(cache, (KVCache, BatchKVCache)) def _mlx_greedy_sample(logits: mx.array) -> mx.array: @@ -303,7 +255,7 @@ class RequestState: def _merge_kv_caches( caches_list: list[list[AnyCache]], -) -> list[BatchKVCache | BatchMambaCache]: +) -> list[BatchKVCache | ArraysCache]: """Merge multiple per-request caches into batched caches. Args: @@ -316,12 +268,12 @@ def _merge_kv_caches( return [] num_layers = len(caches_list[0]) - merged: list[BatchKVCache | BatchMambaCache] = [] + merged: list[BatchKVCache | ArraysCache] = [] for layer_idx in range(num_layers): layer_caches = [caches[layer_idx] for caches in caches_list] if _is_mamba_cache(layer_caches[0]): - batch_cache = BatchMambaCache(layer_caches) + batch_cache = ArraysCache.merge(layer_caches) else: batch_cache = BatchKVCache.merge(layer_caches) merged.append(batch_cache) @@ -330,7 +282,7 @@ def _merge_kv_caches( def _extract_kv_cache( - batch_caches: list[BatchKVCache | BatchMambaCache], idx: int + batch_caches: list[BatchKVCache | ArraysCache], idx: int ) -> list[AnyCache]: """Extract a single request's cache from batched caches. @@ -341,7 +293,18 @@ def _extract_kv_cache( Returns: List of caches for the request, one per layer """ - return [cache.extract(idx) for cache in batch_caches] + extracted: list[AnyCache] = [] + for cache in batch_caches: + if hasattr(cache, "extract"): + extracted.append(cache.extract(idx)) + elif hasattr(cache, "cache") and not isinstance(cache, BatchKVCache): + # Fallback for older ArraysCache/MambaCache versions where .extract is missing + new_cache = type(cache)(len(cache.cache)) + new_cache.cache = [c[idx : idx + 1] for c in cache.cache] + extracted.append(new_cache) + else: + raise TypeError(f"Unsupported cache type: {type(cache)}") + return extracted class MetalModelRunner: From d46730ba710294211a3d38e0e5207053b29c7665 Mon Sep 17 00:00:00 2001 From: gaurav Date: Tue, 10 Feb 2026 21:29:26 +0530 Subject: [PATCH 2/2] fix Importerror Signed-off-by: gaurav --- tests/test_platform.py | 4 ++-- vllm_metal/platform.py | 4 ++-- vllm_metal/v1/worker.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_platform.py b/tests/test_platform.py index 1187601e..e2b82321 100644 --- a/tests/test_platform.py +++ b/tests/test_platform.py @@ -5,8 +5,8 @@ import pytest import torch -from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.selector import AttentionSelectorConfig +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.attention.selector import AttentionSelectorConfig from vllm_metal.platform import MetalPlatform diff --git a/vllm_metal/platform.py b/vllm_metal/platform.py index 4b55f07c..0a0db541 100644 --- a/vllm_metal/platform.py +++ b/vllm_metal/platform.py @@ -7,13 +7,13 @@ import psutil import torch -from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum from vllm_metal.config import get_config if TYPE_CHECKING: - from vllm.attention.selector import AttentionSelectorConfig + from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig logger = logging.getLogger(__name__) diff --git a/vllm_metal/v1/worker.py b/vllm_metal/v1/worker.py index e4ee02bd..41b95d40 100644 --- a/vllm_metal/v1/worker.py +++ b/vllm_metal/v1/worker.py @@ -15,7 +15,7 @@ ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed +from vllm.utils.torch_utils import set_random_seed from vllm.tasks import SupportedTask from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec