Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,67 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(
worker.add_remote_agent(meta, remote_tp_size=2)
worker.add_remote_agent(meta, remote_tp_size=1)

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.base_worker.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_validates_gqa_replicated_block_len(
self, default_vllm_config, dist_init
):
"""Regression test for https://github.com/vllm-project/vllm/issues/45330.

When tp_size > total_num_kv_heads, GQA replication caps the per-rank
KV head count at 1, so block_len stops scaling linearly with 1/tp.
With 8 KV heads and D_TP=16 pulling from P_TP=8, both sides hold one
head per rank and report the *same* block_len; the handshake
validation used to expect local_block_len * tp_ratio and reject the
valid handshake.
"""
vllm_config = create_vllm_config()

with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.base_worker.get_tensor_model_parallel_world_size", # noqa: E501
return_value=16,
):
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker

# 8 total KV heads: local TP=16 is capped at 1 head per rank.
worker.transfer_topo.total_num_kv_heads = 8
worker.transfer_topo.local_physical_heads = 1
# Head-splitting validation requires HND for heterogeneous TP.
worker.kv_cache_layout = "HND"

worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks

# Remote P with TP=8 and 8 KV heads also holds exactly one head
# per rank -> identical block_len despite tp_ratio == 2.
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=list(worker.block_len_per_layer),
kv_cache_layout="HND",
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
physical_blocks_per_logical_kv_block=1,
)

# Must validate cleanly (used to raise AssertionError expecting
# remote_block_len == local_block_len * tp_ratio).
worker.add_remote_agent(meta, remote_tp_size=8)

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.base_worker.NixlWrapper",
FakeNixlWrapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1596,14 +1596,24 @@ def _validate_remote_agent_handshake(
# Per-region block_len validation enforcing the P/D invariant.
# REPLICATE regions (MLA, or a whole-model MLA / replicated-KV transfer)
# only allow the number of blocks to differ; SPLIT regions scale with
# tp_ratio. Mamba uses the ssm_sizes counterpart, so skip block_len here.
# the per-rank head ratio. Mamba uses the ssm_sizes counterpart, so
# skip block_len here.
if not self._has_mamba:
assert len(self.block_len_per_layer) == len(nixl_agent_meta.block_lens), (
"Number of KV layers must match between prefill and decode"
)
model_replicated = self.use_mla or self.transfer_topo.is_kv_replicated(
remote_engine_id
)
# SPLIT block_lens scale with the per-rank head count, which GQA
# replication caps at 1 when tp_size > total_num_kv_heads — so
# they do not scale linearly with 1/tp beyond that point
# (issue #45330). Compare against the actual per-rank head ratio
# rather than the raw tp_ratio (identical whenever neither side
# is capped).
total_kv_heads = self.transfer_topo.total_num_kv_heads
local_heads = self.transfer_topo.local_physical_heads
remote_heads = max(1, total_kv_heads // remote_tp_size)
for i, local_len in enumerate(self.block_len_per_layer):
replicated = model_replicated or self._is_region_replicated(i)
remote_len = nixl_agent_meta.block_lens[i]
Expand All @@ -1614,20 +1624,30 @@ def _validate_remote_agent_handshake(
f"remote={remote_len}, bsr={block_size_ratio})."
)
elif tp_ratio > 0:
assert remote_len == (local_len * tp_ratio) // block_size_ratio, (
# D_TP >= P_TP: remote holds remote_heads/local_heads x
# local heads (== tp_ratio when GQA replication caps
# neither side).
assert (
remote_len
== (local_len * remote_heads // local_heads) // block_size_ratio
), (
f"SPLIT region {i}: remote P KV block_len {remote_len} "
f"must equal local {local_len} * tp_ratio {tp_ratio} "
f"must equal local {local_len} * remote_heads "
f"{remote_heads} // local_heads {local_heads} "
f"// block_size_ratio {block_size_ratio}."
)
else:
# P_TP > D_TP: local holds local_heads/remote_heads x
# remote heads (== |tp_ratio|; the capped combination is
# rejected by the tp_ratio < 0 guard above).
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported "
"when P TP > D TP."
)
assert remote_len == local_len // (-tp_ratio), (
f"SPLIT region {i}: remote P KV block_len "
f"{remote_len} must equal local {local_len} "
f"// |tp_ratio| {-tp_ratio}."
assert remote_len == local_len * remote_heads // local_heads, (
f"SPLIT region {i}: remote P KV block_len {remote_len} "
f"must equal local {local_len} * remote_heads "
f"{remote_heads} // local_heads {local_heads}."
)

# TP workers that handhshake with same remote have same #blocks.
Expand Down
Loading