Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions docs/features/nixl_connector_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```

### Cross layers blocks

By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred.
To enable this feature:

```bash
--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
```

## Example Scripts/Code

Refer to these example scripts in the vLLM repository:
Expand Down
11 changes: 2 additions & 9 deletions tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,11 @@ else
KV_CONFIG_HETERO_LAYOUT=''
fi

CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers
if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then
KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"cross_layers_blocks": "True"}'
else
KV_EXTRA_CONFIG=''
fi

# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi

# Models to run
Expand Down
225 changes: 47 additions & 178 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@
import torch

from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
get_current_attn_backend,
)
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
Expand Down Expand Up @@ -52,11 +48,8 @@
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup

from .utils import create_request, create_scheduler, create_vllm_config

Expand Down Expand Up @@ -373,7 +366,6 @@ def test_kv_transfer_handshake(dist_init):

# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector.register_kv_caches(kv_caches)

# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
Expand Down Expand Up @@ -410,23 +402,6 @@ def __init__(
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
test_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=test_shape,
)

self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)

def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
Expand Down Expand Up @@ -1395,7 +1370,6 @@ def req_id(outputs: list[RequestOutput]) -> str:
),
),
"TRITON_ATTN",
"FLASHINFER",
],
)
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
Expand All @@ -1412,11 +1386,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):

vllm_config = create_vllm_config(attention_backend=attn_backend)

# Enable cross layers blocks
vllm_config.kv_transfer_config.kv_connector_extra_config[
"enable_cross_layers_blocks"
] = True

# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
Expand All @@ -1426,11 +1395,49 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend

backend_cls = RocmAttentionBackend
else: # TRITON
else: # TRITON_ATTN
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

backend_cls = TritonAttentionBackend

# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}

# Store tensor info for validation

test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1

if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4

nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
Expand Down Expand Up @@ -1459,107 +1466,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False

expected_tensor_size: int
expected_base_addrs: list[int]
expected_num_entries: int
kv_caches: dict[str, torch.Tensor]
if connector.prefer_cross_layer_blocks:
num_layers = 32
block_size = 16
num_blocks = 8
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=64,
dtype=torch.bfloat16,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["dummy-layer"],
)
for i in range(num_layers)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups=[],
)

with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
kv_cache_config=kv_cache_config,
attn_groups=[
[
AttentionGroup(
backend=backend_cls,
layer_names=[],
kv_cache_spec=kv_cache_spec,
kv_cache_group_id=0,
)
]
],
cache_dtype=torch.bfloat16,
device=torch.cuda.current_device(),
kernel_block_sizes=[block_size],
)
)
# Store tensor info for validation
expected_tensor_size = (
cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel()
)
expected_base_addrs = [
cross_layers_kv_cache.data_ptr(),
]
expected_num_entries = 1

expected_blocks_count = 8

kv_caches = {"all-layers": cross_layers_kv_cache}

else:
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}

# Store tensor info for validation

test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1

if is_blocks_first:
expected_tensor_size = (
shared_tensor.element_size() * shared_tensor.numel()
)
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
expected_blocks_count = 8

# Execute register_kv_caches
connector.register_kv_caches(kv_caches)

Expand All @@ -1583,19 +1489,16 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]

# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)

if connector.prefer_cross_layer_blocks:
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
expected_block_len = expected_tensor_size // num_blocks
expected_block_len = expected_tensor_size // num_blocks

for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
Expand Down Expand Up @@ -2146,17 +2049,6 @@ def test_compatibility_hash_validation(
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
decode_connector.register_kv_caches(kv_caches)

remote_config_params: dict[str, Any] = {
"model": "facebook/opt-125m",
Expand All @@ -2179,9 +2071,7 @@ def test_compatibility_hash_validation(
)
)
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
remote_vllm_config, decode_worker.backend_name
)

prefill_block_size = config_overrides.get("block_size", 16)
Expand Down Expand Up @@ -2260,27 +2150,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker

backend = get_current_attn_backend(local_vllm_config)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
decode_worker.kv_topo = TpKVTopology(
tp_rank=decode_worker.tp_rank,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backend=backend,
tensor_shape=test_shape,
)

decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
)

if error_scenario == "handshake_decode_error":
msg_bytes = b"this is not valid msgpack data"
elif error_scenario == "handshake_validation_error":
Expand Down
Loading