4 changes: 3 additions & 1 deletion llms/mlx_lm/models/base.py
@@ -34,13 +34,15 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
     return mask * -1e9


-def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_idx: Optional[int] = None):
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
         if cache is not None and cache[0] is not None:
             c = cache[0]
+            if reference_idx is not None:
+                c = cache[reference_idx]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size
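Note on the change above: `reference_idx` lets a caller pick which cache entry the mask geometry (offset and window size) is read from, instead of always using `cache[0]`. That matters when the per-layer caches are heterogeneous, e.g. a mix of rotating sliding-window caches and full caches. A minimal sketch of the intended call pattern (import paths, shapes, and values are illustrative, not taken from this PR):

import mlx.core as mx
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache

# Hypothetical hybrid cache: layer 0 rotates over an 8-token window,
# layer 1 keeps the full history. reference_idx=1 derives the mask from
# the full cache rather than from cache[0].
caches = [RotatingKVCache(max_size=8, keep=0), KVCache()]
h = mx.zeros((1, 16, 64))  # (batch, sequence, hidden) dummy activations
mask = create_attention_mask(h, caches, reference_idx=1)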
20 changes: 14 additions & 6 deletions llms/mlx_lm/models/cohere2.py
@@ -6,8 +6,8 @@
 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs, create_attention_mask
-
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .cache import KVCache, RotatingKVCache

 @dataclass
 class ModelArgs(BaseModelArgs):
@@ -95,7 +95,6 @@ def __call__(

         if cache is not None:
             keys, values = cache.update_and_fetch(keys, values)
-
         # Apply sliding window attention if enabled
         if self.sliding_window is not None:
             window_size = self.sliding_window
@@ -104,8 +103,8 @@ def __call__(
             if mask is not None:
                 mask = mask[..., -window_size:]

-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )

         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
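Here the call goes through the shared `scaled_dot_product_attention` helper from `.base` and receives the layer cache, so attention can take the quantized-cache path when the KV cache has been quantized; the mask is sliced to the last `window_size` positions so it stays aligned with the trimmed keys and values. A hedged sketch of that trimming, with made-up shapes:

import mlx.core as mx

# Illustrative shapes only: 10 cached positions, a 4-token sliding window.
window_size = 4
keys = mx.zeros((1, 8, 10, 64))    # (batch, kv_heads, cached_tokens, head_dim)
values = mx.zeros((1, 8, 10, 64))
mask = mx.zeros((6, 10))           # (query_tokens, cached_tokens)

# Keep only the most recent window_size keys/values and the matching
# mask columns so the shapes still line up for attention.
keys = keys[..., -window_size:, :]
values = values[..., -window_size:, :]
mask = mask[..., -window_size:]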
@@ -171,7 +170,7 @@ def __call__(
     ):
         h = self.embed_tokens(inputs)

-        mask = create_attention_mask(h, cache)
+        mask = create_attention_mask(h, cache, reference_idx=self.args.sliding_window_pattern - 1)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -198,6 +197,15 @@ def __call__(
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out

+    def make_cache(self):
+        caches = []
+        for i in range(self.args.num_hidden_layers):
+            if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
+                caches.append(KVCache())
+            else:
+                caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
+        return caches
+
     @property
     def layers(self):
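The `make_cache` added above encodes Cohere2's interleaved attention pattern: every `sliding_window_pattern`-th layer attends globally and gets an unbounded `KVCache`, while the remaining layers only ever look back `sliding_window` tokens and can use a bounded `RotatingKVCache`. This is also why the model passes `reference_idx=self.args.sliding_window_pattern - 1` to `create_attention_mask`: that index always lands on a full-cache (global-attention) layer. A small standalone illustration with hypothetical config values:

from mlx_lm.models.cache import KVCache, RotatingKVCache

# Hypothetical config: 8 layers, a global layer every 4th layer, 4096-token window.
num_hidden_layers, sliding_window_pattern, sliding_window = 8, 4, 4096

caches = []
for i in range(num_hidden_layers):
    if i % sliding_window_pattern == sliding_window_pattern - 1:
        caches.append(KVCache())  # layers 3 and 7: global attention, full cache
    else:
        caches.append(RotatingKVCache(max_size=sliding_window, keep=0))

print([type(c).__name__ for c in caches])
# expected: ['RotatingKVCache', 'RotatingKVCache', 'RotatingKVCache', 'KVCache', ...]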
9 changes: 4 additions & 5 deletions llms/mlx_lm/utils.py
@@ -187,11 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
         and prompt_cache[0].offset > quantized_kv_start
     ):
         for i in range(len(prompt_cache)):
-            prompt_cache[i] = prompt_cache[i].to_quantized(
-                group_size=kv_group_size, bits=kv_bits
-            )
-
-
+            if isinstance(prompt_cache[i], cache.KVCache):
+                prompt_cache[i] = prompt_cache[i].to_quantized(
+                    group_size=kv_group_size, bits=kv_bits
+                )
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
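With the added guard, KV-cache quantization only converts plain `KVCache` entries and leaves other cache types (such as the `RotatingKVCache` entries produced by Cohere2's `make_cache`) untouched. A hedged, self-contained sketch of the effect on a mixed cache list (shapes and quantization settings are illustrative):

import mlx.core as mx
from mlx_lm.models import cache

kv = cache.KVCache()
# Give the full cache a few dummy entries so there is something to quantize.
k = mx.zeros((1, 8, 5, 64))  # (batch, kv_heads, tokens, head_dim)
v = mx.zeros((1, 8, 5, 64))
kv.update_and_fetch(k, v)

prompt_cache = [cache.RotatingKVCache(max_size=4096, keep=0), kv]

for i in range(len(prompt_cache)):
    if isinstance(prompt_cache[i], cache.KVCache):
        prompt_cache[i] = prompt_cache[i].to_quantized(group_size=64, bits=4)

print([type(c).__name__ for c in prompt_cache])
# expected: ['RotatingKVCache', 'QuantizedKVCache'], only the plain KVCache is converted.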