diff --git a/vllm_mlx/utils/mamba_cache.py b/vllm_mlx/utils/mamba_cache.py index eec489988..b29a9e877 100644 --- a/vllm_mlx/utils/mamba_cache.py +++ b/vllm_mlx/utils/mamba_cache.py @@ -12,17 +12,12 @@ import mlx.core as mx -# MambaCache was removed in mlx-lm 0.30.6 - make import conditional +# MambaCache was removed in mlx-lm 0.30.6, fall back to ArraysCache try: from mlx_lm.models.cache import MambaCache - - HAS_MAMBA_CACHE = True except ImportError: - # Fallback for mlx-lm >= 0.30.6 where MambaCache was removed from mlx_lm.models.cache import ArraysCache as MambaCache - HAS_MAMBA_CACHE = False - logger = logging.getLogger(__name__) @@ -42,10 +37,9 @@ def __init__(self, left_padding: Optional[List[int]] = None, size: int = 2): left_padding: Amount of left padding for each sequence in batch size: Number of state arrays (default 2 for Mamba models) """ - if HAS_MAMBA_CACHE: - super().__init__(left_padding=left_padding) - else: - super().__init__(size=size, left_padding=left_padding) + # Always pass size - ArraysCache requires it, and MambaCache + # (if it exists) inherits from ArraysCache + super().__init__(size=size, left_padding=left_padding) self._batch_size = len(left_padding) if left_padding else 0 def extract(self, idx: int) -> MambaCache: @@ -59,10 +53,7 @@ def extract(self, idx: int) -> MambaCache: A new MambaCache with the extracted state """ size = len(self.cache) - if HAS_MAMBA_CACHE: - cache = MambaCache() - else: - cache = MambaCache(size=size) + cache = MambaCache(size=size) # Extract the state arrays for this index cache.cache = [ mx.contiguous(c[idx : idx + 1]) if c is not None else None