vllm_mlx/scheduler.py (24 changes: 19 additions & 5 deletions)
@@ -448,7 +448,16 @@ def _schedule_waiting(self) -> List[Request]:
                 break
 
             # Determine tokens to process and cache to use
-            tokens_to_process = request.remaining_tokens or request.prompt_token_ids
+            # Note: Don't use `remaining_tokens or prompt_token_ids` because empty list
+            # is falsy in Python. For exact cache match, remaining_tokens=[] but we should
+            # pass just the last token so BatchGenerator can start generation.
+            if request.remaining_tokens is not None and len(request.remaining_tokens) == 0:
+                # Exact cache match - pass only last token for generation kickoff
+                tokens_to_process = request.prompt_token_ids[-1:]
+            elif request.remaining_tokens:
+                tokens_to_process = request.remaining_tokens
+            else:
+                tokens_to_process = request.prompt_token_ids
             cache_to_use = request.prompt_cache  # May be None
 
             # Validate cache before using it
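Why the one-liner had to go: in Python, `[] or fallback` evaluates to `fallback`, so an exact cache hit (where `remaining_tokens` is an empty list) would silently re-prefill the entire prompt. A minimal sketch of the three branches, using a hypothetical stand-in for the scheduler's `Request` class (illustrative only, not the project's actual types):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeRequest:
    # Stand-in for the scheduler's Request; field names mirror the diff above.
    prompt_token_ids: List[int]
    remaining_tokens: Optional[List[int]]

def tokens_to_process(req: FakeRequest) -> List[int]:
    # Old behavior, `req.remaining_tokens or req.prompt_token_ids`, re-prefills
    # the whole prompt on an exact cache hit because [] is falsy.
    if req.remaining_tokens is not None and len(req.remaining_tokens) == 0:
        return req.prompt_token_ids[-1:]   # exact cache match: kick off generation with last token
    elif req.remaining_tokens:
        return req.remaining_tokens        # partial cache match: process only the uncached tail
    else:
        return req.prompt_token_ids        # no cache (None): process the full prompt

prompt = [101, 7592, 2088, 102]
assert tokens_to_process(FakeRequest(prompt, None)) == prompt               # no cache
assert tokens_to_process(FakeRequest(prompt, [2088, 102])) == [2088, 102]   # partial hit
assert tokens_to_process(FakeRequest(prompt, [])) == [102]                  # exact hit -> last token only
# The removed expression would return the full prompt in the exact-hit case:
assert ([] or prompt) == prompt
```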
@@ -590,16 +599,18 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
             if request is not None and request.prompt_token_ids:
                 if self.block_aware_cache is not None:
                     # Store in paged cache
+                    # Key includes both prompt and output tokens for multi-turn chat caching
                     if hasattr(request, '_extracted_cache') and request._extracted_cache is not None:
                         try:
+                            full_token_sequence = list(request.prompt_token_ids) + list(request.output_token_ids)
                             self.block_aware_cache.store_cache(
                                 request_id,
-                                request.prompt_token_ids,
+                                full_token_sequence,
                                 request._extracted_cache,
                             )
                             logger.debug(
                                 f"Stored paged cache for request {request_id} "
-                                f"({len(request.prompt_token_ids)} tokens)"
+                                f"({len(full_token_sequence)} tokens: {len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output)"
                             )
                         except Exception as e:
                             logger.debug(f"Failed to store paged cache for {request_id}: {e}")
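The switch from `request.prompt_token_ids` to `full_token_sequence` as the cache key is about longest-prefix matching across chat turns: the next turn's prompt re-encodes the previous prompt plus the model's response, so a key that stops at the prompt leaves the response tokens uncached. A rough sketch of the idea with a toy dict-backed cache; the class, its `longest_prefix` method, and the handle strings are assumptions for illustration, not the real `PrefixCache` or `BlockAwarePrefixCache` API:

```python
from typing import Dict, List, Optional, Tuple

class ToyPrefixCache:
    """Toy stand-in: keys are token tuples, values are opaque KV-cache handles."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[int, ...], str] = {}

    def store_cache(self, token_ids: List[int], cache_handle: str) -> None:
        self._store[tuple(token_ids)] = cache_handle

    def longest_prefix(self, token_ids: List[int]) -> Tuple[int, Optional[str]]:
        # Return (matched_length, handle) for the longest stored key that is a
        # prefix of token_ids.
        best_len, best = 0, None
        for key, handle in self._store.items():
            if len(key) > best_len and tuple(token_ids[:len(key)]) == key:
                best_len, best = len(key), handle
        return best_len, best

cache = ToyPrefixCache()
turn1_prompt = [1, 2, 3, 4]   # system + first user message
turn1_output = [5, 6]         # model response tokens

# Old key: prompt only -> the next turn would match just 4 tokens.
# New key: prompt + output -> the next turn matches all 6.
cache.store_cache(turn1_prompt + turn1_output, "kv-after-turn-1")

turn2_prompt = turn1_prompt + turn1_output + [7, 8]   # full history + new user message
matched, handle = cache.longest_prefix(turn2_prompt)
assert (matched, handle) == (6, "kv-after-turn-1")
```

The same reasoning applies to the standard prefix-cache branch in the next hunk.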
@@ -609,15 +620,18 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
 
                 elif self.prefix_cache is not None:
                     # Store in standard prefix cache
+                    # Key includes both prompt and output tokens for multi-turn chat caching
+                    # The next turn's prompt will include the previous response
                     if hasattr(request, '_extracted_cache') and request._extracted_cache is not None:
                         try:
+                            full_token_sequence = list(request.prompt_token_ids) + list(request.output_token_ids)
                             self.prefix_cache.store_cache(
-                                request.prompt_token_ids,
+                                full_token_sequence,
                                 request._extracted_cache,
                             )
                             logger.debug(
                                 f"Stored cache for request {request_id} "
-                                f"({len(request.prompt_token_ids)} tokens)"
+                                f"({len(full_token_sequence)} tokens: {len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output)"
                             )
                         except Exception as e:
                             logger.debug(f"Failed to store cache for {request_id}: {e}")
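Putting the two hunks together, here is the arithmetic for what the scheduler still has to prefill on the following turn, and when the exact-match branch from the first hunk can actually fire (token values are made up, for illustration only):

```python
# Illustrative arithmetic; variable names echo the toy example above.
turn1_prompt, turn1_output = [1, 2, 3, 4], [5, 6]
stored_key = turn1_prompt + turn1_output             # cache key after this change
turn2_prompt = stored_key + [7, 8]                    # full history + new user message

assert turn2_prompt[len(stored_key):] == [7, 8]       # only the new tokens need prefill

# A request that re-sends exactly the cached sequence (e.g. a retry or regenerate)
# has nothing left to prefill; the first hunk then passes just the last token to
# BatchGenerator instead of an empty list.
assert stored_key[len(stored_key):] == []              # remaining_tokens == []
```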