4 changes: 2 additions & 2 deletions mlx_lm/models/baichuan_m1.py
@@ -8,7 +8,7 @@

 from .activations import swiglu
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
-from .cache import CacheList, KVCache, MambaCache, RotatingKVCache
+from .cache import ArraysCache, CacheList, KVCache, RotatingKVCache


 @dataclass
@@ -223,7 +223,7 @@ def make_cache(self) -> List[Any]:
         caches = []
         for i, layer in enumerate(self.model.layers):
             is_swa = i in self.config.sliding_window_layers
-            conv_cache = MambaCache()
+            conv_cache = ArraysCache(size=2)
             if is_swa:
                 kv_cache = RotatingKVCache(max_size=self.config.sliding_window)
             else:
5 changes: 0 additions & 5 deletions mlx_lm/models/cache.py
@@ -646,11 +646,6 @@ def empty(self):
         return self.cache[0] is None


-class MambaCache(ArraysCache):
-    def __init__(self, left_padding: Optional[List[int]] = None):
-        super().__init__(size=2, left_padding=left_padding)
-
-
 class ChunkedKVCache(_BaseCache):
     step = 256
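As the deleted lines show, MambaCache was only a thin wrapper that pinned size=2 (one slot for the convolution state, one for the SSM state), so every call site can construct the base class directly. A minimal sketch of the equivalence, assuming only the constructor visible in the diff above:

    from mlx_lm.models.cache import ArraysCache

    # What the deleted subclass did:
    #   MambaCache()  ->  ArraysCache(size=2, left_padding=None)
    cache = ArraysCache(size=2)  # slot 0: conv state, slot 1: SSM state

    # Layers read and write the slots by index, as mamba.py below does:
    #   cache[0] = new_conv_cache
    #   cache[1] = new_state_cache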
8 changes: 4 additions & 4 deletions mlx_lm/models/falcon_h1.py
@@ -13,7 +13,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import CacheList, KVCache, MambaCache
+from .cache import ArraysCache, CacheList, KVCache
 from .rope_utils import initialize_rope
 from .ssm import ssm_update

@@ -236,7 +236,7 @@ def __init__(self, args):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -273,7 +273,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -495,7 +495,7 @@ def sanitize(self, weights):

     def make_cache(self):
         return [
-            CacheList(MambaCache(), KVCache())
+            CacheList(ArraysCache(size=2), KVCache())
             for _ in range(self.args.num_hidden_layers)
         ]

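In falcon_h1 the SSM and attention branches live in the same layer, so make_cache bundles both cache types into one CacheList per layer. A hedged sketch of the pattern, assuming CacheList exposes its members by index as the per-layer bundle implies:

    from mlx_lm.models.cache import ArraysCache, CacheList, KVCache

    # One bundled cache per hybrid layer, as built in make_cache() above.
    layer_cache = CacheList(ArraysCache(size=2), KVCache())

    # Inside the layer, each branch would take its member (indexing assumed):
    ssm_cache = layer_cache[0]   # conv state + SSM state
    attn_cache = layer_cache[1]  # keys/values for the attention branch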
10 changes: 5 additions & 5 deletions mlx_lm/models/granitemoehybrid.py
@@ -13,7 +13,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .rope_utils import initialize_rope
 from .ssm import ssm_update
 from .switch_layers import SwitchGLU
@@ -123,7 +123,7 @@ def __init__(self, args: ModelArgs):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -160,7 +160,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -197,7 +197,7 @@ def __call__(
         self,
         hidden_states: mx.array,
         mask: Optional[mx.array],
-        cache: Optional[MambaCache] = None,
+        cache: Optional[ArraysCache] = None,
     ) -> mx.array:

         projected = self.in_proj(hidden_states)
@@ -496,7 +496,7 @@ def make_cache(self):
         caches = []
         for layer in self.layers:
             if layer.layer_type == "mamba":
-                caches.append(MambaCache())
+                caches.append(ArraysCache(size=2))
             elif layer.layer_type == "attention":
                 caches.append(KVCache())
         return caches
4 changes: 2 additions & 2 deletions mlx_lm/models/jamba.py
@@ -14,7 +14,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .switch_layers import SwitchGLU


@@ -341,7 +341,7 @@ def make_cache(self):
            if layer.is_attn:
                caches.append(KVCache())
            else:
-               caches.append(MambaCache())
+               caches.append(ArraysCache(size=2))
        return caches

    def sanitize(self, weights):
4 changes: 2 additions & 2 deletions mlx_lm/models/kimi_linear.py
@@ -13,7 +13,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .gated_delta import gated_delta_update
 from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU
@@ -500,7 +500,7 @@ def make_cache(self):
         caches: List[Any] = []
         for layer in self.layers:
             if layer.is_linear:
-                caches.append(MambaCache())
+                caches.append(ArraysCache(size=2))
             else:
                 caches.append(KVCache())
         return caches
6 changes: 3 additions & 3 deletions mlx_lm/models/mamba.py
@@ -8,7 +8,7 @@

 from .activations import swiglu
 from .base import BaseModelArgs
-from .cache import MambaCache
+from .cache import ArraysCache


 @dataclass
@@ -153,7 +153,7 @@ def __call__(self, x, cache):
            x, conv_cache, state_cache
        )

-       if isinstance(cache, MambaCache):
+       if isinstance(cache, ArraysCache):
            cache[0] = new_conv_cache
            cache[1] = new_state_cache

@@ -208,7 +208,7 @@ def __call__(self, inputs: mx.array, cache=None):
         return logits

     def make_cache(self):
-        return [MambaCache() for _ in range(len(self.layers))]
+        return [ArraysCache(size=2) for _ in range(len(self.layers))]

     @property
     def layers(self):
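Downstream usage is unchanged by this swap, since make_cache still returns one cache per layer. A minimal end-to-end sketch, assuming an mlx-lm-compatible mamba checkpoint (the model id here is illustrative, not part of this PR):

    import mlx.core as mx
    from mlx_lm import load

    # Illustrative checkpoint; any mamba-family model mlx-lm supports works.
    model, tokenizer = load("state-spaces/mamba-130m-hf")

    cache = model.make_cache()           # one ArraysCache(size=2) per layer
    tokens = mx.array([tokenizer.encode("hello")])
    logits = model(tokens, cache=cache)  # layers update cache[i][0]/[1] in place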
18 changes: 9 additions & 9 deletions mlx_lm/models/mamba2.py
@@ -9,7 +9,7 @@

 from .activations import swiglu
 from .base import BaseModelArgs, create_ssm_mask
-from .cache import MambaCache
+from .cache import ArraysCache
 from .ssm import ssm_update


@@ -97,7 +97,7 @@ def __init__(self, args: ModelArgs, layer_idx: int):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -134,7 +134,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -169,7 +169,7 @@ def __call__(
         self,
         hidden_states: mx.array,
         mask: Optional[mx.array],
-        cache: Optional[MambaCache] = None,
+        cache: Optional[ArraysCache] = None,
     ) -> mx.array:
         projected = self.in_proj(hidden_states)
         gate, conv_input, dt = mx.split(
@@ -200,7 +200,7 @@ def __init__(self, args: ModelArgs, layer_idx: int):
         self.norm = nn.RMSNorm(args.hidden_size)

     def __call__(
-        self, x: mx.array, mask: Optional[mx.array], cache: Optional[MambaCache] = None
+        self, x: mx.array, mask: Optional[mx.array], cache: Optional[ArraysCache] = None
     ) -> mx.array:
         output = self.mixer(self.norm(x), mask, cache)
         return output + x
@@ -215,7 +215,7 @@ def __init__(self, args: ModelArgs):
         self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

     def __call__(
-        self, x: mx.array, cache: Optional[list[MambaCache]] = None
+        self, x: mx.array, cache: Optional[list[ArraysCache]] = None
     ) -> mx.array:
         hidden = self.embeddings(x)

@@ -240,7 +240,7 @@ def __init__(self, args: ModelArgs):
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
-        self, inputs: mx.array, cache: Optional[list[MambaCache]] = None
+        self, inputs: mx.array, cache: Optional[list[ArraysCache]] = None
     ) -> mx.array:
         hidden = self.backbone(inputs, cache)

@@ -250,8 +250,8 @@ def __call__(
         logits = self.lm_head(hidden)
         return logits

-    def make_cache(self, batch_size: int = 1) -> list[MambaCache]:
-        return [MambaCache() for _ in range(self.args.num_hidden_layers)]
+    def make_cache(self, batch_size: int = 1) -> list[ArraysCache]:
+        return [ArraysCache(size=2) for _ in range(self.args.num_hidden_layers)]

     @property
     def layers(self):
10 changes: 5 additions & 5 deletions mlx_lm/models/nemotron_h.py
@@ -14,7 +14,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .ssm import ssm_update
 from .switch_layers import SwitchMLP

@@ -125,7 +125,7 @@ def __init__(self, args: ModelArgs):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -162,7 +162,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -199,7 +199,7 @@ def __call__(
         self,
         hidden_states: mx.array,
         mask: Optional[mx.array],
-        cache: Optional[MambaCache] = None,
+        cache: Optional[ArraysCache] = None,
     ) -> mx.array:

         projected = self.in_proj(hidden_states)
@@ -495,7 +495,7 @@ def make_cache(self):
         caches = []
         for l in self.layers:
             if l.block_type == "M":
-                caches.append(MambaCache())
+                caches.append(ArraysCache(size=2))
             elif l.block_type == "*":
                 caches.append(KVCache())
         return caches
6 changes: 3 additions & 3 deletions mlx_lm/models/plamo2.py
@@ -10,7 +10,7 @@
 from mlx_lm.models.base import BaseModelArgs, create_attention_mask, create_ssm_mask

 from .activations import swiglu
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .ssm import ssm_update


@@ -101,7 +101,7 @@ def __init__(self, config: ModelArgs) -> None:
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -459,7 +459,7 @@ def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
     def make_cache(self):
         # TODO use RotatingKVCache if not full_attn
         # full_attn = self.layer_idx in self.config.full_attention_idx
-        return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]
+        return [ArraysCache(size=2) if l.is_mamba else KVCache() for l in self.layers]

     def __call__(self, inputs: mx.array, cache=None) -> mx.array:
         outputs = self.model(
4 changes: 2 additions & 2 deletions mlx_lm/models/qwen3_next.py
@@ -15,7 +15,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .gated_delta import gated_delta_update
 from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU
@@ -427,7 +427,7 @@ def layers(self):
         return self.model.layers

     def make_cache(self):
-        return [MambaCache() if l.is_linear else KVCache() for l in self.layers]
+        return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]

     def sanitize(self, weights):
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
4 changes: 2 additions & 2 deletions mlx_lm/models/recurrent_gemma.py
@@ -8,7 +8,7 @@
 import mlx.nn as nn

 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
-from .cache import MambaCache, RotatingKVCache
+from .cache import ArraysCache, RotatingKVCache


 @dataclass
@@ -446,7 +446,7 @@ def make_cache(self):
         cache = []
         for layer in self.layers:
             if layer.temporal_block_type == "recurrent":
-                cache.append(MambaCache())
+                cache.append(ArraysCache(size=2))
             else:
                 cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
         return cache