Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 142 additions & 10 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import codecs
import contextlib
import functools
import importlib
import json
import time
from collections.abc import Sequence
Expand All @@ -16,6 +17,9 @@
from tqdm import tqdm
from transformers import PreTrainedTokenizer

_mlx_lm_generate = importlib.import_module("mlx_lm.generate")
SequenceStateMachine = _mlx_lm_generate.SequenceStateMachine

from .models import cache
from .prompt_utils import apply_chat_template
from .turboquant import TurboQuantKVCache, turboquant_enabled
Expand Down Expand Up @@ -835,6 +839,12 @@ def _left_pad_prompts(prompts, max_length=None):
return mx.array([[0] * (max_length - len(p)) + p for p in prompts])


def _right_pad_prompts(prompts, max_length=None, pad_id=0):
    """Right-pad token prompts to a common length.

    Mirrors `_left_pad_prompts`, but appends padding after each prompt
    instead of prepending it.

    Args:
        prompts: iterable of token-id lists.
        max_length: target padded length; defaults to the longest prompt.
        pad_id: token id used for padding (default 0, matching the
            left-pad helper).

    Returns:
        mx.array of shape (len(prompts), max_length).
    """
    if max_length is None:
        # default=0 keeps an empty prompt list from raising ValueError
        max_length = max((len(p) for p in prompts), default=0)
    return mx.array([p + [pad_id] * (max_length - len(p)) for p in prompts])


def _make_cache(model, left_padding):
"""
Convert a list of regular caches into their corresponding
Expand Down Expand Up @@ -919,6 +929,12 @@ class Batch:
max_tokens: List[int]
num_tokens: List[int]
cache: List[Any]
tokens: Optional[List[List[int]]] = None
samplers: Optional[List[Callable[[mx.array], mx.array]]] = None
logits_processors: Optional[
List[List[Callable[[mx.array, mx.array], mx.array]]]
] = None
state_machine_states: Optional[List[Any]] = None

def __len__(self):
    """Return the number of sequences currently tracked by this batch."""
    count = len(self.uids)
    return count
Expand All @@ -927,6 +943,14 @@ def filter(self, keep_idx: List[int]):
self.uids = [self.uids[k] for k in keep_idx]
self.max_tokens = [self.max_tokens[k] for k in keep_idx]
self.num_tokens = [self.num_tokens[k] for k in keep_idx]
if self.tokens is not None:
self.tokens = [self.tokens[k] for k in keep_idx]
if self.samplers is not None:
self.samplers = [self.samplers[k] for k in keep_idx]
if self.logits_processors is not None:
self.logits_processors = [self.logits_processors[k] for k in keep_idx]
if self.state_machine_states is not None:
self.state_machine_states = [self.state_machine_states[k] for k in keep_idx]
keep_idx = mx.array(keep_idx, mx.int32)
self.y = self.y[keep_idx]
self.logprobs = self.logprobs[keep_idx]
Expand All @@ -939,6 +963,23 @@ def extend(self, other):
self.logprobs = mx.concatenate([self.logprobs, other.logprobs])
self.num_tokens.extend(other.num_tokens)
self.max_tokens.extend(other.max_tokens)
if self.tokens is not None and other.tokens is not None:
self.tokens.extend(other.tokens)
if self.samplers is not None and other.samplers is not None:
self.samplers.extend(other.samplers)
elif other.samplers is not None:
self.samplers = other.samplers
if self.logits_processors is not None and other.logits_processors is not None:
self.logits_processors.extend(other.logits_processors)
elif other.logits_processors is not None:
self.logits_processors = other.logits_processors
if (
self.state_machine_states is not None
and other.state_machine_states is not None
):
self.state_machine_states.extend(other.state_machine_states)
elif other.state_machine_states is not None:
self.state_machine_states = other.state_machine_states
for c, o in zip(self.cache, other.cache):
c.extend(o)

Expand Down Expand Up @@ -978,18 +1019,39 @@ def __init__(
self.prompt_cache = prompt_cache
self._stats = BatchStats()

# Build SequenceStateMachine for stop token matching
stop_seqs = []
if stop_tokens is not None:
for t in stop_tokens:
stop_seqs.append(([t] if isinstance(t, int) else list(t), None))
self._default_state_machine = SequenceStateMachine(
{"normal": stop_seqs} if stop_seqs else {},
initial="normal",
)

# Keep legacy stopping_criteria for backward compat
self.tokenizer.stopping_criteria.add_eos_token_ids(stop_tokens)

self.active_batch = None

def insert(self, prompts, max_tokens: Union[List[int], int, None] = None):
def insert(
self,
prompts,
max_tokens: Union[List[int], int, None] = None,
samplers: Optional[List[Callable[[mx.array], mx.array]]] = None,
logits_processors: Optional[
List[List[Callable[[mx.array, mx.array], mx.array]]]
] = None,
):
uids = []

if max_tokens is None or isinstance(max_tokens, int):
max_tokens = [max_tokens or self.max_tokens] * len(prompts)

for p, m in zip(prompts, max_tokens):
self.unprocessed_prompts.append((self.uid_count, p, m))
for i, (p, m) in enumerate(zip(prompts, max_tokens)):
s = samplers[i] if samplers is not None else None
lp = logits_processors[i] if logits_processors is not None else None
self.unprocessed_prompts.append((self.uid_count, p, m, s, lp))
uids.append(self.uid_count)
self.uid_count += 1
# Sort in ascending order of length
Expand All @@ -999,7 +1061,7 @@ def insert(self, prompts, max_tokens: Union[List[int], int, None] = None):
return uids

def _process_prompts(self, prompts, **kwargs) -> Batch:
uids, inputs, max_tokens = zip(*prompts)
uids, inputs, max_tokens, per_samplers, per_lps = zip(*prompts)
lengths = [len(p) for p in inputs]
max_length = max(lengths)

Expand Down Expand Up @@ -1045,23 +1107,62 @@ def _process_prompts(self, prompts, **kwargs) -> Batch:
inputs = inputs[:, n_to_process:]
mx.clear_cache()

batch_samplers = (
list(per_samplers) if any(s is not None for s in per_samplers) else None
)
batch_lps = list(per_lps) if any(lp is not None for lp in per_lps) else None

y, logprobs = self._step(
inputs, prompt_cache, inputs_embeds=inputs_embeds, **kwargs
inputs,
prompt_cache,
inputs_embeds=inputs_embeds,
samplers=batch_samplers,
logits_processors=batch_lps,
**kwargs,
)

mx.async_eval(y, logprobs)
mx.clear_cache()
return Batch(
list(uids), y, logprobs, list(max_tokens), [0] * len(uids), prompt_cache
uids=list(uids),
y=y,
logprobs=logprobs,
max_tokens=list(max_tokens),
num_tokens=[0] * len(uids),
cache=prompt_cache,
tokens=[[] for _ in uids],
samplers=batch_samplers,
logits_processors=batch_lps,
state_machine_states=[
self._default_state_machine.make_state() for _ in uids
],
)

def _step(self, input_tokens: mx.array, prompt_cache: List[Any], **kwargs):
    """Run one forward/sampling step for the whole batch.

    Args:
        input_tokens: (batch, seq) token ids fed to the model.
        prompt_cache: per-layer KV caches; mutated in place by the model.
        **kwargs: may carry per-sequence ``samplers`` and
            ``logits_processors`` lists (popped here); everything else is
            forwarded to the model.

    Returns:
        Tuple of (sampled token ids, per-sequence log-probabilities).
    """
    samplers = kwargs.pop("samplers", None)
    logits_processors = kwargs.pop("logits_processors", None)

    output = self.model(input_tokens, cache=prompt_cache, **kwargs)
    logits = output.logits[:, -1, :]
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)

    # Apply per-sequence logits processors BEFORE any sampling so the
    # sampled token reflects the processed distribution. (Sampling first
    # and discarding the result was dead work left over from a refactor.)
    if logits_processors is not None:
        for i, lps in enumerate(logits_processors):
            if lps is not None:
                for lp in lps:
                    # NOTE(review): processors only see the tokens fed this
                    # step, not the full generation history — confirm that
                    # is what callers expect.
                    logprobs[i] = lp(input_tokens[i], logprobs[i])

    # Per-sequence samplers override the shared sampler where provided.
    if samplers is not None and any(s is not None for s in samplers):
        sampled = []
        for i, s in enumerate(samplers):
            fn = s if s is not None else self.sampler
            sampled.append(fn(logprobs[i]))
        sampled = mx.stack(sampled)
    else:
        sampled = self.sampler(logprobs)

    # TODO: Add KV cache quantization if specified
    return sampled, logprobs

def stats(self):
Expand Down Expand Up @@ -1111,10 +1212,19 @@ def _next(self, **kwargs):

batch = self.active_batch
y, logprobs = batch.y, batch.logprobs
batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
batch.y, batch.logprobs = self._step(
y[:, None],
batch.cache,
samplers=batch.samplers,
logits_processors=batch.logits_processors,
)
mx.async_eval(batch.y, batch.logprobs)

y = y.tolist()
# Track generated tokens per sequence
if batch.tokens is not None:
for i, t in enumerate(y):
batch.tokens[i].append(t)
toc = time.perf_counter()
if prompt_processing:
self._stats.prompt_time += toc - tic
Expand All @@ -1129,7 +1239,19 @@ def _next(self, **kwargs):
):
num_tok += 1
batch.num_tokens[e] = num_tok
if self.tokenizer.stopping_criteria(t):

# Use SequenceStateMachine for stop detection (supports multi-token sequences)
is_stop = False
if batch.state_machine_states is not None:
state = batch.state_machine_states[e]
new_state, matched_seq, next_name = SequenceStateMachine.match(state, t)
batch.state_machine_states[e] = new_state
if next_name is None and matched_seq is not None:
is_stop = True
else:
is_stop = self.tokenizer.stopping_criteria(t)

if is_stop:
finish_reason = "stop"
end_idx.append(e)
elif num_tok >= max_tok:
Expand Down Expand Up @@ -1382,6 +1504,16 @@ def _generate_batch(
kwargs.pop("prefill_step_size", None)
kwargs["prefill_step_size"] = None

# Ensure stop tokens are passed to BatchGenerator
if "stop_tokens" not in kwargs:
eos_ids = getattr(model.config, "eos_token_id", None)
if eos_ids is not None:
if isinstance(eos_ids, int):
eos_ids = {eos_ids}
elif isinstance(eos_ids, list):
eos_ids = set(eos_ids)
kwargs["stop_tokens"] = eos_ids

# Use batch_size for prefill and completion to ensure consistent processing
gen = BatchGenerator(
model.language_model,
Expand Down
18 changes: 18 additions & 0 deletions mlx_vlm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ def is_trimmable(self):
def trim(self, n):
return 0

@property
def nbytes(self):
    """Total memory held by the cached keys and values, in bytes.

    Returns 0 when nothing has been cached yet.
    """
    if self.keys is None:
        return 0
    total = self.keys.nbytes
    total += self.values.nbytes
    return total

def empty(self):
    """Return True if no keys have been written to this cache yet."""
    return True if self.keys is None else False


class StaticKVCache(_BaseCache):
"""A static cache that grows to accommodate all tokens."""
Expand Down Expand Up @@ -237,3 +246,12 @@ def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n

@property
def nbytes(self):
    """Memory footprint in bytes of the stored key/value arrays (0 if unset)."""
    if self.keys is None:
        return 0
    return sum(arr.nbytes for arr in (self.keys, self.values))

def empty(self):
    """Report whether no key/value tensors have been stored yet."""
    has_data = self.keys is not None
    return not has_data
27 changes: 21 additions & 6 deletions mlx_vlm/models/gemma4/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,27 @@


def masked_scatter(input_tensor, mask, source):
    """Scatter values from `source` into `input_tensor` where `mask` is set.

    Positions where `mask` is truthy are filled, in order, with consecutive
    elements of `source`; all other positions keep their original values.

    Args:
        input_tensor: destination array, leading batch dimension B.
        mask: boolean/int mask, same shape as `input_tensor`.
        source: values to place at masked positions; may be shared across
            the batch when its leading dimension is 1.

    Returns:
        Array with the same shape as `input_tensor`.
    """
    B = input_tensor.shape[0]
    if B == 1:
        # Fast path: flatten everything and index source by the running
        # count of masked positions.
        mask_flat = mask.flatten().astype(mx.int32)
        indices = mx.cumsum(mask_flat) - 1
        aligned = source.flatten()[indices % source.size]
        return mx.where(mask_flat, aligned, input_tensor.flatten()).reshape(
            input_tensor.shape
        )
    # Per-batch scatter to avoid cross-batch index contamination
    results = []
    for i in range(B):
        inp_i = input_tensor[i]  # (seq_len, dim)
        mask_i = mask[i]  # (seq_len, dim)
        # Broadcast a single shared source across the batch if needed.
        src_i = source[i] if source.shape[0] > 1 else source[0]  # (n_tokens, dim)
        mask_flat = mask_i.flatten().astype(mx.int32)
        indices = mx.cumsum(mask_flat) - 1
        aligned = src_i.flatten()[indices % src_i.size]
        results.append(
            mx.where(mask_flat, aligned, inp_i.flatten()).reshape(inp_i.shape)
        )
    return mx.stack(results)


class MultimodalEmbedder(nn.Module):
Expand Down
8 changes: 6 additions & 2 deletions mlx_vlm/models/gemma4/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,14 @@ def __call__(
if self.is_kv_shared_layer and cache is not None:
state = cache.state
keys, values = state[0], state[1]
offset = cache.offset
o = cache.offset
offset = int(o) if not isinstance(o, mx.array) else o + 0
else:
if cache is not None:
offset = cache.offset
# Snapshot offset before update_and_fetch mutates it
# (BatchRotatingKVCache.offset is a mutable mx.array)
o = cache.offset
offset = int(o) if not isinstance(o, mx.array) else o + 0

keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)

Expand Down
22 changes: 13 additions & 9 deletions mlx_vlm/models/gemma4/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,15 +488,19 @@ def __call__(self, pixel_values: mx.array) -> mx.array:
else:
valid_mask = ~pool_mask

# For single batch (typical VLM case), count valid tokens and slice
# Since pooling produces contiguous valid tokens followed by padding,
# we can simply count valid tokens and take that many
all_real = []
for i in range(B):
n_valid = int(valid_mask[i].astype(mx.int32).sum().item())
all_real.append(pooled[i, :n_valid])

hidden_states = mx.concatenate(all_real, axis=0)[None] # [1, total_real, dim]
# Strip padding tokens: count valid tokens per image and slice
valid_counts = [
int(valid_mask[i].astype(mx.int32).sum().item()) for i in range(B)
]
max_valid = max(valid_counts)

if B == 1:
# Single image: no padding needed
hidden_states = pooled[:, :max_valid]
else:
# Batch: slice each image to max_valid tokens (same-shape images
# produce the same count; different-shape images get zero-padded)
hidden_states = pooled[:, :max_valid]

if self.config.standardize:
hidden_states = (hidden_states - self.std_bias) * self.std_scale
Expand Down
16 changes: 10 additions & 6 deletions mlx_vlm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,13 +1416,17 @@ def add_eos_token_ids(self, new_eos_token_ids: Union[int, List[int]] = None):
raise ValueError("Processor is not provided")

if new_eos_token_ids is not None:
if isinstance(new_eos_token_ids, str):
if isinstance(new_eos_token_ids, (str, int)):
new_eos_token_ids = [new_eos_token_ids]
new_eos_token_ids = [
self.tokenizer.encode(" " + token, add_special_tokens=False)[-1]
for token in new_eos_token_ids
]
self.eos_token_ids.extend(new_eos_token_ids)
resolved = []
for token in new_eos_token_ids:
if isinstance(token, int):
resolved.append(token)
else:
resolved.append(
self.tokenizer.encode(" " + token, add_special_tokens=False)[-1]
)
self.eos_token_ids.extend(resolved)

def reset(self, eos_token_ids: List[int] = None):
eos_token_ids = (
Expand Down
Loading