-
Notifications
You must be signed in to change notification settings - Fork 136
Update Mamba cache to support ArraysCache #94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,6 @@ | |
| ArraysCache, | ||
| BatchKVCache, | ||
| KVCache, | ||
| MambaCache, | ||
| make_prompt_cache, | ||
| ) | ||
|
|
||
|
|
@@ -284,59 +283,12 @@ def get_stats(self) -> dict: | |
|
|
||
|
|
||
| # Type alias for any cache type supported by the model | ||
| AnyCache: TypeAlias = KVCache | MambaCache | ArraysCache | ||
|
|
||
|
|
||
| class BatchMambaCache: | ||
| """Batched cache for Mamba/SSM layers. | ||
|
|
||
| Wraps multiple MambaCache instances into a single batched cache | ||
| for efficient batched forward passes on hybrid models. | ||
| """ | ||
|
|
||
| def __init__(self, caches: list[MambaCache | ArraysCache]): | ||
| """Create a batched Mamba cache from individual caches. | ||
|
|
||
| Args: | ||
| caches: List of MambaCache instances to batch | ||
| """ | ||
| self._batch_size = len(caches) | ||
| self._cache_size = len(caches[0].cache) if caches else 0 | ||
|
|
||
| # Stack each state array across the batch dimension | ||
| self.cache: list[mx.array | None] = [] | ||
| for i in range(self._cache_size): | ||
| states = [c.cache[i] for c in caches] | ||
| if all(s is not None for s in states): | ||
| self.cache.append(mx.concatenate(states, axis=0)) | ||
| else: | ||
| self.cache.append(None) | ||
|
|
||
| def __getitem__(self, idx: int) -> mx.array | None: | ||
| return self.cache[idx] | ||
|
|
||
| def __setitem__(self, idx: int, value: mx.array | None) -> None: | ||
| self.cache[idx] = value | ||
|
|
||
| def extract(self, idx: int) -> MambaCache: | ||
| """Extract a single request's cache from the batch. | ||
|
|
||
| Args: | ||
| idx: Index of the request in the batch | ||
|
|
||
| Returns: | ||
| MambaCache for the individual request | ||
| """ | ||
| cache = MambaCache() | ||
| for i in range(self._cache_size): | ||
| if self.cache[i] is not None: | ||
| cache.cache[i] = self.cache[i][idx : idx + 1] | ||
| return cache | ||
| AnyCache: TypeAlias = KVCache | ArraysCache | Any | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this make the alias effectively type safety lost? |
||
|
|
||
|
|
||
| def _is_mamba_cache(cache: AnyCache) -> bool: | ||
| """Check if a cache is a Mamba-style cache (ArraysCache or MambaCache).""" | ||
| return isinstance(cache, (MambaCache, ArraysCache)) | ||
| """Check if a cache is a Mamba-style cache (has .cache attribute).""" | ||
| return hasattr(cache, "cache") and not isinstance(cache, (KVCache, BatchKVCache)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer isinstance(cache, ArraysCache) since MambaCache is a subclass of ArraysCache in mlx-lm; this is clearer and avoids accidental matches ╰─➤ python
Python 3.12.7 (main, Oct 16 2024, 07:12:08) [Clang 18.1.8 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from mlx_lm.models.cache import ArraysCache, MambaCache
>>> print('issubclass:', issubclass(MambaCache, ArraysCache))
issubclass: True
>>> print('isinstance:', isinstance(MambaCache(), ArraysCache))
isinstance: True |
||
|
|
||
|
|
||
| def _mlx_greedy_sample(logits: mx.array) -> mx.array: | ||
|
|
@@ -394,7 +346,7 @@ class RequestState: | |
|
|
||
| def _merge_kv_caches( | ||
| caches_list: list[list[AnyCache]], | ||
| ) -> list[BatchKVCache | BatchMambaCache]: | ||
| ) -> list[BatchKVCache | ArraysCache]: | ||
| """Merge multiple per-request caches into batched caches. | ||
|
|
||
| Args: | ||
|
|
@@ -407,12 +359,12 @@ def _merge_kv_caches( | |
| return [] | ||
|
|
||
| num_layers = len(caches_list[0]) | ||
| merged: list[BatchKVCache | BatchMambaCache] = [] | ||
| merged: list[BatchKVCache | ArraysCache] = [] | ||
|
|
||
| for layer_idx in range(num_layers): | ||
| layer_caches = [caches[layer_idx] for caches in caches_list] | ||
| if _is_mamba_cache(layer_caches[0]): | ||
| batch_cache = BatchMambaCache(layer_caches) | ||
| batch_cache = ArraysCache.merge(layer_caches) | ||
| else: | ||
| batch_cache = BatchKVCache.merge(layer_caches) | ||
| merged.append(batch_cache) | ||
|
|
@@ -421,7 +373,7 @@ def _merge_kv_caches( | |
|
|
||
|
|
||
| def _extract_kv_cache( | ||
| batch_caches: list[BatchKVCache | BatchMambaCache], idx: int | ||
| batch_caches: list[BatchKVCache | ArraysCache], idx: int | ||
| ) -> list[AnyCache]: | ||
| """Extract a single request's cache from batched caches. | ||
|
|
||
|
|
@@ -432,7 +384,18 @@ def _extract_kv_cache( | |
| Returns: | ||
| List of caches for the request, one per layer | ||
| """ | ||
| return [cache.extract(idx) for cache in batch_caches] | ||
| extracted: list[AnyCache] = [] | ||
| for cache in batch_caches: | ||
| if hasattr(cache, "extract"): | ||
| extracted.append(cache.extract(idx)) | ||
| elif hasattr(cache, "cache") and not isinstance(cache, BatchKVCache): | ||
| # Fallback for older ArraysCache/MambaCache versions where .extract is missing | ||
| new_cache = type(cache)(len(cache.cache)) | ||
| new_cache.cache = [c[idx : idx + 1] for c in cache.cache] | ||
| extracted.append(new_cache) | ||
|
Comment on lines
+391
to
+395
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm... can you confirm whether cache can include None entries in this path (hybrid/Mamba/ArraysCache)? The fallback list comprehension slices every entry, which would crash on None. If None is possible, we should preserve those entries in the fallback (or add a test showing it cant happen) |
||
| else: | ||
| raise TypeError(f"Unsupported cache type: {type(cache)}") | ||
| return extracted | ||
|
|
||
|
|
||
| class MetalModelRunner: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| ) | ||
| from vllm.logger import init_logger | ||
| from vllm.lora.request import LoRARequest | ||
| from vllm.model_executor import set_random_seed | ||
| from vllm.utils.torch_utils import set_random_seed | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. importing set_random_seed from vllm.utils.torch_utils breaks on vLLM 0.13.0 (our declared minimum); either keep compatibility with the version we declared or bumped the min version to v0.14.0 |
||
| from vllm.tasks import SupportedTask | ||
| from vllm.v1.core.sched.output import SchedulerOutput | ||
| from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vllm.v1.attention.* imports in vllm_metal/platform.py and tests/test_platform.py do not exist in vLLM 0.13.0; it only exists in 0.14.0?