diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 32652118d526..29c25bd654c5 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1006,6 +1006,67 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( worker.add_remote_agent(meta, remote_tp_size=2) worker.add_remote_agent(meta, remote_tp_size=1) + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.base_worker.NixlWrapper", + FakeNixlWrapper, + ) + def test_handshake_validates_gqa_replicated_block_len( + self, default_vllm_config, dist_init + ): + """Regression test for https://github.com/vllm-project/vllm/issues/45330. + + When tp_size > total_num_kv_heads, GQA replication caps the per-rank + KV head count at 1, so block_len stops scaling linearly with 1/tp. + With 8 KV heads and D_TP=16 pulling from P_TP=8, both sides hold one + head per rank and report the *same* block_len; the handshake + validation used to expect local_block_len * tp_ratio and reject the + valid handshake. + """ + vllm_config = create_vllm_config() + + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.base_worker.get_tensor_model_parallel_world_size", # noqa: E501 + return_value=16, + ): + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # 8 total KV heads: local TP=16 is capped at 1 head per rank. + worker.transfer_topo.total_num_kv_heads = 8 + worker.transfer_topo.local_physical_heads = 1 + # Head-splitting validation requires HND for heterogeneous TP. + worker.kv_cache_layout = "HND" + + worker.slot_size_per_layer = [4096] + worker.block_len_per_layer = [4096 * worker.block_size] + worker.num_blocks = 1 + worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + + # Remote P with TP=8 and 8 KV heads also holds exactly one head + # per rank -> identical block_len despite tp_ratio == 2. + meta = NixlAgentMetadata( + engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + device_id=0, + num_blocks=1, + block_lens=list(worker.block_len_per_layer), + kv_cache_layout="HND", + block_size=worker.block_size, + ssm_sizes=(0, 0), + attn_backend_name=worker.backend_name, + physical_blocks_per_logical_kv_block=1, + ) + + # Must validate cleanly (used to raise AssertionError expecting + # remote_block_len == local_block_len * tp_ratio). + worker.add_remote_agent(meta, remote_tp_size=8) + @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl.base_worker.NixlWrapper", FakeNixlWrapper, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/base_worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/base_worker.py index e587b0cd1fa4..c5b377fa3576 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/base_worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/base_worker.py @@ -1596,7 +1596,8 @@ def _validate_remote_agent_handshake( # Per-region block_len validation enforcing the P/D invariant. # REPLICATE regions (MLA, or a whole-model MLA / replicated-KV transfer) # only allow the number of blocks to differ; SPLIT regions scale with - # tp_ratio. Mamba uses the ssm_sizes counterpart, so skip block_len here. + # the per-rank head ratio. Mamba uses the ssm_sizes counterpart, so + # skip block_len here. if not self._has_mamba: assert len(self.block_len_per_layer) == len(nixl_agent_meta.block_lens), ( "Number of KV layers must match between prefill and decode" @@ -1604,6 +1605,15 @@ def _validate_remote_agent_handshake( model_replicated = self.use_mla or self.transfer_topo.is_kv_replicated( remote_engine_id ) + # SPLIT block_lens scale with the per-rank head count, which GQA + # replication caps at 1 when tp_size > total_num_kv_heads — so + # they do not scale linearly with 1/tp beyond that point + # (issue #45330). Compare against the actual per-rank head ratio + # rather than the raw tp_ratio (identical whenever neither side + # is capped). + total_kv_heads = self.transfer_topo.total_num_kv_heads + local_heads = self.transfer_topo.local_physical_heads + remote_heads = max(1, total_kv_heads // remote_tp_size) for i, local_len in enumerate(self.block_len_per_layer): replicated = model_replicated or self._is_region_replicated(i) remote_len = nixl_agent_meta.block_lens[i] @@ -1614,20 +1624,30 @@ def _validate_remote_agent_handshake( f"remote={remote_len}, bsr={block_size_ratio})." ) elif tp_ratio > 0: - assert remote_len == (local_len * tp_ratio) // block_size_ratio, ( + # D_TP >= P_TP: remote holds remote_heads/local_heads x + # local heads (== tp_ratio when GQA replication caps + # neither side). + assert ( + remote_len + == (local_len * remote_heads // local_heads) // block_size_ratio + ), ( f"SPLIT region {i}: remote P KV block_len {remote_len} " - f"must equal local {local_len} * tp_ratio {tp_ratio} " + f"must equal local {local_len} * remote_heads " + f"{remote_heads} // local_heads {local_heads} " f"// block_size_ratio {block_size_ratio}." ) else: + # P_TP > D_TP: local holds local_heads/remote_heads x + # remote heads (== |tp_ratio|; the capped combination is + # rejected by the tp_ratio < 0 guard above). assert block_size_ratio == 1, ( "Different local/remote block sizes are not supported " "when P TP > D TP." ) - assert remote_len == local_len // (-tp_ratio), ( - f"SPLIT region {i}: remote P KV block_len " - f"{remote_len} must equal local {local_len} " - f"// |tp_ratio| {-tp_ratio}." + assert remote_len == local_len * remote_heads // local_heads, ( + f"SPLIT region {i}: remote P KV block_len {remote_len} " + f"must equal local {local_len} * remote_heads " + f"{remote_heads} // local_heads {local_heads}." ) # TP workers that handhshake with same remote have same #blocks.