From 5195265c1a54d4fa84f99d3debb9c780af6d8e88 Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Thu, 4 Dec 2025 13:11:02 +0000
Subject: [PATCH 01/84] Cross layers implementation

Signed-off-by: Liran Schour
---
 .../kv_connector/v1/nixl_connector.py | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 514b8534aaa6..6a941dd013a9 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -12,7 +12,7 @@
 from collections.abc import Iterator
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, ClassVar
 
 import msgspec
 import numpy as np
@@ -20,7 +20,8 @@
 import zmq
 
 from vllm import envs
-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
@@ -251,6 +252,8 @@
 
 
 class NixlConnector(KVConnectorBase_V1):
+    prefer_cross_layer_blocks: ClassVar[bool] = True
+
     def __init__(
         self,
         vllm_config: VllmConfig,
@@ -348,6 +351,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         assert self.connector_worker is not None
         self.connector_worker.register_kv_caches(kv_caches)
 
+    def register_cross_layers_kv_cache(
+        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
+    ):
+        assert self.connector_worker is not None
+
+        cross_layer_name = "ALL_LAYERS"
+
+        kv_caches = {cross_layer_name: kv_cache}
+        self.connector_worker.cross_layers = True
+        self.connector_worker.register_kv_caches(kv_caches)
+
     def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
         assert self.connector_worker is not None
         self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
@@ -783,6 +797,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if vllm_config.kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set for NixlConnector")
         self.kv_transfer_config = vllm_config.kv_transfer_config
+        self.cross_layers = False
 
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
@@ -1202,7 +1217,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 
         # TODO (NickLucche): Get kernel_block_size in a cleaner way
        # NHD default "view" for non-MLA cache
-        if self.device_type == "cpu":
+        if self.device_type == "cpu" or self.cross_layers:
            block_size_position = -2
        else:
            block_size_position = -2 if self.use_mla else -3
@@ -1211,15 +1226,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
-            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
+            cache_list = cache_or_caches if not self.cross_layers and split_k_and_v else [cache_or_caches]
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 if base_addr in seen_base_addresses:
                     continue
 
                 kernel_block_size = cache.shape[block_size_position]
-
                 if self.block_size != kernel_block_size:
                     logger.info_once(
                         "User-specified logical block size (%s) does not match"

From 0f368880ae9357b81d86ece972ae01088f432a3d Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Sun, 7 Dec 2025 13:29:16 +0000
Subject: [PATCH 02/84] Fix linting

Signed-off-by: Liran Schour
---
 .../kv_transfer/kv_connector/v1/nixl_connector.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 6a941dd013a9..5824fc95ee44 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -12,7 +12,7 @@
 from collections.abc import Iterator
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
 
 import msgspec
 import numpy as np
@@ -21,7 +21,6 @@ import zmq
 
 from vllm import envs
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
-from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
@@ -1226,8 +1225,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
-
-            cache_list = cache_or_caches if not self.cross_layers and split_k_and_v else [cache_or_caches]
+            cache_list = (
+                cache_or_caches
+                if not self.cross_layers and split_k_and_v
+                else [cache_or_caches]
+            )
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 if base_addr in seen_base_addresses:

From 8d36b4ba9d2a72795b092049654cec02e2ba90ec Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Wed, 10 Dec 2025 05:00:31 +0000
Subject: [PATCH 03/84] Add cross layers compatibility check

Signed-off-by: Liran Schour
---
 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 5824fc95ee44..0f438966aa26 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -74,6 +74,7 @@
 # 2: Add remote_request_id to kv_transfer_params
 #
 NIXL_CONNECTOR_VERSION: int = 2
+CROSS_LAYERS: bool = True
 
 GET_META_MSG = b"get_meta_msg"
 
@@ -186,6 +187,7 @@ def compute_nixl_compatibility_hash(
         # Attention backend and KV cache dtype affect memory layout
         "attn_backend_name": attn_backend_name,
         "cache_dtype": str(cache_config.cache_dtype),
+        "cross_layers": CROSS_LAYERS,
     }
 
     compat_hash = hash_factors(factors)

From 2a20197c493057f3701379730063e74d7301e438 Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Thu, 11 Dec 2025 07:58:41 +0000
Subject: [PATCH 04/84] Move cross_layers logic into TpKVTopology

Signed-off-by: Liran Schour
---
 .../kv_transfer/kv_connector/utils.py | 11 +++-
 .../kv_connector/v1/nixl_connector.py | 58 ++++++++++---------
 2 files changed, 42 insertions(+), 27 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 99d3be57c138..a3d75102d568 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -200,6 +200,7 @@ class TpKVTopology:
     attn_backend: type[AttentionBackend]
     engine_id: str
     remote_block_size: dict[str, int]
+    cross_layers: bool
 
     def __post_init__(self):
         # Figure out whether the first dimension of the cache is K/V
@@ -223,7 +224,7 @@ def is_kv_layout_blocks_first(self) -> bool:
     @property
     def split_k_and_v(self) -> bool:
         # Whether to register regions for K and V separately (when present).
-        return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
+        return not (self.cross_layers or self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
 
     @property
     def tp_size(self) -> int:
@@ -233,6 +234,14 @@ def tp_size(self) -> int:
     def block_size(self) -> int:
         return self.remote_block_size[self.engine_id]
 
+    def block_size_position(self, device_type: str) -> int:
+        if device_type == "cpu" or self.cross_layers:
+            block_size_position = -2
+        else:
+            block_size_position = -2 if self.is_mla else -3
+
+        return block_size_position
+
     def tp_ratio(
         self,
         remote_tp_size: int,
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 0f438966aa26..b9f82ffa0f0b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -144,7 +144,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
 
 
 def compute_nixl_compatibility_hash(
-    vllm_config: VllmConfig, attn_backend_name: str
+    vllm_config: VllmConfig, attn_backend_name: str, cross_layers: str
 ) -> str:
     """
     Compute compatibility hash for NIXL KV transfer.
@@ -187,7 +187,7 @@ def compute_nixl_compatibility_hash(
         # Attention backend and KV cache dtype affect memory layout
         "attn_backend_name": attn_backend_name,
         "cache_dtype": str(cache_config.cache_dtype),
-        "cross_layers": CROSS_LAYERS,
+        "cross_layers": cross_layers,
     }
 
     compat_hash = hash_factors(factors)
@@ -360,7 +360,6 @@ def register_cross_layers_kv_cache(
         cross_layer_name = "ALL_LAYERS"
 
         kv_caches = {cross_layer_name: kv_cache}
-        self.connector_worker.cross_layers = True
         self.connector_worker.register_kv_caches(kv_caches)
 
     def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
@@ -798,7 +797,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if vllm_config.kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set for NixlConnector")
         self.kv_transfer_config = vllm_config.kv_transfer_config
-        self.cross_layers = False
 
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
@@ -946,13 +944,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
-        self.compat_hash = compute_nixl_compatibility_hash(
-            self.vllm_config, self.backend_name
-        )
-        self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
-            "enforce_handshake_compat", True
-        )
-
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -960,16 +951,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()
 
-        self.kv_topo = TpKVTopology(
-            tp_rank=self.tp_rank,
-            engine_id=self.engine_id,
-            remote_tp_size=self._tp_size,  # shared state
-            remote_block_size=self._block_size,  # shared state
-            is_mla=self.use_mla,
-            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
-            attn_backend=backend,
-        )
-        self._use_pallas = self.kv_topo._use_pallas
+        self._physical_blocks_per_logical_kv_block = 1
 
     def _nixl_handshake(
@@ -1179,6 +1161,33 @@ def request_ready(f: Future[Any], entry=(req_id, meta)):
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
+        backend = get_attn_backend(
+            self.model_config.get_head_size(),
+            self.model_config.dtype,
+            self.cache_config.cache_dtype,
+            self.block_size,
+            use_mla=self.use_mla,
+        )
+
+        self.kv_topo = TpKVTopology(
+            tp_rank=self.tp_rank,
+            engine_id=self.engine_id,
+            remote_tp_size=self._tp_size,  # shared state
+            remote_block_size=self._block_size,  # shared state
+            is_mla=self.use_mla,
+            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+            attn_backend=backend,
+            cross_layers = next(iter(kv_caches)) == "ALL_LAYERS"
+        )
+        self._use_pallas = self.kv_topo._use_pallas
+
+        self.compat_hash = compute_nixl_compatibility_hash(
+            self.vllm_config, self.backend_name, self.kv_topo.cross_layers
+        )
+        self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
+            "enforce_handshake_compat", True
+        )
+
         if self.use_host_buffer:
             self.initialize_host_xfer_buffer(kv_caches=kv_caches)
             assert len(self.host_xfer_buffers) == len(kv_caches), (
@@ -1218,10 +1227,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 
         # TODO (NickLucche): Get kernel_block_size in a cleaner way
         # NHD default "view" for non-MLA cache
-        if self.device_type == "cpu" or self.cross_layers:
-            block_size_position = -2
-        else:
-            block_size_position = -2 if self.use_mla else -3
+        block_size_position = self.kv_topo.block_size_position(self.device_type)
 
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
@@ -1229,7 +1235,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = (
                 cache_or_caches
-                if not self.cross_layers and split_k_and_v
+                if not self.kv_topo.cross_layers and split_k_and_v
                 else [cache_or_caches]
             )
             for cache in cache_list:

From 073b30e06a3c33dd97f63aad636286cae6398535 Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Thu, 11 Dec 2025 12:47:50 +0000
Subject: [PATCH 05/84] Code review minor fix

Signed-off-by: Liran Schour
---
 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index b9f82ffa0f0b..4a5a4ccb1c80 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -74,7 +74,6 @@
 # 2: Add remote_request_id to kv_transfer_params
 #
 NIXL_CONNECTOR_VERSION: int = 2
-CROSS_LAYERS: bool = True
 
 GET_META_MSG = b"get_meta_msg"
 
@@ -144,7 +143,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
 
 
 def compute_nixl_compatibility_hash(
-    vllm_config: VllmConfig, attn_backend_name: str, cross_layers: str
+    vllm_config: VllmConfig, attn_backend_name: str, cross_layers: bool
 ) -> str:
     """
     Compute compatibility hash for NIXL KV transfer.
From b403a9e00b188f53efee371d7331fe181a507ba1 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 11 Dec 2025 12:54:21 +0000 Subject: [PATCH 06/84] Linting... Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 7 ++++++- .../kv_transfer/kv_connector/v1/nixl_connector.py | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index a3d75102d568..cb5f6a941570 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -224,7 +224,12 @@ def is_kv_layout_blocks_first(self) -> bool: @property def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). - return not (self.cross_layers or self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first) + return not ( + self.cross_layers + or self.is_mla + or self._use_pallas + or self.is_kv_layout_blocks_first + ) @property def tp_size(self) -> int: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4a5a4ccb1c80..872b0b72f993 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -950,7 +950,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - self._physical_blocks_per_logical_kv_block = 1 def _nixl_handshake( @@ -1176,7 +1175,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backend=backend, - cross_layers = next(iter(kv_caches)) == "ALL_LAYERS" + cross_layers=next(iter(kv_caches)) == "ALL_LAYERS", ) self._use_pallas = self.kv_topo._use_pallas From 06d31842590cb6277379e5a1c56c3e0ed23a466a Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 17 Dec 2025 09:56:30 +0000 Subject: [PATCH 07/84] Code review fixes Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/utils.py | 25 +++++++++---- .../kv_connector/v1/nixl_connector.py | 37 ++++++++----------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index cb5f6a941570..582322f15b4d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -198,9 +198,9 @@ class TpKVTopology: is_mla: bool total_num_kv_heads: int attn_backend: type[AttentionBackend] + tensor_shape: torch.Size engine_id: str remote_block_size: dict[str, int] - cross_layers: bool def __post_init__(self): # Figure out whether the first dimension of the cache is K/V @@ -217,6 +217,12 @@ def __post_init__(self): attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS + test_shape = self.attn_backend.get_kv_cache_shape( + num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 + ) + + self._cross_layers_blocks = len(self.tensor_shape) != len(test_shape) + @property def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first @@ -225,7 +231,7 @@ def is_kv_layout_blocks_first(self) -> bool: def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately 
(when present). return not ( - self.cross_layers + self._cross_layers_blocks or self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first @@ -239,13 +245,16 @@ def tp_size(self) -> int: def block_size(self) -> int: return self.remote_block_size[self.engine_id] - def block_size_position(self, device_type: str) -> int: - if device_type == "cpu" or self.cross_layers: - block_size_position = -2 - else: - block_size_position = -2 if self.is_mla else -3 + @property + def use_pallas(self) -> bool: + return self._use_pallas + + @property + def cross_layers_blocks(self) -> bool: + return self._cross_layers_blocks - return block_size_position + def block_size_position(self, device_type: str) -> int: + return -2 if self.is_mla or self._cross_layers_blocks else -3 def tp_ratio( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 872b0b72f993..800ca5772d6d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -930,19 +930,21 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.block_window_per_layer: list[int | None] = [] self.use_mla = self.model_config.use_mla - backend = get_attn_backend( + self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, use_mla=self.use_mla, ) - self.backend_name = backend.get_name() + self.backend_name = self.attn_backend.get_name() self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + self.compat_hash: str | None = None + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} # With heterogeneous TP, P must wait for all assigned D TP workers to @@ -952,6 +954,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._physical_blocks_per_logical_kv_block = 1 + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( + "enforce_handshake_compat", True + ) + def _nixl_handshake( self, host: str, @@ -1159,14 +1165,6 @@ def request_ready(f: Future[Any], entry=(req_id, meta)): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - use_mla=self.use_mla, - ) - self.kv_topo = TpKVTopology( tp_rank=self.tp_rank, engine_id=self.engine_id, @@ -1174,16 +1172,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): remote_block_size=self._block_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=backend, - cross_layers=next(iter(kv_caches)) == "ALL_LAYERS", + attn_backend=self.attn_backend, + tensor_shape=next(iter(kv_caches.values())).shape, ) - self._use_pallas = self.kv_topo._use_pallas self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name, self.kv_topo.cross_layers - ) - self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( - "enforce_handshake_compat", True + self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks ) if 
self.use_host_buffer: @@ -1220,7 +1214,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). - split_k_and_v = self.kv_topo.split_k_and_v tensor_size_bytes = None # TODO (NickLucche): Get kernel_block_size in a cleaner way @@ -1233,7 +1226,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for layer_name, cache_or_caches in xfer_buffers.items(): cache_list = ( cache_or_caches - if not self.kv_topo.cross_layers and split_k_and_v + if not self.kv_topo.cross_layers_blocks and self.kv_topo.split_k_and_v else [cache_or_caches] ) for cache in cache_list: @@ -1583,7 +1576,7 @@ def _validate_remote_agent_handshake( remote_engine_id ) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - assert not self._use_pallas or tp_ratio == 1, ( + assert not self.kv_topo.use_pallas or tp_ratio == 1, ( "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." ) kv_cache_layout = ( @@ -1745,7 +1738,9 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): if len(self.device_kv_caches) == 0: return split_k_and_v = not ( - self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first + self.use_mla + or self.kv_topo.use_pallas + or self.kv_topo.is_kv_layout_blocks_first ) sample_cache = list(self.device_kv_caches.values())[0][0] for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): From cd278668dcdb4709b0d309e0a39dfc99e5b7f28b Mon Sep 17 00:00:00 2001 From: liranschour Date: Mon, 22 Dec 2025 09:24:47 +0200 Subject: [PATCH 08/84] Update vllm/distributed/kv_transfer/kv_connector/utils.py Co-authored-by: Or Ozeri Signed-off-by: liranschour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 582322f15b4d..20995601ac89 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -222,6 +222,9 @@ def __post_init__(self): ) self._cross_layers_blocks = len(self.tensor_shape) != len(test_shape) + if self._cross_layers_blocks: + # expect one additional dimension (num_layers) + assert len(self.tensor_shape) == len(test_shape) + 1 @property def is_kv_layout_blocks_first(self) -> bool: From 19319af66c9809b5eab5e60f4c1123c37b1e4a40 Mon Sep 17 00:00:00 2001 From: liranschour Date: Mon, 22 Dec 2025 09:25:10 +0200 Subject: [PATCH 09/84] Update vllm/distributed/kv_transfer/kv_connector/utils.py Co-authored-by: Or Ozeri Signed-off-by: liranschour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 20995601ac89..4bf5a67eb62e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -256,7 +256,8 @@ def use_pallas(self) -> bool: def cross_layers_blocks(self) -> bool: return self._cross_layers_blocks - def block_size_position(self, device_type: str) -> int: + @property + def block_size_position(self) -> int: return -2 if self.is_mla or self._cross_layers_blocks else -3 def tp_ratio( From 994bf1de0c2673e09e0fdfd0b4bb0bbbb4b5383c Mon Sep 17 00:00:00 2001 From: liranschour Date: Mon, 22 Dec 2025 09:29:59 +0200 
Subject: [PATCH 10/84] Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Co-authored-by: Or Ozeri Signed-off-by: liranschour --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 800ca5772d6d..96dea928d5c8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1218,7 +1218,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # TODO (NickLucche): Get kernel_block_size in a cleaner way # NHD default "view" for non-MLA cache - block_size_position = self.kv_topo.block_size_position(self.device_type) + if self.device_type == "cpu": + block_size_position = -2 + else: + block_size_position = self.kv_topo.block_size_position # Enable different block lengths for different layers when MLA is used. self.block_len_per_layer = list[int]() From 0efeba3b0719f723d55cf9e713f9610e64865dee Mon Sep 17 00:00:00 2001 From: liranschour Date: Mon, 22 Dec 2025 09:42:16 +0200 Subject: [PATCH 11/84] Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Co-authored-by: Or Ozeri Signed-off-by: liranschour --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 96dea928d5c8..61d8bfd97194 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1229,7 +1229,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for layer_name, cache_or_caches in xfer_buffers.items(): cache_list = ( cache_or_caches - if not self.kv_topo.cross_layers_blocks and self.kv_topo.split_k_and_v + if self.kv_topo.split_k_and_v else [cache_or_caches] ) for cache in cache_list: From ef8e7adf45ec15b67e0813d954707d1362993d40 Mon Sep 17 00:00:00 2001 From: liranschour Date: Mon, 22 Dec 2025 09:49:08 +0200 Subject: [PATCH 12/84] Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Co-authored-by: Or Ozeri Signed-off-by: liranschour --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 61d8bfd97194..c4358c467c60 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -186,7 +186,7 @@ def compute_nixl_compatibility_hash( # Attention backend and KV cache dtype affect memory layout "attn_backend_name": attn_backend_name, "cache_dtype": str(cache_config.cache_dtype), - "cross_layers": cross_layers, + "cross_layers_blocks": cross_layers_blocks, } compat_hash = hash_factors(factors) From 6e2b751268a33c7fb19c418243f6eb18ca4f581a Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 22 Dec 2025 07:49:52 +0000 Subject: [PATCH 13/84] Code review fixes Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 11 ++++++----- .../kv_transfer/kv_connector/v1/nixl_connector.py | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 4bf5a67eb62e..156199c41d66 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -198,9 +198,9 @@ class TpKVTopology: is_mla: bool total_num_kv_heads: int attn_backend: type[AttentionBackend] - tensor_shape: torch.Size engine_id: str remote_block_size: dict[str, int] + tensor_shape: torch.Size | None = None def __post_init__(self): # Figure out whether the first dimension of the cache is K/V @@ -221,10 +221,11 @@ def __post_init__(self): num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 ) - self._cross_layers_blocks = len(self.tensor_shape) != len(test_shape) - if self._cross_layers_blocks: - # expect one additional dimension (num_layers) - assert len(self.tensor_shape) == len(test_shape) + 1 + if self.tensor_shape is not None: + self._cross_layers_blocks = len(self.tensor_shape) != len(test_shape) + if self._cross_layers_blocks: + # expect one additional dimension (num_layers) + assert len(self.tensor_shape) == len(test_shape) + 1 @property def is_kv_layout_blocks_first(self) -> bool: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c4358c467c60..73920a15bca0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -944,6 +944,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.debug("Detected kv cache layout %s", self.kv_cache_layout) self.compat_hash: str | None = None + self.kv_topo: TpKVTopology | None = None self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} From 5e66e8f69709efd670709cba5b5e0a4bb7ba0c7f Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Sun, 28 Dec 2025 07:48:03 +0000 Subject: [PATCH 14/84] Code review fix Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/utils.py | 1 + .../kv_connector/v1/nixl_connector.py | 36 +++++++++++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 156199c41d66..c6a7272903cb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -221,6 +221,7 @@ def __post_init__(self): num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 ) + self._cross_layers_blocks = False if self.tensor_shape is not None: self._cross_layers_blocks = len(self.tensor_shape) != len(test_shape) if self._cross_layers_blocks: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 73920a15bca0..36095cf3b760 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -143,7 +143,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata): def compute_nixl_compatibility_hash( - vllm_config: VllmConfig, attn_backend_name: str, cross_layers: bool + vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool ) -> str: """ Compute compatibility hash for NIXL KV transfer. 
@@ -970,6 +970,8 @@ def _nixl_handshake( start_time = time.perf_counter() + assert self.kv_topo is not None, "kv_topo is not initialized" + # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. @@ -1006,6 +1008,7 @@ def _nixl_handshake( ) # Check compatibility hash BEFORE decoding agent metadata + assert self.compat_hash is not None, "compat_hash is not initialized" if ( self.enforce_compat_hash and handshake_payload.compatibility_hash != self.compat_hash @@ -1177,6 +1180,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): tensor_shape=next(iter(kv_caches.values())).shape, ) + assert self.kv_topo is not None, "kv_topo is not initialized" + self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks ) @@ -1229,9 +1234,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.slot_size_per_layer = list[int]() # HD bytes in kv terms for layer_name, cache_or_caches in xfer_buffers.items(): cache_list = ( - cache_or_caches - if self.kv_topo.split_k_and_v - else [cache_or_caches] + cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches] ) for cache in cache_list: base_addr = cache.data_ptr() @@ -1355,6 +1358,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_size=self.block_size, ) # Wrap metadata in payload with hash for defensive decoding + assert self.compat_hash is not None, "compat_hash is not initialized" encoder = msgspec.msgpack.Encoder() self.xfer_handshake_metadata = NixlHandshakePayload( compatibility_hash=self.compat_hash, @@ -1376,6 +1380,8 @@ def register_local_xfer_handler( register another local_xfer_handler using remote block len to ensure data copy correctness. """ + assert self.kv_topo is not None, "kv_topo is not initialized" + block_size_ratio = self.block_size // block_size blocks_data = [] for i, base_addr in enumerate(self.seen_base_addresses): @@ -1479,6 +1485,7 @@ def add_remote_agent( nixl_agent_meta.agent_metadata ) + assert self.kv_topo is not None, "kv_topo is not initialized" # Handle tp_size>num_kv_heads: replicate KV cache. replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id) @@ -1574,6 +1581,7 @@ def _validate_remote_agent_handshake( remote_engine_id = nixl_agent_meta.engine_id assert self._tp_size[remote_engine_id] == remote_tp_size + assert self.kv_topo is not None, "kv_topo is not initialized" tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( @@ -1698,6 +1706,7 @@ def permute_device_kv(self, block_ids: list[int]): - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back """ + assert self.kv_topo is not None, "kv_topo is not initialized" split_k_and_v = self.kv_topo.split_k_and_v inv_order = [0, 2, 1, 3] sample_cache = list(self.device_kv_caches.values())[0][0] @@ -1739,13 +1748,11 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): ) return permuted_blocks + assert self.kv_topo is not None, "kv_topo is not initialized" + if len(self.device_kv_caches) == 0: return - split_k_and_v = not ( - self.use_mla - or self.kv_topo.use_pallas - or self.kv_topo.is_kv_layout_blocks_first - ) + sample_cache = list(self.device_kv_caches.values())[0][0] for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): assert block_size_ratio > 1, "Only nP < nD supported currently." 
@@ -1755,7 +1762,11 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): indices = torch.tensor(block_ids, device=sample_cache.device) for _, cache_or_caches in self.device_kv_caches.items(): - cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + cache_list = ( + cache_or_caches + if self.kv_topo.split_k_and_v + else [cache_or_caches] + ) for cache in cache_list: blocks_to_update = cache.index_select(0, indices) # because kv_cache is always using original layout NHD as @@ -1776,6 +1787,8 @@ def get_finished(self) -> tuple[set[str], set[str]]: done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) + assert self.kv_topo is not None, "kv_topo is not initialized" + # add requests that skipped transfer to done_recving done_recving.update(self._failed_recv_reqs) self._failed_recv_reqs.clear() @@ -2007,6 +2020,8 @@ def _read_blocks( request_id: str, remote_request_id: str, ): + assert self.kv_topo is not None, "kv_topo is not initialized" + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: local_block_ids = self.get_mapped_blocks( @@ -2243,6 +2258,7 @@ def get_backend_aware_kv_block_len(self, layer_idx: int): For FlashInfer, this is half the length of the whole block, as K and V share the same region. """ + assert self.kv_topo is not None, "kv_topo is not initialized" if self.kv_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). block_len = self.block_len_per_layer[layer_idx] // 2 From eaf5e3d4542df30aeba0cdd249fc817a9bc4ce15 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Sun, 28 Dec 2025 09:27:59 +0000 Subject: [PATCH 15/84] Code review fix Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 36095cf3b760..586420eeba37 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -970,14 +970,13 @@ def _nixl_handshake( start_time = time.perf_counter() - assert self.kv_topo is not None, "kv_topo is not initialized" - # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. 
+ assert self.kv_topo is not None, "kv_topo is not initialized" p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size) path = make_zmq_path("tcp", host, port) logger.debug( @@ -1180,8 +1179,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): tensor_shape=next(iter(kv_caches.values())).shape, ) - assert self.kv_topo is not None, "kv_topo is not initialized" - self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks ) From e85f458ac78e7602b39e00259af336ea9063d7b1 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Sun, 28 Dec 2025 13:26:11 +0000 Subject: [PATCH 16/84] Code review fix Signed-off-by: Liran Schour --- .../kv_connector/v1/nixl_connector.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 586420eeba37..d344430da46e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -976,7 +976,7 @@ def _nixl_handshake( # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size) path = make_zmq_path("tcp", host, port) logger.debug( @@ -1007,7 +1007,7 @@ def _nixl_handshake( ) # Check compatibility hash BEFORE decoding agent metadata - assert self.compat_hash is not None, "compat_hash is not initialized" + assert self.compat_hash is not None if ( self.enforce_compat_hash and handshake_payload.compatibility_hash != self.compat_hash @@ -1355,7 +1355,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_size=self.block_size, ) # Wrap metadata in payload with hash for defensive decoding - assert self.compat_hash is not None, "compat_hash is not initialized" + assert self.compat_hash is not None encoder = msgspec.msgpack.Encoder() self.xfer_handshake_metadata = NixlHandshakePayload( compatibility_hash=self.compat_hash, @@ -1377,7 +1377,7 @@ def register_local_xfer_handler( register another local_xfer_handler using remote block len to ensure data copy correctness. """ - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None block_size_ratio = self.block_size // block_size blocks_data = [] @@ -1482,7 +1482,7 @@ def add_remote_agent( nixl_agent_meta.agent_metadata ) - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None # Handle tp_size>num_kv_heads: replicate KV cache. 
replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id) @@ -1578,7 +1578,7 @@ def _validate_remote_agent_handshake( remote_engine_id = nixl_agent_meta.engine_id assert self._tp_size[remote_engine_id] == remote_tp_size - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( @@ -1703,7 +1703,7 @@ def permute_device_kv(self, block_ids: list[int]): - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back """ - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None split_k_and_v = self.kv_topo.split_k_and_v inv_order = [0, 2, 1, 3] sample_cache = list(self.device_kv_caches.values())[0][0] @@ -1745,7 +1745,7 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): ) return permuted_blocks - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None if len(self.device_kv_caches) == 0: return @@ -1784,7 +1784,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None # add requests that skipped transfer to done_recving done_recving.update(self._failed_recv_reqs) @@ -2017,7 +2017,7 @@ def _read_blocks( request_id: str, remote_request_id: str, ): - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: @@ -2255,7 +2255,7 @@ def get_backend_aware_kv_block_len(self, layer_idx: int): For FlashInfer, this is half the length of the whole block, as K and V share the same region. """ - assert self.kv_topo is not None, "kv_topo is not initialized" + assert self.kv_topo is not None if self.kv_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). 
block_len = self.block_len_per_layer[layer_idx] // 2 From 9bd95989f50f9c5e1ca69ddfe0afc9e1bcc078c6 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Sun, 28 Dec 2025 14:27:04 +0000 Subject: [PATCH 17/84] Code review fix Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f257549d840a..5b76471326d7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -973,6 +973,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + # lazy initialized in register_kv_caches self.compat_hash: str | None = None self.kv_topo: TpKVTopology | None = None From f153e83944c2810d231248447e6376dd8dbec62a Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 13 Jan 2026 07:24:53 +0000 Subject: [PATCH 18/84] n/a Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/utils.py | 1 + .../kv_connector/v1/nixl_connector.py | 59 ++----------------- 2 files changed, 6 insertions(+), 54 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 2840b02b241d..1649e0fa81d2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput if TYPE_CHECKING: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 28c46fb9e83b..5a0ed1ca98e5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -12,7 +12,7 @@ from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import TYPE_CHECKING, Any, Optional import msgspec import numpy as np @@ -20,7 +20,7 @@ import zmq from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( @@ -55,7 +55,6 @@ from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout -from vllm.v1.attention.selector import get_attn_backend from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.block_table import BlockTable @@ -1005,7 +1004,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): remote_block_size=self._block_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=backend, + 
attn_backend=self.attn_backend, ) self._physical_blocks_per_logical_kv_block = 1 @@ -1810,53 +1809,8 @@ def post_process_device_kv_on_receive( block size to large block size and convert from HND to NHD """ - assert self.kv_topo is not None - split_k_and_v = self.kv_topo.split_k_and_v - inv_order = [0, 2, 1, 3] - sample_cache = list(self.device_kv_caches.values())[0][0] - target_shape = list(sample_cache.shape) - target_shape[0] = -1 - src_shape = tuple(target_shape[i] for i in inv_order) - indices = torch.tensor(block_ids, device=sample_cache.device) - - for _, cache_or_caches in self.device_kv_caches.items(): - cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] - for cache in cache_list: - blocks_to_update = cache.index_select(0, indices) - permuted_blocks = blocks_to_update.reshape(src_shape).permute( - *inv_order - ) - cache.index_copy_(0, indices, permuted_blocks) - - def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]): - def _process_local_gt_remote(blocks_to_update, block_size_ratio): - n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] - remote_block_size = block_size // block_size_ratio - n_blocks = block_size_ratio - # actual permute is to convert - # for local blocksize > remote blocksize - # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens - # local block[0] = remote block[0, 1, 2, 3] - # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... - # local is |h0-b0..................|h1-b0..................|... - # permute is to: - # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D) - # 2. permute => (H, nblocks, remoteN, D) - # 3. flatten => (H, localN, D) - permuted_blocks = ( - blocks_to_update.reshape( - -1, n_blocks, n_kv_heads, remote_block_size, head_size - ) - .permute(0, 2, 1, 3, 4) - .flatten(2, 3) - ) - return permuted_blocks - - assert self.kv_topo is not None - if len(self.device_kv_caches) == 0: return - assert block_size_ratio >= 1, "Only nP < nD supported currently." 
if self.enable_permute_local_kv and block_size_ratio > 1: logger.debug( @@ -1877,17 +1831,14 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): block_size_ratio, ) + assert self.kv_topo is not None split_k_and_v = self.kv_topo.split_k_and_v for block_ids in block_ids_list: indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) for _, cache_or_caches in self.device_kv_caches.items(): - cache_list = ( - cache_or_caches - if self.kv_topo.split_k_and_v - else [cache_or_caches] - ) + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] for cache in cache_list: if self.enable_permute_local_kv and block_size_ratio > 1: kv_postprocess_blksize_and_layout_on_receive( From 15f2a78ad4175a4f802d26bdb645a7983b655327 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 13 Jan 2026 07:30:34 +0000 Subject: [PATCH 19/84] n/a Signed-off-by: Liran Schour --- .../distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 5a0ed1ca98e5..db5773f65a26 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -20,8 +20,8 @@ import zmq from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.selector import get_attn_backend +from vllm.attention.v1.backends.abstract import AttentionBackend +from vllm.attention.v1.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( EngineId, From 0cb182507f2c827455e9876e75e12243444be58b Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 13 Jan 2026 07:32:06 +0000 Subject: [PATCH 20/84] n/a Signed-off-by: Liran Schour --- .../distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index db5773f65a26..91f02c5c5671 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -20,8 +20,6 @@ import zmq from vllm import envs -from vllm.attention.v1.backends.abstract import AttentionBackend -from vllm.attention.v1.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( EngineId, @@ -54,7 +52,9 @@ from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.attention.selector import get_attn_backend from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.block_table import BlockTable From 9630c8ea0d7cd6f5cf44429c12a7bf76088c8a24 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 13 Jan 2026 07:33:51 +0000 Subject: [PATCH 21/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py 
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 91f02c5c5671..9405f4e098db 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -51,8 +51,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket -from vllm.v1.attention.backend import AttentionMetadata -from vllm.v1.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.selector import get_attn_backend from vllm.v1.core.sched.output import SchedulerOutput From c148f6def2bc8918fb090c3beb1960401eea57e5 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 13 Jan 2026 07:41:00 +0000 Subject: [PATCH 22/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1649e0fa81d2..f54a0bc781e9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -14,7 +14,6 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput if TYPE_CHECKING: @@ -330,9 +329,6 @@ def __post_init__(self): len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) - attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] - self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS - test_shape = self.attn_backend.get_kv_cache_shape( num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 ) @@ -352,10 +348,7 @@ def is_kv_layout_blocks_first(self) -> bool: def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). 
return not ( - self._cross_layers_blocks - or self.is_mla - or self._use_pallas - or self.is_kv_layout_blocks_first + self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first ) @property @@ -366,10 +359,6 @@ def tp_size(self) -> int: def block_size(self) -> int: return self.remote_block_size[self.engine_id] - @property - def use_pallas(self) -> bool: - return self._use_pallas - @property def cross_layers_blocks(self) -> bool: return self._cross_layers_blocks From 96329f6f35b25205ef79a04220ccf4c02f07b347 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 13 Jan 2026 08:06:29 +0000 Subject: [PATCH 23/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 9405f4e098db..2ddf7f843704 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1679,9 +1679,6 @@ def _validate_remote_agent_handshake( ) # Num kv_heads > tp_size and P TP > D TP case, not supported assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)) - assert not self.kv_topo._use_pallas or tp_ratio == 1, ( - "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." - ) kv_cache_layout = ( self.kv_cache_layout From b4d70450cc7b1f2971e64c6c5d99e2dfb89b6970 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 15 Jan 2026 10:30:44 +0000 Subject: [PATCH 24/84] Unit test fix Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 6d25ee6f61c4..af7fc20d02f5 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2063,7 +2063,9 @@ def test_compatibility_hash_validation( ) ) remote_hash = compute_nixl_compatibility_hash( - remote_vllm_config, decode_worker.backend_name + remote_vllm_config, + decode_worker.backend_name, + decode_worker.cross_layers_blocks, ) prefill_block_size = config_overrides.get("block_size", 16) From 52b1155799d27ba761a00dffd585e4daa49fdacf Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 15 Jan 2026 13:45:10 +0000 Subject: [PATCH 25/84] Unit test fix Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index af7fc20d02f5..e2f06914a47b 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -19,7 +19,10 @@ from vllm import LLM from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.utils import ( + KVOutputAggregator, + TpKVTopology, +) from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( @@ -367,6 +370,21 @@ def test_kv_transfer_handshake(dist_init): # Decode connector will be able to create handshake with the prefill connector. 
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + decode_connector.kv_topo = TpKVTopology( + tp_rank=decode_connector.tp_rank, + engine_id=decode_connector.engine_id, + remote_tp_size=decode_connector._tp_size, # shared state + remote_block_size=decode_connector._block_size, # shared state + is_mla=decode_connector.use_mla, + total_num_kv_heads=decode_connector.model_config.get_total_num_kv_heads(), + attn_backend=decode_connector.attn_backend, + tensor_shape=next(iter(kv_caches.values())).shape, + ) + decode_connector.compat_hash = compute_nixl_compatibility_hash( + decode_connector.vllm_config, + decode_connector.backend_name, + decode_connector.kv_topo.cross_layers_blocks, + ) # Here we are testing the retrieval of NIXLAgentMetadata. # Knowing the implementation detail, we override the add_remote_agent # to validate the metadata received is the same as the one in prefill_connector. From e34db34d4a5cbc0ada66ef84e01db448ab6872fa Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 15 Jan 2026 13:52:16 +0000 Subject: [PATCH 26/84] Unit test fix Signed-off-by: Liran Schour --- .../v1/kv_connector/unit/test_nixl_connector.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e2f06914a47b..6af6e6a0a2d7 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -419,7 +419,24 @@ def __init__( self._hand_shake_latency = hand_shake_latency self.kv_cache_layout = kv_cache_layout # Mock register_kv_caches attribute needed for tests that do not call it. + test_shape = self.attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) self.src_xfer_handles_by_block_size = {self.block_size: 1} + self.kv_topo = TpKVTopology( + tp_rank=self.tp_rank, + engine_id=self.engine_id, + remote_tp_size=self._tp_size, # shared state + remote_block_size=self._block_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=self.attn_backend, + tensor_shape=test_shape, + ) + + self.compat_hash = compute_nixl_compatibility_hash( + self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks + ) def _nixl_handshake( self, host: str, port: int, remote_tp_size: int, expected_engine_id: str From edc075573063926ea3824d9a5a1c82278bb4dfa9 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 15 Jan 2026 14:00:54 +0000 Subject: [PATCH 27/84] Unit test fix Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 6af6e6a0a2d7..82c96b8469f0 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -369,21 +369,21 @@ def test_kv_transfer_handshake(dist_init): # Decode connector will be able to create handshake with the prefill connector. 
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - - decode_connector.kv_topo = TpKVTopology( - tp_rank=decode_connector.tp_rank, - engine_id=decode_connector.engine_id, - remote_tp_size=decode_connector._tp_size, # shared state - remote_block_size=decode_connector._block_size, # shared state - is_mla=decode_connector.use_mla, - total_num_kv_heads=decode_connector.model_config.get_total_num_kv_heads(), - attn_backend=decode_connector.attn_backend, + decode_worker = decode_connector.connector_worker + decode_worker.kv_topo = TpKVTopology( + tp_rank=decode_worker.tp_rank, + engine_id=decode_worker.engine_id, + remote_tp_size=decode_worker._tp_size, # shared state + remote_block_size=decode_worker._block_size, # shared state + is_mla=decode_worker.use_mla, + total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), + attn_backend=decode_worker.attn_backend, tensor_shape=next(iter(kv_caches.values())).shape, ) - decode_connector.compat_hash = compute_nixl_compatibility_hash( - decode_connector.vllm_config, - decode_connector.backend_name, - decode_connector.kv_topo.cross_layers_blocks, + decode_worker.compat_hash = compute_nixl_compatibility_hash( + decode_worker.vllm_config, + decode_worker.backend_name, + decode_worker.kv_topo.cross_layers_blocks, ) # Here we are testing the retrieval of NIXLAgentMetadata. # Knowing the implementation detail, we override the add_remote_agent From 03af3ece6a9544c6426633a8b3039dcd2751763e Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 15 Jan 2026 14:11:25 +0000 Subject: [PATCH 28/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 82c96b8469f0..052dc7b780c5 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2100,7 +2100,7 @@ def test_compatibility_hash_validation( remote_hash = compute_nixl_compatibility_hash( remote_vllm_config, decode_worker.backend_name, - decode_worker.cross_layers_blocks, + decode_worker.kv_topo.cross_layers_blocks, ) prefill_block_size = config_overrides.get("block_size", 16) From f0f2cf9dfed1046577fc29d1172107d8f07485a5 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 15 Jan 2026 14:22:02 +0000 Subject: [PATCH 29/84] n/a Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 052dc7b780c5..c71116e43605 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2076,6 +2076,24 @@ def test_compatibility_hash_validation( ) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_worker = decode_connector.connector_worker + test_shape = decode_worker.attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + decode_worker.kv_topo = TpKVTopology( + tp_rank=decode_worker.tp_rank, + engine_id=decode_worker.engine_id, + remote_tp_size=decode_worker._tp_size, # shared state + remote_block_size=decode_worker._block_size, # shared state + is_mla=decode_worker.use_mla, + total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), + attn_backend=decode_worker.attn_backend, + 
tensor_shape=test_shape, + ) + decode_worker.compat_hash = compute_nixl_compatibility_hash( + decode_worker.vllm_config, + decode_worker.backend_name, + decode_worker.kv_topo.cross_layers_blocks, + ) remote_config_params: dict[str, Any] = { "model": "facebook/opt-125m", From 701d4efd3d6915a27032acad1c5ef81874dfbbfe Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Sun, 18 Jan 2026 08:11:38 +0000 Subject: [PATCH 30/84] Code review fix Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 7 ++++++- .../kv_transfer/kv_connector/v1/nixl_connector.py | 9 +-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index c46f96fdd98f..4e5361ae9166 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -317,6 +317,7 @@ class TpKVTopology: engine_id: EngineId remote_block_size: dict[EngineId, int] tensor_shape: torch.Size | None = None + device_type: str = "cuda" def __post_init__(self): # Figure out whether the first dimension of the cache is K/V @@ -366,7 +367,11 @@ def cross_layers_blocks(self) -> bool: @property def block_size_position(self) -> int: - return -2 if self.is_mla or self._cross_layers_blocks else -3 + return ( + -2 + if self.device_type == "cpu" or self.is_mla or self._cross_layers_blocks + else -3 + ) def tp_ratio( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4c59ec225c4a..8446c105ce23 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1336,13 +1336,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # to better exploit the memory layout (ie num_blocks is the first dim). tensor_size_bytes = None - # TODO (NickLucche): Get kernel_block_size in a cleaner way - # NHD default "view" for non-MLA cache - if self.device_type == "cpu": - block_size_position = -2 - else: - block_size_position = self.kv_topo.block_size_position - # Enable different block lengths for different layers when MLA is used. 
self.block_len_per_layer = list[int]() self.slot_size_per_layer = list[int]() # HD bytes in kv terms @@ -1355,7 +1348,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if base_addr in seen_base_addresses: continue - kernel_block_size = cache.shape[block_size_position] + kernel_block_size = cache.shape[self.kv_topo.block_size_position] if self.block_size != kernel_block_size: logger.info_once( "User-specified logical block size (%s) does not match" From e7df5f83b5e775377a9bd1a39de0a53378c690cc Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Sun, 18 Jan 2026 08:14:37 +0000 Subject: [PATCH 31/84] Code review fix Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 280 +++++++++++++++--- .../kv_transfer/kv_connector/utils.py | 11 +- .../kv_connector/v1/nixl_connector.py | 1 - 3 files changed, 249 insertions(+), 43 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index c71116e43605..028d01aec64a 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -51,8 +51,11 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import RequestStatus +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin +from vllm.v1.worker.utils import AttentionGroup from .utils import create_request, create_scheduler, create_vllm_config @@ -369,22 +372,8 @@ def test_kv_transfer_handshake(dist_init): # Decode connector will be able to create handshake with the prefill connector. decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - decode_worker = decode_connector.connector_worker - decode_worker.kv_topo = TpKVTopology( - tp_rank=decode_worker.tp_rank, - engine_id=decode_worker.engine_id, - remote_tp_size=decode_worker._tp_size, # shared state - remote_block_size=decode_worker._block_size, # shared state - is_mla=decode_worker.use_mla, - total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), - attn_backend=decode_worker.attn_backend, - tensor_shape=next(iter(kv_caches.values())).shape, - ) - decode_worker.compat_hash = compute_nixl_compatibility_hash( - decode_worker.vllm_config, - decode_worker.backend_name, - decode_worker.kv_topo.cross_layers_blocks, - ) + decode_connector.register_kv_caches(kv_caches) + # Here we are testing the retrieval of NIXLAgentMetadata. # Knowing the implementation detail, we override the add_remote_agent # to validate the metadata received is the same as the one in prefill_connector. @@ -419,10 +408,10 @@ def __init__( self._hand_shake_latency = hand_shake_latency self.kv_cache_layout = kv_cache_layout # Mock register_kv_caches attribute needed for tests that do not call it. 
+ self.src_xfer_handles_by_block_size = {self.block_size: 1} test_shape = self.attn_backend.get_kv_cache_shape( num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) - self.src_xfer_handles_by_block_size = {self.block_size: 1} self.kv_topo = TpKVTopology( tp_rank=self.tp_rank, engine_id=self.engine_id, @@ -492,6 +481,76 @@ def _nixl_handshake( return remote_agents +class FakeNixlConnectorWorkerCrossLayers(NixlConnectorWorker): + REMOTE_ENGINE_ID = "remote_engine" + + def __init__( + self, + *args, + hand_shake_latency: float = 1.8, + kv_cache_layout="HND", + **kwargs, + ): + super().__init__(*args, **kwargs) + self._hand_shake_latency = hand_shake_latency + self.kv_cache_layout = kv_cache_layout + # Mock register_kv_caches attribute needed for tests that do not call it. + self.src_xfer_handles_by_block_size = {self.block_size: 1} + + def _nixl_handshake( + self, host: str, port: int, remote_tp_size: int, expected_engine_id: str + ) -> dict[int, str]: + # Mimic slow _nixl_handshake, as well as bypass zmq communication. + time.sleep(self._hand_shake_latency) + # These should've been done in register_kv_caches(), called by + # gpu_model_runner. Here we just hardcode some dummy values. + slot_size_bytes = 4096 + self.slot_size_per_layer = [slot_size_bytes] + self.block_len_per_layer = [slot_size_bytes * self.block_size] + self.num_blocks = 1 + self.dst_num_blocks[self.engine_id] = self.num_blocks + + assert expected_engine_id == self.REMOTE_ENGINE_ID + + # Adjust remote block length metadata to satisfy heterogeneous TP + # invariants enforced during handshake validation. + remote_block_lens = list(self.block_len_per_layer) + tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) + if remote_tp_size > self.world_size: + # P TP > D TP case, block_len of remote is smaller + remote_block_lens = [ + block_len // (-tp_ratio) for block_len in remote_block_lens + ] + elif remote_tp_size < self.world_size: + remote_block_lens = [ + block_len * tp_ratio for block_len in remote_block_lens + ] + + # When remote tp_size > local tp_size, handshake with multiple + # remote ranks. + num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio + remote_agents: dict[int, str] = {} + for remote_tp_rank in range(num_hanshakes): + remote_agent_name = self.add_remote_agent( + NixlAgentMetadata( + engine_id=self.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + device_id=remote_tp_rank, + num_blocks=1, + block_lens=remote_block_lens, + # `self.kv_cache_layout` is only forced to HND when vllm engine + # is started. We mock HND here. 
+ kv_cache_layout="HND", + block_size=self.block_size, + ), + remote_tp_rank=remote_tp_rank, + remote_tp_size=remote_tp_size, + ) + remote_agents[remote_tp_rank] = remote_agent_name + return remote_agents + + class TestNixlHandshake: @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", @@ -1539,6 +1598,166 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ) +@pytest.mark.parametrize( + "attn_backend", + [ + pytest.param( + "FLASH_ATTN", + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Attention backend FLASH_ATTN is not supported on ROCm", + ), + ), + pytest.param( + "ROCM_ATTN", + marks=pytest.mark.skipif( + not current_platform.is_rocm(), + reason="Attention backend ROCM_ATTN is only supported on ROCm", + ), + ), + "TRITON_ATTN", + ], +) +def test_register_kv_caches_cross_layers(default_vllm_config, dist_init, attn_backend): + """ + Test that register_kv_caches() properly calls nixl_wrapper methods with + correct data in cross layers mode. + + This test verifies: + 1. nixl_wrapper.get_reg_descs() is called with caches_data containing + tensor metadata + 2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing + block layout info + """ + + vllm_config = create_vllm_config(attention_backend=attn_backend) + + # Import the appropriate backend based on the parameter + if attn_backend == "FLASH_ATTN": + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + backend_cls = FlashAttentionBackend + elif attn_backend == "ROCM_ATTN": + from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend + + backend_cls = RocmAttentionBackend + else: # TRITON_ATTN + from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend + + backend_cls = TritonAttentionBackend + + num_layers = 32 + block_size = 16 + num_blocks = 2 + kv_cache_spec = AttentionSpec( + block_size=block_size, num_kv_heads=4, head_size=64, dtype=torch.bfloat16 + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor( + size=kv_cache_spec.page_size_bytes * num_blocks, + shared_by=["dummy-layer"], + ) + for i in range(num_layers) + ], + # allocate_uniform_kv_caches does not use this + kv_cache_groups=[], + ) + _, cross_layers_kv_cache, _ = ( + KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( + kv_cache_config=kv_cache_config, + attn_groups=[ + [ + AttentionGroup( + backend=backend_cls, + layer_names=[], + kv_cache_spec=kv_cache_spec, + kv_cache_group_id=0, + ) + ] + ], + cache_dtype=torch.bfloat16, + device=torch.cuda.current_device(), + kernel_block_sizes=[block_size], + ) + ) + + # Store tensor info for validation + expected_tensor_size = ( + cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel() + ) + expected_base_addrs = [ + cross_layers_kv_cache.data_ptr(), + ] + expected_num_entries = 1 + + nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" + with ( + patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, + patch(f"{nixl_module}.threading.Event"), + patch(f"{nixl_module}.threading.Thread") as mock_thread, + patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend, + ): + # Ensure get_attn_backend returns the correct value due to + # _cached_get_attn_backend returning the backend from previous + # test run if not mocking. 
+ mock_get_attn_backend.return_value = backend_cls + + # Create connector + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorkerCrossLayers( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + + # Get the mock instance + mock_wrapper_instance = mock_nixl_wrapper.return_value + connector.connector_worker.nixl_wrapper = mock_wrapper_instance + + # Appease NixlHandshakePayload encoding with some bytes + mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata" + + # Reassure the shutdown() check that the thread is terminated + mock_thread.return_value.is_alive.return_value = False + + # Execute register_kv_caches + connector.register_kv_caches({"all-layers": cross_layers_kv_cache}) + + # Verify get_reg_descs was called with caches_data + assert mock_wrapper_instance.get_reg_descs.called + caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] + assert len(caches_data) == expected_num_entries + + for i, cache_entry in enumerate(caches_data): + base_addr, size, _tp_rank, _ = cache_entry + assert size == expected_tensor_size, ( + f"Entry {i}: Expected tensor size {expected_tensor_size}, got {size}" + ) + assert base_addr == expected_base_addrs[i], ( + f"Entry {i}: Expected base address {expected_base_addrs[i]}, " + f"got {base_addr}" + ) + + # Verify get_xfer_descs was called with blocks_data + assert mock_wrapper_instance.get_xfer_descs.called + blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] + + # Validate blocks_data structure and size + expected_blocks_count = num_blocks + assert len(blocks_data) == expected_blocks_count, ( + f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" + ) + + expected_block_len = expected_tensor_size // num_blocks + + for i, block_entry in enumerate(blocks_data): + block_start_addr, block_len, tp_rank = block_entry + assert block_len == expected_block_len, ( + f"Block entry {i}: Expected block len {expected_block_len}, " + f"got {block_len}" + ) + + class FakePlatform(Platform): device_type: str = "oot" @@ -2076,24 +2295,17 @@ def test_compatibility_hash_validation( ) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_worker = decode_connector.connector_worker - test_shape = decode_worker.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - decode_worker.kv_topo = TpKVTopology( - tp_rank=decode_worker.tp_rank, - engine_id=decode_worker.engine_id, - remote_tp_size=decode_worker._tp_size, # shared state - remote_block_size=decode_worker._block_size, # shared state - is_mla=decode_worker.use_mla, - total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), - attn_backend=decode_worker.attn_backend, - tensor_shape=test_shape, - ) - decode_worker.compat_hash = compute_nixl_compatibility_hash( - decode_worker.vllm_config, - decode_worker.backend_name, - decode_worker.kv_topo.cross_layers_blocks, + kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + decode_connector.register_kv_caches(kv_caches) remote_config_params: dict[str, Any] = { "model": "facebook/opt-125m", diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py 
b/vllm/distributed/kv_transfer/kv_connector/utils.py index 4e5361ae9166..3c82e6e7d6e3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -331,16 +331,11 @@ def __post_init__(self): len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) - test_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 - ) - self._cross_layers_blocks = False if self.tensor_shape is not None: - self._cross_layers_blocks = len(self.tensor_shape) != len(test_shape) - if self._cross_layers_blocks: - # expect one additional dimension (num_layers) - assert len(self.tensor_shape) == len(test_shape) + 1 + self._cross_layers_blocks = ( + len(self.tensor_shape) == len(kv_cache_shape) + 1 + ) @property def is_kv_layout_blocks_first(self) -> bool: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8446c105ce23..1f58b81e102f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1295,7 +1295,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): attn_backend=self.attn_backend, tensor_shape=next(iter(kv_caches.values())).shape, ) - self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks ) From 6dae9b52f3178b97a5ce5106ccc4f957e09f4bf9 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 19 Jan 2026 15:02:22 +0000 Subject: [PATCH 32/84] Code review fix Signed-off-by: Liran Schour --- .../distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1f58b81e102f..e4111293bcac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -301,7 +301,8 @@ def add_new_req_to_recv( class NixlConnector(KVConnectorBase_V1): @property def prefer_cross_layer_blocks(self) -> bool: - return True + extra_config = self.kv_transfer_config.kv_connector_extra_config + return extra_config.get("cross_layers_block", False) def __init__( self, @@ -314,6 +315,7 @@ def __init__( assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id + self.kv_transfer_config = vllm_config.kv_transfer_config if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: NixlConnectorScheduler | None = ( From 012bb9e890542b08f4e966b4b2e2cce69c361ea0 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 20 Jan 2026 08:56:06 +0000 Subject: [PATCH 33/84] Handle hetrogenous TP for FLASHINFER and TRITON Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/utils.py | 26 ++++++++++++++++++- .../kv_connector/v1/nixl_connector.py | 26 ++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 3c82e6e7d6e3..8f0b15c38cf8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -323,7 +323,7 @@ def __post_init__(self): # Figure out whether the first dimension of the cache is K/V 
# or num_blocks. This is used to register the memory regions correctly. kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + num_blocks=1, block_size=16, num_kv_heads=4, head_size=1 ) # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # we just mock num_blocks to 1 for the dimension check below. @@ -331,12 +331,31 @@ def __post_init__(self): len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) + self._kv_heads_position: int | None = None self._cross_layers_blocks = False if self.tensor_shape is not None: self._cross_layers_blocks = ( len(self.tensor_shape) == len(kv_cache_shape) + 1 ) + if self._cross_layers_blocks: + # prepend layers dimension + kv_cache_shape = (80,) + kv_cache_shape + try: + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=self._cross_layers_blocks + ) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(self.tensor_shape))) + + # permute kv_cache_shape according to stride_order + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + + physical_kv_heads_position = kv_cache_shape.index(4) + assert physical_kv_heads_position is not None + + self._physical_kv_heads_position = physical_kv_heads_position + @property def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first @@ -368,6 +387,11 @@ def block_size_position(self) -> int: else -3 ) + @property + def physical_kv_heads_position(self) -> int: + assert self._physical_kv_heads_position is not None + return self._physical_kv_heads_position + def tp_ratio( self, remote_tp_size: int, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e4111293bcac..00b73bd1a468 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1): @property def prefer_cross_layer_blocks(self) -> bool: extra_config = self.kv_transfer_config.kv_connector_extra_config - return extra_config.get("cross_layers_block", False) + return bool(str(extra_config.get("enable_cross_layers_block", "True"))) def __init__( self, @@ -408,8 +408,8 @@ def register_cross_layers_kv_cache( assert self.connector_worker is not None cross_layer_name = "ALL_LAYERS" - kv_caches = {cross_layer_name: kv_cache} + self.connector_worker.register_kv_caches(kv_caches) def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): @@ -1411,7 +1411,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self.kv_topo.is_kv_layout_blocks_first: + + if self.kv_topo.cross_layers_blocks: + assert len(kv_caches) == 1 + tensor_shape = list(kv_caches.values())[0].shape + + # NOTE (liranschour) When FlashInfer is used, memory is registered + # with joint KV for each block and in Triton joint KV and num_layers + # for each block. + # In order to be able to split on kv_heads dim as required by + # heterogeneous TP, one must be able to index (K/V, layer_idx) separately. + # Hence we multiply the number of 'virtual' regions here and divide + # `block_len` below. 
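+            # A purely illustrative sketch of the arithmetic (hypothetical
+            # shape, not taken from any specific backend): suppose the
+            # cross-layer tensor, after applying the backend's stride order,
+            # were laid out as
+            #   (num_blocks, num_layers, 2, num_kv_heads, block_size, head_size)
+            #   e.g.  (1024,         32, 2,            4,         16,      128)
+            # with physical_kv_heads_position == 3. The dims between the block
+            # dim and the kv-heads dim give multiply = 32 * 2 = 64, so each
+            # logical block is exposed as 64 separately addressable regions
+            # and slot_size_per_layer shrinks by the same factor, keeping
+            # (K/V, layer_idx) individually indexable for heterogeneous TP.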
+ multiply = 1 + for dim_idx in range(1, self.kv_topo.physical_kv_heads_position): + multiply *= tensor_shape[dim_idx] + logger.info("Multiply is %d", multiply) + for i in range(len(self.slot_size_per_layer)): + assert self.slot_size_per_layer[i] % multiply == 0 + self.slot_size_per_layer[i] //= multiply + self.num_regions *= multiply + elif self.kv_topo.is_kv_layout_blocks_first: for i in range(len(self.slot_size_per_layer)): assert self.slot_size_per_layer[i] % 2 == 0 self.slot_size_per_layer[i] //= 2 From 99b340193029af5518f85142ae3a854f62dac866 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 20 Jan 2026 09:24:33 +0000 Subject: [PATCH 34/84] n/a Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 00b73bd1a468..08eb57971925 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1426,7 +1426,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): multiply = 1 for dim_idx in range(1, self.kv_topo.physical_kv_heads_position): multiply *= tensor_shape[dim_idx] - logger.info("Multiply is %d", multiply) + logger.info( + "Multiply is %d shape %s kv_heads_pos %d", + multiply, + tensor_shape, + self.kv_topo.physical_kv_heads_position, + ) for i in range(len(self.slot_size_per_layer)): assert self.slot_size_per_layer[i] % multiply == 0 self.slot_size_per_layer[i] //= multiply From 9fe2eb6b566be06ef0f7b04390ef6a4bb27e2a5e Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 20 Jan 2026 09:45:31 +0000 Subject: [PATCH 35/84] n/a Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 028d01aec64a..1d52894849e7 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -18,7 +18,7 @@ import torch from vllm import LLM -from vllm.config import KVTransferConfig +from vllm.config import KVTransferConfig, set_current_vllm_config from vllm.distributed.kv_transfer.kv_connector.utils import ( KVOutputAggregator, TpKVTopology, @@ -1664,24 +1664,26 @@ def test_register_kv_caches_cross_layers(default_vllm_config, dist_init, attn_ba # allocate_uniform_kv_caches does not use this kv_cache_groups=[], ) - _, cross_layers_kv_cache, _ = ( - KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( - kv_cache_config=kv_cache_config, - attn_groups=[ - [ - AttentionGroup( - backend=backend_cls, - layer_names=[], - kv_cache_spec=kv_cache_spec, - kv_cache_group_id=0, - ) - ] - ], - cache_dtype=torch.bfloat16, - device=torch.cuda.current_device(), - kernel_block_sizes=[block_size], + + with set_current_vllm_config(vllm_config): + _, cross_layers_kv_cache, _ = ( + KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( + kv_cache_config=kv_cache_config, + attn_groups=[ + [ + AttentionGroup( + backend=backend_cls, + layer_names=[], + kv_cache_spec=kv_cache_spec, + kv_cache_group_id=0, + ) + ] + ], + cache_dtype=torch.bfloat16, + device=torch.cuda.current_device(), + kernel_block_sizes=[block_size], + ) ) - ) # Store tensor info for validation expected_tensor_size = ( From 
043c4d8faa787490abb15845a6bbd38abe6e6ffd Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 20 Jan 2026 10:34:12 +0000 Subject: [PATCH 36/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 8f0b15c38cf8..7c361c5aa440 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -353,9 +353,14 @@ def __post_init__(self): physical_kv_heads_position = kv_cache_shape.index(4) assert physical_kv_heads_position is not None - self._physical_kv_heads_position = physical_kv_heads_position + physical_block_size_position = kv_cache_shape.index(16) + assert physical_block_size_position is not None + self._physical_block_size_position = -( + len(kv_cache_shape) - physical_block_size_position + ) + @property def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first @@ -381,11 +386,7 @@ def cross_layers_blocks(self) -> bool: @property def block_size_position(self) -> int: - return ( - -2 - if self.device_type == "cpu" or self.is_mla or self._cross_layers_blocks - else -3 - ) + return self._physical_block_size_position @property def physical_kv_heads_position(self) -> int: From fe7197ce6ddad57a95b5cbffce6e571d96afaa9a Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 09:29:37 +0000 Subject: [PATCH 37/84] Run cross layers only for FlashAttention and FLASHINFER Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 08eb57971925..2e0718baeea4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -301,6 +301,11 @@ def add_new_req_to_recv( class NixlConnector(KVConnectorBase_V1): @property def prefer_cross_layer_blocks(self) -> bool: + if self.attn_backend.get_name() not in ["FLASH_ATTN", "FLASHINFER"]: + # For now there is no benefit to run cross layers when backend + # does not support on HND + return False + extra_config = self.kv_transfer_config.kv_connector_extra_config return bool(str(extra_config.get("enable_cross_layers_block", "True"))) From 5d59ea682fb7e0b066dc55eedd15c8410941c123 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 09:59:52 +0000 Subject: [PATCH 38/84] Enhance test_register_kv_caches Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 293 ++++++------------ .../kv_connector/v1/nixl_connector.py | 8 +- 2 files changed, 103 insertions(+), 198 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 1d52894849e7..49c739916eef 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1490,44 +1490,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): backend_cls = TritonAttentionBackend - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = 
torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } - - # Store tensor info for validation - - test_shape = backend_cls.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 - - if is_blocks_first: - expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel() - expected_base_addrs = [ - shared_tensor.data_ptr(), - unique_tensor.data_ptr(), - ] - expected_num_entries = 2 - else: - expected_tensor_size = ( - shared_tensor[0].element_size() * shared_tensor[0].numel() - ) - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - ] - expected_num_entries = 4 - nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" with ( patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, @@ -1556,174 +1518,105 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): # Reassure the shutdown() check that the thread is terminated mock_thread.return_value.is_alive.return_value = False - # Execute register_kv_caches - connector.register_kv_caches(kv_caches) - - # Verify get_reg_descs was called with caches_data - assert mock_wrapper_instance.get_reg_descs.called - caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] - assert len(caches_data) == expected_num_entries - - for i, cache_entry in enumerate(caches_data): - base_addr, size, _tp_rank, _ = cache_entry - assert size == expected_tensor_size, ( - f"Entry {i}: Expected tensor size {expected_tensor_size}, got {size}" + expected_tensor_size: int + expected_base_addrs: list[int] + expected_num_entries: int + kv_caches: dict[str, torch.Tensor] + if connector.prefer_cross_layer_blocks: + num_layers = 32 + block_size = 16 + num_blocks = 8 + kv_cache_spec = AttentionSpec( + block_size=block_size, + num_kv_heads=4, + head_size=64, + dtype=torch.bfloat16, ) - assert base_addr == expected_base_addrs[i], ( - f"Entry {i}: Expected base address {expected_base_addrs[i]}, " - f"got {base_addr}" + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor( + size=kv_cache_spec.page_size_bytes * num_blocks, + shared_by=["dummy-layer"], + ) + for i in range(num_layers) + ], + # allocate_uniform_kv_caches does not use this + kv_cache_groups=[], ) - # Verify get_xfer_descs was called with blocks_data - assert mock_wrapper_instance.get_xfer_descs.called - blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] - - # Validate blocks_data structure and size - expected_blocks_count = 8 - assert len(blocks_data) == expected_blocks_count, ( - f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" - ) + with set_current_vllm_config(vllm_config): + _, cross_layers_kv_cache, _ = ( + KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( + kv_cache_config=kv_cache_config, + attn_groups=[ + [ + AttentionGroup( + backend=backend_cls, + layer_names=[], + kv_cache_spec=kv_cache_spec, + kv_cache_group_id=0, + ) + ] + ], + cache_dtype=torch.bfloat16, + device=torch.cuda.current_device(), + kernel_block_sizes=[block_size], + ) + ) + # Store tensor info for validation + expected_tensor_size = ( + cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel() + ) + expected_base_addrs = [ + cross_layers_kv_cache.data_ptr(), + ] + expected_num_entries = 1 - num_blocks = 2 - if 
is_blocks_first: - expected_block_len = expected_tensor_size // num_blocks // 2 + kv_caches = {"all-layers": cross_layers_kv_cache} else: - expected_block_len = expected_tensor_size // num_blocks - - for i, block_entry in enumerate(blocks_data): - block_start_addr, block_len, tp_rank = block_entry - assert block_len == expected_block_len, ( - f"Block entry {i}: Expected block len {expected_block_len}, " - f"got {block_len}" + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + # Store tensor info for validation -@pytest.mark.parametrize( - "attn_backend", - [ - pytest.param( - "FLASH_ATTN", - marks=pytest.mark.skipif( - current_platform.is_rocm(), - reason="Attention backend FLASH_ATTN is not supported on ROCm", - ), - ), - pytest.param( - "ROCM_ATTN", - marks=pytest.mark.skipif( - not current_platform.is_rocm(), - reason="Attention backend ROCM_ATTN is only supported on ROCm", - ), - ), - "TRITON_ATTN", - ], -) -def test_register_kv_caches_cross_layers(default_vllm_config, dist_init, attn_backend): - """ - Test that register_kv_caches() properly calls nixl_wrapper methods with - correct data in cross layers mode. - - This test verifies: - 1. nixl_wrapper.get_reg_descs() is called with caches_data containing - tensor metadata - 2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing - block layout info - """ - - vllm_config = create_vllm_config(attention_backend=attn_backend) - - # Import the appropriate backend based on the parameter - if attn_backend == "FLASH_ATTN": - from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend - - backend_cls = FlashAttentionBackend - elif attn_backend == "ROCM_ATTN": - from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend - - backend_cls = RocmAttentionBackend - else: # TRITON_ATTN - from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend - - backend_cls = TritonAttentionBackend - - num_layers = 32 - block_size = 16 - num_blocks = 2 - kv_cache_spec = AttentionSpec( - block_size=block_size, num_kv_heads=4, head_size=64, dtype=torch.bfloat16 - ) - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=[ - KVCacheTensor( - size=kv_cache_spec.page_size_bytes * num_blocks, - shared_by=["dummy-layer"], - ) - for i in range(num_layers) - ], - # allocate_uniform_kv_caches does not use this - kv_cache_groups=[], - ) - - with set_current_vllm_config(vllm_config): - _, cross_layers_kv_cache, _ = ( - KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( - kv_cache_config=kv_cache_config, - attn_groups=[ - [ - AttentionGroup( - backend=backend_cls, - layer_names=[], - kv_cache_spec=kv_cache_spec, - kv_cache_group_id=0, - ) - ] - ], - cache_dtype=torch.bfloat16, - device=torch.cuda.current_device(), - kernel_block_sizes=[block_size], + test_shape = backend_cls.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) - ) - - # Store tensor info for validation - expected_tensor_size = ( - cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel() - ) - expected_base_addrs = [ - cross_layers_kv_cache.data_ptr(), - ] - expected_num_entries = 1 - - nixl_module = 
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" - with ( - patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, - patch(f"{nixl_module}.threading.Event"), - patch(f"{nixl_module}.threading.Thread") as mock_thread, - patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend, - ): - # Ensure get_attn_backend returns the correct value due to - # _cached_get_attn_backend returning the backend from previous - # test run if not mocking. - mock_get_attn_backend.return_value = backend_cls + is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 - # Create connector - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - connector.connector_worker = FakeNixlConnectorWorkerCrossLayers( - vllm_config, connector.engine_id, hand_shake_latency=0 - ) - - # Get the mock instance - mock_wrapper_instance = mock_nixl_wrapper.return_value - connector.connector_worker.nixl_wrapper = mock_wrapper_instance - - # Appease NixlHandshakePayload encoding with some bytes - mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata" - - # Reassure the shutdown() check that the thread is terminated - mock_thread.return_value.is_alive.return_value = False + if is_blocks_first: + expected_tensor_size = ( + shared_tensor.element_size() * shared_tensor.numel() + ) + expected_base_addrs = [ + shared_tensor.data_ptr(), + unique_tensor.data_ptr(), + ] + expected_num_entries = 2 + else: + expected_tensor_size = ( + shared_tensor[0].element_size() * shared_tensor[0].numel() + ) + expected_base_addrs = [ + shared_tensor[0].data_ptr(), + shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), + ] + expected_num_entries = 4 # Execute register_kv_caches - connector.register_kv_caches({"all-layers": cross_layers_kv_cache}) + connector.register_kv_caches(kv_caches) # Verify get_reg_descs was called with caches_data assert mock_wrapper_instance.get_reg_descs.called @@ -1745,12 +1638,20 @@ def test_register_kv_caches_cross_layers(default_vllm_config, dist_init, attn_ba blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] # Validate blocks_data structure and size - expected_blocks_count = num_blocks + expected_blocks_count = 8 assert len(blocks_data) == expected_blocks_count, ( f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" ) - expected_block_len = expected_tensor_size // num_blocks + if connector.prefer_cross_layer_blocks: + num_blocks = 8 + expected_block_len = expected_tensor_size // num_blocks + else: + num_blocks = 2 + if is_blocks_first: + expected_block_len = expected_tensor_size // num_blocks // 2 + else: + expected_block_len = expected_tensor_size // num_blocks for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2e0718baeea4..4c293f8c0662 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -301,7 +301,11 @@ def add_new_req_to_recv( class NixlConnector(KVConnectorBase_V1): @property def prefer_cross_layer_blocks(self) -> bool: - if self.attn_backend.get_name() not in ["FLASH_ATTN", "FLASHINFER"]: + backend = get_current_attn_backend(self._vllm_config) + if backend().get_name() not in [ + "FLASH_ATTN", + "FLASHINFER", + ]: # For now there is no benefit to run cross layers when backend # does not support on HND 
return False @@ -1431,7 +1435,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): multiply = 1 for dim_idx in range(1, self.kv_topo.physical_kv_heads_position): multiply *= tensor_shape[dim_idx] - logger.info( + logger.debug( "Multiply is %d shape %s kv_heads_pos %d", multiply, tensor_shape, From 392e5d594b0f4f5539475a6fbe44fe0f20a196d3 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 12:16:56 +0000 Subject: [PATCH 39/84] Documentation Signed-off-by: Liran Schour --- docs/features/nixl_connector_usage.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index af38087e4b3d..a1051a70bb2b 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -184,6 +184,13 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con --kv-transfer-config '{..., "enable_permute_local_kv":"True"}' ``` +### Cross layers blocks +By default, this feature is enabled. On backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. +To disable this feature: +``` +--kv-transfer-config '{..., "kv_connector_extra_config":{"enable_cross_layers_block": "False"}}' +``` + ## Example Scripts/Code Refer to these example scripts in the vLLM repository: From 3c6921f5fe02d76d03f0e3aa5efcf9486f26703f Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 13:01:49 +0000 Subject: [PATCH 40/84] Code review fix Signed-off-by: Liran Schour --- docs/features/nixl_connector_usage.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index a1051a70bb2b..ac5d3f558e92 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -185,8 +185,10 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con ``` ### Cross layers blocks + By default, this feature is enabled. On backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. To disable this feature: + ``` --kv-transfer-config '{..., "kv_connector_extra_config":{"enable_cross_layers_block": "False"}}' ``` From ed3180cbaa9c2812793d9cee8a2feab69bdf0a8f Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 13:14:14 +0000 Subject: [PATCH 41/84] Code review fix Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 123 ------------------ 1 file changed, 123 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b4085a221cf5..899261950cbe 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -427,129 +427,6 @@ def __init__( self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks ) - def _nixl_handshake( - self, host: str, port: int, remote_tp_size: int, expected_engine_id: str - ) -> dict[int, str]: - # Mimic slow _nixl_handshake, as well as bypass zmq communication. - time.sleep(self._hand_shake_latency) - # These should've been done in register_kv_caches(), called by - # gpu_model_runner. Here we just hardcode some dummy values. 
- slot_size_bytes = 4096 - self.slot_size_per_layer = [slot_size_bytes] - self.block_len_per_layer = [slot_size_bytes * self.block_size] - self.num_blocks = 1 - self.dst_num_blocks[self.engine_id] = self.num_blocks - - assert expected_engine_id == self.REMOTE_ENGINE_ID - - # Adjust remote block length metadata to satisfy heterogeneous TP - # invariants enforced during handshake validation. - remote_block_lens = list(self.block_len_per_layer) - tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) - if remote_tp_size > self.world_size: - # P TP > D TP case, block_len of remote is smaller - remote_block_lens = [ - block_len // (-tp_ratio) for block_len in remote_block_lens - ] - elif remote_tp_size < self.world_size: - remote_block_lens = [ - block_len * tp_ratio for block_len in remote_block_lens - ] - - # When remote tp_size > local tp_size, handshake with multiple - # remote ranks. - num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio - remote_agents: dict[int, str] = {} - for remote_tp_rank in range(num_hanshakes): - remote_agent_name = self.add_remote_agent( - NixlAgentMetadata( - engine_id=self.REMOTE_ENGINE_ID, - agent_metadata=FakeNixlWrapper.AGENT_METADATA, - kv_caches_base_addr=[0], - device_id=remote_tp_rank, - num_blocks=1, - block_lens=remote_block_lens, - # `self.kv_cache_layout` is only forced to HND when vllm engine - # is started. We mock HND here. - kv_cache_layout="HND", - block_size=self.block_size, - ), - remote_tp_rank=remote_tp_rank, - remote_tp_size=remote_tp_size, - ) - remote_agents[remote_tp_rank] = remote_agent_name - return remote_agents - - -class FakeNixlConnectorWorkerCrossLayers(NixlConnectorWorker): - REMOTE_ENGINE_ID = "remote_engine" - - def __init__( - self, - *args, - hand_shake_latency: float = 1.8, - kv_cache_layout="HND", - **kwargs, - ): - super().__init__(*args, **kwargs) - self._hand_shake_latency = hand_shake_latency - self.kv_cache_layout = kv_cache_layout - # Mock register_kv_caches attribute needed for tests that do not call it. - self.src_xfer_handles_by_block_size = {self.block_size: 1} - - def _nixl_handshake( - self, host: str, port: int, remote_tp_size: int, expected_engine_id: str - ) -> dict[int, str]: - # Mimic slow _nixl_handshake, as well as bypass zmq communication. - time.sleep(self._hand_shake_latency) - # These should've been done in register_kv_caches(), called by - # gpu_model_runner. Here we just hardcode some dummy values. - slot_size_bytes = 4096 - self.slot_size_per_layer = [slot_size_bytes] - self.block_len_per_layer = [slot_size_bytes * self.block_size] - self.num_blocks = 1 - self.dst_num_blocks[self.engine_id] = self.num_blocks - - assert expected_engine_id == self.REMOTE_ENGINE_ID - - # Adjust remote block length metadata to satisfy heterogeneous TP - # invariants enforced during handshake validation. - remote_block_lens = list(self.block_len_per_layer) - tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) - if remote_tp_size > self.world_size: - # P TP > D TP case, block_len of remote is smaller - remote_block_lens = [ - block_len // (-tp_ratio) for block_len in remote_block_lens - ] - elif remote_tp_size < self.world_size: - remote_block_lens = [ - block_len * tp_ratio for block_len in remote_block_lens - ] - - # When remote tp_size > local tp_size, handshake with multiple - # remote ranks. 
- num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio - remote_agents: dict[int, str] = {} - for remote_tp_rank in range(num_hanshakes): - remote_agent_name = self.add_remote_agent( - NixlAgentMetadata( - engine_id=self.REMOTE_ENGINE_ID, - agent_metadata=FakeNixlWrapper.AGENT_METADATA, - kv_caches_base_addr=[0], - device_id=remote_tp_rank, - num_blocks=1, - block_lens=remote_block_lens, - # `self.kv_cache_layout` is only forced to HND when vllm engine - # is started. We mock HND here. - kv_cache_layout="HND", - block_size=self.block_size, - ), - remote_tp_rank=remote_tp_rank, - remote_tp_size=remote_tp_size, - ) - remote_agents[remote_tp_rank] = remote_agent_name - return remote_agents - class TestNixlHandshake: @patch( From ced9ad455d98374f9a2ce83729fd8e208f510c78 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 13:21:51 +0000 Subject: [PATCH 42/84] Code review fix Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 53 +++++++++++++++++++ .../kv_connector/v1/nixl_connector.py | 30 ++--------- 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 899261950cbe..a78f3b64570f 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -427,6 +427,59 @@ def __init__( self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks ) + def _nixl_handshake( + self, host: str, port: int, remote_tp_size: int, expected_engine_id: str + ) -> dict[int, str]: + # Mimic slow _nixl_handshake, as well as bypass zmq communication. + time.sleep(self._hand_shake_latency) + # These should've been done in register_kv_caches(), called by + # gpu_model_runner. Here we just hardcode some dummy values. + slot_size_bytes = 4096 + self.slot_size_per_layer = [slot_size_bytes] + self.block_len_per_layer = [slot_size_bytes * self.block_size] + self.num_blocks = 1 + self.dst_num_blocks[self.engine_id] = self.num_blocks + + assert expected_engine_id == self.REMOTE_ENGINE_ID + + # Adjust remote block length metadata to satisfy heterogeneous TP + # invariants enforced during handshake validation. + remote_block_lens = list(self.block_len_per_layer) + tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) + if remote_tp_size > self.world_size: + # P TP > D TP case, block_len of remote is smaller + remote_block_lens = [ + block_len // (-tp_ratio) for block_len in remote_block_lens + ] + elif remote_tp_size < self.world_size: + remote_block_lens = [ + block_len * tp_ratio for block_len in remote_block_lens + ] + + # When remote tp_size > local tp_size, handshake with multiple + # remote ranks. + num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio + remote_agents: dict[int, str] = {} + for remote_tp_rank in range(num_hanshakes): + remote_agent_name = self.add_remote_agent( + NixlAgentMetadata( + engine_id=self.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + device_id=remote_tp_rank, + num_blocks=1, + block_lens=remote_block_lens, + # `self.kv_cache_layout` is only forced to HND when vllm engine + # is started. We mock HND here. 
+ kv_cache_layout="HND", + block_size=self.block_size, + ), + remote_tp_rank=remote_tp_rank, + remote_tp_size=remote_tp_size, + ) + remote_agents[remote_tp_rank] = remote_agent_name + return remote_agents + class TestNixlHandshake: @patch( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 37d2e7ce3f67..333dded2c944 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -302,10 +302,10 @@ class NixlConnector(KVConnectorBase_V1): @property def prefer_cross_layer_blocks(self) -> bool: backend = get_current_attn_backend(self._vllm_config) - if backend().get_name() not in [ + if backend().get_name() not in ( "FLASH_ATTN", "FLASHINFER", - ]: + ): # For now there is no benefit to run cross layers when backend # does not support on HND return False @@ -1421,31 +1421,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self.kv_topo.cross_layers_blocks: - assert len(kv_caches) == 1 - tensor_shape = list(kv_caches.values())[0].shape - - # NOTE (liranschour) When FlashInfer is used, memory is registered - # with joint KV for each block and in Triton joint KV and num_layers - # for each block. - # In order to be able to split on kv_heads dim as required by - # heterogeneous TP, one must be able to index (K/V, layer_idx) separately. - # Hence we multiply the number of 'virtual' regions here and divide - # `block_len` below. - multiply = 1 - for dim_idx in range(1, self.kv_topo.physical_kv_heads_position): - multiply *= tensor_shape[dim_idx] - logger.debug( - "Multiply is %d shape %s kv_heads_pos %d", - multiply, - tensor_shape, - self.kv_topo.physical_kv_heads_position, - ) - for i in range(len(self.slot_size_per_layer)): - assert self.slot_size_per_layer[i] % multiply == 0 - self.slot_size_per_layer[i] //= multiply - self.num_regions *= multiply - elif self.kv_topo.is_kv_layout_blocks_first: + if self.kv_topo.is_kv_layout_blocks_first: for i in range(len(self.slot_size_per_layer)): assert self.slot_size_per_layer[i] % 2 == 0 self.slot_size_per_layer[i] //= 2 From 580dbc4c0fcbb536c8e75863190d25fbaa641fce Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 13:25:26 +0000 Subject: [PATCH 43/84] Code review fix Signed-off-by: Liran Schour --- docs/features/nixl_connector_usage.md | 6 +++--- .../kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index ac5d3f558e92..4c9e8b164b65 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -186,11 +186,11 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con ### Cross layers blocks -By default, this feature is enabled. On backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. -To disable this feature: +By default, this feature is enabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. 
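+As a rough illustration (numbers are hypothetical): for a 32-layer model, a per-layer layout needs on the order of 32 x 2 (K and V) transfer descriptors per block, while the cross-layers layout can cover the same block with a single descriptor.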
+You can disable this feature: ``` ---kv-transfer-config '{..., "kv_connector_extra_config":{"enable_cross_layers_block": "False"}}' +--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "False"}}' ``` ## Example Scripts/Code diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 333dded2c944..b8cf19038115 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -311,7 +311,7 @@ def prefer_cross_layer_blocks(self) -> bool: return False extra_config = self.kv_transfer_config.kv_connector_extra_config - return bool(str(extra_config.get("enable_cross_layers_block", "True"))) + return bool(str(extra_config.get("enable_cross_layers_blocks", "True"))) def __init__( self, From 19fff296a2e4a617661ec116d653129e4b347127 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 13:27:52 +0000 Subject: [PATCH 44/84] Code review fix Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a78f3b64570f..94b36fdfef28 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1390,6 +1390,7 @@ def req_id(outputs: list[RequestOutput]) -> str: ), ), "TRITON_ATTN", + "FLASHINFER", ], ) def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): From dd97e99b7eb13707aab663d0632af4d6372e9c9c Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 14:26:08 +0000 Subject: [PATCH 45/84] n/a Signed-off-by: Liran Schour --- docs/features/nixl_connector_usage.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index 4c9e8b164b65..7b1a93347f20 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -189,7 +189,7 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con By default, this feature is enabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. 
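For context, the flag fragment shown next slots into a full `--kv-transfer-config` value. A complete invocation might look like the following sketch, which is not part of this patch set: the serve command, model placeholder and role are assumptions, and only the `enable_cross_layers_blocks` key comes from these changes.

```bash
# Hypothetical launch sketch; substitute your own model and deployment flags.
vllm serve <model> \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"enable_cross_layers_blocks":"False"}}'
```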
You can disable this feature: -``` +```bash --kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "False"}}' ``` From fce2050e442bf37ffd68d30ae9ec481fb71ba6ce Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 14:37:57 +0000 Subject: [PATCH 46/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 3 +++ vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 94b36fdfef28..36aed0465243 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1407,6 +1407,9 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): vllm_config = create_vllm_config(attention_backend=attn_backend) + # Enable cross layers blocks + vllm_config.kv_transfer_config.extra_config["enable_cross_layers_blocks"] = True + # Import the appropriate backend based on the parameter if attn_backend == "FLASH_ATTN": from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index b8cf19038115..71f5ff13fc58 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -311,7 +311,7 @@ def prefer_cross_layer_blocks(self) -> bool: return False extra_config = self.kv_transfer_config.kv_connector_extra_config - return bool(str(extra_config.get("enable_cross_layers_blocks", "True"))) + return bool(str(extra_config.get("enable_cross_layers_blocks", "False"))) def __init__( self, From 716115096543e36e0c1e4f8608d90d32ff4cc243 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 14:39:28 +0000 Subject: [PATCH 47/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 36aed0465243..4e5058376c81 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1408,7 +1408,9 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): vllm_config = create_vllm_config(attention_backend=attn_backend) # Enable cross layers blocks - vllm_config.kv_transfer_config.extra_config["enable_cross_layers_blocks"] = True + vllm_config.kv_transfer_config.kv_connector_extra_config[ + "enable_cross_layers_blocks" + ] = True # Import the appropriate backend based on the parameter if attn_backend == "FLASH_ATTN": From 4d3890ec5a9207cca0d9804e75d9a60af2645616 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 14:42:48 +0000 Subject: [PATCH 48/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 4e5058376c81..ebcda2fcbc52 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1421,10 +1421,14 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): from 
vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend backend_cls = RocmAttentionBackend - else: # TRITON_ATTN + elif attn_backend == "TRITON_ATTN": from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend backend_cls = TritonAttentionBackend + else: # FLASHINFER + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + + backend_cls = FlashInferBackend nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" with ( From d92bf969731eb67183b31bc279cade140f649d55 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 14:49:47 +0000 Subject: [PATCH 49/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index ebcda2fcbc52..f2389d5c3a19 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1513,7 +1513,11 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ] expected_num_entries = 1 + kv_heads_idx = cross_layers_kv_cache.shape.index(4) + expected_blocks_count = 8 if kv_heads_idx == 1 else 16 + kv_caches = {"all-layers": cross_layers_kv_cache} + else: # Create test kv cache tensors using proper backend shape kv_cache_shape = backend_cls.get_kv_cache_shape( @@ -1554,6 +1558,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): unique_tensor[1].data_ptr(), ] expected_num_entries = 4 + expected_blocks_count = 8 # Execute register_kv_caches connector.register_kv_caches(kv_caches) @@ -1578,7 +1583,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] # Validate blocks_data structure and size - expected_blocks_count = 8 assert len(blocks_data) == expected_blocks_count, ( f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" ) From 6991cddabf073f8ecb1c1a7d016471731b7579c9 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 14:51:17 +0000 Subject: [PATCH 50/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index f2389d5c3a19..479321581bcd 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1558,7 +1558,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): unique_tensor[1].data_ptr(), ] expected_num_entries = 4 - expected_blocks_count = 8 + expected_blocks_count = 8 # Execute register_kv_caches connector.register_kv_caches(kv_caches) From 2791d34fd566b59f4168938cd79d3ca391a69731 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 14:55:56 +0000 Subject: [PATCH 51/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 479321581bcd..80bbd847d8e8 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1514,6 +1514,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): expected_num_entries = 1 
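# Editorial sketch, not part of the patch: a back-of-the-envelope count of
# NIXL transfer descriptors with and without cross-layer blocks. The sizes
# below are hypothetical and only mirror the intent of the expectations
# asserted in this test (one contiguous region per block vs. one region per
# layer and K/V half).
num_layers, num_blocks, split_k_and_v = 32, 8, True
per_layer_descs = num_blocks * num_layers * (2 if split_k_and_v else 1)
cross_layer_descs = num_blocks  # whole block is contiguous across all layers
assert (per_layer_descs, cross_layer_descs) == (512, 8)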
kv_heads_idx = cross_layers_kv_cache.shape.index(4) + print("kv_heads idx %d", kv_heads_idx) expected_blocks_count = 8 if kv_heads_idx == 1 else 16 kv_caches = {"all-layers": cross_layers_kv_cache} From ff1f24409a25a0acfca16dae5b2a016c9033004e Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 15:02:36 +0000 Subject: [PATCH 52/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 80bbd847d8e8..821eaa4b13a8 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1421,14 +1421,10 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend backend_cls = RocmAttentionBackend - elif attn_backend == "TRITON_ATTN": + else: # TRITON from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend backend_cls = TritonAttentionBackend - else: # FLASHINFER - from vllm.v1.attention.backends.flashinfer import FlashInferBackend - - backend_cls = FlashInferBackend nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" with ( @@ -1513,9 +1509,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ] expected_num_entries = 1 - kv_heads_idx = cross_layers_kv_cache.shape.index(4) - print("kv_heads idx %d", kv_heads_idx) - expected_blocks_count = 8 if kv_heads_idx == 1 else 16 + expected_blocks_count = 8 kv_caches = {"all-layers": cross_layers_kv_cache} From 92f262805dfbd0ef73aff380137f4c526d1e7e73 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 21 Jan 2026 15:35:16 +0000 Subject: [PATCH 53/84] Add cross layers blocks to run_accuracy_tests.sh Signed-off-by: Liran Schour --- .../nixl_integration/run_accuracy_test.sh | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index c2c38f51c500..e5e333f19fe8 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -34,11 +34,18 @@ else KV_CONFIG_HETERO_LAYOUT='' fi +CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers +if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then + KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"cross_layers_blocks": "True"}' +else + KV_EXTRA_CONFIG='' +fi + # Build the kv-transfer-config once if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then - KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}' + KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' else - KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}" + KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" fi # Models to run From 4715cedbc97038879a9bb6fdd6098411655ab5c2 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 22 Jan 2026 03:49:42 +0000 Subject: [PATCH 54/84] Code review fix Signed-off-by: Liran Schour --- docs/features/nixl_connector_usage.md | 4 +-- .../kv_transfer/kv_connector/utils.py | 
10 -------- .../kv_connector/v1/nixl_connector.py | 25 ++++++------------- 3 files changed, 10 insertions(+), 29 deletions(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index 7b1a93347f20..489e639ad312 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -187,10 +187,10 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con ### Cross layers blocks By default, this feature is enabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. -You can disable this feature: +To enable this feature: ```bash ---kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "False"}}' +--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}' ``` ## Example Scripts/Code diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 7c361c5aa440..f73431fd7b2f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -317,7 +317,6 @@ class TpKVTopology: engine_id: EngineId remote_block_size: dict[EngineId, int] tensor_shape: torch.Size | None = None - device_type: str = "cuda" def __post_init__(self): # Figure out whether the first dimension of the cache is K/V @@ -351,10 +350,6 @@ def __post_init__(self): # permute kv_cache_shape according to stride_order kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - physical_kv_heads_position = kv_cache_shape.index(4) - assert physical_kv_heads_position is not None - self._physical_kv_heads_position = physical_kv_heads_position - physical_block_size_position = kv_cache_shape.index(16) assert physical_block_size_position is not None self._physical_block_size_position = -( @@ -388,11 +383,6 @@ def cross_layers_blocks(self) -> bool: def block_size_position(self) -> int: return self._physical_block_size_position - @property - def physical_kv_heads_position(self) -> int: - assert self._physical_kv_heads_position is not None - return self._physical_kv_heads_position - def tp_ratio( self, remote_tp_size: int, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 71f5ff13fc58..e82ef910c322 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1021,15 +1021,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - self.kv_topo = TpKVTopology( - tp_rank=self.tp_rank, - engine_id=self.engine_id, - remote_tp_size=self._tp_size, # shared state - remote_block_size=self._block_size, # shared state - is_mla=self.use_mla, - total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - attn_backend=self.attn_backend, - ) self._physical_blocks_per_logical_kv_block = 1 self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( @@ -1740,8 +1731,10 @@ def _validate_remote_agent_handshake( """ remote_engine_id = nixl_agent_meta.engine_id - assert self._tp_size[remote_engine_id] == remote_tp_size - assert self.kv_topo is not None + assert ( + self._tp_size[remote_engine_id] == remote_tp_size + and self.kv_topo is not None + 
) tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( @@ -1878,6 +1871,7 @@ def post_process_device_kv_on_receive( if len(self.device_kv_caches) == 0: return assert block_size_ratio >= 1, "Only nP < nD supported currently." + assert self.kv_topo is not None if self.enable_permute_local_kv and block_size_ratio > 1: logger.debug( "Post-processing device kv cache on receive by converting " @@ -1897,7 +1891,6 @@ def post_process_device_kv_on_receive( block_size_ratio, ) - assert self.kv_topo is not None split_k_and_v = self.kv_topo.split_k_and_v for block_ids in block_ids_list: @@ -1923,11 +1916,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: The scheduler process (via the MultiprocExecutor) will use this output to track which workers are done. """ + assert self.kv_topo is not None done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) - assert self.kv_topo is not None - # add requests that skipped transfer to done_recving done_recving.update(self._failed_recv_reqs) self._failed_recv_reqs.clear() @@ -1994,6 +1986,7 @@ def _get_new_notifs(self) -> set[str]: are reading from the same producer (heterogeneous TP scenario), wait for all consumers to be done pulling. """ + assert self.kv_topo is not None notified_req_ids: set[str] = set() for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: @@ -2012,7 +2005,6 @@ def _get_new_notifs(self) -> set[str]: # NOTE: `tp_ratio` is the opposite when swapping local<>remote n_consumers = int(tp_size) - assert self.kv_topo is not None tp_ratio = self.kv_topo.tp_ratio(n_consumers) # Number of reads *per producer* to wait for. @@ -2154,8 +2146,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): - assert meta.remote is not None - assert self.kv_topo is not None + assert meta.remote is not None and self.kv_topo is not None remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( meta.remote.engine_id ) From c2e0ca05422059b3c18e1a967cb9aa3cadd97f22 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 22 Jan 2026 04:31:31 +0000 Subject: [PATCH 55/84] Code review fix Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 821eaa4b13a8..1286af75d5aa 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -22,6 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( KVOutputAggregator, TpKVTopology, + get_current_attn_backend, ) from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats @@ -2251,6 +2252,27 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_worker = decode_connector.connector_worker + backend = get_current_attn_backend(local_vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + decode_worker.kv_topo = TpKVTopology( + tp_rank=decode_worker.tp_rank, + engine_id=decode_worker.engine_id, + 
remote_tp_size=decode_worker._tp_size, # shared state + remote_block_size=decode_worker._block_size, # shared state + is_mla=decode_worker.use_mla, + total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), + attn_backend=backend, + tensor_shape=test_shape, + ) + + decode_worker.compat_hash = compute_nixl_compatibility_hash( + decode_worker.vllm_config, + decode_worker.backend_name, + decode_worker.kv_topo.cross_layers_blocks, + ) + if error_scenario == "handshake_decode_error": msg_bytes = b"this is not valid msgpack data" elif error_scenario == "handshake_validation_error": From d9ad710cb157d31360c4fc712ea4153011d173d2 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 22 Jan 2026 07:21:37 +0000 Subject: [PATCH 56/84] Minor fix Signed-off-by: Liran Schour --- docs/features/nixl_connector_usage.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index 489e639ad312..749c6fbe7f6d 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -186,7 +186,7 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con ### Cross layers blocks -By default, this feature is enabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. +By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. To enable this feature: ```bash From 7f0d3b45e7884758446e08288fcab19f9ab882ae Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 28 Jan 2026 08:55:54 +0000 Subject: [PATCH 57/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 6 ++++++ .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f73431fd7b2f..1561a2ef077a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -347,10 +347,16 @@ def __post_init__(self): except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(self.tensor_shape))) + logger.info("XXX shape: %s", kv_cache_shape) # permute kv_cache_shape according to stride_order kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) physical_block_size_position = kv_cache_shape.index(16) + logger.info( + "XXX shape %s blk_size_pos %d", + kv_cache_shape, + physical_block_size_position, + ) assert physical_block_size_position is not None self._physical_block_size_position = -( len(kv_cache_shape) - physical_block_size_position diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e82ef910c322..2fbd34e244f8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -311,7 +311,10 @@ def prefer_cross_layer_blocks(self) -> bool: return False extra_config = self.kv_transfer_config.kv_connector_extra_config - return bool(str(extra_config.get("enable_cross_layers_blocks", "False"))) + return ( + str(extra_config.get("enable_cross_layers_blocks", "False")).lower() + == "true" + 
) def __init__( self, From 156359e65e577ed2f9f57c81a8d62c1d1e31c14e Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 28 Jan 2026 09:05:38 +0000 Subject: [PATCH 58/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1561a2ef077a..3ba11981a776 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -352,15 +352,16 @@ def __post_init__(self): kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) physical_block_size_position = kv_cache_shape.index(16) - logger.info( - "XXX shape %s blk_size_pos %d", - kv_cache_shape, - physical_block_size_position, - ) + assert physical_block_size_position is not None self._physical_block_size_position = -( len(kv_cache_shape) - physical_block_size_position ) + logger.info( + "XXX shape %s blk_size_pos %d", + kv_cache_shape, + self._physical_block_size_position, + ) @property def is_kv_layout_blocks_first(self) -> bool: From fceb3f97ce35bc303adc56d3c63cd7604884e691 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 28 Jan 2026 09:27:06 +0000 Subject: [PATCH 59/84] n/a Signed-off-by: Liran Schour --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2fbd34e244f8..0cc072433c17 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1340,6 +1340,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # to better exploit the memory layout (ie num_blocks is the first dim). tensor_size_bytes = None + if self.kv_topo.cross_layers_blocks: + block_size_position = self.kv_topo.block_size_position + else: + if self.device_type == "cpu": + block_size_position = -2 + else: + block_size_position = -2 if self.use_mla else -3 + # Enable different block lengths for different layers when MLA is used. 
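# Editorial sketch, not part of the patch: what the -2 / -3 defaults above
# index into. Both shapes are hypothetical; only the position convention
# comes from the surrounding code.
nhd_shape = (2, 1024, 16, 8, 128)   # (K/V, num_blocks, block_size, kv_heads, head_size)
assert nhd_shape[-3] == 16          # non-MLA GPU default
mla_shape = (1024, 16, 576)         # (num_blocks, block_size, head_dim)
assert mla_shape[-2] == 16          # MLA (and CPU) default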
self.block_len_per_layer = list[int]() self.slot_size_per_layer = list[int]() # HD bytes in kv terms @@ -1352,7 +1360,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if base_addr in seen_base_addresses: continue - kernel_block_size = cache.shape[self.kv_topo.block_size_position] + kernel_block_size = cache.shape[block_size_position] if self.block_size != kernel_block_size: logger.info_once( "User-specified logical block size (%s) does not match" From 9f3f8c4b718b44c1bc511bcee6c9d129ae8b4a0a Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 28 Jan 2026 10:56:34 +0000 Subject: [PATCH 60/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0cc072433c17..20924f536bac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1360,6 +1360,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if base_addr in seen_base_addresses: continue + logger.info("XXX cache shape %s", cache.shape) kernel_block_size = cache.shape[block_size_position] if self.block_size != kernel_block_size: logger.info_once( From a19fc2c5c7689d62ff0e2b441818877ffc56b208 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 28 Jan 2026 13:34:22 +0000 Subject: [PATCH 61/84] n/a Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 6 ------ .../kv_transfer/kv_connector/v1/nixl_connector.py | 1 - 2 files changed, 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 3ba11981a776..9a2730eb909c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -347,7 +347,6 @@ def __post_init__(self): except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(self.tensor_shape))) - logger.info("XXX shape: %s", kv_cache_shape) # permute kv_cache_shape according to stride_order kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) @@ -357,11 +356,6 @@ def __post_init__(self): self._physical_block_size_position = -( len(kv_cache_shape) - physical_block_size_position ) - logger.info( - "XXX shape %s blk_size_pos %d", - kv_cache_shape, - self._physical_block_size_position, - ) @property def is_kv_layout_blocks_first(self) -> bool: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 20924f536bac..0cc072433c17 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1360,7 +1360,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if base_addr in seen_base_addresses: continue - logger.info("XXX cache shape %s", cache.shape) kernel_block_size = cache.shape[block_size_position] if self.block_size != kernel_block_size: logger.info_once( From eea914c1f267a728b0fcd4f456db2f2061d24832 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Wed, 28 Jan 2026 13:40:45 +0000 Subject: [PATCH 62/84] Calculate block_size_position Signed-off-by: Liran Schour --- .../nixl_integration/run_accuracy_test.sh | 2 +- .../kv_transfer/kv_connector/utils.py | 27 +++++++++---------- 
.../kv_connector/v1/nixl_connector.py | 10 +++---- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index e5e333f19fe8..6f08fd9c4f06 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -36,7 +36,7 @@ fi CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then - KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"cross_layers_blocks": "True"}' + KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"enable_cross_layers_blocks": "True"}' else KV_EXTRA_CONFIG='' fi diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 9a2730eb909c..d288738dc18d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -340,22 +340,21 @@ def __post_init__(self): if self._cross_layers_blocks: # prepend layers dimension kv_cache_shape = (80,) + kv_cache_shape - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=self._cross_layers_blocks - ) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(self.tensor_shape))) + try: + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=self._cross_layers_blocks + ) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(self.tensor_shape))) - # permute kv_cache_shape according to stride_order - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + # In case of cross layers permute kv_cache_shape according to + # stride_order to retrieve physical position of block_size + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - physical_block_size_position = kv_cache_shape.index(16) + block_size_position = kv_cache_shape.index(16) - assert physical_block_size_position is not None - self._physical_block_size_position = -( - len(kv_cache_shape) - physical_block_size_position - ) + assert block_size_position is not None + self._block_size_position = -(len(kv_cache_shape) - block_size_position) @property def is_kv_layout_blocks_first(self) -> bool: @@ -382,7 +381,7 @@ def cross_layers_blocks(self) -> bool: @property def block_size_position(self) -> int: - return self._physical_block_size_position + return self._block_size_position def tp_ratio( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0cc072433c17..1f1a76305186 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1340,13 +1340,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # to better exploit the memory layout (ie num_blocks is the first dim). tensor_size_bytes = None - if self.kv_topo.cross_layers_blocks: - block_size_position = self.kv_topo.block_size_position - else: - if self.device_type == "cpu": - block_size_position = -2 - else: - block_size_position = -2 if self.use_mla else -3 + block_size_position = ( + self.kv_topo.block_size_position if self.device_type != "cpu" else -2 + ) # Enable different block lengths for different layers when MLA is used. 
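# Editorial sketch, not part of the patch: the sentinel-based lookup that
# TpKVTopology uses to locate the block_size dimension. The shape below is
# hypothetical; in the real code it comes from
# attn_backend.get_kv_cache_shape(..., block_size=16, ...).
mock_shape = (2, 1, 16, 4, 1)                 # block_size mocked to 16
pos = mock_shape.index(16) - len(mock_shape)  # negative physical position
assert pos == -3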
self.block_len_per_layer = list[int]() From 568e6412e3590d8d08e536b5067296f22e667ac6 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 29 Jan 2026 11:16:22 +0000 Subject: [PATCH 63/84] Enhance test_register_kv_caches() to test cross layers case Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 7 +++++-- vllm/distributed/kv_transfer/kv_connector/utils.py | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e93835598a41..5c6e288eece4 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1377,6 +1377,7 @@ def req_id(outputs: list[RequestOutput]) -> str: llm.llm_engine.engine_core.shutdown() +@pytest.mark.parametrize("enable_cross_layers", [False, True]) @pytest.mark.parametrize( "attn_backend", [ @@ -1398,7 +1399,9 @@ def req_id(outputs: list[RequestOutput]) -> str: "FLASHINFER", ], ) -def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): +def test_register_kv_caches( + default_vllm_config, dist_init, attn_backend, enable_cross_layers +): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. @@ -1415,7 +1418,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): # Enable cross layers blocks vllm_config.kv_transfer_config.kv_connector_extra_config[ "enable_cross_layers_blocks" - ] = True + ] = enable_cross_layers # Import the appropriate backend based on the parameter if attn_backend == "FLASH_ATTN": diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index d288738dc18d..4c865e197bf9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -351,6 +351,10 @@ def __post_init__(self): # stride_order to retrieve physical position of block_size kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + # In the default non-cross layers layout the block_size position + # is logical while in the cross layers case it is the physical + # position. 
This matches the shape of the actual kv cache tensors + # passed at register_kv_caches()/register_cross_layers_kv_cache() block_size_position = kv_cache_shape.index(16) assert block_size_position is not None From 6096ce96916b498a1f5159eee6593a8b8cb52fce Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 29 Jan 2026 11:37:36 +0000 Subject: [PATCH 64/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 5c6e288eece4..607d67e47a5f 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1396,7 +1396,6 @@ def req_id(outputs: list[RequestOutput]) -> str: ), ), "TRITON_ATTN", - "FLASHINFER", ], ) def test_register_kv_caches( @@ -1607,6 +1606,8 @@ def test_register_kv_caches( f"got {block_len}" ) + assert connector.connector_worker.block_size == 16 + class FakePlatform(Platform): device_type: str = "oot" From 264c550e0fa30ea405074f3500df15fa88094882 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Thu, 29 Jan 2026 14:28:26 +0000 Subject: [PATCH 65/84] Code review fix Signed-off-by: Liran Schour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 4c865e197bf9..9331355e34a3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -321,8 +321,9 @@ class TpKVTopology: def __post_init__(self): # Figure out whether the first dimension of the cache is K/V # or num_blocks. This is used to register the memory regions correctly. + _MOCK_BLOCK_SIZE = 16 kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=4, head_size=1 + num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=4, head_size=1 ) # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # we just mock num_blocks to 1 for the dimension check below. @@ -339,7 +340,8 @@ def __post_init__(self): if self._cross_layers_blocks: # prepend layers dimension - kv_cache_shape = (80,) + kv_cache_shape + _MOCK_NUM_LAYERS = 80 + kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape try: kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( include_num_layers_dimension=self._cross_layers_blocks @@ -355,7 +357,7 @@ def __post_init__(self): # is logical while in the cross layers case it is the physical # position. 
This matches the shape of the actual kv cache tensors # passed at register_kv_caches()/register_cross_layers_kv_cache() - block_size_position = kv_cache_shape.index(16) + block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE) assert block_size_position is not None self._block_size_position = -(len(kv_cache_shape) - block_size_position) From d14a4876281643df752ebe3e90e83b4be0d7707d Mon Sep 17 00:00:00 2001 From: liranschour Date: Mon, 2 Feb 2026 09:56:00 +0200 Subject: [PATCH 66/84] Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Co-authored-by: Or Ozeri Signed-off-by: liranschour --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1f1a76305186..09489be9d1ad 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1): @property def prefer_cross_layer_blocks(self) -> bool: backend = get_current_attn_backend(self._vllm_config) - if backend().get_name() not in ( + if backend.get_name() not in ( "FLASH_ATTN", "FLASHINFER", ): From 969c3a1cfeefc7bf39758281ef5f82b4639371cb Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 2 Feb 2026 08:04:29 +0000 Subject: [PATCH 67/84] Code review fix Signed-off-by: Liran Schour --- tests/v1/kv_connector/unit/test_nixl_connector.py | 2 +- .../kv_transfer/kv_connector/v1/nixl_connector.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 607d67e47a5f..07d4c5db25dd 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1428,7 +1428,7 @@ def test_register_kv_caches( from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend backend_cls = RocmAttentionBackend - else: # TRITON + else: # TRITON_ATTN from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend backend_cls = TritonAttentionBackend diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 09489be9d1ad..f9852016ae6c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -310,6 +310,9 @@ def prefer_cross_layer_blocks(self) -> bool: # does not support on HND return False + if get_kv_cache_layout() != "HND": + return False + extra_config = self.kv_transfer_config.kv_connector_extra_config return ( str(extra_config.get("enable_cross_layers_blocks", "False")).lower() @@ -1738,10 +1741,8 @@ def _validate_remote_agent_handshake( """ remote_engine_id = nixl_agent_meta.engine_id - assert ( - self._tp_size[remote_engine_id] == remote_tp_size - and self.kv_topo is not None - ) + assert self._tp_size[remote_engine_id] == remote_tp_size + assert self.kv_topo is not None tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( @@ -2222,6 +2223,10 @@ def _read_blocks( local_xfer_side_handle: int, remote_xfer_side_handle: int, ): + """ + Post a READ point-to-point xfer request from a single local worker to + a single remote worker. 
+ """ assert self.kv_topo is not None block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: From f0335d3b47618cbb3a6a3e72018cf22faf68f8e4 Mon Sep 17 00:00:00 2001 From: liranschour Date: Mon, 2 Feb 2026 11:58:41 +0200 Subject: [PATCH 68/84] Update vllm/distributed/kv_transfer/kv_connector/utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Nicolò Lucchesi Signed-off-by: liranschour --- vllm/distributed/kv_transfer/kv_connector/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 9331355e34a3..9f89c2f84344 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -323,7 +323,7 @@ def __post_init__(self): # or num_blocks. This is used to register the memory regions correctly. _MOCK_BLOCK_SIZE = 16 kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=4, head_size=1 + num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1 ) # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # we just mock num_blocks to 1 for the dimension check below. From 7f25bca8aacfa4eba2541488b335e9c673421232 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 2 Feb 2026 10:14:33 +0000 Subject: [PATCH 69/84] Code review fix Signed-off-by: Liran Schour --- .../kv_connector/unit/test_nixl_connector.py | 2 +- .../kv_transfer/kv_connector/utils.py | 52 ++++++++++--------- .../kv_connector/v1/nixl_connector.py | 10 ++-- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 07d4c5db25dd..41bbe28b9642 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1377,7 +1377,7 @@ def req_id(outputs: list[RequestOutput]) -> str: llm.llm_engine.engine_core.shutdown() -@pytest.mark.parametrize("enable_cross_layers", [False, True]) +@pytest.mark.parametrize("enable_cross_layers", ["False", "True"]) @pytest.mark.parametrize( "attn_backend", [ diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 9f89c2f84344..13845fe29734 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.platforms import current_platform from vllm.v1.attention.backend import AttentionBackend from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -331,36 +332,39 @@ def __post_init__(self): len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) - self._kv_heads_position: int | None = None self._cross_layers_blocks = False if self.tensor_shape is not None: self._cross_layers_blocks = ( len(self.tensor_shape) == len(kv_cache_shape) + 1 ) - if self._cross_layers_blocks: - # prepend layers dimension - _MOCK_NUM_LAYERS = 80 - kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=self._cross_layers_blocks - ) - except 
(AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(self.tensor_shape))) - - # In case of cross layers permute kv_cache_shape according to - # stride_order to retrieve physical position of block_size - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - - # In the default non-cross layers layout the block_size position - # is logical while in the cross layers case it is the physical - # position. This matches the shape of the actual kv cache tensors - # passed at register_kv_caches()/register_cross_layers_kv_cache() - block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE) - - assert block_size_position is not None - self._block_size_position = -(len(kv_cache_shape) - block_size_position) + if self._cross_layers_blocks: + # prepend layers dimension + _MOCK_NUM_LAYERS = 80 + kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape + try: + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=self._cross_layers_blocks + ) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(self.tensor_shape))) + + # In case of cross layers permute kv_cache_shape according to + # stride_order to retrieve physical position of block_size + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + + # In the default non-cross layers layout the block_size position + # is logical while in the cross layers case it is the physical + # position. This matches the shape of the actual kv cache tensors + # passed at register_kv_caches()/register_cross_layers_kv_cache() + block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE) + + assert block_size_position is not None + self._block_size_position = ( + -(len(kv_cache_shape) - block_size_position) + if current_platform.device_type != "cpu" + else -2 + ) @property def is_kv_layout_blocks_first(self) -> bool: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f9852016ae6c..8a4e2959e7ec 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -306,10 +306,10 @@ def prefer_cross_layer_blocks(self) -> bool: "FLASH_ATTN", "FLASHINFER", ): - # For now there is no benefit to run cross layers when backend - # does not support on HND return False + # For now there is no benefit to run cross layers when backend + # does not support on HND if get_kv_cache_layout() != "HND": return False @@ -1343,10 +1343,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # to better exploit the memory layout (ie num_blocks is the first dim). tensor_size_bytes = None - block_size_position = ( - self.kv_topo.block_size_position if self.device_type != "cpu" else -2 - ) - # Enable different block lengths for different layers when MLA is used. 
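# Editorial sketch, not part of the patch: permuting a logical shape by the
# backend's stride order to recover the physical dimension order, as the
# cross-layers branch in TpKVTopology does. Both tuples are hypothetical.
logical = ("layers", "kv", "blocks", "block_size", "heads", "head")
order = (2, 0, 1, 4, 3, 5)
physical = tuple(logical[i] for i in order)
assert physical == ("blocks", "layers", "kv", "heads", "block_size", "head")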
self.block_len_per_layer = list[int]() self.slot_size_per_layer = list[int]() # HD bytes in kv terms @@ -1359,7 +1355,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if base_addr in seen_base_addresses: continue - kernel_block_size = cache.shape[block_size_position] + kernel_block_size = cache.shape[self.kv_topo.block_size_position] if self.block_size != kernel_block_size: logger.info_once( "User-specified logical block size (%s) does not match" From 858f856c78d4715fe63919c95dea330e25bb376a Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 2 Feb 2026 10:54:20 +0000 Subject: [PATCH 70/84] Added CI tests Signed-off-by: Liran Schour --- .../config_sweep_accuracy_test.sh | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index 2e25e2f1ac32..9976bd392943 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -13,15 +13,26 @@ tp_configs=( "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" ) +cross_layers_configs=( + "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" + "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" + "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1" + "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case + "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" + "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" +) dp_ep_configs=( -"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) -"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) +"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) +"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) ) # Select config array based on DP_EP env var if [[ -n "${DP_EP:-}" ]]; then configs=("${dp_ep_configs[@]}") echo "DP_EP is set, using dp_ep_configs" +elif [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then + configs=("${cross_layers_configs[@]}") + echo "CROSS_LAYERS_BLOCKS is set, using cross_layers_configs" else configs=("${tp_configs[@]}") fi From ff7e8c7aa474b3079758b8877b32f71872bc26ac Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 2 Feb 2026 11:51:29 +0000 Subject: [PATCH 71/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 6f08fd9c4f06..1182f59c6ef0 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ 
b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -4,6 +4,7 @@ set -xe # Parse command line arguments KV_BUFFER_DEVICE="cuda" # Default to cuda ATTENTION_BACKEND="" # Default to empty (use vllm default) +CROSS_LAYERS_BLOCKS="False" while [[ $# -gt 0 ]]; do case $1 in --kv_buffer_device) @@ -14,6 +15,8 @@ while [[ $# -gt 0 ]]; do ATTENTION_BACKEND="$2" shift 2 ;; + --enable-cross-layers + CROSS_LAYERS_BLOCKS="True" *) echo "Unknown option $1" echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ]" @@ -34,7 +37,6 @@ else KV_CONFIG_HETERO_LAYOUT='' fi -CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"enable_cross_layers_blocks": "True"}' else From 08f1ac44c77b4bc253d7db6c8f24830e340831c0 Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 2 Feb 2026 11:55:35 +0000 Subject: [PATCH 72/84] n/a Signed-off-by: Liran Schour --- tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 1182f59c6ef0..560ce4407038 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -15,8 +15,10 @@ while [[ $# -gt 0 ]]; do ATTENTION_BACKEND="$2" shift 2 ;; - --enable-cross-layers + --enable-cross-layers) CROSS_LAYERS_BLOCKS="True" + shift 1 + ;; *) echo "Unknown option $1" echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ]" From 1c734a9b13539723bb65cdc847c75cd2e922476f Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Mon, 2 Feb 2026 12:05:05 +0000 Subject: [PATCH 73/84] CI fix Signed-off-by: Liran Schour --- .../config_sweep_accuracy_test.sh | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index 9976bd392943..9cdcb86b7077 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -13,14 +13,6 @@ tp_configs=( "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" ) -cross_layers_configs=( - "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" - "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" - "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1" - "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case - "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" - "CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" -) dp_ep_configs=( "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 
(TP=1) @@ -30,9 +22,6 @@ dp_ep_configs=( if [[ -n "${DP_EP:-}" ]]; then configs=("${dp_ep_configs[@]}") echo "DP_EP is set, using dp_ep_configs" -elif [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then - configs=("${cross_layers_configs[@]}") - echo "CROSS_LAYERS_BLOCKS is set, using cross_layers_configs" else configs=("${tp_configs[@]}") fi @@ -68,3 +57,11 @@ if [[ -n "${FLASHINFER:-}" ]]; then else echo "FLASHINFER not set, skipping FLASHINFER runs." fi + +# Check if cross-layers is enabled (non-empty) +if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then + echo "CROSS_LAYERS_BLOCKS is set, rerunning with --enable-cross-layers" + run_tests "default backend" "--enable-cross-layers" +else + echo "CROSS_LAYERS_BLOCKS is not set, skipping --enable-cross-layers runs." +fi From e51a290adb5285132f80dfd28812488d65366cda Mon Sep 17 00:00:00 2001 From: Liran Schour Date: Tue, 3 Feb 2026 08:34:40 +0000 Subject: [PATCH 74/84] Code review fix Signed-off-by: Liran Schour --- .../nixl_integration/config_sweep_accuracy_test.sh | 3 --- tests/v1/kv_connector/unit/test_nixl_connector.py | 1 + vllm/distributed/kv_transfer/kv_connector/utils.py | 8 +------- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index 9cdcb86b7077..a4b92ab2cd07 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -62,6 +62,3 @@ fi if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then echo "CROSS_LAYERS_BLOCKS is set, rerunning with --enable-cross-layers" run_tests "default backend" "--enable-cross-layers" -else - echo "CROSS_LAYERS_BLOCKS is not set, skipping --enable-cross-layers runs." 
From e51a290adb5285132f80dfd28812488d65366cda Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 08:34:40 +0000
Subject: [PATCH 74/84] Code review fix

Signed-off-by: Liran Schour
---
 .../nixl_integration/config_sweep_accuracy_test.sh | 3 ---
 tests/v1/kv_connector/unit/test_nixl_connector.py  | 1 +
 vllm/distributed/kv_transfer/kv_connector/utils.py | 8 +-------
 3 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
index 9cdcb86b7077..a4b92ab2cd07 100755
--- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
@@ -62,6 +62,3 @@ fi
 if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then
   echo "CROSS_LAYERS_BLOCKS is set, rerunning with --enable-cross-layers"
   run_tests "default backend" "--enable-cross-layers"
-else
-  echo "CROSS_LAYERS_BLOCKS is not set, skipping --enable-cross-layers runs."
-fi
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 41bbe28b9642..9cab154bdfb8 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1465,6 +1465,7 @@ def test_register_kv_caches(
     expected_base_addrs: list[int]
     expected_num_entries: int
     kv_caches: dict[str, torch.Tensor]
+    assert (not enable_cross_layers) or connector.prefer_cross_layer_blocks
     if connector.prefer_cross_layer_blocks:
         num_layers = 32
         block_size = 16
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 13845fe29734..019201ede73e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -193,8 +193,6 @@ def copy_kv_blocks(
         dst_device=dst_device,
     )
 
-    from vllm.platforms import current_platform
-
     if direction == "h2d":
         copy_fn = current_platform.insert_blocks_to_device
     else:
@@ -360,11 +358,7 @@ def __post_init__(self):
         block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)
 
         assert block_size_position is not None
-        self._block_size_position = (
-            -(len(kv_cache_shape) - block_size_position)
-            if current_platform.device_type != "cpu"
-            else -2
-        )
+        self._block_size_position = -(len(kv_cache_shape) - block_size_position)
 
     @property
     def is_kv_layout_blocks_first(self) -> bool:

From 0ba9b051305154d90e50e46be9a9b9dc1403863b Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 08:46:32 +0000
Subject: [PATCH 75/84] Code review fix

Signed-off-by: Liran Schour
---
 tests/v1/kv_connector/unit/test_nixl_connector.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 9cab154bdfb8..704300f8cb97 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1465,7 +1465,9 @@ def test_register_kv_caches(
     expected_base_addrs: list[int]
     expected_num_entries: int
     kv_caches: dict[str, torch.Tensor]
-    assert (not enable_cross_layers) or connector.prefer_cross_layer_blocks
+    assert (
+        str(enable_cross_layers).lower() == "false"
+    ) or connector.prefer_cross_layer_blocks
     if connector.prefer_cross_layer_blocks:
         num_layers = 32
         block_size = 16

From 60cf003187424934d1a83a08e9c84d939255bf15 Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 09:01:48 +0000
Subject: [PATCH 76/84] n/a

Signed-off-by: Liran Schour
---
 tests/v1/kv_connector/unit/test_nixl_connector.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 704300f8cb97..95b0658680cc 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1465,9 +1465,10 @@ def test_register_kv_caches(
     expected_base_addrs: list[int]
     expected_num_entries: int
     kv_caches: dict[str, torch.Tensor]
-    assert (
-        str(enable_cross_layers).lower() == "false"
-    ) or connector.prefer_cross_layer_blocks
+    assert str(enable_cross_layers).lower() != "true" or (
+        (attn_backend not in ("FLASH_ATTN", "FLASHINFER"))
+        or connector.prefer_cross_layer_blocks
+    )
     if connector.prefer_cross_layer_blocks:
         num_layers = 32
         block_size = 16
From b95de2816543e49a65c4537e5fd7d86344678d7e Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 09:09:36 +0000
Subject: [PATCH 77/84] n/a

Signed-off-by: Liran Schour
---
 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 8a4e2959e7ec..52e1ba184300 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -314,6 +314,7 @@ def prefer_cross_layer_blocks(self) -> bool:
             return False
 
         extra_config = self.kv_transfer_config.kv_connector_extra_config
+        logger.info("XXX %s", extra_config)
         return (
             str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
             == "true"
@@ -331,7 +332,7 @@ def __init__(
         assert vllm_config.kv_transfer_config.engine_id is not None
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
         self.kv_transfer_config = vllm_config.kv_transfer_config
-
+        logger.info("XXX %s", self.kv_transfer_config.kv_connector_extra_config)
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: NixlConnectorScheduler | None = (
                 NixlConnectorScheduler(vllm_config, self.engine_id)

From 6142cfbb239182a0fe37408356cc15702395c28b Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 09:13:38 +0000
Subject: [PATCH 78/84] n/a

Signed-off-by: Liran Schour
---
 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 52e1ba184300..91ed905a8190 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -302,6 +302,7 @@ class NixlConnector(KVConnectorBase_V1):
     @property
     def prefer_cross_layer_blocks(self) -> bool:
         backend = get_current_attn_backend(self._vllm_config)
+        logger.info("XXX %s %s", backend.get_name(), backend().get_name())
         if backend.get_name() not in (
             "FLASH_ATTN",
             "FLASHINFER",

From f21077427022d469c61d0b1a1e91d5aae48f13ce Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 09:15:38 +0000
Subject: [PATCH 79/84] n/a

Signed-off-by: Liran Schour
---
 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 91ed905a8190..fb39c68bb5d8 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -307,11 +307,13 @@ def prefer_cross_layer_blocks(self) -> bool:
             "FLASH_ATTN",
             "FLASHINFER",
         ):
+            logger.info("XXX AA")
             return False
 
         # For now there is no benefit to run cross layers when backend
         # does not support on HND
         if get_kv_cache_layout() != "HND":
+            logger.info("XXX HERE")
             return False
 
         extra_config = self.kv_transfer_config.kv_connector_extra_config

From e1832cb2c7d372bdc97bc1fce2dce779f66ac44b Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 09:36:39 +0000
Subject: [PATCH 80/84] n/a

Signed-off-by: Liran Schour
---
 tests/v1/kv_connector/unit/test_nixl_connector.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 95b0658680cc..025448dcf54a 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -49,6 +49,7 @@
 from vllm.platforms import current_platform
 from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
+from vllm.v1.attention.backends import set_kv_cache_layout
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.output_processor import OutputProcessor
@@ -1418,6 +1419,7 @@ def test_register_kv_caches(
     vllm_config.kv_transfer_config.kv_connector_extra_config[
         "enable_cross_layers_blocks"
     ] = enable_cross_layers
+    set_kv_cache_layout("HND")
 
     # Import the appropriate backend based on the parameter
     if attn_backend == "FLASH_ATTN":

From b8252dd97e9b257005ac175850d7c9ef8834dd07 Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 09:37:50 +0000
Subject: [PATCH 81/84] n/a

Signed-off-by: Liran Schour
---
 tests/v1/kv_connector/unit/test_nixl_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 025448dcf54a..1975d2226073 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -49,8 +49,8 @@
 from vllm.platforms import current_platform
 from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
-from vllm.v1.attention.backends import set_kv_cache_layout
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.v1.attention.backends.utils import set_kv_cache_layout
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
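Taken together, the test patches above converge on a small setup recipe: pin the KV-cache layout to HND, then opt in to cross-layer blocks through the connector's extra config. A condensed sketch follows, assuming the import path that patch 81 settles on and a vllm_config fixture like the one the test already builds; the helper name is hypothetical.

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    def enable_cross_layer_blocks_for_test(vllm_config, enable: str = "True") -> None:
        # Cross-layer blocks are only preferred on the HND layout, so pin it first.
        set_kv_cache_layout("HND")
        # Then flip the connector's opt-in flag on the existing kv_transfer_config.
        vllm_config.kv_transfer_config.kv_connector_extra_config[
            "enable_cross_layers_blocks"
        ] = enable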
From bd9f1c442d3e136c9dc38cd2c5c1f021d6667d7c Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 09:39:59 +0000
Subject: [PATCH 82/84] n/a

Signed-off-by: Liran Schour
---
 .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index fb39c68bb5d8..fccc76a9f742 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -302,22 +302,18 @@ class NixlConnector(KVConnectorBase_V1):
     @property
     def prefer_cross_layer_blocks(self) -> bool:
         backend = get_current_attn_backend(self._vllm_config)
-        logger.info("XXX %s %s", backend.get_name(), backend().get_name())
         if backend.get_name() not in (
             "FLASH_ATTN",
             "FLASHINFER",
         ):
-            logger.info("XXX AA")
             return False
 
         # For now there is no benefit to run cross layers when backend
         # does not support on HND
         if get_kv_cache_layout() != "HND":
-            logger.info("XXX HERE")
             return False
 
         extra_config = self.kv_transfer_config.kv_connector_extra_config
-        logger.info("XXX %s", extra_config)
         return (
             str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
             == "true"
@@ -335,7 +331,6 @@ def __init__(
         assert vllm_config.kv_transfer_config.engine_id is not None
        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
         self.kv_transfer_config = vllm_config.kv_transfer_config
-        logger.info("XXX %s", self.kv_transfer_config.kv_connector_extra_config)
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: NixlConnectorScheduler | None = (
                 NixlConnectorScheduler(vllm_config, self.engine_id)
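With the temporary logging stripped out, the property reduces to three checks: a supported attention backend, the HND KV-cache layout, and an explicit opt-in flag. The standalone restatement below swaps the helpers get_current_attn_backend and get_kv_cache_layout from the diff for plain arguments, so it is a sketch of the decision logic rather than the connector's actual method.

    def prefers_cross_layer_blocks(
        backend_name: str, kv_cache_layout: str, extra_config: dict
    ) -> bool:
        # Cross-layer blocks are only worthwhile on FLASH_ATTN/FLASHINFER.
        if backend_name not in ("FLASH_ATTN", "FLASHINFER"):
            return False
        # For now there is no benefit when the backend does not use HND.
        if kv_cache_layout != "HND":
            return False
        # Finally, the feature must be requested via the connector's extra config.
        return (
            str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
            == "true"
        )

    # Example: enabled only when all three conditions hold.
    assert prefers_cross_layer_blocks(
        "FLASH_ATTN", "HND", {"enable_cross_layers_blocks": "True"}
    )
    assert not prefers_cross_layer_blocks(
        "FLASH_ATTN", "NHD", {"enable_cross_layers_blocks": "True"}
    )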
From 4652655356b58c0ac4004f1a6fefd7b6fa2a9c80 Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Tue, 3 Feb 2026 13:46:09 +0000
Subject: [PATCH 83/84] n/a

Signed-off-by: Liran Schour
---
 .../kv_connector/nixl_integration/config_sweep_accuracy_test.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
index a4b92ab2cd07..cdbcdca546e7 100755
--- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
@@ -62,3 +62,4 @@ fi
 if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then
   echo "CROSS_LAYERS_BLOCKS is set, rerunning with --enable-cross-layers"
   run_tests "default backend" "--enable-cross-layers"
+fi

From ac8903f26a3882748ff299f1433f61c101962915 Mon Sep 17 00:00:00 2001
From: Liran Schour
Date: Thu, 5 Feb 2026 07:58:44 +0000
Subject: [PATCH 84/84] retrigger tests

Signed-off-by: Liran Schour