diff --git a/mlx_lm/models/baichuan_m1.py b/mlx_lm/models/baichuan_m1.py
index 3221c02a4..1a52b4f36 100644
--- a/mlx_lm/models/baichuan_m1.py
+++ b/mlx_lm/models/baichuan_m1.py
@@ -8,7 +8,7 @@
 from .activations import swiglu
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
-from .cache import CacheList, KVCache, MambaCache, RotatingKVCache
+from .cache import ArraysCache, CacheList, KVCache, RotatingKVCache
 
 
 @dataclass
@@ -223,7 +223,7 @@ def make_cache(self) -> List[Any]:
         caches = []
         for i, layer in enumerate(self.model.layers):
             is_swa = i in self.config.sliding_window_layers
-            conv_cache = MambaCache()
+            conv_cache = ArraysCache(size=2)
             if is_swa:
                 kv_cache = RotatingKVCache(max_size=self.config.sliding_window)
             else:
diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py
index aecbcb72e..745691249 100644
--- a/mlx_lm/models/cache.py
+++ b/mlx_lm/models/cache.py
@@ -646,11 +646,6 @@ def empty(self):
         return self.cache[0] is None
 
 
-class MambaCache(ArraysCache):
-    def __init__(self, left_padding: Optional[List[int]] = None):
-        super().__init__(size=2, left_padding=left_padding)
-
-
 class ChunkedKVCache(_BaseCache):
     step = 256
 
diff --git a/mlx_lm/models/falcon_h1.py b/mlx_lm/models/falcon_h1.py
index 531b78b80..773893485 100644
--- a/mlx_lm/models/falcon_h1.py
+++ b/mlx_lm/models/falcon_h1.py
@@ -13,7 +13,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import CacheList, KVCache, MambaCache
+from .cache import ArraysCache, CacheList, KVCache
 from .rope_utils import initialize_rope
 from .ssm import ssm_update
 
@@ -236,7 +236,7 @@ def __init__(self, args):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -273,7 +273,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -495,7 +495,7 @@ def sanitize(self, weights):
 
     def make_cache(self):
         return [
-            CacheList(MambaCache(), KVCache())
+            CacheList(ArraysCache(size=2), KVCache())
             for _ in range(self.args.num_hidden_layers)
         ]
 
diff --git a/mlx_lm/models/granitemoehybrid.py b/mlx_lm/models/granitemoehybrid.py
index 40dae9977..a95c887e2 100644
--- a/mlx_lm/models/granitemoehybrid.py
+++ b/mlx_lm/models/granitemoehybrid.py
@@ -13,7 +13,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .rope_utils import initialize_rope
 from .ssm import ssm_update
 from .switch_layers import SwitchGLU
@@ -123,7 +123,7 @@ def __init__(self, args: ModelArgs):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -160,7 +160,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -197,7 +197,7 @@ def __call__(
         self,
         hidden_states: mx.array,
         mask: Optional[mx.array],
-        cache: Optional[MambaCache] = None,
+        cache: Optional[ArraysCache] = None,
     ) -> mx.array:
 
         projected = self.in_proj(hidden_states)
@@ -496,7 +496,7 @@ def make_cache(self):
         caches = []
         for layer in self.layers:
             if layer.layer_type == "mamba":
-                caches.append(MambaCache())
+                caches.append(ArraysCache(size=2))
             elif layer.layer_type == "attention":
                 caches.append(KVCache())
         return caches
diff --git a/mlx_lm/models/jamba.py b/mlx_lm/models/jamba.py
index f7515c018..c41380caf 100644
--- a/mlx_lm/models/jamba.py
+++ b/mlx_lm/models/jamba.py
@@ -14,7 +14,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .switch_layers import SwitchGLU
 
 
@@ -341,7 +341,7 @@ def make_cache(self):
             if layer.is_attn:
                 caches.append(KVCache())
             else:
-                caches.append(MambaCache())
+                caches.append(ArraysCache(size=2))
         return caches
 
     def sanitize(self, weights):
diff --git a/mlx_lm/models/kimi_linear.py b/mlx_lm/models/kimi_linear.py
index bf61f5aa3..633f6eb95 100644
--- a/mlx_lm/models/kimi_linear.py
+++ b/mlx_lm/models/kimi_linear.py
@@ -13,7 +13,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .gated_delta import gated_delta_update
 from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU
@@ -500,7 +500,7 @@ def make_cache(self):
         caches: List[Any] = []
         for layer in self.layers:
             if layer.is_linear:
-                caches.append(MambaCache())
+                caches.append(ArraysCache(size=2))
             else:
                 caches.append(KVCache())
         return caches
diff --git a/mlx_lm/models/mamba.py b/mlx_lm/models/mamba.py
index 319a950db..0eff678fb 100644
--- a/mlx_lm/models/mamba.py
+++ b/mlx_lm/models/mamba.py
@@ -8,7 +8,7 @@
 from .activations import swiglu
 from .base import BaseModelArgs
-from .cache import MambaCache
+from .cache import ArraysCache
 
 
 @dataclass
@@ -153,7 +153,7 @@ def __call__(self, x, cache):
             x, conv_cache, state_cache
         )
 
-        if isinstance(cache, MambaCache):
+        if isinstance(cache, ArraysCache):
             cache[0] = new_conv_cache
             cache[1] = new_state_cache
 
@@ -208,7 +208,7 @@ def __call__(self, inputs: mx.array, cache=None):
         return logits
 
     def make_cache(self):
-        return [MambaCache() for _ in range(len(self.layers))]
+        return [ArraysCache(size=2) for _ in range(len(self.layers))]
 
     @property
     def layers(self):
diff --git a/mlx_lm/models/mamba2.py b/mlx_lm/models/mamba2.py
index 87db6a6b9..f562b3374 100644
--- a/mlx_lm/models/mamba2.py
+++ b/mlx_lm/models/mamba2.py
@@ -9,7 +9,7 @@
 from .activations import swiglu
 from .base import BaseModelArgs, create_ssm_mask
-from .cache import MambaCache
+from .cache import ArraysCache
 from .ssm import ssm_update
 
 
@@ -97,7 +97,7 @@ def __init__(self, args: ModelArgs, layer_idx: int):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -134,7 +134,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -169,7 +169,7 @@ def __call__(
         self,
         hidden_states: mx.array,
         mask: Optional[mx.array],
-        cache: Optional[MambaCache] = None,
+        cache: Optional[ArraysCache] = None,
     ) -> mx.array:
         projected = self.in_proj(hidden_states)
         gate, conv_input, dt = mx.split(
@@ -200,7 +200,7 @@ def __init__(self, args: ModelArgs, layer_idx: int):
         self.norm = nn.RMSNorm(args.hidden_size)
 
     def __call__(
-        self, x: mx.array, mask: Optional[mx.array], cache: Optional[MambaCache] = None
+        self, x: mx.array, mask: Optional[mx.array], cache: Optional[ArraysCache] = None
     ) -> mx.array:
         output = self.mixer(self.norm(x), mask, cache)
         return output + x
@@ -215,7 +215,7 @@ def __init__(self, args: ModelArgs):
         self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
 
     def __call__(
-        self, x: mx.array, cache: Optional[list[MambaCache]] = None
+        self, x: mx.array, cache: Optional[list[ArraysCache]] = None
     ) -> mx.array:
         hidden = self.embeddings(x)
 
@@ -240,7 +240,7 @@ def __init__(self, args: ModelArgs):
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
-        self, inputs: mx.array, cache: Optional[list[MambaCache]] = None
+        self, inputs: mx.array, cache: Optional[list[ArraysCache]] = None
     ) -> mx.array:
         hidden = self.backbone(inputs, cache)
 
@@ -250,8 +250,8 @@ def __call__(
         logits = self.lm_head(hidden)
         return logits
 
-    def make_cache(self, batch_size: int = 1) -> list[MambaCache]:
-        return [MambaCache() for _ in range(self.args.num_hidden_layers)]
+    def make_cache(self, batch_size: int = 1) -> list[ArraysCache]:
+        return [ArraysCache(size=2) for _ in range(self.args.num_hidden_layers)]
 
     @property
     def layers(self):
diff --git a/mlx_lm/models/nemotron_h.py b/mlx_lm/models/nemotron_h.py
index cb94edcb2..5fb500eea 100644
--- a/mlx_lm/models/nemotron_h.py
+++ b/mlx_lm/models/nemotron_h.py
@@ -14,7 +14,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .ssm import ssm_update
 from .switch_layers import SwitchMLP
 
@@ -125,7 +125,7 @@ def __init__(self, args: ModelArgs):
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -162,7 +162,7 @@ def _ssm(
         B: mx.array,
         C: mx.array,
         dt: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         batch_size, seq_len, _ = hidden_states.shape
@@ -199,7 +199,7 @@ def __call__(
         self,
         hidden_states: mx.array,
         mask: Optional[mx.array],
-        cache: Optional[MambaCache] = None,
+        cache: Optional[ArraysCache] = None,
     ) -> mx.array:
 
         projected = self.in_proj(hidden_states)
@@ -495,7 +495,7 @@ def make_cache(self):
         caches = []
         for l in self.layers:
             if l.block_type == "M":
-                caches.append(MambaCache())
+                caches.append(ArraysCache(size=2))
             elif l.block_type == "*":
                 caches.append(KVCache())
         return caches
diff --git a/mlx_lm/models/plamo2.py b/mlx_lm/models/plamo2.py
index 9e32aa150..b835de0a5 100644
--- a/mlx_lm/models/plamo2.py
+++ b/mlx_lm/models/plamo2.py
@@ -10,7 +10,7 @@
 from mlx_lm.models.base import BaseModelArgs, create_attention_mask, create_ssm_mask
 
 from .activations import swiglu
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .ssm import ssm_update
 
 
@@ -101,7 +101,7 @@ def __init__(self, config: ModelArgs) -> None:
     def _conv(
         self,
         conv_input: mx.array,
-        cache: Optional[MambaCache],
+        cache: Optional[ArraysCache],
         mask: Optional[mx.array],
     ) -> mx.array:
         if mask is not None:
@@ -459,7 +459,7 @@ def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
     def make_cache(self):
        # TODO use RotatingKVCache is not full_attn
        # full_attn = self.layer_idx in self.config.full_attention_idx
-        return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]
+        return [ArraysCache(size=2) if l.is_mamba else KVCache() for l in self.layers]
 
     def __call__(self, inputs: mx.array, cache=None) -> mx.array:
         outputs = self.model(
diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py
index 28b6c97dc..47c9971bd 100644
--- a/mlx_lm/models/qwen3_next.py
+++ b/mlx_lm/models/qwen3_next.py
@@ -15,7 +15,7 @@
     create_ssm_mask,
     scaled_dot_product_attention,
 )
-from .cache import KVCache, MambaCache
+from .cache import ArraysCache, KVCache
 from .gated_delta import gated_delta_update
 from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU
@@ -427,7 +427,7 @@ def layers(self):
         return self.model.layers
 
     def make_cache(self):
-        return [MambaCache() if l.is_linear else KVCache() for l in self.layers]
+        return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]
 
     def sanitize(self, weights):
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
diff --git a/mlx_lm/models/recurrent_gemma.py b/mlx_lm/models/recurrent_gemma.py
index 86cffcb0e..4659d6c3d 100644
--- a/mlx_lm/models/recurrent_gemma.py
+++ b/mlx_lm/models/recurrent_gemma.py
@@ -8,7 +8,7 @@
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
-from .cache import MambaCache, RotatingKVCache
+from .cache import ArraysCache, RotatingKVCache
 
 
 @dataclass
@@ -446,7 +446,7 @@ def make_cache(self):
         cache = []
         for layer in self.layers:
             if layer.temporal_block_type == "recurrent":
-                cache.append(MambaCache())
+                cache.append(ArraysCache(size=2))
             else:
                 cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
         return cache
diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index 887dd153d..c36b27a53 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -720,13 +720,11 @@ def progress_callback(info):
             if request is not None:
                 rqueue, request, args = request
 
-                is_batchable = self._is_batchable(args)
-
                 # Can it be added to the current batch?
                 if (
                     batch_generator is not None
                     and current_model == args.model
-                    and is_batchable
+                    and self._is_batchable(args)
                 ):
                     try:
                         prompt = self._tokenize(current_tokenizer, request, args)
@@ -773,13 +771,9 @@ def progress_callback(info):
                     }
                     continue

-                # We have no batch and it actually is not a batchable request
-                # so serve single sequence at a time.
-                elif batch_generator is None and not is_batchable:
-                    self._serve_single((rqueue, request, args))
-                    continue
-
-                # No batch so make one and serve this batched
+                # No batch generator. Load the model and, if the request is
+                # not batchable, serve it sequentially; otherwise make a
+                # batch generator and serve batched.
                 elif batch_generator is None:
                     try:
                         model, tokenizer = self.model_provider.load(
@@ -789,6 +783,10 @@ def progress_callback(info):
                         rqueue.put(e)
                         continue
 
+                    if not self._is_batchable(args):
+                        self._serve_single((rqueue, request, args))
+                        continue
+
                     current_model = args.model
                     current_tokenizer = tokenizer
                     current_model_key = self.model_provider.model_key
@@ -881,9 +879,8 @@ def progress(tokens_processed, tokens_total):
 
         try:
             # Load the model and tokenizer
-            model, tokenizer = self.model_provider.load(
-                args.model.model, args.model.adapter, args.model.draft
-            )
+            model = self.model_provider.model
+            tokenizer = self.model_provider.tokenizer
             draft_model = self.model_provider.draft_model
 
             # Prepare the prompt
diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py
index 444ec8d75..1a5d3199d 100644
--- a/tests/test_prompt_cache.py
+++ b/tests/test_prompt_cache.py
@@ -16,7 +16,6 @@
     CacheList,
     ChunkedKVCache,
     KVCache,
-    MambaCache,
     QuantizedKVCache,
     RotatingKVCache,
     load_prompt_cache,
@@ -103,14 +102,14 @@ def test_save_load_mixed_cache(self):
         cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
 
         cache = [
-            MambaCache(),
+            ArraysCache(size=2),
             KVCache(),
             RotatingKVCache(8),
-            MambaCache(),
+            ArraysCache(size=2),
             ChunkedKVCache(256),
         ]
         for c in cache:
-            if isinstance(c, MambaCache):
+            if isinstance(c, ArraysCache):
                 c[0] = mx.random.uniform(shape=(4, 4, 4))
                 c[1] = mx.random.uniform(shape=(4, 4, 4))
             else:
@@ -121,7 +120,7 @@ def test_save_load_mixed_cache(self):
         save_prompt_cache(cache_file, cache)
         loaded_cache = load_prompt_cache(cache_file)
         for c, lc in zip(cache, loaded_cache):
-            if isinstance(c, MambaCache):
+            if isinstance(c, ArraysCache):
                 self.assertTrue(mx.array_equal(c[0], lc[0]))
                 self.assertTrue(mx.array_equal(c[1], lc[1]))
             else:
@@ -133,10 +132,10 @@ def test_save_load_mixed_cache(self):
                     self.assertTrue(mx.array_equal(k, lk))
                     self.assertTrue(mx.array_equal(v, lv))
 
-    def test_save_load_mamba_cache(self):
+    def test_save_load_arrays_cache(self):
         cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
 
-        cache = [MambaCache()]
+        cache = [ArraysCache(size=2)]
         cache[0][0] = mx.zeros((1, 4, 4))
         cache[0][1] = mx.zeros((1, 4, 4))
 
@@ -182,16 +181,18 @@ def test_trim_cache(self):
         num_trimmed = trim_prompt_cache(cache, 4)
         self.assertEqual(num_trimmed, 3)
 
-        # Can't trim mamba cache
-        cache = [MambaCache() for _ in range(2)]
+        # Can't trim arrays cache
+        cache = [ArraysCache(size=2) for _ in range(2)]
         for c in cache:
-            c.state = mx.zeros((5, 5))
+            c[0] = mx.zeros((5, 5))
+            c[1] = mx.zeros((5, 5))
         num_trimmed = trim_prompt_cache(cache, 7)
         self.assertEqual(num_trimmed, 0)
 
         # All cache's have to be trimmable
-        cache = [MambaCache(), KVCache()]
-        cache[0].state = mx.zeros((5, 5))
+        cache = [ArraysCache(size=2), KVCache()]
+        cache[0][0] = mx.zeros((5, 5))
+        cache[0][1] = mx.zeros((5, 5))
         x = mx.random.uniform(shape=(1, 8, 10, 4))
         cache[1].update_and_fetch(x, x)
         num_trimmed = trim_prompt_cache(cache, 1)
@@ -338,7 +339,7 @@ def test_cache_list(self):
             m = c.trim(5)
             self.assertEqual(m, 5)
 
-        c = CacheList(MambaCache(), KVCache())
+        c = CacheList(ArraysCache(size=2), KVCache())
         self.assertFalse(c.is_trimmable())
 
         c1 = CacheList(ArraysCache(size=1), KVCache())
@@ -570,12 +571,12 @@ def test_save_load_batch_caches(self):
         cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
 
         cache = [
-            MambaCache(left_padding=[1, 2]),
+            ArraysCache(size=2, left_padding=[1, 2]),
             BatchKVCache(left_padding=[1, 2]),
             BatchRotatingKVCache(max_size=10, left_padding=[1, 2]),
         ]
         for c in cache:
-            if isinstance(c, MambaCache):
+            if isinstance(c, ArraysCache):
                 c[0] = mx.random.uniform(shape=(4, 4, 4))
                 c[1] = mx.random.uniform(shape=(4, 4, 4))
             else:
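
Note (not part of the patch): a minimal migration sketch for any external code that constructed MambaCache directly, assuming only the ArraysCache indexing and save/load behavior exercised in tests/test_prompt_cache.py:

    import mlx.core as mx

    from mlx_lm.models.cache import ArraysCache, load_prompt_cache, save_prompt_cache

    # MambaCache() becomes ArraysCache(size=2): slot 0 holds the conv state,
    # slot 1 the SSM state (see the cache[0]/cache[1] assignments in mamba.py).
    cache = [ArraysCache(size=2)]
    cache[0][0] = mx.zeros((1, 4, 4))
    cache[0][1] = mx.zeros((1, 4, 4))

    # Round-trips through the prompt-cache file format as before.
    save_prompt_cache("prompt_cache.safetensors", cache)
    loaded = load_prompt_cache("prompt_cache.safetensors")
    assert mx.array_equal(cache[0][0], loaded[0][0])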