diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index eca61c43..74ec5a10 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -448,7 +448,16 @@ def _schedule_waiting(self) -> List[Request]:
                 break
 
             # Determine tokens to process and cache to use
-            tokens_to_process = request.remaining_tokens or request.prompt_token_ids
+            # Note: Don't use `remaining_tokens or prompt_token_ids` because empty list
+            # is falsy in Python. For exact cache match, remaining_tokens=[] but we should
+            # pass just the last token so BatchGenerator can start generation.
+            if request.remaining_tokens is not None and len(request.remaining_tokens) == 0:
+                # Exact cache match - pass only last token for generation kickoff
+                tokens_to_process = request.prompt_token_ids[-1:]
+            elif request.remaining_tokens:
+                tokens_to_process = request.remaining_tokens
+            else:
+                tokens_to_process = request.prompt_token_ids
             cache_to_use = request.prompt_cache  # May be None
 
             # Validate cache before using it
@@ -590,16 +599,18 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
             if request is not None and request.prompt_token_ids:
                 if self.block_aware_cache is not None:
                     # Store in paged cache
+                    # Key includes both prompt and output tokens for multi-turn chat caching
                     if hasattr(request, '_extracted_cache') and request._extracted_cache is not None:
                         try:
+                            full_token_sequence = list(request.prompt_token_ids) + list(request.output_token_ids)
                             self.block_aware_cache.store_cache(
                                 request_id,
-                                request.prompt_token_ids,
+                                full_token_sequence,
                                 request._extracted_cache,
                             )
                             logger.debug(
                                 f"Stored paged cache for request {request_id} "
-                                f"({len(request.prompt_token_ids)} tokens)"
+                                f"({len(full_token_sequence)} tokens: {len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output)"
                             )
                         except Exception as e:
                             logger.debug(f"Failed to store paged cache for {request_id}: {e}")
@@ -609,15 +620,18 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
                 elif self.prefix_cache is not None:
                     # Store in standard prefix cache
+                    # Key includes both prompt and output tokens for multi-turn chat caching
+                    # The next turn's prompt will include the previous response
                     if hasattr(request, '_extracted_cache') and request._extracted_cache is not None:
                         try:
+                            full_token_sequence = list(request.prompt_token_ids) + list(request.output_token_ids)
                             self.prefix_cache.store_cache(
-                                request.prompt_token_ids,
+                                full_token_sequence,
                                 request._extracted_cache,
                             )
                             logger.debug(
                                 f"Stored cache for request {request_id} "
-                                f"({len(request.prompt_token_ids)} tokens)"
+                                f"({len(full_token_sequence)} tokens: {len(request.prompt_token_ids)} prompt + {len(request.output_token_ids)} output)"
                             )
                         except Exception as e:
                             logger.debug(f"Failed to store cache for {request_id}: {e}")
 
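
The token-selection change in _schedule_waiting works around Python's falsy empty list. Below is a minimal, self-contained sketch of that rule; select_tokens_to_process is a hypothetical helper name used only for illustration (the patch inlines this logic on `request` inside the scheduler loop):

from typing import List, Optional

def select_tokens_to_process(
    remaining_tokens: Optional[List[int]],
    prompt_token_ids: List[int],
) -> List[int]:
    # Exact cache hit: the whole prompt is cached, so feed only the last
    # prompt token to kick off generation instead of re-prefilling everything.
    if remaining_tokens is not None and len(remaining_tokens) == 0:
        return prompt_token_ids[-1:]
    # Partial cache hit: prefill only the uncached suffix.
    if remaining_tokens:
        return remaining_tokens
    # No cache information (remaining_tokens is None): full prefill.
    return prompt_token_ids

prompt = [1, 2, 3, 4]
assert select_tokens_to_process(None, prompt) == [1, 2, 3, 4]  # no cache
assert select_tokens_to_process([3, 4], prompt) == [3, 4]      # partial hit
assert select_tokens_to_process([], prompt) == [4]             # exact hit
# The old expression `remaining_tokens or prompt_token_ids` returns the full
# prompt in the exact-hit case, throwing away the benefit of the cache.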
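
The cache-key change in _cleanup_finished stores prompt + output tokens so the next chat turn, whose prompt re-includes the previous assistant reply, can match a longer cached prefix. A rough sketch of that effect, assuming a longest-prefix lookup over token sequences; longest_cached_prefix is a made-up stand-in for what the prefix cache does internally:

from typing import List

def longest_cached_prefix(cached_keys: List[List[int]], prompt: List[int]) -> int:
    # Length of the longest shared prefix between `prompt` and any cached key.
    best = 0
    for key in cached_keys:
        match = 0
        for a, b in zip(key, prompt):
            if a != b:
                break
            match += 1
        best = max(best, match)
    return best

turn1_prompt = [101, 7, 8, 9]   # system + first user message
turn1_output = [42, 43, 44]     # assistant reply from turn 1
# Turn 2's prompt re-includes the previous reply plus the new user message.
turn2_prompt = turn1_prompt + turn1_output + [55, 56]

assert longest_cached_prefix([turn1_prompt], turn2_prompt) == 4                  # old key: prompt only
assert longest_cached_prefix([turn1_prompt + turn1_output], turn2_prompt) == 7   # new key: prompt + output
# With the new key, turn 2 reuses the KV cache for the previous reply as well,
# not just for the original prompt.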