Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,87 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
# whole block is moved.
worker.add_remote_agent(meta, remote_tp_size=1)

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_mixed_fa_mla_hetero_tp(self, default_vllm_config, dist_init):
"""Mixed full-attn (SPLIT) + MLA (REPLICATE) single KV group under
heterogeneous TP must NOT raise (previously a NotImplementedError),
and the per-region gate must still reject a wrong block_len.
"""
vllm_config = create_vllm_config()
with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2,
):
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker

# Region 0: full-attn (SPLIT). Region 1: MLA (REPLICATE).
fa_len = 4096 * worker.block_size
idx_len = 512 * worker.block_size
worker.slot_size_per_layer = [4096, 512]
worker.block_len_per_layer = [fa_len, idx_len]
worker._region_is_mla = [False, True]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [
(0, fa_len, worker.tp_rank),
(0, idx_len, worker.tp_rank),
]
worker.num_descs = len(worker.src_blocks_data)

# D_TP=2, P_TP=1 -> tp_ratio=2. SPLIT region scales by tp_ratio;
# REPLICATE region is unchanged.
tp_ratio = 2
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0, 0],
device_id=0,
num_blocks=1,
block_lens=[fa_len * tp_ratio, idx_len],
kv_cache_layout=worker.kv_cache_layout,
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
physical_blocks_per_logical_kv_block=1,
)
worker.add_remote_agent(meta, remote_tp_size=1)
assert (
FakeNixlConnectorWorker.REMOTE_ENGINE_ID in worker.dst_xfer_side_handles
)
# Gate rejects an MLA region wrongly scaled by tp_ratio.
worker2 = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker2.block_len_per_layer = [fa_len, idx_len]
worker2._region_is_mla = [False, True]
worker2.num_blocks = 1
worker2.dst_num_blocks[worker2.engine_id] = worker2.num_blocks
bad_meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0, 0],
device_id=0,
num_blocks=1,
# WRONG: MLA region scaled by tp_ratio (it should be replicated).
block_lens=[fa_len * tp_ratio, idx_len * tp_ratio],
kv_cache_layout=worker2.kv_cache_layout,
block_size=worker2.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker2.backend_name,
physical_blocks_per_logical_kv_block=1,
)
with pytest.raises(AssertionError):
worker2.add_remote_agent(bad_meta, remote_tp_size=1)


# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then
Expand Down
12 changes: 11 additions & 1 deletion tests/v1/kv_connector/unit/test_tp_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,19 @@ def test_source_ranks_p_gt_d(self):


def _make_mock_worker_for_splits(group_spec_types):
"""Build a mock NixlConnectorWorker with _group_spec_types for split tests."""
"""Build a mock NixlConnectorWorker with _group_spec_types for split tests.
No per-region replicate flags are configured (``block_len_per_layer`` empty
and ``num_regions == 0``), so ``_fa_desc_replicated`` takes its early-return
path and treats every FA descriptor as SPLIT, matching the legacy behavior
these tests assert.
"""
worker = object.__new__(NixlConnectorWorker)
worker._group_spec_types = group_spec_types
worker.transfer_topo = SimpleNamespace(virtually_split_kv_in_blocks=False)
worker.block_len_per_layer = []
worker.num_regions = 0
worker._region_is_mla = []
return worker


Expand Down
Loading
Loading