diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index ef143cba7fb5..ee6f747bdde1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -62,6 +62,7 @@ PromMetricT, ) from vllm.forward_context import ForwardContext + from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -446,6 +447,16 @@ def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None: # Scheduler-side methods # ============================== + def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None: + """ + Bind the GPU block pool to the connector for per-GPU block status tracking. + For example, inc/dec ref counts, or iterate over the prefix cache blocks. + + Args: + gpu_block_pool: the GPU block pool. + """ + return + @abstractmethod def get_num_new_matched_tokens( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 3888d2e0f44c..8fe736e56be7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext + from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -219,6 +220,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for c in self._connectors: c.register_kv_caches(kv_caches) + def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None: + for c in self._connectors: + c.bind_gpu_block_pool(gpu_block_pool) + # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py index 6475b941ba59..4f4d3234833d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py @@ -165,7 +165,6 @@ def build_connector_worker_meta(self): # --- Scheduler-side methods --- - # NOTE: New API only for SimpleCPUOffloadConnector. def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None: if self.scheduler_manager is not None: self.scheduler_manager.bind_gpu_block_pool(gpu_block_pool) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f61d54faedc7..e9d221220a89 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -236,9 +236,7 @@ def __init__( ) # Bind GPU block pool to the KV connector. This must happen after # kv_cache_manager is constructed so block_pool is available. - if self.connector is not None and hasattr( - self.connector, "bind_gpu_block_pool" - ): + if self.connector is not None: self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool) self.use_pp = self.parallel_config.pipeline_parallel_size > 1