diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index b8364b237e9d..3fc735efa68e 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -213,6 +213,15 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con --kv-transfer-config '{..., "enable_permute_local_kv":"True"}' ``` +### Cross layers blocks + +By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. +To enable this feature: + +```bash +--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}' +``` + ## Example Scripts/Code Refer to these example scripts in the vLLM repository: diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index 2e25e2f1ac32..cdbcdca546e7 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -14,8 +14,8 @@ tp_configs=( "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" ) dp_ep_configs=( -"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) -"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) +"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) +"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) ) # Select config array based on DP_EP env var @@ -57,3 +57,9 @@ if [[ -n "${FLASHINFER:-}" ]]; then else echo "FLASHINFER not set, skipping FLASHINFER runs." fi + +# Check if cross-layers is enabled (non-empty) +if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then + echo "CROSS_LAYERS_BLOCKS is set, rerunning with --enable-cross-layers" + run_tests "default backend" "--enable-cross-layers" +fi diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index c2c38f51c500..560ce4407038 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -4,6 +4,7 @@ set -xe # Parse command line arguments KV_BUFFER_DEVICE="cuda" # Default to cuda ATTENTION_BACKEND="" # Default to empty (use vllm default) +CROSS_LAYERS_BLOCKS="False" while [[ $# -gt 0 ]]; do case $1 in --kv_buffer_device) @@ -14,6 +15,10 @@ while [[ $# -gt 0 ]]; do ATTENTION_BACKEND="$2" shift 2 ;; + --enable-cross-layers) + CROSS_LAYERS_BLOCKS="True" + shift 1 + ;; *) echo "Unknown option $1" echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ]" @@ -34,11 +39,17 @@ else KV_CONFIG_HETERO_LAYOUT='' fi +if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then + KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"enable_cross_layers_blocks": "True"}' +else + KV_EXTRA_CONFIG='' +fi + # Build the kv-transfer-config once if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then - KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}' + KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' else - KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}" + KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" fi # Models to run diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a3fe2e21b56e..1975d2226073 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -18,8 +18,12 @@ import torch from vllm import LLM -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.config import KVTransferConfig, set_current_vllm_config +from vllm.distributed.kv_transfer.kv_connector.utils import ( + KVOutputAggregator, + TpKVTopology, + get_current_attn_backend, +) from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( @@ -46,10 +50,14 @@ from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import RequestStatus +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin +from vllm.v1.worker.utils import AttentionGroup from .utils import create_request, create_scheduler, create_vllm_config @@ -366,6 +374,7 @@ def test_kv_transfer_handshake(dist_init): # Decode connector will be able to create handshake with the prefill connector. decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + decode_connector.register_kv_caches(kv_caches) # Here we are testing the retrieval of NIXLAgentMetadata. # Knowing the implementation detail, we override the add_remote_agent @@ -402,6 +411,23 @@ def __init__( self.kv_cache_layout = kv_cache_layout # Mock register_kv_caches attribute needed for tests that do not call it. self.src_xfer_handles_by_block_size = {self.block_size: 1} + test_shape = self.attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + self.kv_topo = TpKVTopology( + tp_rank=self.tp_rank, + engine_id=self.engine_id, + remote_tp_size=self._tp_size, # shared state + remote_block_size=self._block_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=self.attn_backend, + tensor_shape=test_shape, + ) + + self.compat_hash = compute_nixl_compatibility_hash( + self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks + ) def _nixl_handshake( self, host: str, port: int, remote_tp_size: int, expected_engine_id: str @@ -1352,6 +1378,7 @@ def req_id(outputs: list[RequestOutput]) -> str: llm.llm_engine.engine_core.shutdown() +@pytest.mark.parametrize("enable_cross_layers", ["False", "True"]) @pytest.mark.parametrize( "attn_backend", [ @@ -1372,7 +1399,9 @@ def req_id(outputs: list[RequestOutput]) -> str: "TRITON_ATTN", ], ) -def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): +def test_register_kv_caches( + default_vllm_config, dist_init, attn_backend, enable_cross_layers +): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. @@ -1386,6 +1415,12 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): vllm_config = create_vllm_config(attention_backend=attn_backend) + # Enable cross layers blocks + vllm_config.kv_transfer_config.kv_connector_extra_config[ + "enable_cross_layers_blocks" + ] = enable_cross_layers + set_kv_cache_layout("HND") + # Import the appropriate backend based on the parameter if attn_backend == "FLASH_ATTN": from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend @@ -1400,44 +1435,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): backend_cls = TritonAttentionBackend - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } - - # Store tensor info for validation - - test_shape = backend_cls.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 - - if is_blocks_first: - expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel() - expected_base_addrs = [ - shared_tensor.data_ptr(), - unique_tensor.data_ptr(), - ] - expected_num_entries = 2 - else: - expected_tensor_size = ( - shared_tensor[0].element_size() * shared_tensor[0].numel() - ) - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - ] - expected_num_entries = 4 - nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" with ( patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, @@ -1466,6 +1463,111 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): # Reassure the shutdown() check that the thread is terminated mock_thread.return_value.is_alive.return_value = False + expected_tensor_size: int + expected_base_addrs: list[int] + expected_num_entries: int + kv_caches: dict[str, torch.Tensor] + assert str(enable_cross_layers).lower() != "true" or ( + (attn_backend not in ("FLASH_ATTN", "FLASHINFER")) + or connector.prefer_cross_layer_blocks + ) + if connector.prefer_cross_layer_blocks: + num_layers = 32 + block_size = 16 + num_blocks = 8 + kv_cache_spec = AttentionSpec( + block_size=block_size, + num_kv_heads=4, + head_size=64, + dtype=torch.bfloat16, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor( + size=kv_cache_spec.page_size_bytes * num_blocks, + shared_by=["dummy-layer"], + ) + for i in range(num_layers) + ], + # allocate_uniform_kv_caches does not use this + kv_cache_groups=[], + ) + + with set_current_vllm_config(vllm_config): + _, cross_layers_kv_cache, _ = ( + KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( + kv_cache_config=kv_cache_config, + attn_groups=[ + [ + AttentionGroup( + backend=backend_cls, + layer_names=[], + kv_cache_spec=kv_cache_spec, + kv_cache_group_id=0, + ) + ] + ], + cache_dtype=torch.bfloat16, + device=torch.cuda.current_device(), + kernel_block_sizes=[block_size], + ) + ) + # Store tensor info for validation + expected_tensor_size = ( + cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel() + ) + expected_base_addrs = [ + cross_layers_kv_cache.data_ptr(), + ] + expected_num_entries = 1 + + expected_blocks_count = 8 + + kv_caches = {"all-layers": cross_layers_kv_cache} + + else: + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + + test_shape = backend_cls.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 + + if is_blocks_first: + expected_tensor_size = ( + shared_tensor.element_size() * shared_tensor.numel() + ) + expected_base_addrs = [ + shared_tensor.data_ptr(), + unique_tensor.data_ptr(), + ] + expected_num_entries = 2 + else: + expected_tensor_size = ( + shared_tensor[0].element_size() * shared_tensor[0].numel() + ) + expected_base_addrs = [ + shared_tensor[0].data_ptr(), + shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), + ] + expected_num_entries = 4 + expected_blocks_count = 8 + # Execute register_kv_caches connector.register_kv_caches(kv_caches) @@ -1489,16 +1591,19 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] # Validate blocks_data structure and size - expected_blocks_count = 8 assert len(blocks_data) == expected_blocks_count, ( f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" ) - num_blocks = 2 - if is_blocks_first: - expected_block_len = expected_tensor_size // num_blocks // 2 - else: + if connector.prefer_cross_layer_blocks: + num_blocks = 8 expected_block_len = expected_tensor_size // num_blocks + else: + num_blocks = 2 + if is_blocks_first: + expected_block_len = expected_tensor_size // num_blocks // 2 + else: + expected_block_len = expected_tensor_size // num_blocks for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry @@ -1507,6 +1612,8 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): f"got {block_len}" ) + assert connector.connector_worker.block_size == 16 + class FakePlatform(Platform): device_type: str = "oot" @@ -2049,6 +2156,17 @@ def test_compatibility_hash_validation( ) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_worker = decode_connector.connector_worker + kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + decode_connector.register_kv_caches(kv_caches) remote_config_params: dict[str, Any] = { "model": "facebook/opt-125m", @@ -2071,7 +2189,9 @@ def test_compatibility_hash_validation( ) ) remote_hash = compute_nixl_compatibility_hash( - remote_vllm_config, decode_worker.backend_name + remote_vllm_config, + decode_worker.backend_name, + decode_worker.kv_topo.cross_layers_blocks, ) prefill_block_size = config_overrides.get("block_size", 16) @@ -2150,6 +2270,27 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_worker = decode_connector.connector_worker + backend = get_current_attn_backend(local_vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + decode_worker.kv_topo = TpKVTopology( + tp_rank=decode_worker.tp_rank, + engine_id=decode_worker.engine_id, + remote_tp_size=decode_worker._tp_size, # shared state + remote_block_size=decode_worker._block_size, # shared state + is_mla=decode_worker.use_mla, + total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), + attn_backend=backend, + tensor_shape=test_shape, + ) + + decode_worker.compat_hash = compute_nixl_compatibility_hash( + decode_worker.vllm_config, + decode_worker.backend_name, + decode_worker.kv_topo.cross_layers_blocks, + ) + if error_scenario == "handshake_decode_error": msg_bytes = b"this is not valid msgpack data" elif error_scenario == "handshake_validation_error": diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index fd833e293938..019201ede73e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.platforms import current_platform from vllm.v1.attention.backend import AttentionBackend from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -192,8 +193,6 @@ def copy_kv_blocks( dst_device=dst_device, ) - from vllm.platforms import current_platform - if direction == "h2d": copy_fn = current_platform.insert_blocks_to_device else: @@ -316,12 +315,14 @@ class TpKVTopology: attn_backend: type[AttentionBackend] engine_id: EngineId remote_block_size: dict[EngineId, int] + tensor_shape: torch.Size | None = None def __post_init__(self): # Figure out whether the first dimension of the cache is K/V # or num_blocks. This is used to register the memory regions correctly. + _MOCK_BLOCK_SIZE = 16 kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1 ) # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # we just mock num_blocks to 1 for the dimension check below. @@ -329,6 +330,36 @@ def __post_init__(self): len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) + self._cross_layers_blocks = False + if self.tensor_shape is not None: + self._cross_layers_blocks = ( + len(self.tensor_shape) == len(kv_cache_shape) + 1 + ) + + if self._cross_layers_blocks: + # prepend layers dimension + _MOCK_NUM_LAYERS = 80 + kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape + try: + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=self._cross_layers_blocks + ) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(self.tensor_shape))) + + # In case of cross layers permute kv_cache_shape according to + # stride_order to retrieve physical position of block_size + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + + # In the default non-cross layers layout the block_size position + # is logical while in the cross layers case it is the physical + # position. This matches the shape of the actual kv cache tensors + # passed at register_kv_caches()/register_cross_layers_kv_cache() + block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE) + + assert block_size_position is not None + self._block_size_position = -(len(kv_cache_shape) - block_size_position) + @property def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first @@ -336,7 +367,9 @@ def is_kv_layout_blocks_first(self) -> bool: @property def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). - return not (self.is_mla or self.is_kv_layout_blocks_first) + return not ( + self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first + ) @property def tp_size(self) -> int: @@ -346,6 +379,14 @@ def tp_size(self) -> int: def block_size(self) -> int: return self.remote_block_size[self.engine_id] + @property + def cross_layers_blocks(self) -> bool: + return self._cross_layers_blocks + + @property + def block_size_position(self) -> int: + return self._block_size_position + def tp_ratio( self, remote_tp_size: int, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d03d70860396..8ce939ee405e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -54,7 +54,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket -from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.block_table import BlockTable @@ -173,7 +173,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata): def compute_nixl_compatibility_hash( - vllm_config: VllmConfig, attn_backend_name: str + vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool ) -> str: """ Compute compatibility hash for NIXL KV transfer. @@ -216,6 +216,7 @@ def compute_nixl_compatibility_hash( # Attention backend and KV cache dtype affect memory layout "attn_backend_name": attn_backend_name, "cache_dtype": str(cache_config.cache_dtype), + "cross_layers_blocks": cross_layers_blocks, } compat_hash = hash_factors(factors) @@ -298,6 +299,26 @@ def add_new_req_to_recv( class NixlConnector(KVConnectorBase_V1): + @property + def prefer_cross_layer_blocks(self) -> bool: + backend = get_current_attn_backend(self._vllm_config) + if backend.get_name() not in ( + "FLASH_ATTN", + "FLASHINFER", + ): + return False + + # For now there is no benefit to run cross layers when backend + # does not support on HND + if get_kv_cache_layout() != "HND": + return False + + extra_config = self.kv_transfer_config.kv_connector_extra_config + return ( + str(extra_config.get("enable_cross_layers_blocks", "False")).lower() + == "true" + ) + def __init__( self, vllm_config: VllmConfig, @@ -309,7 +330,7 @@ def __init__( assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id - + self.kv_transfer_config = vllm_config.kv_transfer_config if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: NixlConnectorScheduler | None = ( NixlConnectorScheduler(vllm_config, self.engine_id) @@ -395,6 +416,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) + def register_cross_layers_kv_cache( + self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + ): + assert self.connector_worker is not None + + cross_layer_name = "ALL_LAYERS" + kv_caches = {cross_layer_name: kv_cache} + + self.connector_worker.register_kv_caches(kv_caches) + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None self.connector_worker.set_host_xfer_buffer_ops(copy_operation) @@ -976,20 +1007,17 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Get the attention backend from the first layer # NOTE (NickLucche) models with multiple backends are not supported yet - backend = get_current_attn_backend(vllm_config) + self.attn_backend = get_current_attn_backend(vllm_config) - self.backend_name = backend.get_name() + self.backend_name = self.attn_backend.get_name() self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) - self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name - ) - self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( - "enforce_handshake_compat", True - ) + # lazy initialized in register_kv_caches + self.compat_hash: str | None = None + self.kv_topo: TpKVTopology | None = None self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} @@ -998,17 +1026,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - self.kv_topo = TpKVTopology( - tp_rank=self.tp_rank, - engine_id=self.engine_id, - remote_tp_size=self._tp_size, # shared state - remote_block_size=self._block_size, # shared state - is_mla=self.use_mla, - total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=backend, - ) self._physical_blocks_per_logical_kv_block = 1 + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( + "enforce_handshake_compat", True + ) + def _nixl_handshake( self, host: str, @@ -1022,6 +1045,7 @@ def _nixl_handshake( # Regardless, only handshake with the remote TP rank(s) that current # local rank will read from. Note that With homogeneous TP, # this happens to be the same single rank_i. + assert self.kv_topo is not None p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) remote_rank_to_agent_name = {} path = make_zmq_path("tcp", host, port) @@ -1059,6 +1083,7 @@ def _nixl_handshake( ) # Check compatibility hash BEFORE decoding agent metadata + assert self.compat_hash is not None if ( self.enforce_compat_hash and handshake_payload.compatibility_hash != self.compat_hash @@ -1267,6 +1292,20 @@ def request_ready(f: Future[Any], entry=(req_id, meta)): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" + self.kv_topo = TpKVTopology( + tp_rank=self.tp_rank, + engine_id=self.engine_id, + remote_tp_size=self._tp_size, # shared state + remote_block_size=self._block_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=self.attn_backend, + tensor_shape=next(iter(kv_caches.values())).shape, + ) + self.compat_hash = compute_nixl_compatibility_hash( + self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks + ) + if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( @@ -1301,29 +1340,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). - split_k_and_v = self.kv_topo.split_k_and_v tensor_size_bytes = None - # TODO (NickLucche): Get kernel_block_size in a cleaner way - # NHD default "view" for non-MLA cache - if self.device_type == "cpu": - block_size_position = -2 - else: - block_size_position = -2 if self.use_mla else -3 - # Enable different block lengths for different layers when MLA is used. self.block_len_per_layer = list[int]() self.slot_size_per_layer = list[int]() # HD bytes in kv terms for layer_name, cache_or_caches in xfer_buffers.items(): - cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] - + cache_list = ( + cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches] + ) for cache in cache_list: base_addr = cache.data_ptr() if base_addr in seen_base_addresses: continue - kernel_block_size = cache.shape[block_size_position] - + kernel_block_size = cache.shape[self.kv_topo.block_size_position] if self.block_size != kernel_block_size: logger.info_once( "User-specified logical block size (%s) does not match" @@ -1385,6 +1416,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks + if self.kv_topo.is_kv_layout_blocks_first: for i in range(len(self.slot_size_per_layer)): assert self.slot_size_per_layer[i] % 2 == 0 @@ -1440,6 +1472,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_size=self.block_size, ) # Wrap metadata in payload with hash for defensive decoding + assert self.compat_hash is not None encoder = msgspec.msgpack.Encoder() self.xfer_handshake_metadata = NixlHandshakePayload( compatibility_hash=self.compat_hash, @@ -1461,6 +1494,8 @@ def register_local_xfer_handler( register another local_xfer_handler using remote block len to ensure data copy correctness. """ + assert self.kv_topo is not None + block_size_ratio = self.block_size // block_size blocks_data = [] for i, base_addr in enumerate(self.seen_base_addresses): @@ -1573,6 +1608,7 @@ def add_remote_agent( # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| + assert self.kv_topo is not None block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) if engine_id not in self.dst_num_blocks: @@ -1701,6 +1737,7 @@ def _validate_remote_agent_handshake( remote_engine_id = nixl_agent_meta.engine_id assert self._tp_size[remote_engine_id] == remote_tp_size + assert self.kv_topo is not None tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( @@ -1837,6 +1874,7 @@ def post_process_device_kv_on_receive( if len(self.device_kv_caches) == 0: return assert block_size_ratio >= 1, "Only nP < nD supported currently." + assert self.kv_topo is not None if self.enable_permute_local_kv and block_size_ratio > 1: logger.debug( "Post-processing device kv cache on receive by converting " @@ -1856,7 +1894,7 @@ def post_process_device_kv_on_receive( block_size_ratio, ) - split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first) + split_k_and_v = self.kv_topo.split_k_and_v for block_ids in block_ids_list: indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) @@ -1881,6 +1919,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: The scheduler process (via the MultiprocExecutor) will use this output to track which workers are done. """ + assert self.kv_topo is not None done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) @@ -1950,6 +1989,7 @@ def _get_new_notifs(self) -> set[str]: are reading from the same producer (heterogeneous TP scenario), wait for all consumers to be done pulling. """ + assert self.kv_topo is not None notified_req_ids: set[str] = set() for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: @@ -2109,7 +2149,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): - assert meta.remote is not None + assert meta.remote is not None and self.kv_topo is not None remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( meta.remote.engine_id ) @@ -2182,6 +2222,7 @@ def _read_blocks( Post a READ point-to-point xfer request from a single local worker to a single remote worker. """ + assert self.kv_topo is not None block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: local_block_ids = self.get_mapped_blocks( @@ -2414,6 +2455,7 @@ def get_backend_aware_kv_block_len(self, layer_idx: int) -> int: For FlashInfer, this is half the length of the whole block, as K and V share the same region. """ + assert self.kv_topo is not None if self.kv_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). block_len = self.block_len_per_layer[layer_idx] // 2