diff --git a/tests/test_logger.py b/tests/test_logger.py index b4f44f52d4df..2c1cf9c2ce24 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -452,6 +452,7 @@ def test_request_logger_log_outputs_integration(): prompt_embeds=None, params=None, lora_request=None, + cache_hit_threshold=None, ) request_logger.log_outputs( diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index acac3753d712..073073493d2c 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1415,6 +1415,176 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 +def _iterate_until_done(scheduler: Scheduler): + while True: + scheduler_output = scheduler.schedule() + if len(scheduler.running) == 0: + break + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + + +@pytest.mark.parametrize( + "global_threshold," + "request_num_tokens," + "request_local_hit_blocks," + "request_external_hit_blocks," + "request_thresholds," + "request_expected_scehduled", + [ + ( + 0.0, + [57, 34, 28], + [1, 1, 0], + [0, 1, 0], + # expected hit ratio: [0.281, 0.941, 0.0] + # calculated as (local + external) * BLOCK_SIZE / tokens + [None, 0.4, 0.1], + [True, True, False], + ), + ( + 0.3, + [157, 134, 128, 20, 150], + [4, 1, 0, 0, 1], + [2, 4, 0, 1, 0], + # expected hit ratio: [0.611, 0.597, 0.0, 0.8, 0.106] + [0.8, 0.4, 0.1, None, None], + [False, True, False, True, False], + ), + ], +) +def test_cache_hit_threshold( + # we validate global_threshold is used when request threshold is None + global_threshold: float, + # number of tokens in each request + request_num_tokens: list[int], + # number of blocks hit in local cache per request + request_local_hit_blocks: list[int], + # number of blocks hit in external cache per request + request_external_hit_blocks: list[int], + # optional 
cache_hit_threshold for each request + request_thresholds: list[float | None], + # bool per request indicating if it is expected to be scheduled + request_expected_scehduled: list[bool], +): + assert ( + len(request_num_tokens) + == len(request_thresholds) + == len(request_local_hit_blocks) + == len(request_external_hit_blocks) + == len(request_expected_scehduled) + ) + + scheduler = create_scheduler( + enable_prefix_caching=True, + global_cache_hit_threshold=global_threshold, + use_kv_connector=True, + ) + + _insert_to_local_cache(request_local_hit_blocks, scheduler) + _mock_external_cache_hit(request_external_hit_blocks, scheduler) + + requests, scheduler_output = _create_and_schedule_requests( + request_num_tokens, request_thresholds, scheduler + ) + + # assert all requests expected to be scheduled are indeed scheduled + assert [ + r.request_id + for r, expected in zip(requests, request_expected_scehduled) + if expected + ] == [s.req_id for s in scheduler_output.scheduled_new_reqs] + + # assert other requests are "finished" due to cache threshold + requests_expected_not_scheduled = [ + r for r, expected in zip(requests, request_expected_scehduled) if not expected + ] + assert all( + r.status == RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD + for r in requests_expected_not_scheduled + ) + + _iterate_until_done(scheduler) + assert_scheduler_empty(scheduler) + + +def _create_and_schedule_requests( + request_num_tokens: list[int], + request_thresholds: list[float | None], + scheduler: Scheduler, +): + num_requests = len(request_num_tokens) + requests = create_requests( + num_requests=num_requests, + num_tokens=request_num_tokens, + block_size=scheduler.cache_config.block_size, + cache_hit_thresholds=request_thresholds, + ) + + for request in requests: + scheduler.add_request(request) + + scheduler_output = scheduler.schedule() + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + return requests, 
scheduler_output + + +def _mock_external_cache_hit(request_external_hit_blocks, scheduler: Scheduler): + BLOCK_SIZE = scheduler.cache_config.block_size + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.side_effect = [ + (i * BLOCK_SIZE, False) for i in request_external_hit_blocks + ] + + +def _insert_to_local_cache(request_local_hit_blocks, scheduler: Scheduler): + """Schedule requests to fill in the local cache""" + BLOCK_SIZE = scheduler.cache_config.block_size + num_total_requests = len(request_local_hit_blocks) + + requests_to_schedule = [ + i for i, hit_blocks in enumerate(request_local_hit_blocks) if hit_blocks > 0 + ] + + num_requests_to_schedule = len(requests_to_schedule) + if num_requests_to_schedule == 0: + # nothing to do + return + + # Mock no external Cache Hit for this cache-warmup phase + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = (0, False) + + # set threshold to 0.0 to ensure all are scheduled + zero_thresholds: list[float | None] = [0.0] * num_total_requests + + # Only requests with local hits should run and populate the cache + # We create all requests to make sure the correct tokens are cached + # (since the tokens are generated according to request id) + requests = create_requests( + num_requests=num_total_requests, + num_tokens=[x * BLOCK_SIZE for x in request_local_hit_blocks], + block_size=BLOCK_SIZE, + cache_hit_thresholds=zero_thresholds, + ) + + # Only schedule the request we want to run and populate the cache + for i in requests_to_schedule: + scheduler.add_request(requests[i]) + + scheduler_output = scheduler.schedule() + + # verify all were indeed scheduled + assert len(scheduler_output.scheduled_new_reqs) == num_requests_to_schedule + + # iterate until all scheduled requests are done + model_runner_output = make_output(scheduler) + 
scheduler.update_from_output(scheduler_output, model_runner_output) + _iterate_until_done(scheduler) + assert_scheduler_empty(scheduler) + + def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], @@ -1487,13 +1657,7 @@ def test_memory_leak(): model_runner_output = make_output(scheduler) scheduler.update_from_output(scheduler_output, model_runner_output) - # Iterate until done. - while True: - scheduler_output = scheduler.schedule() - if len(scheduler.running) == 0: - break - model_runner_output = make_output(scheduler) - scheduler.update_from_output(scheduler_output, model_runner_output) + _iterate_until_done(scheduler) # Confirm no memory leak. assert_scheduler_empty(scheduler) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 751a29795634..e1f2a4f75dbb 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -57,6 +57,7 @@ def create_scheduler( pipeline_parallel_size: int = 1, use_ec_connector: bool = False, ec_role: str | None = None, + global_cache_hit_threshold: float = 0.0, ) -> Scheduler | AsyncScheduler: """Create scheduler under test. 
@@ -89,6 +90,14 @@ def create_scheduler( enable_chunked_prefill=enable_chunked_prefill, async_scheduling=async_scheduling, is_encoder_decoder=model_config.is_encoder_decoder, + global_cache_hit_threshold=global_cache_hit_threshold, + ) + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="float16", + seed=42, + skip_tokenizer_init=skip_tokenizer_init, ) # Cache config, optionally force APC cache_config = CacheConfig( @@ -171,8 +180,8 @@ def create_scheduler( def create_requests( num_requests: int, - num_tokens: int = 10, mm_hashes_list: list[list[str]] | None = None, + num_tokens: int | list[int] = 10, mm_positions: list[list[PlaceholderRange]] | None = None, max_tokens: int = 16, stop_token_ids: list[int] | None = None, @@ -180,12 +189,14 @@ def create_requests( same_prompt: bool = False, block_size: int = 16, req_ids: list[str] | None = None, + cache_hit_thresholds: list[float | None] | None = None, ) -> list[Request]: global _none_hash_initialized if not _none_hash_initialized: init_none_hash(sha256) _none_hash_initialized = True - + if cache_hit_thresholds is not None: + assert len(cache_hit_thresholds) == num_requests block_hasher = get_request_block_hasher(block_size, sha256) sampling_params = SamplingParams( ignore_eos=False, @@ -243,7 +254,16 @@ def create_requests( ) mm_features.append(mm_feature) - prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens + request_num_tokens: int = ( + num_tokens[i] if isinstance(num_tokens, list) else num_tokens + ) + prompt_token_ids = ( + [0] * request_num_tokens if same_prompt else [i] * request_num_tokens + ) + if cache_hit_thresholds is not None: + cache_hit_threshold = cache_hit_thresholds[i] + else: + cache_hit_threshold = None request = Request( request_id=req_ids[i], prompt_token_ids=prompt_token_ids, @@ -252,6 +272,7 @@ def create_requests( mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, block_hasher=block_hasher, + 
cache_hit_threshold=cache_hit_threshold, ) requests.append(request) return requests diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 5ff9fc930a56..e0a7850d9030 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -152,6 +152,14 @@ def default_factory(**kwargs): kwargs["is_encoder_decoder"] = False return SchedulerConfig(**kwargs) + global_cache_hit_threshold: float = 0.0 + """The threshold for cache hit ratio to handle all requests, + except for requests which override it using the "cache_hit_threshold" field. + This feature enables Decode-first optimization in P/D disaggregation: + Decode nodes can avoid remote Prefill in case of high cache hit ratio. + If set to 0.0, the optimization is disabled. + """ + def get_scheduler_cls(self) -> type["SchedulerInterface"]: if self.scheduler_cls is None: if self.async_scheduling: @@ -292,4 +300,13 @@ def verify_max_model_len(self, max_model_len: int) -> Self: f"{self.max_num_partial_prefills=}." ) + if (self.global_cache_hit_threshold < 0.0) or ( + self.global_cache_hit_threshold > 1.0 + ): + raise ValueError( + "global_cache_hit_threshold " + f"({self.global_cache_hit_threshold}) " + "must be between 0.0 and 1.0, inclusive." 
+ ) + return self diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 1f8f5e5dbff9..e6ab6c1b6b9a 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1376,7 +1376,8 @@ def __str__(self): f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}" + f"compilation_config={self.compilation_config!r}, " + f"global_cache_hit_threshold={self.scheduler_config.global_cache_hit_threshold}" ) @model_validator(mode="after") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 82be97ce6842..598cca4ff27e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -579,6 +579,7 @@ class EngineArgs: kv_offloading_size: float | None = CacheConfig.kv_offloading_size kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend tokens_only: bool = False + global_cache_hit_threshold: float = SchedulerConfig.global_cache_hit_threshold def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -1124,6 +1125,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: scheduler_group.add_argument( "--scheduler-cls", **scheduler_kwargs["scheduler_cls"] ) + scheduler_group.add_argument( + "--global-cache-hit-threshold", + **scheduler_kwargs["global_cache_hit_threshold"], + ) scheduler_group.add_argument( "--disable-hybrid-kv-cache-manager", **scheduler_kwargs["disable_hybrid_kv_cache_manager"], @@ -1651,6 +1656,7 @@ def create_engine_config( disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager, async_scheduling=self.async_scheduling, stream_interval=self.stream_interval, + global_cache_hit_threshold=self.global_cache_hit_threshold, ) if not model_config.is_multimodal_model and self.default_mm_loras: diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 
205efd1d582e..8c285fcfe8b6 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -59,6 +59,7 @@ def generate( trace_headers: Mapping[str, str] | None = None, priority: int = 0, data_parallel_rank: int | None = None, + cache_hit_threshold: float | None = None, ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request.""" ... @@ -74,6 +75,7 @@ def encode( priority: int = 0, truncate_prompt_tokens: int | None = None, tokenization_kwargs: dict[str, Any] | None = None, + cache_hit_threshold: float | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model. diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 7512723515e0..b31fcf6c58a7 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -62,10 +62,13 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: stream = request_dict.pop("stream", False) # Since SamplingParams is created fresh per request, safe to skip clone sampling_params = SamplingParams(**request_dict, skip_clone=True) + cache_hit_threshold = request_dict.pop("cache_hit_threshold", None) request_id = random_uuid() assert engine is not None - results_generator = engine.generate(prompt, sampling_params, request_id) + results_generator = engine.generate( + prompt, sampling_params, request_id, cache_hit_threshold=cache_hit_threshold + ) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index c9e809353b59..4ac1e72139e2 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -26,6 +26,7 @@ def log_inputs( prompt_embeds: torch.Tensor | None, params: SamplingParams | PoolingParams | BeamSearchParams | None, lora_request: LoRARequest | None, + cache_hit_threshold: float | None = None, ) -> None: if logger.isEnabledFor(logging.DEBUG): max_log_len = self.max_log_len @@ -47,10 
+48,12 @@ def log_inputs( ) logger.info( - "Received request %s: params: %s, lora_request: %s.", + "Received request %s: params: %s, lora_request: %s, " + "cache_hit_threshold: %s.", request_id, params, lora_request, + cache_hit_threshold, ) def log_outputs( diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index a76dc73d9ba3..45d7bb962fa3 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -154,6 +154,17 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" +def _validate_cache_hit_threshold(cls, data): + cache_hit_threshold = data.get("cache_hit_threshold") + if cache_hit_threshold is not None and ( + cache_hit_threshold < 0.0 or cache_hit_threshold > 1.0 + ): + raise ValueError( + "Parameter `cache_hit_threshold` must be between 0.0 and 1.0 if provided." + ) + return data + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -346,6 +357,11 @@ class ChatCompletionRequest(OpenAIBaseModel): description="KVTransfer parameters used for disaggregated serving.", ) + cache_hit_threshold: float | None = Field( + default=None, + description="Minimum required KV-cache hit ratio to process the request.", + ) + vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( default=None, description=( @@ -649,6 +665,18 @@ def check_generation_prompt(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def _validate_cache_hit_threshold(cls, data): + cache_hit_threshold = data.get("cache_hit_threshold") + if cache_hit_threshold is not None and ( + cache_hit_threshold < 0.0 or cache_hit_threshold > 1.0 + ): + raise ValueError( + "`cache_hit_threshold` must be between 0.0 and 1.0 if provided." 
+ ) + return data + @model_validator(mode="before") @classmethod def check_cache_salt_support(cls, data): diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index aa79e9da30fc..0eaf4e0d3b83 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -406,11 +406,14 @@ async def create_chat_completion( sampling_params, ) + cache_hit_threshold = request.cache_hit_threshold + self._log_inputs( sub_request_id, engine_prompt, params=sampling_params, lora_request=lora_request, + cache_hit_threshold=cache_hit_threshold, ) trace_headers = ( @@ -436,6 +439,7 @@ async def create_chat_completion( trace_headers=trace_headers, priority=request.priority, data_parallel_rank=data_parallel_rank, + cache_hit_threshold=request.cache_hit_threshold, ) generator = self.engine_client.generate( @@ -448,6 +452,7 @@ async def create_chat_completion( prompt_text=prompt_text, tokenization_kwargs=tokenization_kwargs, data_parallel_rank=data_parallel_rank, + cache_hit_threshold=request.cache_hit_threshold, ) generators.append(generator) diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index fc773c402ede..c85b58f1fc4f 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -168,6 +168,11 @@ class CompletionRequest(OpenAIBaseModel): description="KVTransfer parameters used for disaggregated serving.", ) + cache_hit_threshold: float | None = Field( + default=None, + description="Minimum required KV-cache hit ratio to process the request.", + ) + vllm_xargs: dict[str, str | int | float] | None = Field( default=None, description=( @@ -400,6 +405,18 @@ def check_cache_salt_support(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def _validate_cache_hit_threshold(cls, data): + cache_hit_threshold = data.get("cache_hit_threshold") 
+ if cache_hit_threshold is not None and ( + cache_hit_threshold < 0.0 or cache_hit_threshold > 1.0 + ): + raise ValueError( + "`cache_hit_threshold` must be between 0.0 and 1.0 if provided." + ) + return data + class CompletionLogProbs(OpenAIBaseModel): text_offset: list[int] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index 24cf486a61fe..3d1113e20912 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -189,12 +189,14 @@ async def create_completion( ) request_id_item = f"{request_id}-{i}" + cache_hit_threshold = request.cache_hit_threshold self._log_inputs( request_id_item, engine_prompt, params=sampling_params, lora_request=lora_request, + cache_hit_threshold=cache_hit_threshold, ) trace_headers = ( @@ -224,6 +226,7 @@ async def create_completion( trace_headers=trace_headers, priority=request.priority, data_parallel_rank=data_parallel_rank, + cache_hit_threshold=cache_hit_threshold, ) generator = self.engine_client.generate( @@ -236,6 +239,7 @@ async def create_completion( prompt_text=prompt_text, tokenization_kwargs=tokenization_kwargs, data_parallel_rank=data_parallel_rank, + cache_hit_threshold=request.cache_hit_threshold, ) generators.append(generator) diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 0433f28d978f..70e5340e2db7 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -655,12 +655,13 @@ async def _prepare_generators( for i, engine_prompt in enumerate(ctx.engine_prompts): request_id_item = f"{ctx.request_id}-{i}" - + cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold", None) self._log_inputs( request_id_item, engine_prompt, params=pooling_params, lora_request=ctx.lora_request, + cache_hit_threshold=cache_hit_threshold, ) generator = self.engine_client.encode( @@ -670,6 +671,7 
@@ async def _prepare_generators( lora_request=ctx.lora_request, trace_headers=trace_headers, priority=getattr(ctx.request, "priority", 0), + cache_hit_threshold=cache_hit_threshold, ) generators.append(generator) @@ -1236,6 +1238,7 @@ async def _process_inputs( trace_headers: Mapping[str, str] | None, priority: int, data_parallel_rank: int | None = None, + cache_hit_threshold: float | None = None, ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for AsyncLLM.""" tokenization_kwargs: dict[str, Any] = {} @@ -1252,6 +1255,7 @@ async def _process_inputs( trace_headers=trace_headers, priority=priority, data_parallel_rank=data_parallel_rank, + cache_hit_threshold=cache_hit_threshold, ) return engine_request, tokenization_kwargs @@ -1297,11 +1301,13 @@ async def _generate_with_builtin_tools( while True: # Ensure that each sub-request has a unique request id. sub_request_id = f"{request_id}_{sub_request}" + cache_hit_threshold = kwargs.get("cache_hit_threshold") self._log_inputs( sub_request_id, engine_prompt, params=sampling_params, lora_request=lora_request, + cache_hit_threshold=cache_hit_threshold, ) trace_headers = kwargs.get("trace_headers") engine_request, tokenization_kwargs = await self._process_inputs( @@ -1311,6 +1317,7 @@ async def _generate_with_builtin_tools( lora_request=lora_request, trace_headers=trace_headers, priority=priority, + cache_hit_threshold=cache_hit_threshold, ) generator = self.engine_client.generate( @@ -1377,6 +1384,7 @@ def _log_inputs( inputs: PromptType, params: SamplingParams | PoolingParams | BeamSearchParams | None, lora_request: LoRARequest | None, + cache_hit_threshold: float | None = None, ) -> None: if self.request_logger is None: return @@ -1390,6 +1398,7 @@ def _log_inputs( prompt_embeds, params=params, lora_request=lora_request, + cache_hit_threshold=cache_hit_threshold, ) async def _get_trace_headers( diff --git a/vllm/entrypoints/pooling/embed/serving.py 
b/vllm/entrypoints/pooling/embed/serving.py index 22a6188878c8..b746fb7fd083 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -238,12 +238,14 @@ async def _process_chunked_request( # Create engine prompt for this chunk chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens) + cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold", None) # Log the chunk self._log_inputs( chunk_request_id, chunk_engine_prompt, params=pooling_params, lora_request=ctx.lora_request, + cache_hit_threshold=cache_hit_threshold, ) # Create generator for this chunk and wrap it to return indices @@ -254,6 +256,7 @@ async def _process_chunked_request( lora_request=ctx.lora_request, trace_headers=trace_headers, priority=getattr(ctx.request, "priority", 0), + cache_hit_threshold=cache_hit_threshold, ) generators.append(original_generator) @@ -345,12 +348,14 @@ async def _create_single_prompt_generator( ) -> AsyncGenerator[PoolingRequestOutput, None]: """Create a generator for a single prompt using standard processing.""" request_id_item = f"{ctx.request_id}-{prompt_index}" + cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold", None) self._log_inputs( request_id_item, engine_prompt, params=pooling_params, lora_request=ctx.lora_request, + cache_hit_threshold=cache_hit_threshold, ) # Return the original generator without wrapping @@ -361,6 +366,7 @@ async def _create_single_prompt_generator( lora_request=ctx.lora_request, trace_headers=trace_headers, priority=getattr(ctx.request, "priority", 0), + cache_hit_threshold=cache_hit_threshold, ) async def _prepare_generators( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 30a459386a73..1eb7634deca7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -138,6 +138,11 @@ def __init__( config=self.vllm_config, role=ECConnectorRole.SCHEDULER ) + # List to collect requests that are below the cache hit 
ratio + # threshold. These requests will be finished and the list cleared + # in update_from_output(). + self.cache_hit_below_threshold_request_ids: list[str] = [] + num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 @@ -626,6 +631,39 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens = ( num_new_local_computed_tokens + num_external_computed_tokens ) + # Cache hit threshold in request overrides global setting + scheduler_config = self.vllm_config.scheduler_config + cache_hit_threshold = ( + request.cache_hit_threshold + if request.cache_hit_threshold is not None + else scheduler_config.global_cache_hit_threshold + ) + + # Check if cache hit is above threshold + cache_hit_percent = ( + num_computed_tokens / request.num_prompt_tokens + if request.num_prompt_tokens > 0 + else 0.0 + ) + if cache_hit_percent < cache_hit_threshold: + threshold_source = ( + "request" + if request.cache_hit_threshold is not None + else "global" + ) + logger.debug( + "Request %s rejected: cache hit rate %.2f" + " < threshold %.2f (%s)", + request.request_id, + cache_hit_percent, + cache_hit_threshold, + threshold_source, + ) + self.waiting.pop_request() + self.cache_hit_below_threshold_request_ids.append( + request.request_id + ) + continue else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -1437,6 +1475,26 @@ def update_from_output( batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch) + # Handle requests that were rejected due to low cache hit rate. + if self.cache_hit_below_threshold_request_ids: + for req_id in self.cache_hit_below_threshold_request_ids: + req = self.requests.get(req_id) + if req is None: + # The request is already finished, e.g. aborted. + continue + # Add EngineCoreOutput for this Request. 
+ req.status = RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD + outputs[req.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=[], + finish_reason=req.get_finished_reason(), + ) + ) + self._free_request(req) + # Clear the list after finishing all such requests. + self.cache_hit_below_threshold_request_ids.clear() + # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. engine_core_outputs = { diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e8e44746bf47..d9d41e1e5a3d 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -20,7 +20,7 @@ # These are possible values of RequestOutput.finish_reason, # so form part of the external API. -FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") +FINISH_REASON_STRINGS = ("stop", "length", "abort", "error", "cache_threshold") class FinishReason(enum.IntEnum): @@ -34,6 +34,7 @@ class FinishReason(enum.IntEnum): abort - aborted by client error - retryable request-level internal error (e.g., KV load failure). Invariant: always converted to 500 Internal Server Error. + cache_threshold - not handled due to cache hit below threshold """ @@ -41,6 +42,7 @@ class FinishReason(enum.IntEnum): LENGTH = 1 ABORT = 2 ERROR = 3 + CACHE_THRESHOLD = 4 def __str__(self): return FINISH_REASON_STRINGS[self.value] @@ -63,6 +65,7 @@ class EngineCoreRequest( cache_salt: str | None data_parallel_rank: int | None prompt_embeds: torch.Tensor | None = None + cache_hit_threshold: float | None = None # Index of the client, used to ensure outputs are sent back to the same # client for this request when scaling out the front-end. 
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 9f40f41a10a5..18f7842f30b5 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -297,6 +297,7 @@ async def add_request( priority: int = 0, data_parallel_rank: int | None = None, prompt_text: str | None = None, + cache_hit_threshold: float | None = None, ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" @@ -362,6 +363,7 @@ async def add_request( trace_headers, priority, data_parallel_rank, + cache_hit_threshold=cache_hit_threshold, ) prompt_text = get_prompt_text(prompt) @@ -533,6 +535,7 @@ async def generate( trace_headers: Mapping[str, str] | None = None, priority: int = 0, data_parallel_rank: int | None = None, + cache_hit_threshold: float | None = None, ) -> AsyncGenerator[RequestOutput, None]: """ Main function called by the API server to kick off a request @@ -561,6 +564,7 @@ async def generate( priority=priority, data_parallel_rank=data_parallel_rank, prompt_text=prompt_text, + cache_hit_threshold=cache_hit_threshold, ) # The output_handler task pushes items into the queue. @@ -771,6 +775,7 @@ async def encode( priority: int = 0, truncate_prompt_tokens: int | None = None, tokenization_kwargs: dict[str, Any] | None = None, + cache_hit_threshold: float | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """ Main function called by the API server to kick off a request @@ -808,6 +813,7 @@ async def encode( tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=priority, + cache_hit_threshold=cache_hit_threshold, ) # The output_handler task pushes items into the queue. 
diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index f7f1608ec1b4..c455aa05b99d 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -488,6 +488,7 @@ def process_inputs( priority: int = 0, data_parallel_rank: int | None = None, resumable: bool = False, + cache_hit_threshold: float | None = None, ) -> EngineCoreRequest: self._validate_lora(lora_request) self._validate_params(params) @@ -633,6 +634,7 @@ def process_inputs( data_parallel_rank=data_parallel_rank, trace_headers=trace_headers, resumable=resumable, + cache_hit_threshold=cache_hit_threshold, ) def _validate_model_inputs( diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 5811e94dd3cc..06ad799484e5 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -219,6 +219,7 @@ def add_request( trace_headers: Mapping[str, str] | None = None, priority: int = 0, prompt_text: str | None = None, + cache_hit_threshold: float | None = None, ) -> None: # Validate the request_id type. if not isinstance(request_id, str): @@ -244,6 +245,7 @@ def add_request( tokenization_kwargs, trace_headers, priority, + cache_hit_threshold=cache_hit_threshold, ) if isinstance(prompt, str): prompt_text = prompt diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index cb1a860e38fb..45487466a34d 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -266,7 +266,10 @@ def update_from_output( num_new_generation_tokens = len(output.new_token_ids) self.num_generation_tokens += num_new_generation_tokens - if is_prefilling: + + # num_new_generation_tokens can be 0, e.g. 
if cache hit threshold is not met + # and in that case we do not want to influence TTFT stats + if is_prefilling and num_new_generation_tokens > 0: self.num_prompt_tokens += prompt_len first_token_latency = self._time_since(req_stats.arrival_time) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b963fea43df5..ec6aa44138fe 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -72,6 +72,7 @@ def __init__( cache_salt: str | None = None, priority: int = 0, trace_headers: Mapping[str, str] | None = None, + cache_hit_threshold: float | None = None, block_hasher: Callable[["Request"], list["BlockHash"]] | None = None, resumable: bool = False, ) -> None: @@ -95,6 +96,8 @@ def __init__( # P/D: Connector-specific KV transfer parameters. self.kv_transfer_params: dict[str, Any] | None = None + self.cache_hit_threshold: float | None = cache_hit_threshold + if pooling_params is not None: # Pooling models. self.max_tokens = 1 @@ -192,6 +195,7 @@ def from_engine_core_request( trace_headers=request.trace_headers, block_hasher=block_hasher, resumable=request.resumable, + cache_hit_threshold=request.cache_hit_threshold, ) def append_output_token_ids( @@ -298,6 +302,7 @@ class RequestStatus(enum.IntEnum): FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() FINISHED_ERROR = enum.auto() + FINISHED_CACHE_HIT_BELOW_THRESHOLD = enum.auto() def __str__(self) -> str: return self.name @@ -322,4 +327,5 @@ def get_finished_reason(status: "RequestStatus") -> FinishReason | None: RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, RequestStatus.FINISHED_ERROR: FinishReason.ERROR, RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP, + RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD: FinishReason.CACHE_THRESHOLD, }