omlx/scheduler.py: 62 changes (54 additions, 8 deletions)
@@ -321,12 +321,27 @@ def _step(
         return sampled, list(logprobs)

     def _process_prompts(self, prompts):
-        uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts)
+        (
+            uids,
+            inputs,
+            max_tokens,
+            caches,
+            samplers,
+            logits_processors,
+            prompt_checkpoints,
+        ) = zip(*prompts)

         lengths = [len(p) for p in inputs]
         max_length = max(lengths)
         padding = [max_length - l for l in lengths]

+        # Compute the effective prompt checkpoint exactly as upstream BatchGenerator does.
+        # When prompt_checkpoints are all default (-1), this yields 1, matching the old behavior.
+        prompt_checkpoints_offsets = [
+            (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints)
+        ]
+        prompt_checkpoint = max(1, max(prompt_checkpoints_offsets))
+
         # Collect per-UID VLM embeddings registered by _schedule_waiting().
         vlm_embeds_map: Dict[int, Tuple[mx.array, Dict[str, Any], int]] = {}
         for uid in uids:
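
Note: the offset arithmetic above is easiest to see with concrete numbers. A minimal standalone sketch, using made-up prompt lengths and a hypothetical checkpoint position (illustrative values, not taken from this PR):

# Illustration only: mirrors the prompt_checkpoint computation in _process_prompts.
lengths = [10, 12]            # hypothetical prompt lengths
prompt_checkpoints = [-1, 8]  # -1 is the default; 8 is a hypothetical checkpoint position
offsets = [(l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints)]
print(offsets)               # [1, 4]: the default yields 1; a checkpoint at 8 of 12 leaves 4
print(max(1, max(offsets)))  # 4: the batch leaves its last 4 tokens out of prefill
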
@@ -391,8 +406,8 @@ def _process_prompts(self, prompts):
             # per-block ArraysCache snapshots.
             all_boundaries = True

-            while inputs.shape[1] > 1:
-                max_allowed = inputs.shape[1] - 1
+            while inputs.shape[1] > prompt_checkpoint:
+                max_allowed = inputs.shape[1] - prompt_checkpoint
                 if boundary_enabled:
                     n_to_process = self._next_boundary_limited_step(
                         processed_tokens,
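
Note: the loop-bound change means prefill now stops while prompt_checkpoint tokens are still unprocessed, rather than one. A tiny sketch of the invariant, with a made-up per-step chunk size standing in for the scheduler's real step logic:

# Illustration only: the prefill loop drains inputs down to the checkpoint window.
prompt_checkpoint = 4
width = 10                                   # stand-in for inputs.shape[1]
while width > prompt_checkpoint:
    n = min(3, width - prompt_checkpoint)    # stand-in for the per-step chunk size
    width -= n
print(width)  # 4: exactly prompt_checkpoint tokens remain for the catch-up pass
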
Expand Down Expand Up @@ -482,7 +497,7 @@ def _process_prompts(self, prompts):
# 2. Process
# 3. Finalize the KV caches so they are left padded again
else:
last_inputs = mx.array([p[-1:] for p in inputs])
last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs])
inputs = _right_pad_prompts(inputs, max_length=max_length)
prompt_cache = _merge_caches(caches)

@@ -499,11 +514,15 @@ def _process_prompts(self, prompts):
                all_boundaries = True

            for c in prompt_cache:
-               # subtract one from lengths since we don't process the last token during prefill
-               c.prepare(lengths=[l - 1 for l in lengths], right_padding=padding)
+               # Subtract prompt_checkpoint from lengths, since we don't process
+               # the last prompt_checkpoint tokens during prefill.
+               c.prepare(
+                   lengths=[max(0, l - prompt_checkpoint) for l in lengths],
+                   right_padding=padding,
+               )

-           while inputs.shape[1] > 1:
-               max_allowed = inputs.shape[1] - 1
+           while inputs.shape[1] > prompt_checkpoint:
+               max_allowed = inputs.shape[1] - prompt_checkpoint
                if boundary_enabled:
                    n_to_process = self._next_boundary_limited_step(
                        processed_tokens,
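
Note: with the checkpoint in play, each cache is prepared for l - prompt_checkpoint prefill tokens, clamped at 0 for prompts shorter than the window. A sketch with made-up lengths:

# Illustration only: per-sequence prefill lengths passed to c.prepare().
prompt_checkpoint = 4
lengths = [10, 12, 3]
prefill_lengths = [max(0, l - prompt_checkpoint) for l in lengths]
print(prefill_lengths)  # [6, 8, 0]: the 3-token prompt is handled entirely after prefill
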
@@ -590,6 +609,33 @@ def _process_prompts(self, prompts):

            for c in prompt_cache:
                c.finalize()
+
+        # Emit the prompt checkpoint callback for upstream parity.
+        # When prompt_checkpoint > 1, process the remaining tokens before _step.
+        if self.prompt_checkpoint_callback is not None:
+            self.prompt_checkpoint_callback(
+                [
+                    (uid, prompt_checkpoint, tuple(c.extract(i) for c in prompt_cache))
+                    for i, uid in enumerate(uids)
+                ]
+            )
+        if prompt_checkpoint > 1:
+            model_kwargs_cp = {}
+            if batched_embeds is not None and batched_embeds.shape[1] >= (prompt_checkpoint - 1):
+                # Slice the VLM embeds to cover all but the last remaining token.
+                model_kwargs_cp["inputs_embeds"] = batched_embeds[:, :prompt_checkpoint - 1]
+            if batched_extra:
+                model_kwargs_cp["vlm_extra_kwargs"] = _slice_vlm_extra(
+                    batched_extra, prompt_checkpoint - 1
+                )
+            self.model(inputs[:, :prompt_checkpoint - 1], cache=prompt_cache, **model_kwargs_cp)
+            mx.eval([c.state for c in prompt_cache])
+            inputs = inputs[:, prompt_checkpoint - 1:]
+            if batched_embeds is not None:
+                batched_embeds = batched_embeds[:, prompt_checkpoint - 1:]
+            if batched_extra:
+                batched_extra = _advance_vlm_extra(batched_extra, prompt_checkpoint - 1)
+
         mx.clear_cache()

         # Pass remaining VLM embeddings (last token) to _step if available.
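
Note: the checkpoint window is consumed in two stages: one catch-up forward pass over its first prompt_checkpoint - 1 tokens, then _step samples from the final token. A sketch of the split, using hypothetical tail tokens:

# Illustration only: how the remaining window is divided between the catch-up
# pass and _step. With prompt_checkpoint == 1 this degenerates to the old
# behavior: no catch-up pass, one token left for _step.
prompt_checkpoint = 4
remaining = ["t1", "t2", "t3", "t4"]             # hypothetical tail tokens
catch_up = remaining[:prompt_checkpoint - 1]     # processed by the catch-up self.model(...) call
for_step = remaining[prompt_checkpoint - 1:]     # consumed by _step
print(catch_up, for_step)  # ['t1', 't2', 't3'] ['t4']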