diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py
index 7defa16d..4b9d6cf0 100644
--- a/mlx_vlm/generate.py
+++ b/mlx_vlm/generate.py
@@ -2,6 +2,7 @@
 import codecs
 import contextlib
 import functools
+import importlib
 import json
 import time
 from collections.abc import Sequence
@@ -16,6 +17,9 @@
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer
 
+_mlx_lm_generate = importlib.import_module("mlx_lm.generate")
+SequenceStateMachine = _mlx_lm_generate.SequenceStateMachine
+
 from .models import cache
 from .prompt_utils import apply_chat_template
 from .turboquant import TurboQuantKVCache, turboquant_enabled
@@ -835,6 +839,12 @@ def _left_pad_prompts(prompts, max_length=None):
     return mx.array([[0] * (max_length - len(p)) + p for p in prompts])
 
 
+def _right_pad_prompts(prompts, max_length=None):
+    if max_length is None:
+        max_length = max(len(p) for p in prompts)
+    return mx.array([p + [0] * (max_length - len(p)) for p in prompts])
+
+
 def _make_cache(model, left_padding):
     """
     Convert a list of regular caches into their corresponding
@@ -919,6 +929,12 @@ class Batch:
     max_tokens: List[int]
     num_tokens: List[int]
     cache: List[Any]
+    tokens: Optional[List[List[int]]] = None
+    samplers: Optional[List[Callable[[mx.array], mx.array]]] = None
+    logits_processors: Optional[
+        List[List[Callable[[mx.array, mx.array], mx.array]]]
+    ] = None
+    state_machine_states: Optional[List[Any]] = None
 
     def __len__(self):
         return len(self.uids)
@@ -927,6 +943,14 @@ def filter(self, keep_idx: List[int]):
         self.uids = [self.uids[k] for k in keep_idx]
         self.max_tokens = [self.max_tokens[k] for k in keep_idx]
         self.num_tokens = [self.num_tokens[k] for k in keep_idx]
+        if self.tokens is not None:
+            self.tokens = [self.tokens[k] for k in keep_idx]
+        if self.samplers is not None:
+            self.samplers = [self.samplers[k] for k in keep_idx]
+        if self.logits_processors is not None:
+            self.logits_processors = [self.logits_processors[k] for k in keep_idx]
+        if self.state_machine_states is not None:
+            self.state_machine_states = [self.state_machine_states[k] for k in keep_idx]
         keep_idx = mx.array(keep_idx, mx.int32)
         self.y = self.y[keep_idx]
         self.logprobs = self.logprobs[keep_idx]
@@ -939,6 +963,23 @@ def extend(self, other):
         self.logprobs = mx.concatenate([self.logprobs, other.logprobs])
         self.num_tokens.extend(other.num_tokens)
         self.max_tokens.extend(other.max_tokens)
+        if self.tokens is not None and other.tokens is not None:
+            self.tokens.extend(other.tokens)
+        if self.samplers is not None and other.samplers is not None:
+            self.samplers.extend(other.samplers)
+        elif other.samplers is not None:
+            self.samplers = other.samplers
+        if self.logits_processors is not None and other.logits_processors is not None:
+            self.logits_processors.extend(other.logits_processors)
+        elif other.logits_processors is not None:
+            self.logits_processors = other.logits_processors
+        if (
+            self.state_machine_states is not None
+            and other.state_machine_states is not None
+        ):
+            self.state_machine_states.extend(other.state_machine_states)
+        elif other.state_machine_states is not None:
+            self.state_machine_states = other.state_machine_states
 
         for c, o in zip(self.cache, other.cache):
             c.extend(o)
@@ -978,18 +1019,39 @@ def __init__(
         self.prompt_cache = prompt_cache
         self._stats = BatchStats()
 
+        # Build a SequenceStateMachine for stop-token matching
+        stop_seqs = []
+        if stop_tokens is not None:
+            for t in stop_tokens:
+                stop_seqs.append(([t] if isinstance(t, int) else list(t), None))
+        self._default_state_machine = SequenceStateMachine(
+            {"normal": stop_seqs} if stop_seqs else {},
+            initial="normal",
+        )
+
+        # Keep the legacy stopping_criteria for backward compatibility
         self.tokenizer.stopping_criteria.add_eos_token_ids(stop_tokens)
         self.active_batch = None
 
-    def insert(self, prompts, max_tokens: Union[List[int], int, None] = None):
+    def insert(
+        self,
+        prompts,
+        max_tokens: Union[List[int], int, None] = None,
+        samplers: Optional[List[Callable[[mx.array], mx.array]]] = None,
+        logits_processors: Optional[
+            List[List[Callable[[mx.array, mx.array], mx.array]]]
+        ] = None,
+    ):
         uids = []
         if max_tokens is None or isinstance(max_tokens, int):
             max_tokens = [max_tokens or self.max_tokens] * len(prompts)
-        for p, m in zip(prompts, max_tokens):
-            self.unprocessed_prompts.append((self.uid_count, p, m))
+        for i, (p, m) in enumerate(zip(prompts, max_tokens)):
+            s = samplers[i] if samplers is not None else None
+            lp = logits_processors[i] if logits_processors is not None else None
+            self.unprocessed_prompts.append((self.uid_count, p, m, s, lp))
             uids.append(self.uid_count)
             self.uid_count += 1
         # Sort in ascending order of length
@@ -999,7 +1061,7 @@ def insert(self, prompts, max_tokens: Union[List[int], int, None] = None):
         return uids
 
     def _process_prompts(self, prompts, **kwargs) -> Batch:
-        uids, inputs, max_tokens = zip(*prompts)
+        uids, inputs, max_tokens, per_samplers, per_lps = zip(*prompts)
         lengths = [len(p) for p in inputs]
         max_length = max(lengths)
@@ -1045,23 +1107,62 @@ def _process_prompts(self, prompts, **kwargs) -> Batch:
             inputs = inputs[:, n_to_process:]
             mx.clear_cache()
 
+        batch_samplers = (
+            list(per_samplers) if any(s is not None for s in per_samplers) else None
+        )
+        batch_lps = list(per_lps) if any(lp is not None for lp in per_lps) else None
+
         y, logprobs = self._step(
-            inputs, prompt_cache, inputs_embeds=inputs_embeds, **kwargs
+            inputs,
+            prompt_cache,
+            inputs_embeds=inputs_embeds,
+            samplers=batch_samplers,
+            logits_processors=batch_lps,
+            **kwargs,
         )
         mx.async_eval(y, logprobs)
         mx.clear_cache()
 
         return Batch(
-            list(uids), y, logprobs, list(max_tokens), [0] * len(uids), prompt_cache
+            uids=list(uids),
+            y=y,
+            logprobs=logprobs,
+            max_tokens=list(max_tokens),
+            num_tokens=[0] * len(uids),
+            cache=prompt_cache,
+            tokens=[[] for _ in uids],
+            samplers=batch_samplers,
+            logits_processors=batch_lps,
+            state_machine_states=[
+                self._default_state_machine.make_state() for _ in uids
+            ],
         )
 
     def _step(self, input_tokens: mx.array, prompt_cache: List[Any], **kwargs):
+        samplers = kwargs.pop("samplers", None)
+        logits_processors = kwargs.pop("logits_processors", None)
+
         output = self.model(input_tokens, cache=prompt_cache, **kwargs)
         logits = output.logits[:, -1, :]
 
         logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-        sampled = self.sampler(logprobs)
 
-        # TODO: Add KV cache quantization if specified
+        # Apply per-sequence logits processors if provided
+        if logits_processors is not None:
+            for i, lps in enumerate(logits_processors):
+                if lps is not None:
+                    for lp in lps:
+                        logprobs[i] = lp(input_tokens[i], logprobs[i])
+
+        # Apply per-sequence samplers, falling back to the shared sampler
+        if samplers is not None and any(s is not None for s in samplers):
+            sampled = []
+            for i, s in enumerate(samplers):
+                fn = s if s is not None else self.sampler
+                sampled.append(fn(logprobs[i]))
+            sampled = mx.stack(sampled)
+        else:
+            sampled = self.sampler(logprobs)
+
         return sampled, logprobs
 
     def stats(self):
@@ -1111,10 +1212,19 @@ def _next(self, **kwargs):
         batch = self.active_batch
         y, logprobs = batch.y, batch.logprobs
-        batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
+        batch.y, batch.logprobs = self._step(
+            y[:, None],
+            batch.cache,
+            samplers=batch.samplers,
+            logits_processors=batch.logits_processors,
+        )
         mx.async_eval(batch.y, batch.logprobs)
         y = y.tolist()
+        # Track generated tokens per sequence
+        if batch.tokens is not None:
+            for i, t in enumerate(y):
+                batch.tokens[i].append(t)
         toc = time.perf_counter()
         if prompt_processing:
             self._stats.prompt_time += toc - tic
@@ -1129,7 +1239,19 @@ def _next(self, **kwargs):
         ):
             num_tok += 1
             batch.num_tokens[e] = num_tok
-            if self.tokenizer.stopping_criteria(t):
+
+            # Use the SequenceStateMachine for stop detection (supports multi-token sequences)
+            is_stop = False
+            if batch.state_machine_states is not None:
+                state = batch.state_machine_states[e]
+                new_state, matched_seq, next_name = SequenceStateMachine.match(state, t)
+                batch.state_machine_states[e] = new_state
+                if next_name is None and matched_seq is not None:
+                    is_stop = True
+            else:
+                is_stop = self.tokenizer.stopping_criteria(t)
+
+            if is_stop:
                 finish_reason = "stop"
                 end_idx.append(e)
             elif num_tok >= max_tok:
@@ -1382,6 +1504,16 @@ def _generate_batch(
     kwargs.pop("prefill_step_size", None)
     kwargs["prefill_step_size"] = None
 
+    # Ensure stop tokens are passed to the BatchGenerator
+    if "stop_tokens" not in kwargs:
+        eos_ids = getattr(model.config, "eos_token_id", None)
+        if eos_ids is not None:
+            if isinstance(eos_ids, int):
+                eos_ids = {eos_ids}
+            elif isinstance(eos_ids, list):
+                eos_ids = set(eos_ids)
+            kwargs["stop_tokens"] = eos_ids
+
     # Use batch_size for prefill and completion to ensure consistent processing
     gen = BatchGenerator(
         model.language_model,
diff --git a/mlx_vlm/models/cache.py b/mlx_vlm/models/cache.py
index 505c8c27..df85bfc9 100644
--- a/mlx_vlm/models/cache.py
+++ b/mlx_vlm/models/cache.py
@@ -172,6 +172,15 @@ def is_trimmable(self):
     def trim(self, n):
         return 0
 
+    @property
+    def nbytes(self):
+        if self.keys is None:
+            return 0
+        return self.keys.nbytes + self.values.nbytes
+
+    def empty(self):
+        return self.keys is None
+
 
 class StaticKVCache(_BaseCache):
     """A static cache that grows to accommodate all tokens."""
@@ -237,3 +246,12 @@ def trim(self, n):
         n = min(self.offset, n)
         self.offset -= n
         return n
+
+    @property
+    def nbytes(self):
+        if self.keys is None:
+            return 0
+        return self.keys.nbytes + self.values.nbytes
+
+    def empty(self):
+        return self.keys is None
diff --git a/mlx_vlm/models/gemma4/gemma4.py b/mlx_vlm/models/gemma4/gemma4.py
index 1698f316..5b57c9ba 100644
--- a/mlx_vlm/models/gemma4/gemma4.py
+++ b/mlx_vlm/models/gemma4/gemma4.py
@@ -11,12 +11,27 @@
 
 def masked_scatter(input_tensor, mask, source):
-    mask_flat = mask.flatten().astype(mx.int32)
-    indices = mx.cumsum(mask_flat) - 1
-    aligned = source.flatten()[indices % source.size]
-    return mx.where(mask_flat, aligned, input_tensor.flatten()).reshape(
-        input_tensor.shape
-    )
+    B = input_tensor.shape[0]
+    if B == 1:
+        mask_flat = mask.flatten().astype(mx.int32)
+        indices = mx.cumsum(mask_flat) - 1
+        aligned = source.flatten()[indices % source.size]
+        return mx.where(mask_flat, aligned, input_tensor.flatten()).reshape(
+            input_tensor.shape
+        )
+    # Per-batch scatter to avoid cross-batch index contamination
+    results = []
+    for i in range(B):
+        inp_i = input_tensor[i]  # (seq_len, dim)
+        mask_i = mask[i]  # (seq_len, dim)
+        src_i = source[i] if source.shape[0] > 1 else source[0]  # (n_tokens, dim)
+        mask_flat = mask_i.flatten().astype(mx.int32)
+        indices = mx.cumsum(mask_flat) - 1
+        aligned = src_i.flatten()[indices % src_i.size]
+        results.append(
+            mx.where(mask_flat, aligned, inp_i.flatten()).reshape(inp_i.shape)
+        )
+    return mx.stack(results)
 
 
 class MultimodalEmbedder(nn.Module):
diff --git a/mlx_vlm/models/gemma4/language.py b/mlx_vlm/models/gemma4/language.py
index ac419b9f..fc1f8f7d 100644
--- a/mlx_vlm/models/gemma4/language.py
+++ b/mlx_vlm/models/gemma4/language.py
@@ -204,10 +204,14 @@ def __call__(
         if self.is_kv_shared_layer and cache is not None:
             state = cache.state
             keys, values = state[0], state[1]
-            offset = cache.offset
+            o = cache.offset
+            offset = int(o) if not isinstance(o, mx.array) else o + 0
         else:
             if cache is not None:
-                offset = cache.offset
+                # Snapshot the offset before update_and_fetch mutates it
+                # (BatchRotatingKVCache.offset is a mutable mx.array)
+                o = cache.offset
+                offset = int(o) if not isinstance(o, mx.array) else o + 0
 
             keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)
diff --git a/mlx_vlm/models/gemma4/vision.py b/mlx_vlm/models/gemma4/vision.py
index 7ac6a804..6a2d6927 100644
--- a/mlx_vlm/models/gemma4/vision.py
+++ b/mlx_vlm/models/gemma4/vision.py
@@ -488,15 +488,19 @@ def __call__(self, pixel_values: mx.array) -> mx.array:
         else:
             valid_mask = ~pool_mask
 
-        # For single batch (typical VLM case), count valid tokens and slice
-        # Since pooling produces contiguous valid tokens followed by padding,
-        # we can simply count valid tokens and take that many
-        all_real = []
-        for i in range(B):
-            n_valid = int(valid_mask[i].astype(mx.int32).sum().item())
-            all_real.append(pooled[i, :n_valid])
-
-        hidden_states = mx.concatenate(all_real, axis=0)[None]  # [1, total_real, dim]
+        # Strip padding tokens: count valid tokens per image and slice
+        valid_counts = [
+            int(valid_mask[i].astype(mx.int32).sum().item()) for i in range(B)
+        ]
+        max_valid = max(valid_counts)
+
+        if B == 1:
+            # Single image: no padding needed
+            hidden_states = pooled[:, :max_valid]
+        else:
+            # Batch: slice each image to max_valid tokens (same-shape images
+            # produce the same count; differently shaped images remain zero-padded)
+            hidden_states = pooled[:, :max_valid]
 
         if self.config.standardize:
             hidden_states = (hidden_states - self.std_bias) * self.std_scale
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 6f7e40ca..1df3436f 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -1416,13 +1416,17 @@ def add_eos_token_ids(self, new_eos_token_ids: Union[int, List[int]] = None):
             raise ValueError("Processor is not provided")
 
         if new_eos_token_ids is not None:
-            if isinstance(new_eos_token_ids, str):
+            if isinstance(new_eos_token_ids, (str, int)):
                 new_eos_token_ids = [new_eos_token_ids]
-            new_eos_token_ids = [
-                self.tokenizer.encode(" " + token, add_special_tokens=False)[-1]
-                for token in new_eos_token_ids
-            ]
-            self.eos_token_ids.extend(new_eos_token_ids)
+            resolved = []
+            for token in new_eos_token_ids:
+                if isinstance(token, int):
+                    resolved.append(token)
+                else:
+                    resolved.append(
+                        self.tokenizer.encode(" " + token, add_special_tokens=False)[-1]
+                    )
+            self.eos_token_ids.extend(resolved)
 
     def reset(self, eos_token_ids: List[int] = None):
         eos_token_ids = (
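
Usage sketch (not part of the patch): the snippet below shows how the per-sequence hooks added by this diff could be driven. `insert()` now accepts an optional sampler per prompt and an optional list of logits processors per prompt, and `_step()` applies them row by row, falling back to the generator's shared sampler. The model id, the prompt encoding, and the helper functions are illustrative assumptions; only the `samplers=` and `logits_processors=` parameters and the multi-token `stop_tokens` behaviour come from the patch, and constructing the BatchGenerator itself is elided (see `_generate_batch`).

    import mlx.core as mx

    from mlx_vlm import load

    model, processor = load("mlx-community/Qwen2-VL-2B-Instruct-4bit")  # any mlx-vlm model
    tokenizer = processor.tokenizer
    prompt_ids = [
        tokenizer.encode("Describe a cat."),
        tokenizer.encode("Describe a dog."),
    ]

    def greedy(logprobs: mx.array) -> mx.array:
        # _step() calls each sampler on a single sequence's 1-D logprob row.
        return mx.argmax(logprobs, axis=-1)

    def make_temperature_sampler(temp: float):
        def sample(logprobs: mx.array) -> mx.array:
            return mx.random.categorical(logprobs * (1.0 / temp))
        return sample

    def make_ban_token_processor(token_id: int):
        # Matches the processor call in _step(): lp(input_tokens_row, logprobs_row).
        def process(tokens: mx.array, logprobs: mx.array) -> mx.array:
            mask = mx.arange(logprobs.shape[-1]) == token_id
            return mx.where(mask, mx.array(-mx.inf), logprobs)
        return process

    # Assume `gen` is a BatchGenerator built from model.language_model with
    # stop_tokens={tokenizer.eos_token_id}; ints stop on a single token, and
    # the new SequenceStateMachine also accepts multi-token stop sequences.
    uids = gen.insert(
        prompt_ids,
        max_tokens=256,
        samplers=[greedy, make_temperature_sampler(0.7)],
        logits_processors=[None, [make_ban_token_processor(0)]],  # token id 0 is arbitrary
    )

Design note: because `_step()` falls back to the shared sampler for any `None` entry, callers can mix per-sequence overrides with the default behaviour in a single batch.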