diff --git a/vllm_mlx/utils/mamba_cache.py b/vllm_mlx/utils/mamba_cache.py index 03dbead5..eec48998 100644 --- a/vllm_mlx/utils/mamba_cache.py +++ b/vllm_mlx/utils/mamba_cache.py @@ -34,14 +34,18 @@ class BatchMambaCache(MambaCache): mlx-lm's BatchGenerator, specifically the `extract` method. """ - def __init__(self, left_padding: Optional[List[int]] = None): + def __init__(self, left_padding: Optional[List[int]] = None, size: int = 2): """ Initialize BatchMambaCache. Args: left_padding: Amount of left padding for each sequence in batch + size: Number of state arrays (default 2 for Mamba models) """ - super().__init__(left_padding=left_padding) + if HAS_MAMBA_CACHE: + super().__init__(left_padding=left_padding) + else: + super().__init__(size=size, left_padding=left_padding) self._batch_size = len(left_padding) if left_padding else 0 def extract(self, idx: int) -> MambaCache: @@ -54,7 +58,11 @@ def extract(self, idx: int) -> MambaCache: Returns: A new MambaCache with the extracted state """ - cache = MambaCache() + size = len(self.cache) + if HAS_MAMBA_CACHE: + cache = MambaCache() + else: + cache = MambaCache(size=size) # Extract the state arrays for this index cache.cache = [ mx.contiguous(c[idx : idx + 1]) if c is not None else None