diff --git a/omlx/scheduler.py b/omlx/scheduler.py
index d8eae74d..869d79b0 100644
--- a/omlx/scheduler.py
+++ b/omlx/scheduler.py
@@ -321,12 +321,27 @@ def _step(
         return sampled, list(logprobs)
 
     def _process_prompts(self, prompts):
-        uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts)
+        (
+            uids,
+            inputs,
+            max_tokens,
+            caches,
+            samplers,
+            logits_processors,
+            prompt_checkpoints,
+        ) = zip(*prompts)
 
         lengths = [len(p) for p in inputs]
         max_length = max(lengths)
         padding = [max_length - l for l in lengths]
 
+        # Compute effective prompt checkpoint exactly as upstream BatchGenerator does.
+        # When prompt_checkpoints are default (-1), this yields 1 — matching old behavior.
+        prompt_checkpoints_offsets = [
+            (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints)
+        ]
+        prompt_checkpoint = max(1, max(prompt_checkpoints_offsets))
+
         # Collect per-UID VLM embeddings registered by _schedule_waiting().
         vlm_embeds_map: Dict[int, Tuple[mx.array, Dict[str, Any], int]] = {}
         for uid in uids:
@@ -391,8 +406,8 @@ def _process_prompts(self, prompts):
             # per-block ArraysCache snapshots.
             all_boundaries = True
 
-            while inputs.shape[1] > 1:
-                max_allowed = inputs.shape[1] - 1
+            while inputs.shape[1] > prompt_checkpoint:
+                max_allowed = inputs.shape[1] - prompt_checkpoint
                 if boundary_enabled:
                     n_to_process = self._next_boundary_limited_step(
                         processed_tokens,
@@ -482,7 +497,7 @@ def _process_prompts(self, prompts):
         # 2. Process
         # 3. Finalize the KV caches so they are left padded again
         else:
-            last_inputs = mx.array([p[-1:] for p in inputs])
+            last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs])
             inputs = _right_pad_prompts(inputs, max_length=max_length)
 
             prompt_cache = _merge_caches(caches)
@@ -499,11 +514,15 @@ def _process_prompts(self, prompts):
             all_boundaries = True
 
             for c in prompt_cache:
-                # subtract one from lengths since we don't process the last token during prefill
-                c.prepare(lengths=[l - 1 for l in lengths], right_padding=padding)
+                # subtract prompt_checkpoint from lengths since we don't process
+                # the last prompt_checkpoint tokens during prefill
+                c.prepare(
+                    lengths=[max(0, l - prompt_checkpoint) for l in lengths],
+                    right_padding=padding,
+                )
 
-            while inputs.shape[1] > 1:
-                max_allowed = inputs.shape[1] - 1
+            while inputs.shape[1] > prompt_checkpoint:
+                max_allowed = inputs.shape[1] - prompt_checkpoint
                 if boundary_enabled:
                     n_to_process = self._next_boundary_limited_step(
                         processed_tokens,
@@ -590,6 +609,33 @@ def _process_prompts(self, prompts):
 
             for c in prompt_cache:
                 c.finalize()
+
+            # Emit prompt checkpoint callback for upstream parity.
+            # When prompt_checkpoint > 1, process remaining tokens before _step.
+            if self.prompt_checkpoint_callback is not None:
+                self.prompt_checkpoint_callback(
+                    [
+                        (uid, prompt_checkpoint, tuple(c.extract(i) for c in prompt_cache))
+                        for i, uid in enumerate(uids)
+                    ]
+                )
+            if prompt_checkpoint > 1:
+                model_kwargs_cp = {}
+                if batched_embeds is not None and batched_embeds.shape[1] >= (prompt_checkpoint - 1):
+                    # Slice VLM embeds for the checkpoint-to-last-1 range
+                    model_kwargs_cp["inputs_embeds"] = batched_embeds[:, :prompt_checkpoint - 1]
+                if batched_extra:
+                    model_kwargs_cp["vlm_extra_kwargs"] = _slice_vlm_extra(
+                        batched_extra, prompt_checkpoint - 1
+                    )
+                self.model(inputs[:, :prompt_checkpoint - 1], cache=prompt_cache, **model_kwargs_cp)
+                mx.eval([c.state for c in prompt_cache])
+                inputs = inputs[:, prompt_checkpoint - 1:]
+                if batched_embeds is not None:
+                    batched_embeds = batched_embeds[:, prompt_checkpoint - 1:]
+                if batched_extra:
+                    batched_extra = _advance_vlm_extra(batched_extra, prompt_checkpoint - 1)
+                mx.clear_cache()
 
         # Pass remaining VLM embeddings (last token) to _step if available.
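
For reference, not part of the patch: a minimal standalone sketch of the checkpoint-offset arithmetic added in the first hunk, using made-up prompt lengths and checkpoint values, to show that the default checkpoint of -1 collapses to the old behavior of deferring only the final token.

# Hypothetical values for illustration only; not taken from the scheduler.
lengths = [12, 20]             # per-prompt token counts in the batch
prompt_checkpoints = [-1, 15]  # default (-1) and an absolute checkpoint at token 15

# Same formula as the diff: a positive checkpoint becomes "tokens after the
# checkpoint", while the default -1 becomes 1.
offsets = [(l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints)]
prompt_checkpoint = max(1, max(offsets))

assert offsets == [1, 5]
assert prompt_checkpoint == 5  # the last 5 tokens of every prompt are withheld from prefill
# With every entry at the default -1, the maximum is 1 and only the final token
# is withheld, which matches the pre-patch behavior.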