8 changes: 4 additions & 4 deletions tests/v1/kv_connector/unit/test_mooncake_connector.py
@@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
mock_thread.return_value.is_alive.return_value = False

worker.use_mla = True
worker.kv_topo.is_mla = True
worker.transfer_topo.is_mla = True

# MLA cache tensor: shape[-2] is the block size.
mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16)
@@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
# Override TP rank/size to simulate P TP=2
prefill_worker.tp_rank = P_TP_RANK
prefill_worker.tp_size = P_TP_SIZE
# Update shared dict so kv_topo sees correct TP size
prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE
prefill_worker.kv_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_size = P_TP_SIZE

prefill_worker.kv_caches_base_addr = [0x1000]
prefill_worker.block_len_per_layer = [local_block_len]
@@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
send_meta.ready.set()

# Compute target D ranks using the production code path
target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size)
target_d_ranks = prefill_worker.transfer_topo.handshake_target_ranks(d_tp_size)

mock_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_socket.send_multipart = AsyncMock()
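For orientation, a minimal sketch of the rank mapping that the heterogeneous-TP producer test above exercises. The helper below is a hypothetical stand-in for `TransferTopology.handshake_target_ranks`, assuming the usual head-sharded mapping between prefill and decode TP groups; the actual connector code may differ.

```python
# Hypothetical stand-in, NOT the actual TransferTopology.handshake_target_ranks.
def handshake_target_ranks(p_tp_rank: int, p_tp_size: int, d_tp_size: int) -> list[int]:
    if d_tp_size >= p_tp_size:
        # Each prefill rank covers a contiguous group of decode ranks.
        ratio = d_tp_size // p_tp_size
        return [p_tp_rank * ratio + i for i in range(ratio)]
    # More prefill ranks than decode ranks: several prefill ranks share one decode rank.
    ratio = p_tp_size // d_tp_size
    return [p_tp_rank // ratio]


# e.g. with P TP=2: rank 1 and D TP=4 -> [2, 3]; rank 1 and D TP=1 -> [0]
```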
31 changes: 17 additions & 14 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -21,7 +21,7 @@
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
TransferTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl
@@ -463,19 +463,20 @@ def __init__(
test_shape = self.attn_backends[0].get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank,
tp_size=self.world_size,
block_size=self.block_size,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends,
tensor_shape=test_shape,
)

self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
)

def _nixl_handshake(
@@ -496,7 +497,7 @@ def _nixl_handshake(
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
@@ -731,8 +732,9 @@ def check_handshake(remote_tp_size: int):
assert set(remote_agents.keys()) == set(range(tp_ratio))

remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
remote_info = worker.transfer_topo.get_engine_info(remote_engine_id)
assert remote_info.remote_tp_size == remote_tp_size
assert -tp_ratio == worker.transfer_topo.tp_ratio(remote_tp_size)
# ensure src_xfer_handles_by_tp_ratio is populated with tp_ratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
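The assertion above pins down a sign convention for the ratio. The snippet below is an assumed illustration of that convention, not the actual `TransferTopology.tp_ratio` code: the magnitude is the larger TP size divided by the smaller, and the sign records which side is larger.

```python
# Assumed illustration only; the real TransferTopology.tp_ratio may differ.
def tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
    if remote_tp_size <= local_tp_size:
        # Local TP group is at least as large as the remote one: positive ratio.
        return local_tp_size // remote_tp_size
    # Remote TP group is larger (e.g. prefill TP > decode TP): negative ratio.
    return -(remote_tp_size // local_tp_size)


# e.g. a decode worker with TP=2 handshaking a prefill deployment with TP=8 -> -4
```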
@@ -796,7 +798,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size_mla(
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.transfer_topo.tp_size = p_tp_size
worker.tp_rank = rank
worker.use_mla = True

Expand Down Expand Up @@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation(
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
decode_worker.transfer_topo.cross_layers_blocks,
)

prefill_block_size = config_overrides.get("block_size", 16)
@@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
decode_worker.kv_topo = TpKVTopology(
decode_worker.transfer_topo = TransferTopology(
tp_rank=decode_worker.tp_rank,
tp_size=decode_worker.world_size,
block_size=decode_worker.block_size,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
is_mamba=False,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backends=[backend],
tensor_shape=test_shape,
@@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
decode_worker.transfer_topo.cross_layers_blocks,
)

if error_scenario == "handshake_decode_error":
27 changes: 14 additions & 13 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
@@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids(

remote_engine_id = "remote-engine"
if has_mamba:
worker._mamba_phys_ratio = {remote_engine_id: remote_ratio}
worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio}

# Mock kv_topo: empty remote ranks skips the transfer machinery entirely,
# isolating the block-ID expansion logic.
worker.kv_topo = MagicMock()
worker.kv_topo.get_target_remote_ranks_from_engine_id.return_value = []
worker.kv_topo.tp_ratio_from_engine_id.return_value = 1
# Mock transfer_topo: empty remote ranks skips the transfer machinery
# entirely, isolating the block-ID expansion logic.
worker.transfer_topo = MagicMock()
worker.transfer_topo.target_remote_ranks.return_value = []
worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1)
worker.transfer_topo.tp_ratio.return_value = 1

metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
@@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
worker._mamba_phys_ratio = {engine_id: 1}
worker._physical_blocks_per_logical = {engine_id: 1}
worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
@@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker._mamba_phys_ratio = {engine_id: ratio}
worker._physical_blocks_per_logical = {engine_id: ratio}
worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800
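A rough illustration of the logical-to-physical block-ID expansion these Mamba tests rely on, assuming each logical KV block maps to a contiguous run of physical blocks; the helper below is hypothetical and not the connector's actual code.

```python
# Hypothetical expansion helper, assuming a contiguous logical -> physical layout.
def expand_block_ids(logical_ids: list[int], phys_per_logical: int) -> list[int]:
    return [
        logical_id * phys_per_logical + offset
        for logical_id in logical_ids
        for offset in range(phys_per_logical)
    ]


# e.g. with a ratio of 4: logical [0, 2] -> physical [0, 1, 2, 3, 8, 9, 10, 11]
```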

@@ -532,15 +533,15 @@ def test_has_mamba_init(
((9216, 524288), 4096, 131),
],
)
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_mamba_phys_ratio is TP-dependent.
def test_compute_physical_blocks_per_logical(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_physical_blocks_per_logical is TP-dependent.

With dimension-sharded Mamba state, the ratio differs across TP sizes
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
_mamba_phys_ratio must be stored per-engine.
_physical_blocks_per_logical must be stored per-engine.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
compute_mamba_phys_ratio,
compute_physical_blocks_per_logical,
)

assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio
assert compute_physical_blocks_per_logical(ssm_sizes, block_len) == expected_ratio
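As a sanity check on the parametrized case shown above, a minimal sketch of the arithmetic, assuming the ratio is a ceiling division of the summed per-block Mamba state sizes by the block length; the real helper in `ssm_conv_transfer_utils` may compute it differently.

```python
import math


# Assumed arithmetic only; the actual compute_physical_blocks_per_logical may differ.
def physical_blocks_per_logical(ssm_sizes: tuple[int, ...], block_len: int) -> int:
    return math.ceil(sum(ssm_sizes) / block_len)


# Matches the visible parametrize row: (9216 + 524288) / 4096 = 130.25 -> 131
assert physical_blocks_per_logical((9216, 524288), 4096) == 131
```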