diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index d59a9cbdd46a..10fa4f14f237 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -9,7 +9,7 @@
 import time
 import uuid
 from collections import defaultdict
-from typing import Any
+from typing import Any, cast
 from unittest.mock import MagicMock, patch
 
 import msgspec
@@ -332,14 +332,22 @@ def test_kv_transfer_handshake(dist_init):
 
     # Prefill connector will register KV cache to populate proper handshake
     # metadata.
+    # TODO: this must match the values used in the kv cache config
+    kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
     prefill_connector = NixlConnector(
-        vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
+        vllm_config, KVConnectorRole.WORKER, kv_cache_config
+    )
+    kv_cache_spec = cast(
+        AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
     )
     kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
-        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
+        num_blocks=kv_cache_config.num_blocks,
+        block_size=kv_cache_spec.block_size,
+        num_kv_heads=kv_cache_spec.num_kv_heads,
+        head_size=kv_cache_spec.head_size,
     )
-    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
-    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+    shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
+    unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
     kv_caches = {
         "layer0": shared_tensor,
         "layer1": unique_tensor,
@@ -383,7 +391,7 @@ def test_kv_transfer_handshake(dist_init):
 
     # Decode connector will be able to create handshake with the prefill connector.
     decode_connector = NixlConnector(
-        vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
+        vllm_config, KVConnectorRole.WORKER, kv_cache_config
     )
     decode_connector.register_kv_caches(kv_caches)
 
@@ -525,11 +533,13 @@ def test_multi_xfer_one_engine(
     request_id = "req_id"
 
     # Test worker role in decode server.
-    connector = NixlConnector(
-        vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
-    )
+    kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
     connector.connector_worker = FakeNixlConnectorWorker(
-        vllm_config, connector.engine_id, hand_shake_latency=0
+        vllm_config,
+        connector.engine_id,
+        hand_shake_latency=0,
+        kv_cache_config=kv_cache_config,
     )
     assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
     worker = connector.connector_worker
@@ -1479,18 +1489,22 @@ def test_register_kv_caches(
         patch(f"{nixl_module}.threading.Event"),
         patch(f"{nixl_module}.threading.Thread") as mock_thread,
         patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
+        patch(f"{nixl_module}.get_current_attn_backends") as mock_get_attn_backends,
     ):
         # Ensure get_attn_backend returns the correct value due to
         # _cached_get_attn_backend returning the backend from previous
         # test run if not mocking.
         mock_get_attn_backend.return_value = backend_cls
+        mock_get_attn_backends.return_value = [backend_cls]
 
         # Create connector
-        connector = NixlConnector(
-            vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
-        )
+        kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
         connector.connector_worker = FakeNixlConnectorWorker(
-            vllm_config, connector.engine_id, hand_shake_latency=0
+            vllm_config,
+            connector.engine_id,
+            hand_shake_latency=0,
+            kv_cache_config=kv_cache_config,
         )
 
         # Get the mock instance
@@ -1515,6 +1529,13 @@ def test_register_kv_caches(
         num_layers = 32
         block_size = 16
         num_blocks = 8
+        # Keep the fake worker's expected num_blocks in sync with the
+        # cross-layer tensor we are about to register.
+        worker_kv_cache_config = make_kv_cache_config(
+            block_size=block_size, num_blocks=num_blocks
+        )
+        connector.connector_worker.kv_cache_config = worker_kv_cache_config
+        connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks
         kv_cache_spec = AttentionSpec(
             block_size=block_size,
             num_kv_heads=4,
@@ -1568,11 +1589,17 @@ def test_register_kv_caches(
 
     else:
         # Create test kv cache tensors using proper backend shape
+        kv_cache_spec = cast(
+            AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
+        )
        kv_cache_shape = backend_cls.get_kv_cache_shape(
-            num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
+            num_blocks=kv_cache_config.num_blocks,
+            block_size=kv_cache_spec.block_size,
+            num_kv_heads=kv_cache_spec.num_kv_heads,
+            head_size=kv_cache_spec.head_size,
         )
-        shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
-        unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+        shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
+        unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
         kv_caches = {
             "layer0": shared_tensor,
             "layer1": unique_tensor,
@@ -1606,7 +1633,7 @@ def test_register_kv_caches(
             unique_tensor[1].data_ptr(),
         ]
         expected_num_entries = 4
-        expected_blocks_count = 8
+        expected_blocks_count = kv_cache_config.num_blocks * 4
 
     # Execute register_kv_caches
     connector.register_kv_caches(kv_caches)
@@ -1639,7 +1666,7 @@ def test_register_kv_caches(
            num_blocks = 8
            expected_block_len = expected_tensor_size // num_blocks
        else:
-            num_blocks = 2
+            num_blocks = kv_cache_config.num_blocks
            if is_blocks_first:
                expected_block_len = expected_tensor_size // num_blocks // 2
            else:
@@ -2226,15 +2253,22 @@ def test_compatibility_hash_validation(
            "enforce_handshake_compat": enforce_handshake_compat
        },
    )
+    kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
    decode_connector = NixlConnector(
-        local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
+        local_vllm_config, KVConnectorRole.WORKER, kv_cache_config
    )
    decode_worker = decode_connector.connector_worker
+    kv_cache_spec = cast(
+        AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
+    )
    kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
-        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
+        num_blocks=kv_cache_config.num_blocks,
+        block_size=kv_cache_spec.block_size,
+        num_kv_heads=kv_cache_spec.num_kv_heads,
+        head_size=kv_cache_spec.head_size,
    )
-    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
-    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+    shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
+    unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index c8a6c1301444..dd23d9dfaf64 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -38,7 +38,7 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
-from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
+from vllm.v1.worker.utils import select_common_block_size
 
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
@@ -203,37 +203,25 @@ def _make_kv_cache_spec() -> FullAttentionSpec:
 def test_select_common_block_size_prefers_manager_block_size():
     backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
     backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
-    attn_groups = [
-        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
-        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
-    ]
-    selected_size = select_common_block_size(128, attn_groups)
+    selected_size = select_common_block_size(128, [backend_a, backend_b])
     assert selected_size == 128
 
 
 def test_select_common_block_size_uses_largest_shared_int():
     backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
     backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
-    attn_groups = [
-        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
-        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
-    ]
-    selected_size = select_common_block_size(256, attn_groups)
+    selected_size = select_common_block_size(256, [backend_a, backend_b])
     assert selected_size == 64
 
 
 def test_select_common_block_size_no_valid_option():
     backend_a = _make_mock_backend_for_kernel_block_size([64])
     backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
-    attn_groups = [
-        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
-        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
-    ]
     with pytest.raises(ValueError):
-        select_common_block_size(48, attn_groups)
+        select_common_block_size(48, [backend_a, backend_b])
 
 
 def test_update_states_new_request(model_runner, dist_init):
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 6e0366c5202f..ef45a69079a6 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -357,15 +357,6 @@ def __post_init__(self):
             # stride_order to retrieve physical position of block_size
             kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
 
-        # In the default non-cross layers layout the block_size position
-        # is logical while in the cross layers case it is the physical
-        # position. This matches the shape of the actual kv cache tensors
-        # passed at register_kv_caches()/register_cross_layers_kv_cache()
-        block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)
-
-        assert block_size_position is not None
-        self._block_size_position = -(len(kv_cache_shape) - block_size_position)
-
     @property
     def is_kv_layout_blocks_first(self) -> bool:
         return self._is_kv_layout_blocks_first
@@ -389,10 +380,6 @@ def block_size(self) -> int:
     def cross_layers_blocks(self) -> bool:
         return self._cross_layers_blocks
 
-    @property
-    def block_size_position(self) -> int:
-        return self._block_size_position
-
     def tp_ratio(
         self,
         remote_tp_size: int,
@@ -483,23 +470,46 @@ def get_target_remote_ranks_from_engine_id(
         return self.get_target_remote_ranks(remote_tp_size)
 
 
-def get_current_attn_backend(vllm_config: VllmConfig):
+def get_current_attn_backends(
+    vllm_config: VllmConfig, layer_names: list[str] | None = None
+) -> list[type[AttentionBackend]]:
+    """Get all distinct attention backends for the given layers.
+
+    Args:
+        vllm_config: The current vLLM configuration.
+        layer_names: Optional list of layer names to scope the lookup.
+            When None, all attention layers are considered.
+
+    Returns:
+        Deduplicated list of attention backend classes.
+    """
     layer_type = cast(type[Any], AttentionLayerBase)
-    layers = get_layers_from_vllm_config(vllm_config, layer_type, None)
+    layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
     if layers:
-        backend = next(iter(layers.values())).get_attn_backend()
-    else:
-        # Fallback for tests, when static_forward_context is empty.
-        logger.debug(
-            "No layers found in the vLLM config. "
-            "Falling back to default attention backend."
-        )
-        from vllm.v1.attention.selector import get_attn_backend
+        seen: dict[str, type[AttentionBackend]] = {}
+        for layer in layers.values():
+            backend = layer.get_attn_backend()
+            seen[backend.full_cls_name()] = backend
+        return list(seen.values())
+
+    # Fallback for tests, when static_forward_context is empty.
+    logger.debug(
+        "No layers found in the vLLM config. Falling back to default attention backend."
+    )
+    from vllm.v1.attention.selector import get_attn_backend
 
-        backend = get_attn_backend(
+    return [
+        get_attn_backend(
             head_size=vllm_config.model_config.get_head_size(),
             dtype=vllm_config.model_config.dtype,
             kv_cache_dtype=vllm_config.cache_config.cache_dtype,
             use_mla=vllm_config.model_config.use_mla,
         )
-    return backend
+    ]
+
+
+def get_current_attn_backend(
+    vllm_config: VllmConfig, layer_names: list[str] | None = None
+) -> type[AttentionBackend]:
+    """Get the first attention backend for the given layers."""
+    return get_current_attn_backends(vllm_config, layer_names)[0]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 356a837fb36f..1432ef5e342b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -13,7 +13,7 @@
 from collections.abc import Iterator
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 import msgspec
 import numpy as np
@@ -27,6 +27,7 @@
     EngineId,
     TpKVTopology,
     get_current_attn_backend,
+    get_current_attn_backends,
     kv_postprocess_blksize_and_layout_on_receive,
     kv_postprocess_blksize_on_receive,
     kv_postprocess_layout_on_receive,
@@ -61,6 +62,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec
 from vllm.v1.worker.block_table import BlockTable
+from vllm.v1.worker.utils import select_common_block_size
 
 if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -945,7 +947,8 @@ def __init__(
 
         # Config.
         self.vllm_config = vllm_config
-        self.block_size = vllm_config.cache_config.block_size
+        # mypy will complain on re-assignment otherwise.
+        self.block_size: int = cast(int, vllm_config.cache_config.block_size)
 
         if vllm_config.kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set for NixlConnector")
@@ -993,7 +996,7 @@ def __init__(
         self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
         self.tp_group = get_tp_group()
-        self.num_blocks = 0
+        self.num_blocks = kv_cache_config.num_blocks
         self.enable_permute_local_kv = False
 
         # KV Caches and nixl tracking data.
@@ -1128,11 +1131,30 @@ def __init__(
         self.xfer_stats = NixlKVConnectorStats()
 
         self._physical_blocks_per_logical_kv_block = 1
+        self._sync_block_size_with_kernel()
 
         self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
             "enforce_handshake_compat", True
         )
 
+    def _sync_block_size_with_kernel(self) -> None:
+        backends = get_current_attn_backends(self.vllm_config)
+        kernel_block_size = select_common_block_size(self.block_size, backends)
+        if self.block_size != kernel_block_size:
+            logger.info_once(
+                "User-specified logical block size (%s) does not match"
+                " physical kernel block size (%s). Using the latter.",
+                self.block_size,
+                kernel_block_size,
+            )
+            assert self.block_size > kernel_block_size
+            self._physical_blocks_per_logical_kv_block = (
+                self.block_size // kernel_block_size
+            )
+            self.block_size = kernel_block_size
+            self._block_size[self.engine_id] = kernel_block_size
+            self.num_blocks *= self._physical_blocks_per_logical_kv_block
+
     def _nixl_handshake(
         self,
         host: str,
@@ -1466,7 +1488,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
-        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = (
                 cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
@@ -1483,26 +1504,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 logger.debug(
                     "Registering layer %s with cache shape: %s", layer_name, cache.shape
                 )
-                kernel_block_size = cache.shape[self.kv_topo.block_size_position]
-                if self.block_size != kernel_block_size:
-                    logger.info_once(
-                        "User-specified logical block size (%s) does not match"
-                        " physical kernel block size (%s). Using the latter. ",
-                        self.block_size,
-                        kernel_block_size,
-                    )
-                    self._physical_blocks_per_logical_kv_block = (
-                        self.block_size // kernel_block_size
-                    )
-                    self.block_size = kernel_block_size
-                    self._block_size[self.engine_id] = kernel_block_size
-
                 seen_base_addresses.append(base_addr)
                 curr_tensor_size_bytes = cache.numel() * cache.element_size()
 
                 if tensor_size_bytes is None:
                     tensor_size_bytes = curr_tensor_size_bytes
-                    self.num_blocks = cache.shape[0]
 
                 assert cache.shape[0] == self.num_blocks, (
                     "All kv cache tensors must have the same number of blocks"
@@ -1511,9 +1517,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 self.block_len_per_layer.append(
                     curr_tensor_size_bytes // self.num_blocks
                 )
-                self.slot_size_per_layer.append(
-                    self.block_len_per_layer[-1] // self.block_size
-                )
 
                 if not self.use_mla:
                     # Different kv cache shape is not supported by HeteroTP
@@ -1531,7 +1534,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 "Different block lengths collected: %s", set(self.block_len_per_layer)
             )
         assert len(self.block_len_per_layer) == len(seen_base_addresses)
-        assert self.num_blocks != 0
 
         self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
         self.num_regions = len(caches_data)
@@ -1547,10 +1549,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
         if self.kv_topo.is_kv_layout_blocks_first:
-            for i in range(len(self.slot_size_per_layer)):
-                assert self.slot_size_per_layer[i] % 2 == 0
-                self.slot_size_per_layer[i] //= 2
-
             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
             # registerMem allowing faster descs queries. In order to be able to
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 6df8745a500d..d06c40ed64d8 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -258,7 +258,8 @@ def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
 
 
 def select_common_block_size(
-    kv_manager_block_size: int, attn_groups: list[AttentionGroup]
+    kv_manager_block_size: int,
+    backends: list[type[AttentionBackend]],
 ) -> int:
     """
     Select a block size that is supported by all backends and is a factor of
@@ -269,7 +270,7 @@ def select_common_block_size(
 
     Args:
         kv_manager_block_size: Block size of KV cache.
-        attn_groups: List of attention groups.
+        backends: List of attention backend classes.
 
     Returns:
         The selected block size.
@@ -297,8 +298,6 @@ def block_size_is_supported(
                 return False
         return True
 
-    backends = [group.backend for group in attn_groups]
-
     # Case 1: if the block_size of kv cache manager is supported by all backends,
     # return it directly.
     if block_size_is_supported(backends, kv_manager_block_size):
@@ -356,8 +355,9 @@ def prepare_kernel_block_sizes(
         if isinstance(kv_cache_spec, AttentionSpec):
             # This is an attention backend that supports virtual block splitting.
             kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
+            group_backends = [g.backend for g in attn_groups[kv_cache_gid]]
             selected_kernel_size = select_common_block_size(
-                kv_manager_block_size, attn_groups[kv_cache_gid]
+                kv_manager_block_size, group_backends
             )
             kernel_block_sizes.append(selected_kernel_size)
         elif isinstance(kv_cache_spec, MambaSpec):