diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 88d144cb..f1490d73 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -392,6 +392,7 @@ def _chunked_next(self=batch_gen): # noqa: C901 caches, samplers, logits_processors, + _prompt_checkpoints, ) = zip(*batch_prompts) lengths = [len(p) for p in inputs_raw] max_length = max(lengths) @@ -403,7 +404,7 @@ def _chunked_next(self=batch_gen): # noqa: C901 if not is_cached: padded = _left_pad_prompts(inputs_raw, max_length=max_length) - prompt_cache = _make_cache(self.model, padding) + prompt_cache = _make_cache(self.model, padding, self.max_kv_size) else: last_inputs = mx.array([p[-1:] for p in inputs_raw]) padded = _right_pad_prompts(inputs_raw, max_length=max_length)