Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5e0d078
move mamba-specific states to mamba-engine-info; add policy class
ZhanqiuHu Apr 16, 2026
d7e09af
wire policy into TransferTopology, delegate build_engine_transfer_info
ZhanqiuHu Apr 16, 2026
f873894
decouple policy from topology; extract orchestration methods
ZhanqiuHu Apr 17, 2026
3f0d3ac
extract descriptor building, block ID mapping, and read specs into Mo…
ZhanqiuHu Apr 17, 2026
1d58310
updates
ZhanqiuHu Apr 17, 2026
916d18b
restore original comments; remove redundant _physical_blocks_per_logi…
ZhanqiuHu Apr 18, 2026
0146b75
policy ABC cleanup: static helpers, abstract methods, remove dead code
ZhanqiuHu Apr 20, 2026
ff47d5d
revert comment to reference get_backend_aware_kv_block_len
ZhanqiuHu Apr 21, 2026
1e807a8
always pass transfer_config to policy methods, remove has_mamba guards
ZhanqiuHu Apr 21, 2026
4c4bd44
remove abs_tp from build_src_split_handles args; compute from tp_size…
ZhanqiuHu Apr 21, 2026
f6687a0
reorder build_local_descs args: memory, block geometry, layout
ZhanqiuHu Apr 21, 2026
5d67a63
clean up build_src_split_handles: rename transfer_config to remote_in…
ZhanqiuHu Apr 21, 2026
d057e93
clean up build_remote_descs: rename args, internalize indexes_into_re…
ZhanqiuHu Apr 21, 2026
5637058
rename transfer_config to remote_info, inline _get_block_descs_ids, u…
ZhanqiuHu Apr 21, 2026
a8631c3
extract _fa_descs_ids static helper to deduplicate FA descriptor ID c…
ZhanqiuHu Apr 21, 2026
fd90b2c
add FA replication notes on ABC helpers and _build_fa_remote_descs
ZhanqiuHu Apr 21, 2026
00f845c
fix unit tests: adapt to _get_block_descs_ids inlining and _physical_…
ZhanqiuHu Apr 21, 2026
e2f758e
fix unit tests: restore meta.remote.block_ids mutation, mock transfer…
ZhanqiuHu Apr 21, 2026
8482257
consolidate local topology into TransferTopology for build_engine_tra…
ZhanqiuHu Apr 22, 2026
c72b3db
consolidate build_remote_descs params via transfer_topo + engine_id (…
ZhanqiuHu Apr 22, 2026
91717b5
consolidate build_src_split_handles params via transfer_topo + engine…
ZhanqiuHu Apr 22, 2026
d9d877d
fix mooncake: compute physical_blocks_per_logical for TransferTopology
ZhanqiuHu Apr 22, 2026
e7dc5fa
inline _fa_descs_ids, move static methods to module-level private utils
ZhanqiuHu Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def __init__(
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends,
physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block,
tensor_shape=test_shape,
)

Expand Down Expand Up @@ -726,6 +727,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size(
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
worker.num_descs = len(worker.src_blocks_data)

def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
Expand Down Expand Up @@ -2435,6 +2437,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
is_mamba=False,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backends=[backend],
physical_blocks_per_logical=decode_worker._physical_blocks_per_logical_kv_block,
tensor_shape=test_shape,
)

Expand Down
71 changes: 37 additions & 34 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,19 @@ def test_read_blocks_for_req_expands_remote_ids(
)

remote_engine_id = "remote-engine"
if has_mamba:
worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio}

# Mock transfer_topo: empty remote ranks skips the transfer machinery
# entirely, isolating the block-ID expansion logic.
worker.transfer_topo = MagicMock()
worker.transfer_topo.target_remote_ranks.return_value = []
worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1)
worker.transfer_topo.get_engine_info.return_value = MagicMock(
remote_tp_size=1,
remote_physical_blocks_per_logical=remote_ratio if has_mamba else 1,
)
worker.transfer_topo.tp_ratio.return_value = 1
worker.transfer_policy = MagicMock()
worker.transfer_policy.compute_read_specs.return_value = []
worker.use_mla = False

metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
Expand Down Expand Up @@ -308,29 +312,28 @@ def test_nixl_metadata_hma_block_ids_structure():

@pytest.mark.cpu_test
def test_get_block_descs_ids_hybrid_ssm():
"""Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
when ratio=1 (no kernel block size mismatch)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
"""Test get_block_descs_ids uses per-group strides for hybrid
FA+SSM when ratio=1 (no kernel block size mismatch)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501
MambaModelBlockTransferPolicy,
)

worker = object.__new__(NixlConnectorWorker)
policy = object.__new__(MambaModelBlockTransferPolicy)
policy._is_mamba_group = [False, True]

num_blocks = 100
engine_id = "test-engine"
worker.num_regions = 2
worker.dst_num_blocks = {engine_id: num_blocks}
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
worker._physical_blocks_per_logical = {engine_id: 1}
worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
num_regions = 2
block_len_per_layer = [100]

fa_blocks = [3, 5]
ssm_blocks = [1, 2]
result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))
result = policy.get_block_descs_ids(
block_ids=(fa_blocks, ssm_blocks),
num_regions=num_regions,
dst_num_blocks=num_blocks,
block_len_per_layer=block_len_per_layer,
physical_blocks_per_logical=1,
)

# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
Expand All @@ -344,30 +347,30 @@ def test_get_block_descs_ids_hybrid_ssm():

@pytest.mark.cpu_test
def test_get_block_descs_ids_kernel_block_mismatch():
"""Test _get_block_descs_ids uses different strides for FA (kernel blocks)
vs SSM (logical blocks) when ratio > 1."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
"""Test get_block_descs_ids uses different strides for FA
(kernel blocks) vs SSM (logical blocks) when ratio > 1."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501
MambaModelBlockTransferPolicy,
)

worker = object.__new__(NixlConnectorWorker)
policy = object.__new__(MambaModelBlockTransferPolicy)
policy._is_mamba_group = [False, True]

ratio = 4
logical_blocks = 100
num_blocks = logical_blocks * ratio # 400 kernel blocks
engine_id = "test-engine"
worker.num_regions = 2
worker.dst_num_blocks = {engine_id: num_blocks}
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker._physical_blocks_per_logical = {engine_id: ratio}
worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800
num_regions = 2
block_len_per_layer = [100]

fa_blocks = [3, 7] # kernel-level block IDs
ssm_blocks = [1, 2] # logical block IDs
result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))
result = policy.get_block_descs_ids(
block_ids=(fa_blocks, ssm_blocks),
num_regions=num_regions,
dst_num_blocks=num_blocks,
block_len_per_layer=block_len_per_layer,
physical_blocks_per_logical=ratio,
)

# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
Expand Down
Loading
Loading