Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope this fits on CI

# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)

# Select which config array to run: DP_EP selects dp_ep_configs,
# HYBRID_SSM selects hybrid_ssm_configs, otherwise fall back to tp_configs.
if [[ -n "${DP_EP:-}" ]]; then
    configs=("${dp_ep_configs[@]}")
    echo "DP_EP is set, using dp_ep_configs"
elif [[ -n "${HYBRID_SSM:-}" ]]; then
    configs=("${hybrid_ssm_configs[@]}")
    echo "HYBRID_SSM is set, using hybrid_ssm_configs."
else
    # Log the default path too, so CI output always states which set ran.
    configs=("${tp_configs[@]}")
    echo "Neither DP_EP nor HYBRID_SSM is set, using tp_configs"
fi
Expand Down
1 change: 1 addition & 0 deletions tests/v1/kv_connector/nixl_integration/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": 0.84,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite a big model for CI

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can switch to granite

}

SIMPLE_PROMPT = (
Expand Down
121 changes: 76 additions & 45 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@
from vllm.v1.attention.backends.utils import set_kv_cache_layout
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.kv_cache_interface import (
AttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheTensor,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
Expand Down Expand Up @@ -332,8 +338,20 @@ def test_kv_transfer_handshake(dist_init):

# Prefill connector will register KV cache to populate proper handshake
# metadata.
# TODO this must match with values used in kv cache config
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
kv_cache_groups = [
KVCacheGroupSpec(
["layer0", "layer1", "layer2"],
FullAttentionSpec(
block_size=16,
num_kv_heads=4,
head_size=16,
dtype=torch.float16,
),
)
]
kv_cache_config = KVCacheConfig(
num_blocks=2, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
)
prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
Expand Down Expand Up @@ -437,7 +455,7 @@ def __init__(
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
test_shape = self.attn_backend.get_kv_cache_shape(
test_shape = self.attn_backends[0].get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
Expand All @@ -447,7 +465,7 @@ def __init__(
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
attn_backends=self.attn_backends,
tensor_shape=test_shape,
)

Expand Down Expand Up @@ -501,6 +519,7 @@ def _nixl_handshake(
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
ssm_sizes=(0, 0),
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
Expand Down Expand Up @@ -951,6 +970,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(
block_lens=worker.block_len_per_layer,
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
ssm_sizes=(0, 0),
)

with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -1006,6 +1026,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
block_lens=[i * 2 for i in worker.block_len_per_layer],
kv_cache_layout="HND",
block_size=worker.block_size,
ssm_sizes=(0, 0),
)

# We don't check layout for homogeneous TP and MLA for now, as the
Expand Down Expand Up @@ -1496,9 +1517,47 @@ def test_register_kv_caches(
# test run if not mocking.
mock_get_attn_backend.return_value = backend_cls
mock_get_attn_backends.return_value = [backend_cls]

num_layers = 32
block_size = 16
num_blocks = 8
num_heads = 4
head_size = 16

# TODO (NickLucche) the fact that connector depends on kv_cache_config for init
# but cross-layer preference can't be inferred prior to creating kv_cache_config
# is a bit awkward.
dummy_connector = NixlConnector(
vllm_config,
KVConnectorRole.WORKER,
make_kv_cache_config(block_size=block_size),
)
kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=num_heads,
head_size=head_size,
dtype=torch.float16,
)
if dummy_connector.prefer_cross_layer_blocks:
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["all-layers"],
)
for _ in range(num_layers)
],
kv_cache_groups=[KVCacheGroupSpec(["all-layers"], kv_cache_spec)],
)
else:
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec)
],
)
# Create connector
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,
Expand Down Expand Up @@ -1526,35 +1585,6 @@ def test_register_kv_caches(
or connector.prefer_cross_layer_blocks
)
if connector.prefer_cross_layer_blocks:
num_layers = 32
block_size = 16
num_blocks = 8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config = make_kv_cache_config(
block_size=block_size, num_blocks=num_blocks
)
connector.connector_worker.kv_cache_config = worker_kv_cache_config
connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=64,
dtype=torch.bfloat16,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["dummy-layer"],
)
for i in range(num_layers)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups=[],
)

with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
Expand Down Expand Up @@ -1586,12 +1616,8 @@ def test_register_kv_caches(
expected_blocks_count = 8

kv_caches = {"all-layers": cross_layers_kv_cache}

else:
# Create test kv cache tensors using proper backend shape
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
Expand Down Expand Up @@ -2261,18 +2287,22 @@ def test_compatibility_hash_validation(
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
kv_cache_shape = decode_worker.attn_backends[0].get_kv_cache_shape(
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
# Build kv_caches from the actual layer names in kv_cache_config so that
# _layer_specs lookups in register_kv_caches always find a matching key.
layer_names = [
name for group in kv_cache_config.kv_cache_groups for name in group.layer_names
]
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
name: shared_tensor if i % 2 == 0 else unique_tensor
for i, name in enumerate(layer_names)
}
decode_connector.register_kv_caches(kv_caches)

Expand Down Expand Up @@ -2312,6 +2342,7 @@ def test_compatibility_hash_validation(
block_lens=[4096 * prefill_block_size], # slot_size * block_size
kv_cache_layout="HND",
block_size=prefill_block_size,
ssm_sizes=(0, 0),
)
handshake_payload = NixlHandshakePayload(
compatibility_hash=remote_hash,
Expand Down Expand Up @@ -2391,7 +2422,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backend=backend,
attn_backends=[backend],
tensor_shape=test_shape,
)

Expand Down
112 changes: 112 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def test_logical_to_kernel_block_ids_with_hma():
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker._physical_blocks_per_logical_kv_block = 2
# FA + SW groups (neither is MambaSpec, so both get expanded)
worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True)

# Test conversion: FA + SW group
logical_block_ids = [[0, 1, 2], [3, 4]]
Expand Down Expand Up @@ -201,3 +203,113 @@ def test_nixl_metadata_hma_block_ids_structure():
assert len(req_meta.remote.block_ids) == 2
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
assert list(req_meta.remote.block_ids[1]) == [18, 19, 20, 21]


@pytest.mark.cpu_test
def test_get_block_descs_ids_hybrid_ssm():
    """Check _get_block_descs_ids applies a per-group stride for a hybrid
    FA+SSM layout when ratio=1 (kernel and logical block sizes match)."""
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    total_blocks = 100
    engine_id = "test-engine"

    # Bare worker (no __init__) carrying only the attributes the method reads.
    worker = object.__new__(NixlConnectorWorker)
    worker.num_regions = 2
    worker.dst_num_blocks = {engine_id: total_blocks}
    worker._has_mamba = True
    worker._is_mamba_group = [False, True]
    worker._physical_blocks_per_logical_kv_block = 1
    # No blocks_first doubling, so num_descs = num_regions * num_blocks.
    worker.num_descs = 2 * total_blocks

    attn_group_ids = [3, 5]
    mamba_group_ids = [1, 2]
    result = worker._get_block_descs_ids(engine_id, (attn_group_ids, mamba_group_ids))

    # Attention group: stride = num_blocks (100), offset 0
    #   -> region 0: [3, 5]; region 1: [103, 105]
    # Mamba group: stride = logical blocks (100 // 1 == 100),
    #   offset = num_descs (200)
    #   -> region 0: [201, 202]; region 1: [301, 302]
    expected = [3, 5, 103, 105, 201, 202, 301, 302]
    assert list(result) == expected, f"Expected {expected}, got {list(result)}"


@pytest.mark.cpu_test
def test_get_block_descs_ids_kernel_block_mismatch():
    """Check _get_block_descs_ids strides FA groups by kernel blocks but SSM
    groups by logical blocks when the kernel/logical ratio exceeds 1."""
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    kernel_per_logical = 4
    logical_block_count = 100
    kernel_block_count = logical_block_count * kernel_per_logical  # 400
    engine_id = "test-engine"

    # Bare worker (no __init__) carrying only the attributes the method reads.
    worker = object.__new__(NixlConnectorWorker)
    worker.num_regions = 2
    worker.dst_num_blocks = {engine_id: kernel_block_count}
    worker._has_mamba = True
    worker._is_mamba_group = [False, True]
    worker._physical_blocks_per_logical_kv_block = kernel_per_logical
    worker.num_descs = 2 * kernel_block_count  # 800

    attn_group_ids = [3, 7]  # expressed in kernel-level block IDs
    mamba_group_ids = [1, 2]  # expressed in logical block IDs
    result = worker._get_block_descs_ids(engine_id, (attn_group_ids, mamba_group_ids))

    # Attention group: stride = kernel blocks (400), offset 0
    #   -> region 0: [3, 7]; region 1: [403, 407]
    # Mamba group: stride = logical blocks (400 // 4 == 100),
    #   offset = num_descs (800)
    #   -> region 0: [801, 802]; region 1: [901, 902]
    expected = [3, 7, 403, 407, 801, 802, 901, 902]
    assert list(result) == expected, f"Expected {expected}, got {list(result)}"


@pytest.mark.cpu_test
def test_nixl_metadata_hybrid_ssm_block_ids():
    """Check NixlConnectorMetadata keeps per-group block-ID lists of
    different lengths intact for FA + SSM groups (kernel mismatch active)."""
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorMetadata,
    )

    # FA group: 2 logical blocks * ratio 4 = 8 kernel blocks;
    # SSM group: 2 logical blocks.
    local_attn_ids = list(range(8))
    local_mamba_ids = [0, 1]
    remote_attn_ids = [10, 11, 12, 13, 14, 15, 16, 17]
    remote_mamba_ids = [20, 21]

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id="test-req-hybrid",
        local_block_ids=(local_attn_ids, local_mamba_ids),
        kv_transfer_params={
            "remote_block_ids": (remote_attn_ids, remote_mamba_ids),
            "remote_engine_id": "remote-engine",
            "remote_request_id": "prefill-test-req-hybrid",
            "remote_host": "localhost",
            "remote_port": 1234,
            "tp_size": 1,
        },
    )

    assert "test-req-hybrid" in metadata.reqs_to_recv
    req_meta = metadata.reqs_to_recv["test-req-hybrid"]

    # Local side: both groups present, with their differing lengths preserved.
    assert len(req_meta.local_block_ids) == 2
    assert list(req_meta.local_block_ids[0]) == local_attn_ids
    assert list(req_meta.local_block_ids[1]) == local_mamba_ids
    assert len(req_meta.local_block_ids[0]) != len(req_meta.local_block_ids[1])

    # Remote side: the same asymmetry survives the round-trip.
    assert req_meta.remote is not None
    assert len(req_meta.remote.block_ids) == 2
    assert list(req_meta.remote.block_ids[0]) == remote_attn_ids
    assert list(req_meta.remote.block_ids[1]) == remote_mamba_ids
    assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1])
Loading
Loading