68 changes: 57 additions & 11 deletions mlx_lm/generate.py
@@ -948,6 +948,9 @@ def __init__(
         completion_batch_size: int = 32,
         prefill_batch_size: int = 8,
         prefill_step_size: int = 2048,
+        prompt_checkpoint_callback: Optional[
+            Callable[[List[Tuple[int, int, List[Any]]]], None]
+        ] = None,
         prompt_progress_callback: Optional[
             Callable[[List[Tuple[int, int, int]]], None]
         ] = None,
@@ -963,6 +966,7 @@ def __init__(
         self.prefill_step_size = prefill_step_size
         self.prefill_batch_size = prefill_batch_size
         self.completion_batch_size = max(completion_batch_size, prefill_batch_size)
+        self.prompt_checkpoint_callback = prompt_checkpoint_callback
         self.prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
         self._stats = BatchStats()
         self._next_count = 0
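
A minimal sketch of a conforming callback, derived only from the type annotation above; the tuple field meanings are inferred from the call site in `_process_prompts` further down and should be treated as assumptions:

```python
from typing import Any, List, Tuple

def on_prompt_checkpoint(entries: List[Tuple[int, int, List[Any]]]) -> None:
    # Inferred meaning of each entry (see the call site below):
    # (uid, prompt_checkpoint, per-prompt KV cache segments), where
    # prompt_checkpoint is the number of trailing prompt tokens still
    # unprocessed when the checkpoint was taken.
    for uid, withheld, cache_segments in entries:
        print(f"uid={uid}: checkpoint taken {withheld} tokens before the end")
```
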
@@ -993,12 +997,16 @@ def insert(
         caches=None,
         samplers: list | None = None,
         logits_processors: list | None = None,
+        prompt_checkpoints: list | int | None = None,
     ):
         uids = []

         if max_tokens is None or isinstance(max_tokens, int):
             max_tokens = [max_tokens or self.max_tokens] * len(prompts)

+        if prompt_checkpoints is None or isinstance(prompt_checkpoints, int):
+            prompt_checkpoints = [prompt_checkpoints or -1] * len(prompts)
+
         if caches is None:
             caches = [None] * len(prompts)
         for i in range(len(prompts)):
@@ -1008,10 +1016,10 @@ def insert(
         samplers = samplers or [None] * len(prompts)
         logits_processors = logits_processors or [self.logits_processors] * len(prompts)

-        for p, m, c, s, lp in zip(
-            prompts, max_tokens, caches, samplers, logits_processors
+        for p, m, c, s, lp, pc in zip(
+            prompts, max_tokens, caches, samplers, logits_processors, prompt_checkpoints
         ):
-            self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp))
+            self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp, pc))
             uids.append(self.uid_count)
             self.uid_count += 1
         # Sort in ascending order of length
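
For clarity, the broadcasting rule added above mirrors the existing `max_tokens` handling: `None` or a bare `int` expands to one value per prompt, with `-1` (one token from the end, the previous behavior) as the default. A standalone restatement with a couple of sanity checks:

```python
# Restates the normalization at the top of insert(); the function name is
# made up for illustration.
def normalize_checkpoints(prompt_checkpoints, num_prompts):
    if prompt_checkpoints is None or isinstance(prompt_checkpoints, int):
        prompt_checkpoints = [prompt_checkpoints or -1] * num_prompts
    return prompt_checkpoints

assert normalize_checkpoints(None, 3) == [-1, -1, -1]
assert normalize_checkpoints(-8, 2) == [-8, -8]
assert normalize_checkpoints([5, -2], 2) == [5, -2]
```
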
@@ -1052,12 +1060,28 @@ def prompt_cache_nbytes(self):
         return total

     def _process_prompts(self, prompts):
-        uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts)
+        (
+            uids,
+            inputs,
+            max_tokens,
+            caches,
+            samplers,
+            logits_processors,
+            prompt_checkpoints,
+        ) = zip(*prompts)

         lengths = [len(p) for p in inputs]
         max_length = max(lengths)
         padding = [max_length - l for l in lengths]

+        # Get the checkpoint token as an offset from the end of each prompt.
+        # Then select the largest one so that we perform the checkpoint at
+        # least `pc` tokens before the end.
+        prompt_checkpoints = [
+            (l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints)
+        ]
+        prompt_checkpoint = max(1, max(prompt_checkpoints))
+
         self._stats.prompt_tokens += sum(lengths)

         tokens = [mx.array(inp) for inp in inputs]
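
To make the offset arithmetic concrete: a positive `pc` reads as an absolute position from the start of the prompt, a non-positive `pc` as an offset from the end, and the batch checkpoints once at the largest resulting offset. A worked example with made-up lengths:

```python
# Two prompts of (made-up) lengths 10 and 12. The first asks for a
# checkpoint at absolute position 7, i.e. 10 - 7 = 3 tokens from its end;
# the second asks for 4 tokens from its end directly.
lengths = [10, 12]
prompt_checkpoints = [7, -4]

offsets = [(l - pc if pc > 0 else -pc) for l, pc in zip(lengths, prompt_checkpoints)]
assert offsets == [3, 4]

# The whole batch checkpoints at the largest offset, so every prompt stops
# at least its requested distance before its end.
prompt_checkpoint = max(1, max(offsets))
assert prompt_checkpoint == 4
```
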
@@ -1070,8 +1094,10 @@ def _process_prompts(self, prompts):
         inputs = _left_pad_prompts(inputs, max_length=max_length)
         prompt_cache = _make_cache(self.model, padding, self.max_kv_size)

-        while inputs.shape[1] > 1:
-            n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1)
+        while inputs.shape[1] > prompt_checkpoint:
+            n_to_process = min(
+                self.prefill_step_size, inputs.shape[1] - prompt_checkpoint
+            )
             self.model(inputs[:, :n_to_process], cache=prompt_cache)
             mx.eval([c.state for c in prompt_cache])
             inputs = inputs[:, n_to_process:]
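
The loop is the same chunked prefill as before, except it now stops `prompt_checkpoint` tokens before the end rather than exactly one. A shape-only trace with assumed sizes:

```python
# Shape-only trace of the prefill loop (no model involved): padded width 10,
# prefill step 4, checkpoint 3 tokens before the end.
width, prefill_step_size, prompt_checkpoint = 10, 4, 3

chunks = []
while width > prompt_checkpoint:
    n_to_process = min(prefill_step_size, width - prompt_checkpoint)
    chunks.append(n_to_process)
    width -= n_to_process

assert chunks == [4, 3]            # 7 of 10 tokens prefilled
assert width == prompt_checkpoint  # the last 3 are handled after the checkpoint
```
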
@@ -1090,16 +1116,22 @@ def _process_prompts(self, prompts):
         # 2. Process
         # 3. Finalize the KV caches so they are left padded again
         else:
-            last_inputs = mx.array([p[-1:] for p in inputs])
+            last_inputs = mx.array([p[-prompt_checkpoint:] for p in inputs])
             inputs = _right_pad_prompts(inputs, max_length=max_length)
             prompt_cache = _merge_caches(caches)

             for c in prompt_cache:
-                # subtract one from lengths since we don't process the last token during prefill
-                c.prepare(lengths=[l - 1 for l in lengths], right_padding=padding)
+                # subtract from lengths since we don't process the last
+                # `prompt_checkpoint` tokens during prefill
+                c.prepare(
+                    lengths=[l - prompt_checkpoint for l in lengths],
+                    right_padding=padding,
+                )

-            while inputs.shape[1] > 1:
-                n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1)
+            while inputs.shape[1] > prompt_checkpoint:
+                n_to_process = min(
+                    self.prefill_step_size, inputs.shape[1] - prompt_checkpoint
+                )
                 self.model(inputs[:, :n_to_process], cache=prompt_cache)
                 mx.eval([c.state for c in prompt_cache])
                 inputs = inputs[:, n_to_process:]
@@ -1117,6 +1149,20 @@ def _process_prompts(self, prompts):

         for c in prompt_cache:
             c.finalize()
+
+        # We processed L - prompt_checkpoint tokens, so call the checkpoint
+        # callback.
+        if self.prompt_checkpoint_callback is not None:
+            self.prompt_checkpoint_callback(
+                [
+                    (uid, prompt_checkpoint, (c.extract(i) for c in prompt_cache))
+                    for i, uid in enumerate(uids)
+                ]
+            )
+        # Process the remaining prompt_checkpoint - 1 tokens
+        if prompt_checkpoint > 1:
+            self.model(inputs[:, : prompt_checkpoint - 1], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
         mx.clear_cache()

         y, logprobs = self._step(
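
Putting the pieces together, a hedged end-to-end sketch. The enclosing class name (`BatchGenerator`), its constructor arguments other than the new callback, and the tokenized prompts are all assumptions for illustration; only `prompt_checkpoint_callback`, `prompt_checkpoints`, and the callback payload come from this diff:

```python
# Hypothetical usage -- BatchGenerator and its setup are assumed, not shown
# in this diff.
checkpoints = {}

def save_checkpoint(entries):
    # Each entry: (uid, prompt_checkpoint, per-prompt cache segments). The
    # segments arrive as a generator at the call site, so materialize them.
    for uid, withheld, cache_segments in entries:
        checkpoints[uid] = (withheld, list(cache_segments))

gen = BatchGenerator(model, prompt_checkpoint_callback=save_checkpoint)

# Checkpoint the first prompt 16 tokens before its end; the second keeps the
# default one-token offset. prompt_a and prompt_b are pre-tokenized prompts.
uids = gen.insert([prompt_a, prompt_b], prompt_checkpoints=[-16, -1])
```
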