tests/v1/kv_connector/unit/test_nixl_connector.py (80 changes: 57 additions & 23 deletions)
@@ -9,7 +9,7 @@
import time
import uuid
from collections import defaultdict
from typing import Any
from typing import Any, cast
from unittest.mock import MagicMock, patch

import msgspec
@@ -332,14 +332,22 @@ def test_kv_transfer_handshake(dist_init):

# Prefill connector will register KV cache to populate proper handshake
# metadata.
# TODO this must match with values used in kv cache config
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
@@ -383,7 +391,7 @@ def test_kv_transfer_handshake(dist_init):

# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
decode_connector.register_kv_caches(kv_caches)

@@ -525,11 +533,13 @@ def test_multi_xfer_one_engine(
request_id = "req_id"

# Test worker role in decode server.
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_config=kv_cache_config,
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
worker = connector.connector_worker
@@ -1479,18 +1489,22 @@ def test_register_kv_caches(
patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
patch(f"{nixl_module}.get_current_attn_backends") as mock_get_attn_backends,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
# test run if not mocking.
mock_get_attn_backend.return_value = backend_cls
mock_get_attn_backends.return_value = [backend_cls]

# Create connector
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_config=kv_cache_config,
)

# Get the mock instance
@@ -1515,6 +1529,13 @@
num_layers = 32
block_size = 16
num_blocks = 8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config = make_kv_cache_config(
block_size=block_size, num_blocks=num_blocks
)
connector.connector_worker.kv_cache_config = worker_kv_cache_config
connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
@@ -1568,11 +1589,17 @@

else:
# Create test kv cache tensors using proper backend shape
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
@@ -1606,7 +1633,7 @@ def test_register_kv_caches(
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
expected_blocks_count = 8
expected_blocks_count = kv_cache_config.num_blocks * 4

# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
@@ -1639,7 +1666,7 @@ def test_register_kv_caches(
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
else:
num_blocks = 2
num_blocks = kv_cache_config.num_blocks
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
@@ -2226,15 +2253,22 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat": enforce_handshake_compat
},
)
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
local_vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
decode_worker = decode_connector.connector_worker
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
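The recurring change in the hunks above is to derive the KV cache tensor geometry from the `kv_cache_config` instead of hardcoding `num_blocks=2, block_size=16, num_kv_heads=4, head_size=64` and `torch.float16`. A minimal sketch of that pattern, assuming `AttentionSpec` is importable from `vllm.v1.kv_cache_interface` (the diff does not show the test file's imports) and using a hypothetical helper name `build_kv_caches`:

```python
from typing import cast

import torch

from vllm.v1.kv_cache_interface import AttentionSpec  # import path assumed


def build_kv_caches(kv_cache_config, backend_cls) -> dict[str, torch.Tensor]:
    # Pull block geometry and dtype from the config's first KV cache group
    # instead of repeating magic numbers in every test.
    spec = cast(AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec)
    kv_cache_shape = backend_cls.get_kv_cache_shape(
        num_blocks=kv_cache_config.num_blocks,
        block_size=spec.block_size,
        num_kv_heads=spec.num_kv_heads,
        head_size=spec.head_size,
    )
    return {
        "layer0": torch.zeros(*kv_cache_shape, dtype=spec.dtype),
        "layer1": torch.zeros(*kv_cache_shape, dtype=spec.dtype),
    }
```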
tests/v1/worker/test_gpu_model_runner.py (20 changes: 4 additions & 16 deletions)
@@ -38,7 +38,7 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
from vllm.v1.worker.utils import select_common_block_size

BLOCK_SIZE = 16
NUM_BLOCKS = 10
@@ -203,37 +203,25 @@ def _make_kv_cache_spec() -> FullAttentionSpec:
def test_select_common_block_size_prefers_manager_block_size():
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]

selected_size = select_common_block_size(128, attn_groups)
selected_size = select_common_block_size(128, [backend_a, backend_b])
assert selected_size == 128


def test_select_common_block_size_uses_largest_shared_int():
backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]

selected_size = select_common_block_size(256, attn_groups)
selected_size = select_common_block_size(256, [backend_a, backend_b])
assert selected_size == 64


def test_select_common_block_size_no_valid_option():
backend_a = _make_mock_backend_for_kernel_block_size([64])
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]

with pytest.raises(ValueError):
select_common_block_size(48, attn_groups)
select_common_block_size(48, [backend_a, backend_b])


def test_update_states_new_request(model_runner, dist_init):
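For context on the API change these tests exercise: `select_common_block_size` now takes the backend classes themselves rather than `AttentionGroup` wrappers. A hypothetical call-site sketch (not from this PR), assuming each group still exposes its backend class as `group.backend`:

```python
from vllm.v1.worker.utils import select_common_block_size


def pick_block_size(kv_manager_block_size: int, attn_groups) -> int:
    # Previously the groups were passed through as-is; now only the backend
    # classes are needed to pick a block size every backend supports.
    backends = [group.backend for group in attn_groups]
    return select_common_block_size(kv_manager_block_size, backends)
```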
vllm/distributed/kv_transfer/kv_connector/utils.py (60 changes: 35 additions & 25 deletions)
@@ -357,15 +357,6 @@ def __post_init__(self):
# stride_order to retrieve physical position of block_size
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

# In the default non-cross layers layout the block_size position
# is logical while in the cross layers case it is the physical
# position. This matches the shape of the actual kv cache tensors
# passed at register_kv_caches()/register_cross_layers_kv_cache()
block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)

assert block_size_position is not None
self._block_size_position = -(len(kv_cache_shape) - block_size_position)
Comment on lines -360 to -367 (PR author):
we don't need to rely on custom logic to figure out where block dim is


@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@@ -389,10 +380,6 @@ def block_size(self) -> int:
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks

@property
def block_size_position(self) -> int:
return self._block_size_position

def tp_ratio(
self,
remote_tp_size: int,
@@ -483,23 +470,46 @@ def get_target_remote_ranks_from_engine_id(
return self.get_target_remote_ranks(remote_tp_size)


def get_current_attn_backend(vllm_config: VllmConfig):
def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
"""Get all distinct attention backends for the given layers.

Args:
vllm_config: The current vLLM configuration.
layer_names: Optional list of layer names to scope the lookup.
When None, all attention layers are considered.

Returns:
Deduplicated list of attention backend classes.
"""
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type, None)
layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
if layers:
backend = next(iter(layers.values())).get_attn_backend()
else:
# Fallback for tests, when static_forward_context is empty.
logger.debug(
"No layers found in the vLLM config. "
"Falling back to default attention backend."
)
from vllm.v1.attention.selector import get_attn_backend
seen: dict[str, type[AttentionBackend]] = {}
for layer in layers.values():
backend = layer.get_attn_backend()
seen[backend.full_cls_name()] = backend
return list(seen.values())

# Fallback for tests, when static_forward_context is empty.
logger.debug(
"No layers found in the vLLM config. Falling back to default attention backend."
)
from vllm.v1.attention.selector import get_attn_backend

backend = get_attn_backend(
return [
get_attn_backend(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
use_mla=vllm_config.model_config.use_mla,
)
return backend
]


def get_current_attn_backend(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> type[AttentionBackend]:
"""Get the first attention backend for the given layers."""
return get_current_attn_backends(vllm_config, layer_names)[0]
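A short usage sketch for the new helper pair, assuming a populated `VllmConfig`; `resolve_backends` is an illustrative name, not part of the PR:

```python
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_current_attn_backend,
    get_current_attn_backends,
)


def resolve_backends(vllm_config: VllmConfig, layer_names: list[str] | None = None):
    # All distinct attention backends behind the given layers,
    # deduplicated by full class name.
    backends = get_current_attn_backends(vllm_config, layer_names)
    # Callers that assume a homogeneous model can keep using the scalar
    # helper, which simply returns the first entry of that list.
    single = get_current_attn_backend(vllm_config, layer_names)
    return backends, single
```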