Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import pytest
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import AttentionSelectorConfig
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import AttentionSelectorConfig
Comment on lines +8 to +9
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vllm.v1.attention.* imports in vllm_metal/platform.py and tests/test_platform.py do not exist in vLLM 0.13.0; it only exists in 0.14.0?


from vllm_metal.platform import MetalPlatform

Expand Down
20 changes: 17 additions & 3 deletions tests/test_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
import vllm_metal.v1.model_runner as mr


class StubMambaCache:
class StubArraysCache:
def __init__(self):
self.cache = [None, None]

def __getitem__(self, idx):
return self.cache[idx]

@property
def state(self):
return []
return self.cache


class TestPrefixCacheHybridGuard:
Expand All @@ -32,7 +38,7 @@ def fake_make_prompt_cache(model):
kv.keys = mx.zeros((1, 8, 0, 64))
kv.values = mx.zeros((1, 8, 0, 64))
kv.offset = 0
return [kv, StubMambaCache(), kv]
return [kv, StubArraysCache(), kv]

monkeypatch.setattr(mr, "make_prompt_cache", fake_make_prompt_cache)

Expand All @@ -51,6 +57,10 @@ def fake_make_prompt_cache(model):
frequency_penalty=0,
presence_penalty=0,
repetition_penalty=1.0,
logprobs=None,
seed=None,
logit_bias=None,
logits_processors=None,
)

runner._prefill_single("req-1", token_ids, sampling_params)
Expand Down Expand Up @@ -84,6 +94,10 @@ def fake_make_prompt_cache(model):
frequency_penalty=0,
presence_penalty=0,
repetition_penalty=1.0,
logprobs=None,
seed=None,
logit_bias=None,
logits_processors=None,
)

runner._prefill_single("req-1", token_ids, sampling_params)
Expand Down
4 changes: 2 additions & 2 deletions vllm_metal/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

import psutil
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum

from vllm_metal.config import get_config

if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig

logger = logging.getLogger(__name__)
Expand Down
75 changes: 19 additions & 56 deletions vllm_metal/v1/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
ArraysCache,
BatchKVCache,
KVCache,
MambaCache,
make_prompt_cache,
)

Expand Down Expand Up @@ -284,59 +283,12 @@ def get_stats(self) -> dict:


# Type alias for any cache type supported by the model
AnyCache: TypeAlias = KVCache | MambaCache | ArraysCache


class BatchMambaCache:
"""Batched cache for Mamba/SSM layers.

Wraps multiple MambaCache instances into a single batched cache
for efficient batched forward passes on hybrid models.
"""

def __init__(self, caches: list[MambaCache | ArraysCache]):
"""Create a batched Mamba cache from individual caches.

Args:
caches: List of MambaCache instances to batch
"""
self._batch_size = len(caches)
self._cache_size = len(caches[0].cache) if caches else 0

# Stack each state array across the batch dimension
self.cache: list[mx.array | None] = []
for i in range(self._cache_size):
states = [c.cache[i] for c in caches]
if all(s is not None for s in states):
self.cache.append(mx.concatenate(states, axis=0))
else:
self.cache.append(None)

def __getitem__(self, idx: int) -> mx.array | None:
return self.cache[idx]

def __setitem__(self, idx: int, value: mx.array | None) -> None:
self.cache[idx] = value

def extract(self, idx: int) -> MambaCache:
"""Extract a single request's cache from the batch.

Args:
idx: Index of the request in the batch

Returns:
MambaCache for the individual request
"""
cache = MambaCache()
for i in range(self._cache_size):
if self.cache[i] is not None:
cache.cache[i] = self.cache[i][idx : idx + 1]
return cache
AnyCache: TypeAlias = KVCache | ArraysCache | Any
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this make the alias effectively type safety lost?



def _is_mamba_cache(cache: AnyCache) -> bool:
"""Check if a cache is a Mamba-style cache (ArraysCache or MambaCache)."""
return isinstance(cache, (MambaCache, ArraysCache))
"""Check if a cache is a Mamba-style cache (has .cache attribute)."""
return hasattr(cache, "cache") and not isinstance(cache, (KVCache, BatchKVCache))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer isinstance(cache, ArraysCache) since MambaCache is a subclass of ArraysCache in mlx-lm; this is clearer and avoids accidental matches

╰─➤  python
Python 3.12.7 (main, Oct 16 2024, 07:12:08) [Clang 18.1.8 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from mlx_lm.models.cache import ArraysCache, MambaCache
>>> print('issubclass:', issubclass(MambaCache, ArraysCache))
issubclass: True
>>> print('isinstance:', isinstance(MambaCache(), ArraysCache))
isinstance: True



def _mlx_greedy_sample(logits: mx.array) -> mx.array:
Expand Down Expand Up @@ -394,7 +346,7 @@ class RequestState:

def _merge_kv_caches(
caches_list: list[list[AnyCache]],
) -> list[BatchKVCache | BatchMambaCache]:
) -> list[BatchKVCache | ArraysCache]:
"""Merge multiple per-request caches into batched caches.

Args:
Expand All @@ -407,12 +359,12 @@ def _merge_kv_caches(
return []

num_layers = len(caches_list[0])
merged: list[BatchKVCache | BatchMambaCache] = []
merged: list[BatchKVCache | ArraysCache] = []

for layer_idx in range(num_layers):
layer_caches = [caches[layer_idx] for caches in caches_list]
if _is_mamba_cache(layer_caches[0]):
batch_cache = BatchMambaCache(layer_caches)
batch_cache = ArraysCache.merge(layer_caches)
else:
batch_cache = BatchKVCache.merge(layer_caches)
merged.append(batch_cache)
Expand All @@ -421,7 +373,7 @@ def _merge_kv_caches(


def _extract_kv_cache(
batch_caches: list[BatchKVCache | BatchMambaCache], idx: int
batch_caches: list[BatchKVCache | ArraysCache], idx: int
) -> list[AnyCache]:
"""Extract a single request's cache from batched caches.

Expand All @@ -432,7 +384,18 @@ def _extract_kv_cache(
Returns:
List of caches for the request, one per layer
"""
return [cache.extract(idx) for cache in batch_caches]
extracted: list[AnyCache] = []
for cache in batch_caches:
if hasattr(cache, "extract"):
extracted.append(cache.extract(idx))
elif hasattr(cache, "cache") and not isinstance(cache, BatchKVCache):
# Fallback for older ArraysCache/MambaCache versions where .extract is missing
new_cache = type(cache)(len(cache.cache))
new_cache.cache = [c[idx : idx + 1] for c in cache.cache]
extracted.append(new_cache)
Comment on lines +391 to +395
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm... can you confirm whether cache can include None entries in this path (hybrid/Mamba/ArraysCache)? The fallback list comprehension slices every entry, which would crash on None. If None is possible, we should preserve those entries in the fallback (or add a test showing it cant happen)

else:
raise TypeError(f"Unsupported cache type: {type(cache)}")
return extracted


class MetalModelRunner:
Expand Down
2 changes: 1 addition & 1 deletion vllm_metal/v1/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.utils.torch_utils import set_random_seed
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

importing set_random_seed from vllm.utils.torch_utils breaks on vLLM 0.13.0 (our declared minimum); either keep compatibility with the version we declared or bumped the min version to v0.14.0

from vllm.tasks import SupportedTask
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
Expand Down
Loading