From 6af48876e73cb228888e1d3ba6974c40b34e3851 Mon Sep 17 00:00:00 2001
From: Or Ozeri
Date: Wed, 4 Mar 2026 10:06:41 +0200
Subject: [PATCH] [kv_offload+HMA][1/N]: Support multiple KV groups in OffloadingSpec

This commit extends OffloadingSpec to support multiple KV cache groups,
each with a potentially different block size.

We now distinguish between:

1. The block size of each KV cache group (determined by the KVCacheConfig).
2. The block size used by vLLM for hashing request tokens (determined by
   cache_config.block_size).

This will allow the offloading connector to correctly map request tokens to:

1. KVCacheBlocks (using the per-group block sizes from KVCacheConfig)
2. Request.block_hashes (using the hash block size, cache_config.block_size)

For now, the offloading connector keeps using hash_block_size as its block
size. Later on, we will modify the offloading connector to use the
group-specific block sizes.

Signed-off-by: Or Ozeri
---
 .../unit/test_offloading_connector.py         | 42 ++++++++++++++++---
 .../kv_connector/v1/offloading_connector.py   |  8 ++--
 vllm/v1/kv_offload/cpu.py                     | 16 ++++---
 vllm/v1/kv_offload/factory.py                 |  2 +-
 vllm/v1/kv_offload/spec.py                    | 34 +++++++++++----
 5 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py
index 74c8dbd3024a..893a5d8d4d78 100644
--- a/tests/v1/kv_connector/unit/test_offloading_connector.py
+++ b/tests/v1/kv_connector/unit/test_offloading_connector.py
@@ -26,8 +26,13 @@
     get_request_block_hasher,
     init_none_hash,
 )
+from vllm.v1.core.sched.async_scheduler import AsyncScheduler
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+)
 from vllm.v1.kv_offload.abstract import (
     LoadStoreSpec,
     OffloadingEvent,
@@ -43,11 +48,11 @@
 )
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager

 from .utils import (
     EOS_TOKEN_ID,
     create_model_runner_output,
-    create_scheduler,
     create_vllm_config,
 )

@@ -175,10 +180,37 @@ def __init__(
             },
         )

-        self.scheduler: Scheduler = create_scheduler(
-            vllm_config, num_blocks=num_gpu_blocks
+        block_size = vllm_config.cache_config.block_size
+        kv_cache_config = KVCacheConfig(
+            num_blocks=num_gpu_blocks,
+            kv_cache_tensors=[],
+            kv_cache_groups=[
+                KVCacheGroupSpec(
+                    ["layer"],
+                    FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=1,
+                        head_size=1,
+                        dtype=torch.float32,
+                    ),
+                )
+            ],
+        )
+        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.num_kv_groups = len(kv_cache_config.kv_cache_groups)
+
+        scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
+        self.scheduler = scheduler_cls(
+            vllm_config=vllm_config,
+            kv_cache_config=kv_cache_config,
+            log_stats=True,
+            structured_output_manager=StructuredOutputManager(vllm_config),
+            block_size=block_size,
+        )
+
+        self.worker_connector = OffloadingConnector(
+            vllm_config, KVConnectorRole.WORKER, kv_cache_config
         )
-        self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)

         # register worker kv_caches to enable OffloadingWorker creations
         self.worker_connector.register_cross_layers_kv_cache(
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 2eb3fa67c978..021f0144d81d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -126,6 +126,7 @@ def __init__(
     ):
         super().__init__(vllm_config, role, kv_cache_config)

+        assert kv_cache_config is not None
        spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)

         self.connector_scheduler: OffloadingConnectorScheduler | None = None
@@ -245,9 +246,10 @@ class OffloadingConnectorScheduler:
     """Implementation of Scheduler side methods"""

     def __init__(self, spec: OffloadingSpec):
-        self.gpu_block_size = spec.gpu_block_size
-        self.offloaded_block_size = spec.offloaded_block_size
-        self.block_size_factor = self.offloaded_block_size // self.gpu_block_size
+        assert len(spec.gpu_block_size) == 1
+        self.gpu_block_size = spec.gpu_block_size[0]
+        self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
+        self.block_size_factor = spec.block_size_factor
         self.manager: OffloadingManager = spec.get_manager()

         self._requests: dict[ReqId, Request] = {}
diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py
index b245836a5b67..b1acff99ea1a 100644
--- a/vllm/v1/kv_offload/cpu.py
+++ b/vllm/v1/kv_offload/cpu.py
@@ -42,10 +42,8 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
             * len(kv_cache_config.kv_cache_tensors)
             * vllm_config.parallel_config.world_size
         )
-        kv_bytes_per_offloaded_block = kv_bytes_per_block * (
-            self.offloaded_block_size // self.gpu_block_size
-        )
+        kv_bytes_per_offloaded_block = kv_bytes_per_block * self.block_size_factor

         self.num_blocks = (
             int(cpu_bytes_to_use) // kv_bytes_per_offloaded_block
             if kv_bytes_per_offloaded_block > 0
@@ -67,8 +65,11 @@ def get_manager(self) -> OffloadingManager:
             kv_events_config is not None and kv_events_config.enable_kv_cache_events
         )

+        assert len(self.gpu_block_size) == 1
+        gpu_block_size = self.gpu_block_size[0]
+        offloaded_block_size = gpu_block_size * self.block_size_factor
         backend = CPUBackend(
-            block_size=self.offloaded_block_size, num_blocks=self.num_blocks
+            block_size=offloaded_block_size, num_blocks=self.num_blocks
         )

         if self.eviction_policy == "lru":
@@ -111,10 +112,13 @@ def get_handlers(
                     "CPU Offloading is currently only supported on CUDA-alike GPUs"
                 )

+            assert len(self.gpu_block_size) == 1
+            gpu_block_size = self.gpu_block_size[0]
+
             self._handlers = CpuGpuOffloadingHandlers(
                 attn_backends=attn_backends,
-                gpu_block_size=self.gpu_block_size,
-                cpu_block_size=self.offloaded_block_size,
+                gpu_block_size=gpu_block_size,
+                cpu_block_size=gpu_block_size * self.block_size_factor,
                 num_cpu_blocks=self.num_blocks,
                 gpu_caches=kv_caches,
             )
diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py
index 8fe018b89908..d42f2cc63ba5 100644
--- a/vllm/v1/kv_offload/factory.py
+++ b/vllm/v1/kv_offload/factory.py
@@ -33,7 +33,7 @@ def loader() -> type[OffloadingSpec]:
     def create_spec(
         cls,
         config: "VllmConfig",
-        kv_cache_config: "KVCacheConfig | None",
+        kv_cache_config: "KVCacheConfig",
     ) -> OffloadingSpec:
         kv_transfer_config = config.kv_transfer_config
         assert kv_transfer_config is not None
diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py
index 1d41ea71f46b..6d5c74985ae1 100644
--- a/vllm/v1/kv_offload/spec.py
+++ b/vllm/v1/kv_offload/spec.py
@@ -21,9 +21,7 @@ class OffloadingSpec(ABC):
     """Spec for an offloading connector"""

-    def __init__(
-        self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None"
-    ):
+    def __init__(self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"):
         logger.warning(
             "Initializing OffloadingSpec. This API is experimental and "
             "subject to change in the future as we iterate the design."
         )
@@ -35,12 +33,34 @@ def __init__(
         kv_transfer_config = vllm_config.kv_transfer_config
         assert kv_transfer_config is not None
         self.extra_config = kv_transfer_config.kv_connector_extra_config

-        self.gpu_block_size = vllm_config.cache_config.block_size
-        self.offloaded_block_size = int(
-            self.extra_config.get("block_size", self.gpu_block_size)
+        # block size used by vLLM for hashing request tokens for the sake
+        # of enabling prefix caching
+        self.hash_block_size = vllm_config.cache_config.block_size
+        # gpu block size per group
+        self.gpu_block_size: tuple[int, ...] = tuple(
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
         )
-        assert self.offloaded_block_size % self.gpu_block_size == 0
+        for block_size in self.gpu_block_size:
+            assert block_size % self.hash_block_size == 0
+
+        # offloaded_block_size / gpu_block_size
+        self.block_size_factor: int = 1
+
+        offloaded_block_size = self.extra_config.get("block_size")
+        if offloaded_block_size is not None:
+            offloaded_block_size_int = int(offloaded_block_size)
+            gpu_block_sizes = set(self.gpu_block_size)
+            assert len(gpu_block_sizes) == 1, (
+                "If 'block_size' is specified in kv_connector_extra_config, "
+                "there must be at least one KV cache group, "
+                "and all groups must have the same block size."
+            )
+            gpu_block_size = gpu_block_sizes.pop()
+
+            assert offloaded_block_size_int % gpu_block_size == 0
+            self.block_size_factor = offloaded_block_size_int // gpu_block_size

     @abstractmethod
     def get_manager(self) -> OffloadingManager: