diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a7ec0de37263..812b0f50387c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -190,6 +190,8 @@ def __init__( self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER + self.num_requests_being_async_loaded = 0 + def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. @@ -210,6 +212,9 @@ def schedule(self) -> SchedulerOutput: req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens + # budget for asynchronously loaded external tokens + # heuristically set to double the normal token budget + async_load_token_budget = 2 * self.max_num_scheduled_tokens # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} encoder_compute_budget = self.max_num_encoder_input_tokens @@ -409,6 +414,8 @@ def schedule(self) -> SchedulerOutput: is_ready = self._update_waiting_for_remote_kv(request) if is_ready: request.status = RequestStatus.WAITING + self.num_requests_being_async_loaded -= 1 + assert self.num_requests_being_async_loaded >= 0 else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", @@ -418,6 +425,12 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.prepend_request(request) continue + if ( + len(self.running) + self.num_requests_being_async_loaded + == self.max_num_running_reqs + ): + break + # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -586,6 +599,10 @@ def schedule(self) -> SchedulerOutput: # into the WAITING_FOR_REMOTE_KV state. 
skipped_waiting_requests.prepend_request(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + self.num_requests_being_async_loaded += 1 + async_load_token_budget -= request.num_external_computed_tokens + if async_load_token_budget <= 0: + break continue self._update_connector_prefix_cache_stats(request) @@ -1293,6 +1310,8 @@ def finish_requests( if request.status == RequestStatus.RUNNING: running_requests_to_remove.add(request) else: + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + self.num_requests_being_async_loaded -= 1 waiting_requests_to_remove.append(request) # Remove all requests from queues at once for better efficiency