diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py
index b44f2db1926b..ef94d64b10be 100644
--- a/vllm/v1/core/sched/interface.py
+++ b/vllm/v1/core/sched/interface.py
@@ -10,7 +10,11 @@
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
     from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
+    from vllm.v1.core.sched.output import (
+        GrammarOutput,
+        KVCacheUsageMetrics,
+        SchedulerOutput,
+    )
     from vllm.v1.engine import EngineCoreOutputs
     from vllm.v1.kv_cache_interface import KVCacheConfig
     from vllm.v1.metrics.stats import SchedulerStats
@@ -226,6 +230,11 @@ def get_request_counts(self) -> tuple[int, int]:
         """Returns (num_running_reqs, num_waiting_reqs)."""
         raise NotImplementedError
 
+    @abstractmethod
+    def get_kv_cache_usage(self) -> "KVCacheUsageMetrics":
+        """Return current KV cache usage (percentage, used blocks, used tokens)."""
+        raise NotImplementedError
+
     @abstractmethod
     def make_stats(self) -> "SchedulerStats | None":
         """Make a SchedulerStats object for logging.
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 7e53f4f2ec9e..bf01c3840afb 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -179,6 +179,21 @@ def make_empty(cls) -> "CachedRequestData":
     )
 
 
+@dataclass
+class KVCacheUsageMetrics:
+    """KV cache usage metrics from the scheduler."""
+
+    # Usage as a percentage (0.0 to 100.0).
+    usage_pct: float
+    # Number of blocks currently in use (excludes the reserved null block).
+    used_blocks: int
+    # Total number of allocatable blocks (num_gpu_blocks - 1).
+    total_blocks: int
+    # Approximate number of tokens represented by
+    # used blocks (used_blocks * block_size).
+    used_tokens: int
+
+
 @bc_linter_include
 @dataclass
 class SchedulerOutput:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index bf397ad681ca..f22088cad8dd 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -42,6 +42,7 @@
 from vllm.v1.core.sched.output import (
     CachedRequestData,
     GrammarOutput,
+    KVCacheUsageMetrics,
     NewRequestData,
     SchedulerOutput,
 )
@@ -1255,6 +1256,22 @@ def get_grammar_bitmask(
         )
         return GrammarOutput(structured_output_request_ids, bitmask)
 
+    def get_kv_cache_usage(self) -> KVCacheUsageMetrics:
+        """Return current KV cache usage (percentage, used blocks, used tokens)."""
+        pool = self.kv_cache_manager.block_pool
+        total_blocks = pool.num_gpu_blocks - 1  # exclude null block
+        num_free = pool.get_num_free_blocks()
+        used_blocks = total_blocks - num_free
+        usage_fraction = self.kv_cache_manager.usage  # 0.0 to 1.0
+        usage_pct = usage_fraction * 100.0
+        used_tokens = used_blocks * self.block_size
+        return KVCacheUsageMetrics(
+            usage_pct=usage_pct,
+            used_blocks=used_blocks,
+            total_blocks=total_blocks,
+            used_tokens=used_tokens,
+        )
+
     def update_from_output(
         self,
         scheduler_output: SchedulerOutput,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index a258fe295068..25ae5347eac7 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -41,7 +41,7 @@
     init_none_hash,
 )
 from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import KVCacheUsageMetrics, SchedulerOutput
 from vllm.v1.engine import (
     EngineCoreOutput,
     EngineCoreOutputs,
@@ -343,7 +343,11 @@ def log_error_detail(self, scheduler_output: SchedulerOutput):
         raise err
 
     @contextmanager
-    def log_iteration_details(self, scheduler_output: SchedulerOutput):
+    def log_iteration_details(
+        self,
+        scheduler_output: SchedulerOutput,
+        kv_cache_usage: KVCacheUsageMetrics | None = None,
+    ):
         if not self.vllm_config.observability_config.enable_logging_iteration_details:
             yield
             return
@@ -351,25 +355,36 @@ def log_iteration_details(self, scheduler_output: SchedulerOutput):
         iteration_details = compute_iteration_details(scheduler_output)
         before = time.monotonic()
         yield
-        logger.info(
-            "".join(
+        log_parts: list[str] = [
+            "Iteration(",
+            str(self._iteration_index),
+            "): ",
+            str(iteration_details.num_ctx_requests),
+            " context requests, ",
+            str(iteration_details.num_ctx_tokens),
+            " context tokens, ",
+            str(iteration_details.num_generation_requests),
+            " generation requests, ",
+            str(iteration_details.num_generation_tokens),
+            " generation tokens, iteration elapsed time: ",
+            format((time.monotonic() - before) * 1000, ".2f"),
+            " ms",
+        ]
+        if kv_cache_usage is not None:
+            log_parts.extend(
                 [
-                    "Iteration(",
-                    str(self._iteration_index),
-                    "): ",
-                    str(iteration_details.num_ctx_requests),
-                    " context requests, ",
-                    str(iteration_details.num_ctx_tokens),
-                    " context tokens, ",
-                    str(iteration_details.num_generation_requests),
-                    " generation requests, ",
-                    str(iteration_details.num_generation_tokens),
-                    " generation tokens, iteration elapsed time: ",
-                    format((time.monotonic() - before) * 1000, ".2f"),
-                    " ms",
+                    ", kv cache: ",
+                    format(kv_cache_usage.usage_pct, ".1f"),
+                    "% (",
+                    str(kv_cache_usage.used_blocks),
+                    "/",
+                    str(kv_cache_usage.total_blocks),
+                    " blocks, ",
+                    str(kv_cache_usage.used_tokens),
+                    " tokens)",
                 ]
             )
-        )
+        logger.info("".join(log_parts))
         self._iteration_index += 1
 
     def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
@@ -388,7 +403,10 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
         with (
             self.log_error_detail(scheduler_output),
-            self.log_iteration_details(scheduler_output),
+            self.log_iteration_details(
+                scheduler_output,
+                self.scheduler.get_kv_cache_usage(),
+            ),
         ):
             model_output = future.result()
             if model_output is None:
@@ -462,8 +480,8 @@ def step_with_batch_queue(
                 grammar_output, non_block=True
             )
         else:
-            # We need to defer sampling until we have processed the model output
-            # from the prior step.
+            # We need to defer sampling until we have processed
+            # the model output from the prior step.
             deferred_scheduler_output = scheduler_output
 
         if not deferred_scheduler_output:
@@ -488,7 +506,10 @@ def step_with_batch_queue(
         future, scheduler_output, exec_model_fut = batch_queue.pop()
         with (
             self.log_error_detail(scheduler_output),
-            self.log_iteration_details(scheduler_output),
+            self.log_iteration_details(
+                scheduler_output,
+                self.scheduler.get_kv_cache_usage(),
+            ),
         ):
             model_output = future.result()
             if model_output is None:
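
For reference, a minimal standalone sketch of how the new fields compose and what the appended log fragment looks like. The pool numbers below (num_gpu_blocks, free-block count, block_size) are made up for illustration, and usage_pct is derived directly from used/total here, whereas the patch takes it from kv_cache_manager.usage; only the arithmetic and the string format mirror the patch above.

    # Illustrative sketch only -- hypothetical pool numbers, not from a real run.
    from dataclasses import dataclass


    @dataclass
    class KVCacheUsageMetrics:
        usage_pct: float
        used_blocks: int
        total_blocks: int
        used_tokens: int


    num_gpu_blocks = 513   # hypothetical pool size (includes the null block)
    num_free_blocks = 301  # hypothetical free-block count
    block_size = 16        # hypothetical tokens per block

    total_blocks = num_gpu_blocks - 1  # exclude null block, as in the patch
    used_blocks = total_blocks - num_free_blocks
    metrics = KVCacheUsageMetrics(
        usage_pct=used_blocks / total_blocks * 100.0,
        used_blocks=used_blocks,
        total_blocks=total_blocks,
        used_tokens=used_blocks * block_size,
    )

    # Same suffix that log_iteration_details() appends to the iteration line:
    suffix = "".join(
        [
            ", kv cache: ",
            format(metrics.usage_pct, ".1f"),
            "% (",
            str(metrics.used_blocks),
            "/",
            str(metrics.total_blocks),
            " blocks, ",
            str(metrics.used_tokens),
            " tokens)",
        ]
    )
    print(suffix)  # -> , kv cache: 41.2% (211/512 blocks, 3376 tokens)

With these illustrative values, a full iteration line (emitted only when observability_config.enable_logging_iteration_details is set) would read along the lines of: Iteration(7): 2 context requests, 512 context tokens, 8 generation requests, 8 generation tokens, iteration elapsed time: 35.71 ms, kv cache: 41.2% (211/512 blocks, 3376 tokens)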