Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
659826a
move mamba-specifc states to mamba-engine-info; add policy class
ZhanqiuHu Apr 16, 2026
f2524c0
wire policy into TransferTopology, delegate build_engine_transfer_info
ZhanqiuHu Apr 16, 2026
d3b618a
decouple policy from topology; extract orchestration methods
ZhanqiuHu Apr 17, 2026
6bde491
extract descriptor building, block ID mapping, and read specs into Mo…
ZhanqiuHu Apr 17, 2026
6961ae0
updates
ZhanqiuHu Apr 17, 2026
00a44dc
restore original comments; remove redundant _physical_blocks_per_logi…
ZhanqiuHu Apr 18, 2026
3891e97
policy ABC cleanup: static helpers, abstract methods, remove dead code
ZhanqiuHu Apr 20, 2026
3cea947
revert comment to reference get_backend_aware_kv_block_len
ZhanqiuHu Apr 21, 2026
2a1e6c9
always pass transfer_config to policy methods, remove has_mamba guards
ZhanqiuHu Apr 21, 2026
2a2323e
remove abs_tp from build_src_split_handles args; compute from tp_size…
ZhanqiuHu Apr 21, 2026
fea54de
reorder build_local_descs args: memory, block geometry, layout
ZhanqiuHu Apr 21, 2026
0dd4f92
clean up build_src_split_handles: rename transfer_config to remote_in…
ZhanqiuHu Apr 21, 2026
93f21c4
clean up build_remote_descs: rename args, internalize indexes_into_re…
ZhanqiuHu Apr 21, 2026
4bdfe0e
rename transfer_config to remote_info, inline _get_block_descs_ids, u…
ZhanqiuHu Apr 21, 2026
7456cfd
extract _fa_descs_ids static helper to deduplicate FA descriptor ID c…
ZhanqiuHu Apr 21, 2026
b1765d0
add FA replication notes on ABC helpers and _build_fa_remote_descs
ZhanqiuHu Apr 21, 2026
655a9eb
fix unit tests: adapt to _get_block_descs_ids inlining and _physical_…
ZhanqiuHu Apr 21, 2026
6a6087d
fix unit tests: restore meta.remote.block_ids mutation, mock transfer…
ZhanqiuHu Apr 21, 2026
6deae9f
consolidate local topology into TransferTopology for build_engine_tra…
ZhanqiuHu Apr 22, 2026
4b72036
consolidate build_remote_descs params via transfer_topo + engine_id (…
ZhanqiuHu Apr 22, 2026
5b28f28
consolidate build_src_split_handles params via transfer_topo + engine…
ZhanqiuHu Apr 22, 2026
81e82fb
fix mooncake: compute physical_blocks_per_logical for TransferTopology
ZhanqiuHu Apr 22, 2026
d08ab9b
inline _fa_descs_ids, move static methods to module-level private utils
ZhanqiuHu Apr 22, 2026
925a9bd
[3/N] extract model-specific block expansion into plan-based logical_…
ZhanqiuHu Apr 23, 2026
7bca1ca
restore _logical_to_remote_kernel_block_ids as separate method
ZhanqiuHu Apr 23, 2026
e7c59c8
eliminate if _has_mamba branch from hot path via remote_expansion_stride
ZhanqiuHu Apr 23, 2026
a129ae8
introduce GroupKind enum to replace is_mamba_group boolean
ZhanqiuHu Apr 23, 2026
f03a17f
refactor
ZhanqiuHu Apr 23, 2026
d0c4802
remove dead fields, visualization, and revert block ID helpers to main
ZhanqiuHu Apr 23, 2026
e9e8a96
pass remote info/meta objects to plan generators and remove dead code
ZhanqiuHu Apr 23, 2026
93fa814
fix mypy: explicit args instead of dict unpacking, remove stale moonc…
ZhanqiuHu Apr 23, 2026
00dce93
pass transfer_topo to plan generators, fix dead code and stale docstring
ZhanqiuHu Apr 23, 2026
558d528
fix test: set kv_cache_config on mock worker for remote block ID expa…
ZhanqiuHu Apr 23, 2026
be301ac
update
ZhanqiuHu Apr 24, 2026
9d4ffbe
fix: pre-commit lint (unused var, line length, formatting)
ZhanqiuHu Apr 24, 2026
7b8922b
clean
ZhanqiuHu Apr 24, 2026
72ece82
rename
ZhanqiuHu Apr 24, 2026
44caacd
test case
ZhanqiuHu Apr 26, 2026
c3a5c65
update test
ZhanqiuHu Apr 26, 2026
bf52923
test
ZhanqiuHu Apr 26, 2026
0dc8e33
updates
ZhanqiuHu Apr 27, 2026
f8a01e6
fix: add Mamba guard to block ID trimming in _read_blocks
ZhanqiuHu Apr 27, 2026
a6e5266
updates
ZhanqiuHu Apr 27, 2026
2c920b5
add gemma4 heterotp support for NIXL KV transfer
ZhanqiuHu Apr 26, 2026
594e1ab
add gather-read support for Gemma4 HeteroTP NIXL transfer
ZhanqiuHu Apr 28, 2026
06a31ce
rename sub_desc terminology, add gather-read pairing and assertions
ZhanqiuHu Apr 28, 2026
e3d5ebb
gemma4
ZhanqiuHu Apr 28, 2026
56d12e7
fix tests
ZhanqiuHu Apr 29, 2026
80487f1
fix
ZhanqiuHu Apr 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def _nixl_handshake(
block_size=self.block_size,
ssm_sizes=(0, 0),
attn_backend_name=self.backend_name,
physical_blocks_per_logical_kv_block=1,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
Expand Down Expand Up @@ -726,6 +727,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size(
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
worker.num_descs = len(worker.src_blocks_data)

def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
Expand Down Expand Up @@ -978,6 +980,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
physical_blocks_per_logical_kv_block=1,
)

with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -1035,6 +1038,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
physical_blocks_per_logical_kv_block=1,
)

# We don't check layout for homogeneous TP and MLA for now, as the
Expand Down Expand Up @@ -2354,6 +2358,7 @@ def test_compatibility_hash_validation(
block_size=prefill_block_size,
ssm_sizes=(0, 0),
attn_backend_name=decode_worker.backend_name,
physical_blocks_per_logical_kv_block=1,
)
handshake_payload = NixlHandshakePayload(
compatibility_hash=remote_hash,
Expand Down
158 changes: 80 additions & 78 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,75 +93,80 @@ def test_logical_to_kernel_block_ids_with_hma():

@pytest.mark.cpu_test
@pytest.mark.parametrize(
"has_mamba,swa_enabled,mamba_enabled,remote_ratio,"
"remote_block_ids,expected_remote_block_ids",
"group_spec_types,expansion_stride,remote_block_ids,expected_remote_block_ids",
[
# Non-mamba (FA+SWA): both groups expanded via _logical_to_kernel_block_ids.
# Regression for https://github.com/vllm-project/vllm/pull/39724
(
False,
True,
False,
1,
pytest.param(
("FullAttentionSpec", "SlidingWindowSpec"),
2,
([0, 1, 2], [3, 4]),
[[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]],
id="dense_fa_swa",
),
# Mamba (FA+Mamba): FA expanded via _logical_to_remote_kernel_block_ids,
# Mamba passed through unchanged.
# remote_ratio=261 (Nemotron 30B TP=1) != local_ratio=2 so that using
# the wrong conversion method produces different FA results.
(
True,
False,
True,
pytest.param(
("FullAttentionSpec", "MambaSpec"),
261,
([0, 1, 2], [10, 11]),
[[0, 1, 261, 262, 522, 523], [10, 11]],
id="mamba_fa_ssm",
),
],
ids=["non_mamba_fa_swa", "mamba_fa_ssm"],
)
def test_read_blocks_for_req_expands_remote_ids(
has_mamba,
swa_enabled,
mamba_enabled,
remote_ratio,
group_spec_types,
expansion_stride,
remote_block_ids,
expected_remote_block_ids,
):
"""_read_blocks_for_req must expand remote logical block IDs to kernel
block IDs when kernel block size != logical block size.

Non-mamba path uses _logical_to_kernel_block_ids (all groups expanded).
Mamba path uses _logical_to_remote_kernel_block_ids (FA expanded, Mamba
passed through).
The hot path always calls _logical_to_remote_kernel_block_ids with
plan.remote_expansion_stride (model-agnostic).
"""
from unittest.mock import MagicMock

from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import (
EngineTransferPlan,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MambaSpec,
SlidingWindowSpec,
)

spec_name_to_type = {
"FullAttentionSpec": FullAttentionSpec,
"SlidingWindowSpec": SlidingWindowSpec,
"MambaSpec": MambaSpec,
}
resolved_types = tuple(spec_name_to_type[n] for n in group_spec_types)

worker = object.__new__(NixlConnectorWorker)
worker._has_mamba = has_mamba
worker._physical_blocks_per_logical_kv_block = 2

has_mamba = any(t is MambaSpec for t in resolved_types)
has_swa = any(t is SlidingWindowSpec for t in resolved_types)
worker.kv_cache_config = make_kv_cache_config(
block_size=16, swa_enabled=swa_enabled, mamba_enabled=mamba_enabled
block_size=16, swa_enabled=has_swa, mamba_enabled=has_mamba
)

remote_engine_id = "remote-engine"
if has_mamba:
worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio}

# Mock transfer_topo: empty remote ranks skips the transfer machinery
# entirely, isolating the block-ID expansion logic.
worker.transfer_topo = MagicMock()
worker.transfer_topo.target_remote_ranks.return_value = []
worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1)
worker.transfer_topo.tp_ratio.return_value = 1
worker.use_mla = False

mock_plan = MagicMock(spec=EngineTransferPlan)
mock_plan.remote_expansion_stride = expansion_stride
mock_plan.all_source_ranks = ()
mock_plan.source_ranks_per_group = ()
worker._transfer_plans = {remote_engine_id: mock_plan}

metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
Expand Down Expand Up @@ -308,73 +313,70 @@ def test_nixl_metadata_hma_block_ids_structure():

@pytest.mark.cpu_test
def test_get_block_descs_ids_hybrid_ssm():
"""Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
when ratio=1 (no kernel block size mismatch)."""
"""Test _compute_desc_ids uses per-group strides for hybrid
FA+SSM when ratio=1 (no kernel block size mismatch)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

worker = object.__new__(NixlConnectorWorker)
from .test_transfer_plan import _make_mamba_plan_for_desc_ids

num_blocks = 100
engine_id = "test-engine"
worker.num_regions = 2
worker.dst_num_blocks = {engine_id: num_blocks}
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
worker._physical_blocks_per_logical = {engine_id: 1}
worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
plan = _make_mamba_plan_for_desc_ids(
num_fa_regions=2,
num_ssm_regions=4,
group_spec_types=(FullAttentionSpec, MambaSpec),
fa_num_blocks=100,
ssm_num_blocks=100,
)

fa_blocks = [3, 5]
ssm_blocks = [1, 2]
result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))

# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
# region0: [201, 202], region1: [301, 302],
# region2: [401, 402], region3: [501, 502]
result = NixlConnectorWorker._compute_desc_ids_from_plan(
plan,
block_ids=(fa_blocks, ssm_blocks),
dst_num_blocks=100,
block_size_ratio=None,
physical_blocks_per_logical=1,
)

expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"


@pytest.mark.cpu_test
def test_get_block_descs_ids_kernel_block_mismatch():
"""Test _get_block_descs_ids uses different strides for FA (kernel blocks)
vs SSM (logical blocks) when ratio > 1."""
"""Test _compute_desc_ids uses different strides for FA
(kernel blocks) vs SSM (logical blocks) when ratio > 1."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

worker = object.__new__(NixlConnectorWorker)
from .test_transfer_plan import _make_mamba_plan_for_desc_ids

ratio = 4
logical_blocks = 100
num_blocks = logical_blocks * ratio # 400 kernel blocks
engine_id = "test-engine"
worker.num_regions = 2
worker.dst_num_blocks = {engine_id: num_blocks}
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker._physical_blocks_per_logical = {engine_id: ratio}
worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800

fa_blocks = [3, 7] # kernel-level block IDs
ssm_blocks = [1, 2] # logical block IDs
result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))

# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
# 4 regions per Mamba layer (x, B, C, ssm)
# region0: [801, 802], region1: [901, 902],
# region2: [1001, 1002], region3: [1101, 1102]

plan = _make_mamba_plan_for_desc_ids(
num_fa_regions=2,
num_ssm_regions=4,
group_spec_types=(FullAttentionSpec, MambaSpec),
fa_num_blocks=num_blocks,
ssm_num_blocks=logical_blocks,
)

fa_blocks = [3, 7]
ssm_blocks = [1, 2]
result = NixlConnectorWorker._compute_desc_ids_from_plan(
plan,
block_ids=(fa_blocks, ssm_blocks),
dst_num_blocks=num_blocks,
block_size_ratio=None,
physical_blocks_per_logical=ratio,
)

expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"

Expand Down
Loading
Loading