Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cef874a
feat: native MTP speculative decoding for Qwen3.5
AirRunner Mar 12, 2026
e18fe67
fix(mtp): eliminate SSM state contamination on draft rejection
AirRunner Mar 12, 2026
f8616ac
fix(mtp): server integration (yield types, cache fallback, batching)
AirRunner Mar 12, 2026
dfda205
fix(mtp): address @janhilgard code review feedback (double-norm, quan…
AirRunner Mar 12, 2026
b967112
feat(mtp): add --mtp CLI flag for generate and server
AirRunner Mar 12, 2026
f9488e7
test(mtp): add unit tests for MTP speculative decoding
AirRunner Mar 12, 2026
bb0a223
fix(mtp): warn when --mtp flag is used with a model without MTP head
AirRunner Mar 12, 2026
a66d242
style: apply black and isort formatting
AirRunner Mar 13, 2026
04a4383
fix(mtp): stack per-expert MTP weights for MoE models
Thump604 Mar 17, 2026
fb00df4
feat: configurable KVCache step size and pre-allocation
Thump604 Mar 22, 2026
1199dbe
Replay B1 checkpoint and prompt-cache slice
lyonsno Mar 22, 2026
9be5267
Handle exact checkpoint cache hits safely
lyonsno Mar 22, 2026
1dab9cc
Add constraint tests for LRU cache extraction and checkpoint semantics
lyonsno Mar 23, 2026
b7400e7
Skip checkpoint insertion for non-thinking models
lyonsno Mar 23, 2026
71b03b1
fix: handle 1D logits in mtp_generate_step logits processors
Thump604 Mar 23, 2026
fffd7c7
Restore checkpoint creation for non-thinking models
lyonsno Mar 24, 2026
7d18c44
fix: adaptive prefill step for long context + Nemotron SSM init
Thump604 Mar 25, 2026
f0bd14f
Widen memory guardrail test tolerance from 1.35x to 2.0x
lyonsno Mar 25, 2026
6070f37
Rename count to ref_count and add _has_rewind_impl helper
lyonsno Mar 25, 2026
451ffab
Merge PR #1042: prompt cache rewind support
Thump604 Mar 25, 2026
4be9c29
feat: probabilistic MTP acceptance (speculative sampling)
Thump604 Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
362 changes: 330 additions & 32 deletions mlx_lm/generate.py

Large diffs are not rendered by default.

67 changes: 61 additions & 6 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
max_context: Optional[int] = None,
) -> List[Any]:
"""
Construct the model's cache for use in generation.
Expand All @@ -25,6 +26,10 @@ def make_prompt_cache(
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
max_context (Optional[int]): If provided, pre-allocate the KV cache
buffer to hold this many tokens. Eliminates reallocation and
concatenation during generation. Useful when the maximum context
length is known (e.g., from server configuration).
"""
if hasattr(model, "make_cache"):
return model.make_cache()
Expand All @@ -35,7 +40,7 @@ def make_prompt_cache(
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
return [KVCache(max_size=max_context) for _ in range(num_layers)]


def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
Expand Down Expand Up @@ -164,6 +169,22 @@ def empty(self):
"""
raise NotImplementedError("Cache sub-class must implement this.")

def rewind(self, num_to_trim: int) -> bool:
    """Roll the cache back by ``num_to_trim`` tokens.

    Base implementation always raises. Concrete caches that support
    rewinding override this method (``_has_rewind_impl`` detects the
    override by method identity); overrides return ``True`` on success
    and ``False`` when the requested trim cannot be performed.
    """
    raise NotImplementedError("Cache sub-class must implement rewind.")

def _has_rewind_impl(self):
    """Return True when the concrete cache class overrides ``rewind()``.

    Detection compares method identity against the ``_BaseCache``
    default rather than using an opt-in flag, so third-party cache
    subclasses that define ``rewind()`` are recognized automatically
    without having to know about this helper.
    """
    try:
        own_impl = type(self).rewind
        base_impl = _BaseCache.rewind
    except Exception:
        # NOTE(review): attribute lookup on the class should not fail in
        # practice; treat any failure as "no rewind support" to be safe.
        return False
    return own_impl is not base_impl

@classmethod
def from_state(cls, state, meta_state):
# Create an instance of cls without calling __init__
Expand Down Expand Up @@ -230,12 +251,14 @@ def nbytes(self):
class QuantizedKVCache(_BaseCache):
step = 256

def __init__(self, group_size: int = 64, bits: int = 8):
def __init__(self, group_size: int = 64, bits: int = 8, step: Optional[int] = None):
    """Create an empty quantized KV cache.

    Args:
        group_size: Quantization group size for the stored keys/values.
        bits: Bit width used when quantizing cache entries.
        step: Optional per-instance override of the class-level ``step``
            (the cache's allocation granularity). ``None`` keeps the
            class default.
    """
    if step is not None:
        # Instance attribute shadows the class-level ``step`` default.
        self.step = step
    self.group_size = group_size
    self.bits = bits
    self.keys = None
    self.values = None
    self.offset = 0

def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
Expand Down Expand Up @@ -323,19 +346,28 @@ def nbytes(self):
class KVCache(_BaseCache):
step = 256

def __init__(self):
def __init__(self, max_size: Optional[int] = None, step: Optional[int] = None):
    """Create an empty KV cache.

    Args:
        max_size: Optional pre-allocation hint, stored for
            ``update_and_fetch`` so the full buffer can be allocated up
            front instead of growing in ``step``-sized increments.
        step: Optional per-instance override of the class-level ``step``
            growth granularity. ``None`` keeps the class default.
    """
    if step is not None:
        # Instance attribute shadows the class-level ``step`` default.
        self.step = step
    self._max_size = max_size
    self.keys = None
    self.values = None
    self.offset = 0

def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
if self._max_size is not None and self.keys is None:
# Pre-allocate to max_size — eliminates all future
# boundary reallocations and concatenations.
alloc_size = self._max_size
else:
n_steps = (self.step + keys.shape[2] - 1) // self.step
alloc_size = n_steps * self.step
k_shape = (B, n_kv_heads, alloc_size, k_head_dim)
v_shape = (B, n_kv_heads, alloc_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
Expand Down Expand Up @@ -594,6 +626,9 @@ def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
instance.left_padding = None
instance.lengths = None
# Snapshot of (conv_state, ssm_state) saved after processing confirmed tokens
# in an MTP draft-verification step. Cleared after each step.
instance.rollback_state = None
return instance

def __init__(self, size, left_padding: Optional[List[int]] = None):
Expand Down Expand Up @@ -1247,8 +1282,28 @@ def trim(self, n):
self._offset -= n
self._idx -= n
self.offset -= n
if self.rotated:
self.left_padding += n
return n

def can_rewind(self, num_to_trim: int) -> bool:
    """Return True if ``rewind(num_to_trim)`` can succeed.

    Trimming zero or fewer tokens is always allowed. Otherwise the
    cache must hold data, the write index must lie within the buffer,
    and the trim must not exceed either the logical offset or the
    write index.
    """
    if num_to_trim <= 0:
        return True
    if self.keys is None or self.values is None:
        return False
    capacity = self.keys.shape[2]
    index_valid = 0 <= self._idx <= capacity
    within_bounds = num_to_trim <= self._offset and num_to_trim <= self._idx
    return index_valid and within_bounds

def rewind(self, num_to_trim: int) -> bool:
    """Roll the cache back by ``num_to_trim`` tokens.

    Returns True when the trim succeeded (or nothing needed trimming),
    False when ``can_rewind`` rejects the request or ``trim`` removed
    fewer tokens than asked for.
    """
    # A no-op trim always succeeds (can_rewind also short-circuits this).
    if num_to_trim <= 0:
        return True
    if not self.can_rewind(num_to_trim):
        return False
    trimmed = self.trim(num_to_trim)
    return trimmed == num_to_trim

def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
    """Convert this cache to a quantized form.

    Not implemented for batched rotating caches; always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError("BatchRotatingKVCache Quantization NYI")

Expand Down
Loading