Merged

Changes from all commits · 62 commits
5195265
Cross layers implementation
liranschour Dec 4, 2025
0f36888
Fix linting
liranschour Dec 7, 2025
8d36b4b
Add cross layers compatibility check
liranschour Dec 10, 2025
2a20197
Move cross_layers logic into TpKVTopology
liranschour Dec 11, 2025
073b30e
Code review minor fix
liranschour Dec 11, 2025
b403a9e
Linting...
liranschour Dec 11, 2025
06d3184
Code review fixes
liranschour Dec 17, 2025
cd27866
Update vllm/distributed/kv_transfer/kv_connector/utils.py
liranschour Dec 22, 2025
19319af
Update vllm/distributed/kv_transfer/kv_connector/utils.py
liranschour Dec 22, 2025
994bf1d
Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
liranschour Dec 22, 2025
0efeba3
Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
liranschour Dec 22, 2025
ef8e7ad
Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
liranschour Dec 22, 2025
6e2b751
Code review fixes
liranschour Dec 22, 2025
5e66e8f
Code review fix
liranschour Dec 28, 2025
eaf5e3d
Code review fix
liranschour Dec 28, 2025
e85f458
Code review fix
liranschour Dec 28, 2025
fff0935
Merge remote-tracking branch 'vllm/main' into nixl_kv_cont_cross_layers
liranschour Dec 28, 2025
9bd9598
Code review fix
liranschour Dec 28, 2025
cd57ed8
Merge remote-tracking branch 'vllm/main' into nixl_kv_cont_cross_layers
liranschour Jan 13, 2026
f153e83
n/a
liranschour Jan 13, 2026
15f2a78
n/a
liranschour Jan 13, 2026
0cb1825
n/a
liranschour Jan 13, 2026
9630c8e
n/a
liranschour Jan 13, 2026
c148f6d
n/a
liranschour Jan 13, 2026
96329f6
n/a
liranschour Jan 13, 2026
0394b36
Merge remote-tracking branch 'vllm/main' into nixl_kv_cont_cross_layers
liranschour Jan 14, 2026
b4d7045
Unit test fix
liranschour Jan 15, 2026
52b1155
Unit test fix
liranschour Jan 15, 2026
e34db34
Unit test fix
liranschour Jan 15, 2026
edc0755
Unit test fix
liranschour Jan 15, 2026
03af3ec
n/a
liranschour Jan 15, 2026
f0f2cf9
n/a
liranschour Jan 15, 2026
c3e1e5e
Merge branch 'main' into nixl_kv_cont_cross_layers
liranschour Jan 18, 2026
701d4ef
Code review fix
liranschour Jan 18, 2026
e7df5f8
Code review fix
liranschour Jan 18, 2026
6dae9b5
Code review fix
liranschour Jan 19, 2026
012bb9e
Handle heterogeneous TP for FLASHINFER and TRITON
liranschour Jan 20, 2026
99b3401
n/a
liranschour Jan 20, 2026
9fe2eb6
n/a
liranschour Jan 20, 2026
043c4d8
n/a
liranschour Jan 20, 2026
fe7197c
Run cross layers only for FlashAttention and FLASHINFER
liranschour Jan 21, 2026
5d59ea6
Enhance test_register_kv_caches
liranschour Jan 21, 2026
392e5d5
Documentation
liranschour Jan 21, 2026
d0e9aed
Merge branch 'main' into nixl_kv_cont_cross_layers
liranschour Jan 21, 2026
3c6921f
Code review fix
liranschour Jan 21, 2026
ed3180c
Code review fix
liranschour Jan 21, 2026
ced9ad4
Code review fix
liranschour Jan 21, 2026
580dbc4
Code review fix
liranschour Jan 21, 2026
19fff29
Code review fix
liranschour Jan 21, 2026
dd97e99
n/a
liranschour Jan 21, 2026
fce2050
n/a
liranschour Jan 21, 2026
7161150
n/a
liranschour Jan 21, 2026
4d3890e
n/a
liranschour Jan 21, 2026
d92bf96
n/a
liranschour Jan 21, 2026
6991cdd
n/a
liranschour Jan 21, 2026
2791d34
n/a
liranschour Jan 21, 2026
ff1f244
n/a
liranschour Jan 21, 2026
92f2628
Add cross layers blocks to run_accuracy_tests.sh
liranschour Jan 21, 2026
4715ced
Code review fix
liranschour Jan 22, 2026
c2e0ca0
Code review fix
liranschour Jan 22, 2026
7d1df76
Merge branch 'main' into nixl_kv_cont_cross_layers
liranschour Jan 22, 2026
d9ad710
Minor fix
liranschour Jan 22, 2026
docs/features/nixl_connector_usage.md: 9 additions, 0 deletions
@@ -184,6 +184,15 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```

### Cross-layer blocks

By default, this feature is disabled. On attention backends that support it (FlashAttention and FlashInfer), enabling it lays out each logical block contiguously in physical memory across all layers, so a block can be transferred as a single buffer rather than one buffer per layer. This reduces the number of buffers that need to be transferred.
To enable this feature:

```bash
--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
```
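
For reference, the same setting can also be passed through the Python API. The sketch below is illustrative (the model name and the single `kv_both` instance are placeholders, not part of this change):

```python
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    kv_transfer_config=KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"enable_cross_layers_blocks": "True"},
    ),
)
```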

## Example Scripts/Code

Refer to these example scripts in the vLLM repository:
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh: 9 additions, 2 deletions
@@ -34,11 +34,18 @@ else
KV_CONFIG_HETERO_LAYOUT=''
fi

CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default: cross-layer blocks disabled
if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then
KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"enable_cross_layers_blocks": "True"}'
else
KV_EXTRA_CONFIG=''
fi

# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
fi
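
# For reference (illustrative): with CROSS_LAYERS_BLOCKS=True, KV_BUFFER_DEVICE=cuda,
# and no hetero-layout override, KV_CONFIG resolves to:
# {"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"enable_cross_layers_blocks": "True"}}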

# Models to run
tests/v1/kv_connector/unit/test_nixl_connector.py: 178 additions, 47 deletions
@@ -18,8 +18,12 @@
import torch

from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
@@ -48,8 +52,11 @@
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup

from .utils import create_request, create_scheduler, create_vllm_config

@@ -366,6 +373,7 @@ def test_kv_transfer_handshake(dist_init):

# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector.register_kv_caches(kv_caches)

# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
@@ -402,6 +410,23 @@ def __init__(
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
test_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
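# A unit-sized shape is enough to probe the backend's KV cache layout
# (e.g., whether the block dimension comes first).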
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=test_shape,
)

# The compatibility hash folds in the cross-layer setting so peers with
# mismatched KV block layouts fail the handshake.
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)

def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -1366,6 +1391,7 @@ def req_id(outputs: list[RequestOutput]) -> str:
),
),
"TRITON_ATTN",
"FLASHINFER",
],
)
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
@@ -1382,6 +1408,11 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):

vllm_config = create_vllm_config(attention_backend=attn_backend)

# Enable cross-layer blocks
vllm_config.kv_transfer_config.kv_connector_extra_config[
"enable_cross_layers_blocks"
] = True

# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
@@ -1391,49 +1422,11 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend

backend_cls = RocmAttentionBackend
else: # TRITON_ATTN
else: # TRITON
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

backend_cls = TritonAttentionBackend

# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}

# Store tensor info for validation

test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1

if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4

nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
@@ -1462,6 +1455,107 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False

expected_tensor_size: int
expected_base_addrs: list[int]
expected_num_entries: int
kv_caches: dict[str, torch.Tensor]
if connector.prefer_cross_layer_blocks:
num_layers = 32
block_size = 16
num_blocks = 8
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=64,
dtype=torch.bfloat16,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["dummy-layer"],
)
for i in range(num_layers)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups=[],
)

with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
kv_cache_config=kv_cache_config,
attn_groups=[
[
AttentionGroup(
backend=backend_cls,
layer_names=[],
kv_cache_spec=kv_cache_spec,
kv_cache_group_id=0,
)
]
],
cache_dtype=torch.bfloat16,
device=torch.cuda.current_device(),
kernel_block_sizes=[block_size],
)
)
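# With cross-layer blocks, allocate_uniform_kv_caches hands back one
# contiguous tensor spanning all layers, so a single registered region
# (one base address) covers the entire cache.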
# Store tensor info for validation
expected_tensor_size = (
cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel()
)
expected_base_addrs = [
cross_layers_kv_cache.data_ptr(),
]
expected_num_entries = 1

expected_blocks_count = 8

kv_caches = {"all-layers": cross_layers_kv_cache}

else:
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}

# Store tensor info for validation

test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1

if is_blocks_first:
expected_tensor_size = (
shared_tensor.element_size() * shared_tensor.numel()
)
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
expected_blocks_count = 8

# Execute register_kv_caches
connector.register_kv_caches(kv_caches)

@@ -1485,16 +1579,19 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]

# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)

num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
if connector.prefer_cross_layer_blocks:
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
else:
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
expected_block_len = expected_tensor_size // num_blocks
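# Worked example of the arithmetic above (illustrative): the cross-layer
# tensor holds 8 blocks, so block_len = size // 8; a 2-block blocks-first
# tensor is also split into K and V halves, giving size // 2 // 2.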

for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
@@ -2041,6 +2138,17 @@ def test_compatibility_hash_validation(
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
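# The handshake path relies on state that register_kv_caches sets up
# (KV topology and compatibility hash), so register caches first.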
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
decode_connector.register_kv_caches(kv_caches)

remote_config_params: dict[str, Any] = {
"model": "facebook/opt-125m",
@@ -2063,7 +2171,9 @@
)
)
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config, decode_worker.backend_name
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
)

prefill_block_size = config_overrides.get("block_size", 16)
@@ -2142,6 +2252,27 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker

backend = get_current_attn_backend(local_vllm_config)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
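# Hand-build the topology and compatibility hash that register_kv_caches
# would normally populate; this test never registers KV caches.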
decode_worker.kv_topo = TpKVTopology(
tp_rank=decode_worker.tp_rank,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backend=backend,
tensor_shape=test_shape,
)

decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
)

if error_scenario == "handshake_decode_error":
msg_bytes = b"this is not valid msgpack data"
elif error_scenario == "handshake_validation_error":