diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index fb4b641e1376..3803e4fd3869 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -527,6 +527,7 @@ def _nixl_handshake( block_size=self.block_size, ssm_sizes=(0, 0), attn_backend_name=self.backend_name, + physical_blocks_per_logical_kv_block=1, ), remote_tp_rank=remote_tp_rank, remote_tp_size=remote_tp_size, @@ -726,6 +727,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size( worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] + worker.num_descs = len(worker.src_blocks_data) def check_handshake(remote_tp_size: int): tp_ratio = remote_tp_size // local_tp_size @@ -978,6 +980,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( block_size=worker.block_size, ssm_sizes=(0, 0), attn_backend_name=worker.backend_name, + physical_blocks_per_logical_kv_block=1, ) with pytest.raises(RuntimeError): @@ -1035,6 +1038,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( block_size=worker.block_size, ssm_sizes=(0, 0), attn_backend_name=worker.backend_name, + physical_blocks_per_logical_kv_block=1, ) # We don't check layout for homogeneous TP and MLA for now, as the @@ -2354,6 +2358,7 @@ def test_compatibility_hash_validation( block_size=prefill_block_size, ssm_sizes=(0, 0), attn_backend_name=decode_worker.backend_name, + physical_blocks_per_logical_kv_block=1, ) handshake_payload = NixlHandshakePayload( compatibility_hash=remote_hash, diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 3f5a9b9cc031..127db16f2eb5 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -93,75 +93,80 @@ def 
test_logical_to_kernel_block_ids_with_hma(): @pytest.mark.cpu_test @pytest.mark.parametrize( - "has_mamba,swa_enabled,mamba_enabled,remote_ratio," - "remote_block_ids,expected_remote_block_ids", + "group_spec_types,expansion_stride,remote_block_ids,expected_remote_block_ids", [ - # Non-mamba (FA+SWA): both groups expanded via _logical_to_kernel_block_ids. - # Regression for https://github.com/vllm-project/vllm/pull/39724 - ( - False, - True, - False, - 1, + pytest.param( + ("FullAttentionSpec", "SlidingWindowSpec"), + 2, ([0, 1, 2], [3, 4]), [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]], + id="dense_fa_swa", ), - # Mamba (FA+Mamba): FA expanded via _logical_to_remote_kernel_block_ids, - # Mamba passed through unchanged. - # remote_ratio=261 (Nemotron 30B TP=1) != local_ratio=2 so that using - # the wrong conversion method produces different FA results. - ( - True, - False, - True, + pytest.param( + ("FullAttentionSpec", "MambaSpec"), 261, ([0, 1, 2], [10, 11]), [[0, 1, 261, 262, 522, 523], [10, 11]], + id="mamba_fa_ssm", ), ], - ids=["non_mamba_fa_swa", "mamba_fa_ssm"], ) def test_read_blocks_for_req_expands_remote_ids( - has_mamba, - swa_enabled, - mamba_enabled, - remote_ratio, + group_spec_types, + expansion_stride, remote_block_ids, expected_remote_block_ids, ): """_read_blocks_for_req must expand remote logical block IDs to kernel block IDs when kernel block size != logical block size. - Non-mamba path uses _logical_to_kernel_block_ids (all groups expanded). - Mamba path uses _logical_to_remote_kernel_block_ids (FA expanded, Mamba - passed through). + The hot path always calls _logical_to_remote_kernel_block_ids with + plan.remote_expansion_stride (model-agnostic). 
""" from unittest.mock import MagicMock from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( NixlConnectorMetadata, ) + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + EngineTransferPlan, + ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, ) + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + MambaSpec, + SlidingWindowSpec, + ) + + spec_name_to_type = { + "FullAttentionSpec": FullAttentionSpec, + "SlidingWindowSpec": SlidingWindowSpec, + "MambaSpec": MambaSpec, + } + resolved_types = tuple(spec_name_to_type[n] for n in group_spec_types) worker = object.__new__(NixlConnectorWorker) - worker._has_mamba = has_mamba worker._physical_blocks_per_logical_kv_block = 2 + + has_mamba = any(t is MambaSpec for t in resolved_types) + has_swa = any(t is SlidingWindowSpec for t in resolved_types) worker.kv_cache_config = make_kv_cache_config( - block_size=16, swa_enabled=swa_enabled, mamba_enabled=mamba_enabled + block_size=16, swa_enabled=has_swa, mamba_enabled=has_mamba ) remote_engine_id = "remote-engine" - if has_mamba: - worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio} - # Mock transfer_topo: empty remote ranks skips the transfer machinery - # entirely, isolating the block-ID expansion logic. 
worker.transfer_topo = MagicMock() - worker.transfer_topo.target_remote_ranks.return_value = [] - worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1) worker.transfer_topo.tp_ratio.return_value = 1 + worker.use_mla = False + + mock_plan = MagicMock(spec=EngineTransferPlan) + mock_plan.remote_expansion_stride = expansion_stride + mock_plan.all_source_ranks = () + mock_plan.source_ranks_per_group = () + worker._transfer_plans = {remote_engine_id: mock_plan} metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( @@ -308,73 +313,70 @@ def test_nixl_metadata_hma_block_ids_structure(): @pytest.mark.cpu_test def test_get_block_descs_ids_hybrid_ssm(): - """Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM - when ratio=1 (no kernel block size mismatch).""" + """Test _compute_desc_ids uses per-group strides for hybrid + FA+SSM when ratio=1 (no kernel block size mismatch).""" from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, ) + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec - worker = object.__new__(NixlConnectorWorker) + from .test_transfer_plan import _make_mamba_plan_for_desc_ids - num_blocks = 100 - engine_id = "test-engine" - worker.num_regions = 2 - worker.dst_num_blocks = {engine_id: num_blocks} - worker._has_mamba = True - worker._is_mamba_group = [False, True] - worker._physical_blocks_per_logical_kv_block = 1 - worker._physical_blocks_per_logical = {engine_id: 1} - worker.block_len_per_layer = [100] - # num_descs = num_regions * num_blocks (no blocks_first doubling) - worker.num_descs = 2 * num_blocks + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, + group_spec_types=(FullAttentionSpec, MambaSpec), + fa_num_blocks=100, + ssm_num_blocks=100, + ) fa_blocks = [3, 5] ssm_blocks = [1, 2] - result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) - - # FA group: stride=num_blocks=100, offset=0 - # region0: 
[3, 5], region1: [103, 105] - # SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1), - # offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm) - # region0: [201, 202], region1: [301, 302], - # region2: [401, 402], region3: [501, 502] + result = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=(fa_blocks, ssm_blocks), + dst_num_blocks=100, + block_size_ratio=None, + physical_blocks_per_logical=1, + ) + expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502] assert list(result) == expected, f"Expected {expected}, got {list(result)}" @pytest.mark.cpu_test def test_get_block_descs_ids_kernel_block_mismatch(): - """Test _get_block_descs_ids uses different strides for FA (kernel blocks) - vs SSM (logical blocks) when ratio > 1.""" + """Test _compute_desc_ids uses different strides for FA + (kernel blocks) vs SSM (logical blocks) when ratio > 1.""" from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, ) + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec - worker = object.__new__(NixlConnectorWorker) + from .test_transfer_plan import _make_mamba_plan_for_desc_ids ratio = 4 logical_blocks = 100 num_blocks = logical_blocks * ratio # 400 kernel blocks - engine_id = "test-engine" - worker.num_regions = 2 - worker.dst_num_blocks = {engine_id: num_blocks} - worker._has_mamba = True - worker._is_mamba_group = [False, True] - worker._physical_blocks_per_logical_kv_block = ratio - worker._physical_blocks_per_logical = {engine_id: ratio} - worker.block_len_per_layer = [100] - worker.num_descs = 2 * num_blocks # 800 - - fa_blocks = [3, 7] # kernel-level block IDs - ssm_blocks = [1, 2] # logical block IDs - result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) - - # FA group: stride=num_blocks=400, offset=0 - # region0: [3, 7], region1: [403, 407] - # SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800, - # 4 regions per Mamba layer (x, B, 
C, ssm) - # region0: [801, 802], region1: [901, 902], - # region2: [1001, 1002], region3: [1101, 1102] + + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, + group_spec_types=(FullAttentionSpec, MambaSpec), + fa_num_blocks=num_blocks, + ssm_num_blocks=logical_blocks, + ) + + fa_blocks = [3, 7] + ssm_blocks = [1, 2] + result = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=(fa_blocks, ssm_blocks), + dst_num_blocks=num_blocks, + block_size_ratio=None, + physical_blocks_per_logical=ratio, + ) + expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102] assert list(result) == expected, f"Expected {expected}, got {list(result)}" diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py new file mode 100644 index 000000000000..fc7018fa2c54 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -0,0 +1,1002 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for plan-based transfer executors. + +These tests verify that the plan-based design produces correct +outputs (descriptor tuples, descriptor IDs, read specs, split handles). +No GPU or NIXL required. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from unittest.mock import MagicMock + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineTransferInfo, + TransferTopology, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + EngineTransferPlan, + RegionPlan, + generate_dense_plan, + generate_gemma4_plan, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + +# ====================================================================== +# Test fixtures / helpers +# ====================================================================== + +ENGINE_ID = "remote_engine" + + +@dataclass +class FakeNixlAgentMeta: + """Minimal mock of NixlAgentMetadata for testing.""" + + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + device_id: int + num_blocks: int + block_lens: list[int] + kv_cache_layout: str + block_size: int + ssm_sizes: tuple[int, int] + attn_backend_name: str + + +def _make_fake_topo( + tp_rank: int = 0, + tp_size: int = 1, + is_mla: bool = False, + total_num_kv_heads: int = 8, + block_size: int = 16, + is_blocks_first: bool = False, +) -> TransferTopology: + """Build a lightweight TransferTopology mock (skips __post_init__).""" + topo = MagicMock(spec=TransferTopology) + topo.tp_rank = tp_rank + topo.tp_size = tp_size + topo.is_mla = is_mla + topo.total_num_kv_heads = total_num_kv_heads + topo.block_size = block_size + topo.is_kv_layout_blocks_first = is_blocks_first + return topo + + +def _common_plan_params( + tp_rank: int = 0, + tp_size: int = 1, + is_mla: bool = False, + num_kv_heads: int = 8, + block_size: int = 16, + is_blocks_first: bool = False, + block_len_per_layer: list[int] | None = None, + remote_tp_size: int = 1, + remote_block_size: int = 16, + remote_num_blocks: int = 256, + remote_block_lens: list[int] | None = 
None, + remote_physical_blocks_per_logical: int = 1, + local_physical_blocks_per_logical: int = 1, +) -> dict: + """Build common kwargs for plan generators.""" + if block_len_per_layer is None: + slot_size = num_kv_heads * 128 * 2 # num_heads * head_size * dtype_bytes + block_len_per_layer = [slot_size * block_size] * 2 + if remote_block_lens is None: + remote_block_lens = list(block_len_per_layer) + return dict( + transfer_topo=_make_fake_topo( + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=is_mla, + total_num_kv_heads=num_kv_heads, + block_size=block_size, + is_blocks_first=is_blocks_first, + ), + block_len_per_layer=block_len_per_layer, + remote_info=EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + remote_block_len=remote_block_lens[0], + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ), + remote_meta=_make_nixl_meta( + base_addrs=[0] * len(block_len_per_layer), + num_blocks=remote_num_blocks, + block_lens=remote_block_lens, + block_size=remote_block_size, + ), + group_spec_types=(FullAttentionSpec,), + local_physical_blocks_per_logical=local_physical_blocks_per_logical, + ) + + +def _make_nixl_meta( + base_addrs: list[int], + num_blocks: int, + block_lens: list[int], + device_id: int = 0, + block_size: int = 16, +) -> FakeNixlAgentMeta: + return FakeNixlAgentMeta( + engine_id=ENGINE_ID, + agent_metadata=b"", + kv_caches_base_addr=base_addrs, + device_id=device_id, + num_blocks=num_blocks, + block_lens=block_lens, + kv_cache_layout="HND", + block_size=block_size, + ssm_sizes=(0, 0), + attn_backend_name="FlashAttentionBackend", + ) + + +# ====================================================================== +# Dense equivalence tests +# ====================================================================== + + +class TestDensePlanExecutors: + """Verify plan-based executors produce correct outputs for dense models.""" + + @pytest.mark.parametrize( + "tp_size,remote_tp_size", + [ + (1, 1), + 
(2, 1), + (4, 2), + (1, 2), + (2, 4), + ], + ) + @pytest.mark.parametrize("tp_rank_frac", [0.0, 0.5]) + def test_build_remote_descs(self, tp_size, remote_tp_size, tp_rank_frac): + tp_rank = int(tp_rank_frac * (tp_size - 1)) if tp_size > 1 else 0 + num_kv_heads = 8 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads * 128 * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + if tp_size >= remote_tp_size: + tp_ratio = tp_size // remote_tp_size + remote_block_lens = [bl * tp_ratio for bl in block_len_per_layer] + else: + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] + + base_addrs = [0x1000 * (i + 1) for i in range(num_layers)] + plan = generate_dense_plan( + **_common_plan_params( + tp_rank=tp_rank, + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + meta = _make_nixl_meta( + base_addrs, num_blocks, remote_block_lens, block_size=block_size + ) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) + + expected_count = len(plan.fa_regions) * num_blocks + assert len(descs) == expected_count + for addr, length, dev in descs: + assert length > 0 + assert dev == 0 + + @pytest.mark.parametrize( + "tp_size,remote_tp_size", + [(1, 1), (2, 1), (1, 2)], + ) + def test_compute_desc_ids(self, tp_size, remote_tp_size): + num_kv_heads = 8 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads * 128 * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + if tp_size >= remote_tp_size: + tp_ratio = tp_size // remote_tp_size + remote_block_lens = [bl * tp_ratio for bl in block_len_per_layer] + else: + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = 
[bl // tp_ratio_neg for bl in block_len_per_layer] + + plan = generate_dense_plan( + **_common_plan_params( + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + + block_ids = ([1, 5, 10, 20],) + ids = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids, + dst_num_blocks=num_blocks, + block_size_ratio=None, + physical_blocks_per_logical=1, + ) + + num_regions = len(plan.fa_regions) + assert len(ids) == num_regions * len(block_ids[0]) + assert ids[0] == 1 + + @pytest.mark.parametrize( + "tp_size,remote_tp_size", + [(1, 1), (2, 1), (1, 2)], + ) + def test_compute_read_specs(self, tp_size, remote_tp_size): + num_kv_heads = 8 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads * 128 * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + if tp_size >= remote_tp_size: + tp_ratio = tp_size // remote_tp_size + remote_block_lens = [bl * tp_ratio for bl in block_len_per_layer] + else: + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] + + plan = generate_dense_plan( + **_common_plan_params( + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + + local_ids = ([1, 2, 3],) + remote_ids = ([4, 5, 6],) + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, local_ids, remote_ids + ) + + assert len(specs) == len(plan.all_source_ranks) + for spec in specs: + assert list(spec.local_block_ids[0]) == [1, 2, 3] + assert list(spec.remote_block_ids[0]) == [4, 5, 6] + + @pytest.mark.parametrize("remote_tp_size", [2, 4]) 
+ def test_build_src_split_handles(self, remote_tp_size): + tp_rank = 0 + tp_size = 1 + num_kv_heads = 8 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads * 128 * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] + + plan = generate_dense_plan( + **_common_plan_params( + tp_rank=tp_rank, + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + + src_blocks_data = [(0x2000 + i * 1024, 1024, 0) for i in range(8)] + num_descs = len(src_blocks_data) + splits = NixlConnectorWorker._build_local_splits_from_plan( + plan, + src_blocks_data, + num_descs, + ) + + assert len(splits) == remote_tp_size + for handle in splits: + assert len(handle) == len(src_blocks_data) + for _, length, _ in handle: + assert length == 1024 // remote_tp_size + + +class TestDensePlanStructure: + def test_source_ranks_homogeneous(self): + plan = generate_dense_plan( + **_common_plan_params(tp_size=2, tp_rank=1, remote_tp_size=2), + ) + assert plan.all_source_ranks == (1,) + + def test_source_ranks_d_gt_p(self): + plan = generate_dense_plan( + **_common_plan_params(tp_size=4, tp_rank=2, remote_tp_size=2), + ) + assert plan.all_source_ranks == (1,) + + def test_source_ranks_p_gt_d(self): + plan = generate_dense_plan( + **_common_plan_params(tp_size=1, tp_rank=0, remote_tp_size=2), + ) + assert plan.all_source_ranks == (0, 1) + + def test_no_ssm_regions(self): + plan = generate_dense_plan(**_common_plan_params()) + assert plan.ssm_regions == () + assert plan.group_spec_types == (FullAttentionSpec,) + + def test_blocks_first_has_k_and_v(self): + plan = generate_dense_plan( + 
**_common_plan_params(is_blocks_first=True), + ) + num_layers = 2 + assert len(plan.fa_regions) == num_layers * 2 # K + V per layer + + def test_not_blocks_first_has_only_k(self): + plan = generate_dense_plan( + **_common_plan_params(is_blocks_first=False), + ) + num_layers = 2 + assert len(plan.fa_regions) == num_layers # K only per layer + + +# ====================================================================== +# Mamba equivalence tests +# ====================================================================== + + +def _make_mamba_plan_for_desc_ids( + num_fa_regions: int, + num_ssm_regions: int, + group_spec_types: tuple[type, ...], + fa_num_blocks: int = 100, + ssm_num_blocks: int = 100, +) -> EngineTransferPlan: + """Build a minimal plan with enough structure for compute_desc_ids.""" + fa_regions = tuple( + RegionPlan( + layer_idx=i, + descriptor_bytes=100, + offset_in_page=0, + page_stride=100, + num_blocks=fa_num_blocks, + ) + for i in range(num_fa_regions) + ) + ssm_regions = tuple( + RegionPlan( + layer_idx=i % (num_ssm_regions // 4) if num_ssm_regions >= 4 else 0, + descriptor_bytes=50, + offset_in_page=0, + page_stride=200, + num_blocks=ssm_num_blocks, + ) + for i in range(num_ssm_regions) + ) + all_ranks = (0,) + source_ranks_per_group = tuple(all_ranks for _ in group_spec_types) + return EngineTransferPlan( + fa_regions=fa_regions, + ssm_regions=ssm_regions, + group_spec_types=group_spec_types, + source_ranks_per_group=source_ranks_per_group, + all_source_ranks=(0,), + rank_to_attention_slot=({0: 0},) * len(group_spec_types), + remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, + ) + + +class TestMambaPlanDescIds: + """Verify plan-based desc IDs for hybrid FA+SSM models.""" + + def test_hybrid_ssm_ratio_1(self): + """Equivalent to test_get_block_descs_ids_hybrid_ssm.""" + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, # 4 regions per layer, 1 layer + group_spec_types=(FullAttentionSpec, 
MambaSpec), + fa_num_blocks=100, + ssm_num_blocks=100, + ) + + fa_blocks = [3, 5] + ssm_blocks = [1, 2] + + result = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=(fa_blocks, ssm_blocks), + dst_num_blocks=100, + block_size_ratio=None, + physical_blocks_per_logical=1, + ) + + expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502] + assert list(result) == expected, f"Expected {expected}, got {list(result)}" + + def test_kernel_block_mismatch(self): + """Equivalent to test_get_block_descs_ids_kernel_block_mismatch.""" + ratio = 4 + logical_blocks = 100 + num_blocks = logical_blocks * ratio # 400 + + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, + group_spec_types=(FullAttentionSpec, MambaSpec), + fa_num_blocks=num_blocks, + ssm_num_blocks=logical_blocks, + ) + + fa_blocks = [3, 7] + ssm_blocks = [1, 2] + + result = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=(fa_blocks, ssm_blocks), + dst_num_blocks=num_blocks, + block_size_ratio=None, + physical_blocks_per_logical=ratio, + ) + + expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102] + assert list(result) == expected, f"Expected {expected}, got {list(result)}" + + +class TestMambaPlanReadSpecs: + """Verify plan-based read specs handle FA group filtering correctly.""" + + def test_all_source_ranks_serve_fa(self): + """When all ranks are FA sources, no filtering happens.""" + both = (0, 1) + plan = EngineTransferPlan( + fa_regions=(), + ssm_regions=(), + group_spec_types=(FullAttentionSpec, MambaSpec), + source_ranks_per_group=(both, both), + all_source_ranks=(0, 1), + rank_to_attention_slot=({0: 0, 1: 1}, {0: 0, 1: 1}), + remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, + ) + + local_ids = ([1, 2], [3, 4]) + remote_ids = ([5, 6], [7, 8]) + + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, local_ids, remote_ids + ) + assert len(specs) == 2 + for spec in specs: + 
assert list(spec.local_block_ids[0]) == [1, 2] + assert list(spec.local_block_ids[1]) == [3, 4] + + def test_non_fa_rank_skips_fa_groups(self): + """Ranks not in source_ranks_per_group get groups zeroed out.""" + fa_readers = (0,) + ssm_readers = (0, 1, 2) + plan = EngineTransferPlan( + fa_regions=(), + ssm_regions=(), + group_spec_types=(FullAttentionSpec, MambaSpec), + source_ranks_per_group=(fa_readers, ssm_readers), + all_source_ranks=(0, 1, 2), + rank_to_attention_slot=({0: 0}, {0: 0}), + remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, + ) + + local_ids = ([1, 2], [3, 4]) + remote_ids = ([5, 6], [7, 8]) + + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, local_ids, remote_ids + ) + assert len(specs) == 3 + + # Rank 0 (FA source): gets all groups + assert list(specs[0].local_block_ids[0]) == [1, 2] + assert list(specs[0].local_block_ids[1]) == [3, 4] + + # Rank 1 (not FA): FA group zeroed, Mamba group preserved + assert specs[1].local_block_ids[0] == [] + assert list(specs[1].local_block_ids[1]) == [3, 4] + + # Rank 2 (not FA): same + assert specs[2].local_block_ids[0] == [] + assert list(specs[2].local_block_ids[1]) == [3, 4] + + +class TestMambaPlanSplitHandles: + """Verify plan-based split handles for Mamba with FA/SSM distinction.""" + + def test_fa_and_ssm_different_split_factors(self): + """Section 0 split by num_attn_reads, section 1 by abs_tp.""" + fa_readers = (0,) + ssm_readers = (0, 1) + plan = EngineTransferPlan( + fa_regions=(), + ssm_regions=( + RegionPlan( + layer_idx=0, + descriptor_bytes=100, + offset_in_page=0, + page_stride=100, + num_blocks=10, + ), + ), + group_spec_types=(FullAttentionSpec, MambaSpec), + source_ranks_per_group=(fa_readers, ssm_readers), + all_source_ranks=(0, 1), + rank_to_attention_slot=({0: 0, 1: 0}, {0: 0, 1: 0}), + remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, + ) + + # 2 FA descs + 1 SSM desc + src_blocks_data = [ + (1000, 200, 0), # FA desc 0 + 
(2000, 200, 0), # FA desc 1 + (3000, 400, 0), # SSM desc 0 + ] + + splits = NixlConnectorWorker._build_local_splits_from_plan( + plan, src_blocks_data, 2 + ) + + assert len(splits) == 2 # 2 source ranks + + # Rank 0 (FA source, p_idx=0): + # FA: chunk=200//1=200, slot=0 → (1000, 200, 0), (2000, 200, 0) + # SSM: chunk=400//2=200, idx=0 → (3000, 200, 0) + assert splits[0] == [(1000, 200, 0), (2000, 200, 0), (3000, 200, 0)] + + # Rank 1 (not FA source, p_idx=1): + # FA: chunk=200//1=200, slot=0 (skip_fa) → (1000, 200, 0), (2000, 200, 0) + # SSM: chunk=400//2=200, idx=1 → (3200, 200, 0) + assert splits[1] == [(1000, 200, 0), (2000, 200, 0), (3200, 200, 0)] + + +# ====================================================================== +# Gemma4 HeteroTP tests +# ====================================================================== + + +def _make_gemma4_plan_params( + tp_rank: int = 0, + tp_size: int = 4, + remote_tp_size: int = 2, +) -> dict: + """Build kwargs for generate_gemma4_plan at 2p4d. + + Gemma4-26B at P_TP=2, D_TP=4: + SWA: 25 layers, K=8, head_dim=256, block_size=16 on both sides + FA: 5 layers, K=2, head_dim=512, P block_size=32, D block_size=16 + + With page unification + HMA, all groups share one physical pool. + page_size: P=65536, D=32768 → remote_page > local_page (split-read). + For simplicity, use 2 physical layers in tests. + """ + # D side (local): kv_heads_per_rank for all groups = page_size / block_size + # page_size = 32768 for both groups at D_TP=4. 
+ d_page = 32768 + p_page = 65536 + num_layers = 2 + + return dict( + transfer_topo=_make_fake_topo( + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=False, + total_num_kv_heads=8, + block_size=16, + is_blocks_first=False, + ), + block_len_per_layer=[d_page] * num_layers, + remote_info=EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_size=16, + remote_block_len=p_page, + remote_physical_blocks_per_logical=1, + ), + remote_meta=_make_nixl_meta( + base_addrs=[0x10000 * (i + 1) for i in range(num_layers)], + num_blocks=500, + block_lens=[p_page] * num_layers, + block_size=16, + ), + group_spec_types=(FullAttentionSpec, FullAttentionSpec), + total_num_kv_heads_per_group=(8, 2), + local_tokens_per_block=(16, 16), + remote_tokens_per_block=(16, 32), + ) + + +class TestGemma4PlanStructure: + """Verify plan structure for Gemma4-style heterogeneous attention.""" + + def test_plan_fields_2p4d_rank0(self): + """D rank 0 at 2p4d: ratio=2, SWA head-split, FA multi-block.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=0)) + + assert plan.remote_page_size == 65536 + assert plan.local_page_size == 32768 + assert plan.group_spec_types == (FullAttentionSpec, FullAttentionSpec) + assert plan.local_blocks_per_remote_block == (1, 2) + assert plan.remote_desc_offset_per_group == (0, 0) # rank 0: index=0 + assert plan.all_source_ranks == (0,) + assert plan.source_ranks_per_group == ((0,), (0,)) + + def test_plan_fields_2p4d_rank1(self): + """D rank 1 at 2p4d: SWA reads second descriptor (index=1).""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=1)) + + assert plan.remote_desc_offset_per_group == (1, 0) # rank 1: SWA=1 + assert plan.local_blocks_per_remote_block == (1, 2) + assert plan.all_source_ranks == (0,) + + def test_plan_fields_2p4d_rank2(self): + """D rank 2 reads from P rank 1.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=2)) + + assert plan.all_source_ranks == (1,) + assert 
plan.remote_desc_offset_per_group == (0, 0) + + def test_fa_regions_have_multiple_descs_per_block(self): + """FA regions should have descs_per_block = page ratio.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + + for region in plan.fa_regions: + assert region.descs_per_block == 2 + assert region.desc_stride_bytes == 32768 # D page size + + +class TestGemma4RemoteDescs: + """Verify remote descriptor building with sub-descriptors.""" + + def test_descs_per_block(self): + """Each region produces num_blocks * descs_per_block descriptors.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[65536, 65536], + ) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) + + # 2 layers × 1 region/layer × 500 blocks × 2 descs/block = 2000 + expected_count = 2 * 500 * 2 + assert len(descs) == expected_count + + def test_desc_stride_within_block(self): + """Descriptors within a block should be desc_stride_bytes apart.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[65536, 65536], + ) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) + + # First block, layer 0: descriptor 0 and descriptor 1 + addr_d0, len_d0, _ = descs[0] + addr_d1, len_d1, _ = descs[1] + assert addr_d1 - addr_d0 == 32768 # desc_stride_bytes + assert len_d0 == len_d1 == 32768 # descriptor_bytes + + +class TestGemma4DescIds: + """Verify desc ID computation with sub-desc block IDs.""" + + def test_remapped_block_ids(self): + """After remapping, descriptor indices are correct.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + + # SWA blocks [3, 7], FA blocks [10, 11] + # Remapped to descriptor indices: + # SWA (desc_index=0): [3*2+0, 7*2+0] = [6, 14] + # FA (2 local per remote): [10*2, 10*2+1, 11*2, 11*2+1] = [20,21,22,23] + # + # dst_num_blocks 
= 500 * 2 = 1000 (num_blocks * descs_per_block) + # 2 fa_regions (2 layers), each with 1000 desc slots + # SWA: [0*1000+6, 0*1000+14, 1*1000+6, 1*1000+14] + # = [6, 14, 1006, 1014] + # FA: [0*1000+20, 0*1000+21, 0*1000+22, 0*1000+23, + # 1*1000+20, 1*1000+21, 1*1000+22, 1*1000+23] + # = [20, 21, 22, 23, 1020, 1021, 1022, 1023] + + # First remap via read specs to get descriptor-level block IDs + local_swa = [10, 11] + local_fa = [20, 21, 22, 23] + remote_swa = [3, 7] + remote_fa = [10, 11] + + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, + local_block_ids=(local_swa, local_fa), + remote_block_ids=(remote_swa, remote_fa), + ) + assert len(specs) == 1 # Single source rank + spec = specs[0] + + # Verify remapped remote block IDs + assert list(spec.remote_block_ids[0]) == [6, 14] # SWA: b*2+0 + assert list(spec.remote_block_ids[1]) == [20, 21, 22, 23] # FA: 2 per + + # Verify local block IDs unchanged + assert list(spec.local_block_ids[0]) == [10, 11] + assert list(spec.local_block_ids[1]) == [20, 21, 22, 23] + + # Now compute desc IDs with the remapped remote blocks + remote_ids = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=spec.remote_block_ids, + dst_num_blocks=500 * 2, # num_blocks * descs_per_block + block_size_ratio=None, + physical_blocks_per_logical=1, + ) + expected_remote = [6, 14, 1006, 1014, 20, 21, 22, 23, 1020, 1021, 1022, 1023] + assert list(remote_ids) == expected_remote + + # Local desc IDs (standard, descs_per_block=1 locally) + local_ids = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=spec.local_block_ids, + dst_num_blocks=1000, # local num_blocks + block_size_ratio=None, + physical_blocks_per_logical=1, + ) + expected_local = [10, 11, 1010, 1011, 20, 21, 22, 23, 1020, 1021, 1022, 1023] + assert list(local_ids) == expected_local + + # Both have same length → can be paired for transfer + assert len(remote_ids) == len(local_ids) + + +# 
====================================================================== +# Gemma4 Gather-Read tests (local page > remote page) +# ====================================================================== + + +def _make_gemma4_gather_plan_params( + tp_rank: int = 0, + tp_size: int = 2, + remote_tp_size: int = 4, +) -> dict: + """Build kwargs for generate_gemma4_plan at 4p2d (gather-read). + + Gemma4-26B at P_TP=4, D_TP=2: + SWA: K=8, head_dim=256, P_tpb=16, D_tpb=16 → concat (2 P ranks) + FA: K=2, head_dim=512, P_tpb=16, D_tpb=32 → gather (2P→1D block) + + page_size: P=32768, D=65536 → local_page > remote_page (gather-read). + """ + d_page = 65536 + p_page = 32768 + num_layers = 2 + + return dict( + transfer_topo=_make_fake_topo( + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=False, + total_num_kv_heads=8, + block_size=16, + is_blocks_first=False, + ), + block_len_per_layer=[d_page] * num_layers, + remote_info=EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_size=16, + remote_block_len=p_page, + remote_physical_blocks_per_logical=1, + ), + remote_meta=_make_nixl_meta( + base_addrs=[0x10000 * (i + 1) for i in range(num_layers)], + num_blocks=500, + block_lens=[p_page] * num_layers, + block_size=16, + ), + group_spec_types=(FullAttentionSpec, FullAttentionSpec), + total_num_kv_heads_per_group=(8, 2), + local_tokens_per_block=(16, 32), + remote_tokens_per_block=(16, 16), + ) + + +class TestGemma4GatherReadPlanStructure: + """Verify plan structure for gather-read (4p2d).""" + + def test_plan_fields_4p2d_rank0(self): + """D rank 0 at 4p2d: gather_ratio=2, SWA concat, FA gather.""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params(tp_rank=0)) + + assert plan.local_page_size == 65536 + assert plan.remote_page_size == 32768 + assert plan.group_spec_types == (FullAttentionSpec, FullAttentionSpec) + assert plan.remote_blocks_per_local_block == (1, 2) + assert plan.local_blocks_per_remote_block == (1, 1) + # SWA: D rank 0 reads from P rank 0 and P 
rank 1 + assert (0,) in plan.source_ranks_per_group[0] or len( + plan.source_ranks_per_group[0] + ) == 2 + # FA: after GQA dedup, D rank 0 reads from P rank 0 only + assert len(plan.source_ranks_per_group[1]) == 1 + + def test_no_assertion_error(self): + """4p2d should NOT crash (old code had assert page_ratio >= 1).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + assert plan is not None + + def test_fa_regions_standard_descs(self): + """Gather-read: FA regions have descs_per_block=1 (standard).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + + for region in plan.fa_regions: + assert region.descs_per_block == 1 + assert region.descriptor_bytes == 32768 # remote page size + + +class TestGemma4GatherReadRemoteDescs: + """Verify remote descriptor building for gather-read.""" + + def test_standard_descs_per_block(self): + """Gather-read: 1 desc per block (no remote sub-descs).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[32768, 32768], + ) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) + + # 2 layers × 1 region/layer × 500 blocks × 1 desc/block = 1000 + assert len(descs) == 2 * 500 * 1 + + def test_desc_bytes_match_remote_page(self): + """Each remote desc should be remote_page_size bytes.""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[32768, 32768], + ) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) + + for _, length, _ in descs: + assert length == 32768 + + +class TestGemma4GatherReadSpecs: + """Verify read spec computation for gather-read.""" + + def test_gather_read_specs_4p2d_rank0(self): + """4p2d rank 0: SWA from 2 ranks, FA from 1 rank (gather).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params(tp_rank=0)) + + # D has 2 SWA 
blocks and 1 FA block (32 tokens) + local_swa = [10, 11] + local_fa = [20] + # P has 2 SWA blocks per rank and 2 FA blocks (16 tokens each) + remote_swa = [5, 6] + remote_fa = [30, 31] + + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, + local_block_ids=(local_swa, local_fa), + remote_block_ids=(remote_swa, remote_fa), + ) + + # SWA reads from 2 P ranks → 2 specs + assert len(specs) == 2 + + # Spec 0 (P rank 0): + # SWA: local sub-desc slot 0 → [10*2+0, 11*2+0] = [20, 22] + # FA: expanded → [20*2+0, 20*2+1] = [40, 41] + spec0 = specs[0] + assert list(spec0.local_block_ids[0]) == [20, 22] # SWA slot 0 + assert list(spec0.local_block_ids[1]) == [40, 41] # FA gather + assert list(spec0.remote_block_ids[0]) == [5, 6] # SWA blocks + assert list(spec0.remote_block_ids[1]) == [30, 31] # FA blocks + + # Spec 1 (P rank 1): + # SWA: local sub-desc slot 1 → [10*2+1, 11*2+1] = [21, 23] + # FA: empty (rank 1 not in FA source_ranks after GQA dedup) + spec1 = specs[1] + assert list(spec1.local_block_ids[0]) == [21, 23] # SWA slot 1 + assert list(spec1.remote_block_ids[0]) == [5, 6] # SWA blocks + assert spec1.local_block_ids[1] == [] # FA empty for rank 1 + assert spec1.remote_block_ids[1] == [] + + def test_gather_read_desc_ids_match(self): + """Local and remote desc IDs should have same length for NIXL.""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params(tp_rank=0)) + + local_swa = [10, 11] + local_fa = [20] + remote_swa = [5, 6] + remote_fa = [30, 31] + + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, + local_block_ids=(local_swa, local_fa), + remote_block_ids=(remote_swa, remote_fa), + ) + + for spec in specs: + # Remote desc IDs: standard (no sub-descs), num_blocks=500 + remote_ids = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=spec.remote_block_ids, + dst_num_blocks=500, + block_size_ratio=None, + physical_blocks_per_logical=1, + ) + # Local desc IDs: gather sub-descs, 
num_blocks=1000*gather_ratio + local_ids = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, + block_ids=spec.local_block_ids, + dst_num_blocks=1000 * 2, # local_num_blocks * gather_ratio + block_size_ratio=None, + physical_blocks_per_logical=1, + ) + assert len(remote_ids) == len(local_ids), ( + f"Desc ID length mismatch for rank {spec.remote_rank}: " + f"remote={len(remote_ids)}, local={len(local_ids)}" + ) + + +class TestGemma4GatherReadPlan4p1d: + """Verify gather-read for 4p1d (D_TP=1, P_TP=4).""" + + def test_4p1d_no_crash(self): + """4p1d should not crash.""" + params = _make_gemma4_gather_plan_params(tp_rank=0, tp_size=1, remote_tp_size=4) + # D_TP=1: D_page = 131072 (8 heads * 256 * 2 * 16 * 2 for SWA) + # P_TP=4: P_page = 32768 + params["block_len_per_layer"] = [131072, 131072] + params["local_tokens_per_block"] = (16, 32) + params["remote_tokens_per_block"] = (16, 16) + plan = generate_gemma4_plan(**params) + + assert plan.local_page_size == 131072 + assert plan.remote_page_size == 32768 + assert plan.remote_blocks_per_local_block == (1, 2) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 63b56eddfaed..b85416ab3071 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -386,38 +386,6 @@ class EngineTransferInfo: """Physical blocks per logical block.""" -@dataclass(frozen=True) -class MambaEngineTransferInfo(EngineTransferInfo): - """Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry. - - For hybrid SSM+Attention models, FA and Mamba layers may require - different numbers of reads from different remote ranks. This - dataclass captures that per-engine transfer plan. - """ - - remote_fa_source_ranks: tuple[int, ...] - """Remote ranks carrying unique FA heads for this local rank.""" - - remote_all_source_ranks: tuple[int, ...] 
- """All remote ranks this local rank reads from (FA + Mamba).""" - - remote_num_fa_reads: int - """Number of distinct remote ranks needed for FA data.""" - - remote_num_mamba_reads: int - """Number of distinct remote ranks needed for Mamba data.""" - - remote_fa_descriptor_bytes: int - """Byte size of one FA K (or V) descriptor entry.""" - - is_remote_replicated: bool - """Whether the remote engine has replicated KV heads - (remote_tp_size > total_num_kv_heads).""" - - remote_physical_heads: int - """Physical KV heads stored per remote rank.""" - - # ---- Transfer topology ---- @@ -439,8 +407,6 @@ def __post_init__(self): self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size) self._engines: dict[EngineId, EngineTransferInfo] = {} - self._fa_source_sets: dict[EngineId, frozenset[int]] = {} - self._fa_source_indices: dict[EngineId, dict[int, int]] = {} # Figure out whether the first dimension of the cache is K/V # or num_blocks. @@ -487,24 +453,12 @@ def __post_init__(self): def register_remote_engine( self, remote_engine_id: EngineId, - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - *, - local_block_len: int = 0, + info: EngineTransferInfo, ) -> EngineTransferInfo: """Register a remote engine, unifying worker dicts state. - Only remote engines should be registered here — the local engine's - identity (tp_size, block_size, etc.) is set via ``__init__`` params. - - For Mamba models, also computes the Mamba transfer plan and - builds the FA source lookup caches. - - Args: - local_block_len: Local representative block_len (bytes). - Required for Mamba models to compute ``fa_descriptor_bytes``. + The caller (worker) is responsible for computing the info via + the transfer policy. This method only stores and deduplicates. """ assert remote_engine_id != self.engine_id, ( f"Cannot register local engine {self.engine_id} as remote. 
" @@ -512,29 +466,6 @@ def register_remote_engine( ) if remote_engine_id in self._engines: return self._engines[remote_engine_id] - info: EngineTransferInfo - if self.is_mamba: - info = self._build_mamba_info( - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - remote_block_len=remote_block_len, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - local_block_len=local_block_len, - ) - assert isinstance(info, MambaEngineTransferInfo) - self._fa_source_sets[remote_engine_id] = frozenset( - info.remote_fa_source_ranks - ) - self._fa_source_indices[remote_engine_id] = { - r: i for i, r in enumerate(info.remote_fa_source_ranks) - } - else: - info = EngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - ) self._engines[remote_engine_id] = info return info @@ -622,14 +553,8 @@ def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]: """Get the remote TP rank(s) that the current local TP rank will read from. When remote tp_size > local tp_size, reads from multiple remote ranks. - - For Mamba models, returns the precomputed ``all_source_ranks`` - (FA + Mamba union). 
""" info = self._engines[remote_engine_id] - if isinstance(info, MambaEngineTransferInfo): - return list(info.remote_all_source_ranks) - tp_ratio = self.tp_ratio(info.remote_tp_size) if tp_ratio > 0: return [self.tp_rank // tp_ratio] @@ -662,309 +587,15 @@ def get_transfer_cache_regions( # Regular case: backends like FA register K/V in separate regions return cache if self.split_k_and_v else [cache] - # ============================================================ - # Mamba-specific methods - # ============================================================ - - def should_skip_fa(self, remote_engine_id: EngineId, remote_rank: int) -> bool: - """Whether to skip FA groups for this remote rank (mamba-only).""" - return remote_rank not in self._fa_source_sets[remote_engine_id] - - def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int: - """Index into local FA block for this remote rank's head data. - - For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. - For ranks NOT in ``fa_source_ranks`` (replicated duplicates), - returns the slot of the matching source rank with the same head. - """ - fa_index = self._fa_source_indices[remote_engine_id] - if remote_rank in fa_index: - return fa_index[remote_rank] - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - K = self.total_num_kv_heads - remote_tp = mamba_info.remote_tp_size - r_head = self._physical_head_range(remote_tp, K, remote_rank) - for target in mamba_info.remote_fa_source_ranks: - t_head = self._physical_head_range(remote_tp, K, target) - if self._range_overlap(r_head, t_head): - return fa_index[target] - return 0 - - def fa_rank_offset( - self, remote_engine_id: EngineId, remote_kv_block_len: int - ) -> int: - """Byte offset into remote FA block for this local rank. - - When local TP is replicated (local_tp > K), multiple local ranks - share a head. 
Computes offset *relative to the target remote - rank's first head* so it works regardless of how many heads the - remote has. Returns 0 when local does not index into remote. - """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) - if self.is_mla or tp_ratio <= 0: - return 0 - K = self.total_num_kv_heads - is_local_replicated = self.tp_size > K - if is_local_replicated: - local_head = self.tp_rank * K // self.tp_size - p_rank = mamba_info.remote_fa_source_ranks[0] - p_start = p_rank * K // mamba_info.remote_tp_size - return (local_head - p_start) * remote_kv_block_len - return self.tp_rank % tp_ratio * remote_kv_block_len - - def needs_split_handles(self, remote_engine_id: EngineId) -> bool: - """Whether per-remote-rank split handles are needed. - - True when FA and mamba have different read counts, requiring - different splitting factors in the local handle. - """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) - return ( - tp_ratio < 0 - and not self.is_mla - and len(mamba_info.remote_all_source_ranks) > 1 - ) - - def compute_split_handle_data( - self, - remote_engine_id: EngineId, - src_blocks_data: list[tuple[int, int, int]], - num_fa_descs: int, - abs_tp: int, - ) -> list[list[tuple[int, int, int]]]: - """Per-remote-rank (addr, len, dev) triples for Mamba-HMA split - handles. - - FA descriptors (indices < num_fa_descs) are sliced by - ``remote_num_fa_reads``; mamba descriptors are sliced uniformly - by ``abs_tp``. 
- """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - all_handle_data: list[list[tuple[int, int, int]]] = [] - for p_idx, p_rank in enumerate(mamba_info.remote_all_source_ranks): - handle_data: list[tuple[int, int, int]] = [] - skip_fa = self.should_skip_fa(remote_engine_id, p_rank) - fa_slot = self.fa_head_slot(remote_engine_id, p_rank) if not skip_fa else 0 - for j, (addr, local_len, dev) in enumerate(src_blocks_data): - if j < num_fa_descs: - assert mamba_info.remote_num_fa_reads >= 1 - fa_chunk = local_len // mamba_info.remote_num_fa_reads - handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) - else: - mamba_chunk = local_len // abs_tp - handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) - all_handle_data.append(handle_data) - return all_handle_data - - def filter_block_ids_for_rank( - self, - remote_engine_id: EngineId, - remote_rank: int, - local_ids: BlockIds, - remote_ids: BlockIds, - is_mamba_group: list[bool], - ) -> tuple[BlockIds, BlockIds]: - """Zero out FA groups for remote ranks outside ``fa_source_ranks``. - - Returns (filtered_local_ids, filtered_remote_ids). When the - remote rank carries FA data for this local rank, returns the - inputs unchanged. 
- """ - if not self.should_skip_fa(remote_engine_id, remote_rank): - return local_ids, remote_ids - num_groups = len(local_ids) - filtered_local: list[list[int]] = [ - [] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups) - ] - filtered_remote: list[list[int]] = [ - [] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups) - ] - return filtered_local, filtered_remote - def describe(self, remote_engine_id: EngineId) -> str: """One-line summary of transfer config for logging.""" info = self._engines[remote_engine_id] - base = ( + return ( + f"TransferTopology(" f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, " f"K={self.total_num_kv_heads}, " f"local_tp={self.tp_size}, " f"remote_tp={info.remote_tp_size}, " f"local_rank={self.tp_rank}, " - f"remote_block_len={info.remote_block_len}" - ) - if isinstance(info, MambaEngineTransferInfo): - return ( - f"TransferTopology.mamba({base}, " - f"fa_reads={info.remote_num_fa_reads}, " - f"mamba_reads={info.remote_num_mamba_reads}, " - f"fa_sources={list(info.remote_fa_source_ranks)}, " - f"all_sources={list(info.remote_all_source_ranks)}, " - f"fa_desc_bytes={info.remote_fa_descriptor_bytes})" - ) - return f"TransferTopology({base})" - - # ============================================================ - # Private helpers - # ============================================================ - # Mamba-HMA hetero-TP transfer config: - # With hetero-TP (P_TP > D_TP), FA KV cache may be replicated across - # P ranks (when P_TP > num_kv_heads), but Mamba conv/SSM state is - # almost always uniquely sharded per P rank. So the number of P - # ranks D must read from can differ between FA and Mamba, and they - # must be handled separately. - - @staticmethod - def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: - """Physical KV head range stored in a rank's KV cache tensor. - - When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. 
- When ``tp_size > num_heads``: 1 physical head per rank. Heads are - distributed **contiguously** (matching vLLM's GQA weight partitioning): - consecutive ranks share a head before moving to the next one. - """ - if tp_size <= num_heads: - assert num_heads % tp_size == 0 - per_rank = num_heads // tp_size - return range(rank * per_rank, (rank + 1) * per_rank) - else: - h = rank * num_heads // tp_size - return range(h, h + 1) - - @staticmethod - def _range_overlap(a: range, b: range) -> range: - start = max(a.start, b.start) - stop = min(a.stop, b.stop) - return range(start, max(start, stop)) - - # ============================================================ - # Private: build Mamba transfer info - # ============================================================ - - def _build_mamba_info( - self, - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - local_block_len: int, - ) -> MambaEngineTransferInfo: - """Compute Mamba transfer plan.""" - K = self.total_num_kv_heads - local_tp = self.tp_size - local_rank = self.tp_rank - - is_remote_replicated = remote_tp_size > K - remote_physical_heads = max(1, K // remote_tp_size) - - if local_tp >= remote_tp_size: - assert local_tp % remote_tp_size == 0 - tp_ratio = local_tp // remote_tp_size - else: - assert remote_tp_size % local_tp == 0 - tp_ratio = -(remote_tp_size // local_tp) - - abs_tp = -tp_ratio if tp_ratio < 0 else 1 - - mamba_range: range | None = None - if tp_ratio < 0: - mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp) - - # ---- FA read targets ---- - if self.is_mla or tp_ratio >= 0: - num_fa_reads = 1 - fa_source_ranks: list[int] = ( - [0] - if self.is_mla - else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] - ) - else: - local_needs = self._physical_head_range(local_tp, K, local_rank) - search_range = ( - mamba_range if mamba_range is not None else range(remote_tp_size) - ) - seen: set[tuple[int, int]] = set() - 
fa_source_ranks = [] - for p in search_range: - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - if not fa_source_ranks: - for p in range(remote_tp_size): - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - num_fa_reads = len(fa_source_ranks) - - # ---- All source ranks (mamba + FA) ---- - if mamba_range is not None and abs_tp > num_fa_reads: - num_mamba_reads = abs_tp - all_source_ranks = list(mamba_range) - else: - num_mamba_reads = num_fa_reads - all_source_ranks = list(fa_source_ranks) - - # ---- FA descriptor bytes ---- - effective_block_len = min(local_block_len, remote_block_len) - if self.is_kv_layout_blocks_first: - fa_descriptor_bytes = effective_block_len // 2 - else: - fa_descriptor_bytes = effective_block_len - - # ---- Validation ---- - is_local_replicated = local_tp > K - if is_local_replicated and is_remote_replicated and tp_ratio > 0: - logger.info( - "Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.", - local_tp, - remote_tp_size, - K, - ) - tt_set = set(all_source_ranks) - for t in fa_source_ranks: - if t not in tt_set: - logger.error( - "FA source rank %d NOT in all_source_ranks %s.", - t, - all_source_ranks, - ) - if self.is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0: - local_k_half = local_block_len // 2 - remote_k_half = remote_block_len // 2 - expected = local_k_half // num_fa_reads - if expected != remote_k_half: - logger.warning( - "FA size mismatch: local_k_half=%d / reads=%d = %d, " - "but remote_k_half=%d.", - local_k_half, - num_fa_reads, - expected, - remote_k_half, - ) - - return MambaEngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - 
remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - remote_fa_source_ranks=tuple(fa_source_ranks), - remote_all_source_ranks=tuple(all_source_ranks), - remote_num_fa_reads=num_fa_reads, - remote_num_mamba_reads=num_mamba_reads, - remote_fa_descriptor_bytes=fa_descriptor_bytes, - is_remote_replicated=is_remote_replicated, - remote_physical_heads=remote_physical_heads, + f"remote_block_len={info.remote_block_len})" ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 715fcbde16c9..5a94070ebde7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -821,6 +821,7 @@ def __init__( self.cache_config = vllm_config.cache_config self.kv_cache_config = kv_cache_config self.use_mla = self.model_config.use_mla + self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() # Get the attention backend from the first layer @@ -863,6 +864,9 @@ def _sync_block_size_with_kernel(self) -> None: kernel_block_size, ) assert self.block_size > kernel_block_size + self._physical_blocks_per_logical_kv_block = ( + self.block_size // kernel_block_size + ) self.block_size = kernel_block_size def __del__(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py index 71ebbf1174fb..724fc709d841 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py @@ -32,8 +32,9 @@ # Version History: # 1: Initial version with compatibility checking # 2: Add remote_request_id to kv_transfer_params +# 3: Add physical_blocks_per_logical_kv_block to NixlAgentMetadata # -NIXL_CONNECTOR_VERSION: int = 2 
+NIXL_CONNECTOR_VERSION: int = 3 @dataclass @@ -48,6 +49,13 @@ class NixlAgentMetadata: block_size: int ssm_sizes: tuple[int, int] attn_backend_name: str + physical_blocks_per_logical_kv_block: int + # Per-group block_size in tokens after page unification, indexed by + # kv_cache_group position. Needed for HeteroTP models (e.g. Gemma4) + # where groups have different token counts per block. + # Example — Gemma4 at P_TP=2: [16, 32] for [SWA, FA]. + # None for homogeneous models (all groups share the same block_size). + tokens_per_block_per_group: list[int] | None = None @dataclass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py new file mode 100644 index 000000000000..66ff4f2e9be1 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -0,0 +1,1009 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Plan-based transfer design for NIXL connector. + +Data structures, plan generators, and local descriptor builders +for NIXL KV cache transfers. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np + +from vllm.distributed.kv_transfer.kv_connector.utils import ( + BlockIds, + EngineTransferInfo, + TransferTopology, +) +from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( + MambaConvSplitInfo, +) +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec, MambaSpec + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NixlAgentMetadata, + ) + + +# ====================================================================== +# 1. 
Data structures +# ====================================================================== + + +@dataclass(frozen=True) +class ReadSpec: + """Specification for a single remote block read operation.""" + + remote_rank: int + local_block_ids: BlockIds + remote_block_ids: BlockIds + + +def _is_attention_spec(spec_type: type[KVCacheSpec]) -> bool: + return issubclass(spec_type, AttentionSpec) + + +def _is_ssm_spec(spec_type: type[KVCacheSpec]) -> bool: + return issubclass(spec_type, MambaSpec) + + +@dataclass(frozen=True) +class RegionPlan: + """Geometry for one descriptor region. + + Everything needed to build NIXL descriptors and compute descriptor + IDs is baked in. The caller plugs in ``base_addr`` and + ``device_id`` when constructing the final descriptor tuples. + + When ``descs_per_block > 1``, each physical block produces multiple + NIXL descriptors. This happens when the remote page is larger than + the local page (e.g. Gemma4 2p4d where P page = 65536 bytes, + D page = 32768 bytes → ``descs_per_block = 2``). Each descriptor + covers one local-page-sized chunk of the remote block. + """ + + layer_idx: int + + # Descriptor geometry + descriptor_bytes: int + offset_in_page: int + page_stride: int + num_blocks: int + + # How many NIXL descriptors to register per physical block. + # Default 1 (one desc per block). When the remote page is N times + # larger than local, set to N so each block produces N descriptors. + descs_per_block: int = 1 + # Byte offset between consecutive descriptors within the same block. + desc_stride_bytes: int = 0 + + +@dataclass(frozen=True) +class EngineTransferPlan: + """Complete transfer plan for one remote engine. + + Generated once during handshake. Regions are split into + ``fa_regions`` and ``ssm_regions`` matching the descriptor + handle layout. + + Per-group HeteroTP fields enable models where different attention + groups have different transfer behaviors (e.g. Gemma4 SWA + FA). 
+ """ + + # --- Core regions (descriptor handle order) --- + fa_regions: tuple[RegionPlan, ...] + ssm_regions: tuple[RegionPlan, ...] + + # Per-group KVCacheSpec type — used for descriptor indexing. + group_spec_types: tuple[type[KVCacheSpec], ...] + + # Per-group ordered source ranks. Position = local piece index. + source_ranks_per_group: tuple[tuple[int, ...], ...] + + # Superset of all source ranks (union of all groups). + all_source_ranks: tuple[int, ...] + + # Per-group head slot mapping. Each dict maps source rank → slot. + # Per-group because different groups can have different num_kv_heads, + # leading to different head-to-slot assignments. + # Example: Gemma4 has SWA K=8 and FA K=2; at 4p2d these would + # produce genuinely different slot mappings. + rank_to_attention_slot: tuple[dict[int, int], ...] + + # Stride for expanding remote logical block IDs to kernel block IDs. + remote_expansion_stride: int + + # --- Page sizes (bytes per physical block, same for all groups) --- + # Used to determine transfer direction and descriptor layout. + # Split-read: remote_page_size > local_page_size (e.g. Gemma4 2p4d) + # Gather-read: local_page_size > remote_page_size (e.g. Gemma4 4p2d) + # Standard: local_page_size == remote_page_size + local_page_size: int + remote_page_size: int + + # --- HeteroTP per-group fields (e.g. Gemma4 SWA + FA) --- + # For Dense/Mamba (equal page sizes), these are unused and default + # to empty. + + # Per-group: how many local (D) blocks correspond to one remote (P) + # block. Computed as remote_block_size / local_block_size per group. + # Gemma4 2p4d: SWA = 16/16 = 1, FA = 32/16 = 2. + local_blocks_per_remote_block: tuple[int, ...] = () + + # Per-group: which descriptor offset to read from a multi-descriptor + # remote block (for head-split groups where local reads a portion). + # Gemma4 2p4d rank 0: SWA = 0 (first half), FA = 0 (unused, reads all). + # Gemma4 2p4d rank 1: SWA = 1 (second half), FA = 0. 
+ remote_desc_offset_per_group: tuple[int, ...] = () + + # Per-group: how many remote blocks fill one local block. + # FA in 4p2d: D_tpb / P_tpb = 32 / 16 = 2. + remote_blocks_per_local_block: tuple[int, ...] = () + + def __post_init__(self): + big, small = ( + max(self.local_page_size, self.remote_page_size), + min(self.local_page_size, self.remote_page_size), + ) + assert small > 0, "Page sizes must be positive" + assert big % small == 0, ( + f"Page sizes must be evenly divisible: " + f"local={self.local_page_size}, remote={self.remote_page_size}" + ) + + @property + def all_regions(self) -> tuple[RegionPlan, ...]: + return self.fa_regions + self.ssm_regions + + +# ====================================================================== +# 2. Internal helpers +# ====================================================================== + + +def _get_kv_block_len( + layer_idx: int, + block_len_per_layer: list[int], + is_blocks_first: bool, +) -> int: + if is_blocks_first: + return block_len_per_layer[layer_idx] // 2 + return block_len_per_layer[layer_idx] + + +@dataclass(frozen=True) +class TPMapping: + """Complete local-to-remote TP mapping for one remote engine.""" + + source_ranks_per_group: tuple[tuple[int, ...], ...] + all_source_ranks: tuple[int, ...] + rank_to_attention_slot: dict[int, int] + rank_offset_factor: int + + +def _compute_tp_mapping( + tp_rank: int, + tp_size: int, + remote_tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + group_spec_types: tuple[type[KVCacheSpec], ...], +) -> TPMapping: + """Build the complete local-to-remote TP mapping. + + Computes source ranks, head slot assignments, and the rank offset + factor in a single pass. Both generators call this and unpack. + """ + # --- Attention source ranks --- + if is_mla: + # All heads replicated across all ranks. 
+ attn_ranks = [0] + elif tp_size >= remote_tp_size: + attn_ranks = [tp_rank * remote_tp_size // tp_size] + else: + # P (remote TP) > D (local TP): one local rank + # reads from multiple remote ranks. + # GQA dedup: when K < remote_tp_size, several remote ranks + # hold the same KV head. np.unique keeps only the first + # rank per unique head so we don't issue redundant reads. + abs_tp = remote_tp_size // tp_size + start = tp_rank * abs_tp + heads = np.arange(start, start + abs_tp) * total_num_kv_heads // remote_tp_size + _, unique_idx = np.unique(heads, return_index=True) + attn_ranks = (start + np.sort(unique_idx)).tolist() + + # --- SSM source ranks --- + has_ssm = any(_is_ssm_spec(t) for t in group_spec_types) + if has_ssm: + if tp_size < remote_tp_size: + abs_tp = remote_tp_size // tp_size + ssm_ranks = list(range(tp_rank * abs_tp, (tp_rank + 1) * abs_tp)) + else: + ssm_ranks = list(attn_ranks) + else: + ssm_ranks = [] + + all_ranks = sorted(set(attn_ranks) | set(ssm_ranks)) + + # --- Per-group ordered source ranks --- + source_ranks_per_group = tuple( + tuple(ssm_ranks) if _is_ssm_spec(t) else tuple(attn_ranks) + for t in group_spec_types + ) + + # --- Attention head slots --- + head_to_slot: dict[int, int] = {} + for i, r in enumerate(attn_ranks): + head_to_slot[r * total_num_kv_heads // remote_tp_size] = i + rank_to_attention_slot = { + r: head_to_slot.get(r * total_num_kv_heads // remote_tp_size, 0) + for r in all_ranks + } + + # --- Rank offset factor --- + if is_mla or tp_size <= remote_tp_size: + rank_offset_factor = 0 + elif tp_size > total_num_kv_heads: + local_head = tp_rank * total_num_kv_heads // tp_size + p_start = attn_ranks[0] * total_num_kv_heads // remote_tp_size + rank_offset_factor = local_head - p_start + else: + rank_offset_factor = tp_rank % (tp_size // remote_tp_size) + + return TPMapping( + source_ranks_per_group=source_ranks_per_group, + all_source_ranks=tuple(all_ranks), + rank_to_attention_slot=rank_to_attention_slot, + 
def _build_fa_regions(
    *,
    block_len_per_layer: list[int],
    remote_block_lens: list[int],
    is_blocks_first: bool,
    block_size_ratio: int,
    num_attn_reads: int,
    rank_offset_factor: int,
    remote_num_blocks: int,
) -> list[RegionPlan]:
    """Describe the remote attention (FA) regions, layer by layer.

    For every layer one region is emitted, plus a second one when the
    layout is blocks-first:

    * first region: ``remote_kv_block_len // num_attn_reads`` bytes, where
      ``remote_kv_block_len = local_block_len // block_size_ratio``;
    * second region (blocks-first only):
      ``local_block_len // num_attn_reads`` bytes -- note that
      ``block_size_ratio`` is *not* applied here -- located half a remote
      page (``page_stride // 2``) after the first one.

    Both regions start ``rank_offset_factor * remote_kv_block_len`` bytes
    into the page and use the remote layer block length as page stride.

    Args:
        block_len_per_layer: Local per-layer block lengths in bytes.
        remote_block_lens: Remote per-layer block lengths in bytes; must
            have as many entries as ``block_len_per_layer``.
        is_blocks_first: Whether each page is split into two halves.
        block_size_ratio: Divisor applied to the local block length.
        num_attn_reads: Divisor applied to the descriptor sizes.
        rank_offset_factor: Multiplier for the in-page offset.
        remote_num_blocks: Block count stored on every region.

    Returns:
        The regions in layer order; within a layer the first region
        precedes the blocks-first second region.
    """
    n_remote_layers = len(remote_block_lens)
    n_local_layers = len(block_len_per_layer)
    assert n_remote_layers == n_local_layers, (
        f"Layer count mismatch: remote has {n_remote_layers} layers "
        f"but local has {n_local_layers}"
    )

    regions: list[RegionPlan] = []
    for layer, stride in enumerate(remote_block_lens):
        local_len = _get_kv_block_len(layer, block_len_per_layer, is_blocks_first)
        remote_len = local_len // block_size_ratio
        base_offset = rank_offset_factor * remote_len

        # (descriptor_bytes, offset_in_page) of each region of this layer.
        layout = [(remote_len // num_attn_reads, base_offset)]
        if is_blocks_first:
            layout.append((local_len // num_attn_reads, base_offset + stride // 2))

        regions.extend(
            RegionPlan(
                layer_idx=layer,
                descriptor_bytes=nbytes,
                offset_in_page=offset,
                page_stride=stride,
                num_blocks=remote_num_blocks,
            )
            for nbytes, offset in layout
        )

    return regions
def generate_dense_plan(
    *,
    transfer_topo: TransferTopology,
    block_len_per_layer: list[int],
    remote_info: EngineTransferInfo,
    remote_meta: NixlAgentMetadata,
    group_spec_types: tuple[type[KVCacheSpec], ...],
    local_physical_blocks_per_logical: int,
) -> EngineTransferPlan:
    """Build the transfer plan for a dense (attention-only) model.

    The local and remote page sizes are taken from layer 0 of
    ``block_len_per_layer`` and ``remote_meta.block_lens``. The plan has
    no SSM regions, carries a single attention slot map, and uses
    ``local_physical_blocks_per_logical`` as ``remote_expansion_stride``.

    Raises:
        StopIteration: If no entry of ``group_spec_types`` satisfies
            ``_is_attention_spec`` (the read count is taken from the first
            such group).
    """
    local_page_bytes = block_len_per_layer[0]
    remote_layer_lens = remote_meta.block_lens
    remote_page_bytes = remote_layer_lens[0]

    size_ratio = transfer_topo.block_size // remote_info.remote_block_size

    mapping = _compute_tp_mapping(
        tp_rank=transfer_topo.tp_rank,
        tp_size=transfer_topo.tp_size,
        remote_tp_size=remote_info.remote_tp_size,
        is_mla=transfer_topo.is_mla,
        total_num_kv_heads=transfer_topo.total_num_kv_heads,
        group_spec_types=group_spec_types,
    )
    ranks_per_group = mapping.source_ranks_per_group

    # Number of remote ranks the first attention group reads from.
    attn_read_count = next(
        len(group_ranks)
        for spec_type, group_ranks in zip(group_spec_types, ranks_per_group)
        if _is_attention_spec(spec_type)
    )

    attention_regions = _build_fa_regions(
        block_len_per_layer=block_len_per_layer,
        remote_block_lens=remote_layer_lens,
        is_blocks_first=transfer_topo.is_kv_layout_blocks_first,
        block_size_ratio=size_ratio,
        num_attn_reads=attn_read_count,
        rank_offset_factor=mapping.rank_offset_factor,
        remote_num_blocks=remote_meta.num_blocks,
    )

    return EngineTransferPlan(
        fa_regions=tuple(attention_regions),
        ssm_regions=(),
        group_spec_types=group_spec_types,
        source_ranks_per_group=ranks_per_group,
        all_source_ranks=mapping.all_source_ranks,
        rank_to_attention_slot=(mapping.rank_to_attention_slot,),
        remote_expansion_stride=local_physical_blocks_per_logical,
        local_page_size=local_page_bytes,
        remote_page_size=remote_page_bytes,
    )
+ ssm_sizes: tuple[int, int], +) -> EngineTransferPlan: + """Generate transfer plan for hybrid Mamba (SSM + FA) models.""" + tp_rank = transfer_topo.tp_rank + tp_size = transfer_topo.tp_size + remote_tp_size = remote_info.remote_tp_size + remote_phys_ratio = remote_info.remote_physical_blocks_per_logical + remote_block_lens = remote_meta.block_lens + remote_ssm_sizes = remote_meta.ssm_sizes + + block_size_ratio = transfer_topo.block_size // remote_info.remote_block_size + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 " + f"is not tested. Got {block_size_ratio=}." + ) + + tp_mapping = _compute_tp_mapping( + tp_rank, + tp_size, + remote_tp_size, + transfer_topo.is_mla, + transfer_topo.total_num_kv_heads, + group_spec_types, + ) + + # ---- FA regions ---- + num_attn_reads = next( + len(ranks) + for t, ranks in zip(group_spec_types, tp_mapping.source_ranks_per_group) + if _is_attention_spec(t) + ) + fa_regions = _build_fa_regions( + block_len_per_layer=block_len_per_layer, + remote_block_lens=remote_block_lens, + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + block_size_ratio=block_size_ratio, + num_attn_reads=num_attn_reads, + rank_offset_factor=tp_mapping.rank_offset_factor, + remote_num_blocks=remote_meta.num_blocks, + ) + + # ---- SSM regions ---- + effective_ratio = tp_size // remote_tp_size if tp_size >= remote_tp_size else 1 + local_offset = tp_rank % max(effective_ratio, 1) + conv_size_remote = remote_ssm_sizes[0] + ssm_num_blocks = remote_meta.num_blocks // remote_phys_ratio + + # Mamba conv state is always TP-sharded, even when attention KV + # is replicated (num_kv_heads < tp_size). + if tp_size >= remote_tp_size: + # D_TP >= P_TP: P page is larger, D reads its slice. + conv_offsets = conv_decomp.remote_conv_offsets( + local_offset, + effective_ratio, + ) + ssm_read_size = ssm_sizes[1] + else: + # NOTE (ZhanqiuHu): P_TP > D_TP, so P pages are smaller + # than D's. 
def generate_gemma4_plan(
    *,
    transfer_topo: TransferTopology,
    block_len_per_layer: list[int],
    remote_info: EngineTransferInfo,
    remote_meta: NixlAgentMetadata,
    group_spec_types: tuple[type[KVCacheSpec], ...],
    total_num_kv_heads_per_group: tuple[int, ...],
    local_tokens_per_block: tuple[int, ...],
    remote_tokens_per_block: tuple[int, ...],
) -> EngineTransferPlan:
    """Generate transfer plan for Gemma4-style heterogeneous attention.

    Gemma4 has multiple attention groups (SWA, FA) with different
    ``total_num_kv_heads`` and ``head_dim``. With page unification and
    HMA, all groups share physical memory pools. This generator:

    1. Calls ``_compute_tp_mapping`` per group with group-specific K.
    2. Handles both **split-read** (remote page > local page, e.g. 2p4d)
       and **gather-read** (local page > remote page, e.g. 4p2d).
    3. Encodes per-group transfer behavior via
       ``local_blocks_per_remote_block`` / ``remote_blocks_per_local_block``
       and ``remote_desc_offset_per_group``.

    Split-read (P_page > D_page): each remote block → multiple descriptors.
    Gather-read (D_page > P_page): each local block → multiple descriptors.

    Args:
        transfer_topo: Supplies ``tp_rank``, ``tp_size``, ``is_mla`` and
            ``is_kv_layout_blocks_first``.
        block_len_per_layer: Local per-layer block lengths in bytes; entry
            0 is used as the local page size.
        remote_info: Supplies ``remote_tp_size``.
        remote_meta: Supplies the remote ``block_lens`` (entry 0 is used as
            the remote page size) and ``num_blocks``.
        group_spec_types: One spec type per KV cache group.
        total_num_kv_heads_per_group: Per-group head count handed to
            ``_compute_tp_mapping``.
        local_tokens_per_block: Per-group tokens per block on this side.
        remote_tokens_per_block: Per-group tokens per block on the remote.

    Returns:
        An ``EngineTransferPlan`` with attention regions only
        (``ssm_regions`` is empty) and ``remote_expansion_stride`` fixed
        to 1.
    """
    tp_rank = transfer_topo.tp_rank
    tp_size = transfer_topo.tp_size
    remote_tp_size = remote_info.remote_tp_size
    is_mla = transfer_topo.is_mla
    is_blocks_first = transfer_topo.is_kv_layout_blocks_first
    n_groups = len(group_spec_types)

    # Page sizes are read from layer 0 only.
    local_page = block_len_per_layer[0]
    remote_page = remote_meta.block_lens[0]

    # Exactly one of the two factors can exceed 1: the larger page is cut
    # into pieces of the smaller page's size.
    if remote_page >= local_page:
        descs_per_remote_block = remote_page // local_page
        descs_per_local_block = 1
    else:
        descs_per_remote_block = 1
        descs_per_local_block = local_page // remote_page

    # Per-group results, all indexed by group position.
    blocks_per_remote: list[int] = []
    remote_blocks_per_local: list[int] = []
    remote_desc_offset: list[int] = []

    source_ranks_all: list[tuple[int, ...]] = []
    rank_to_slot_all: list[dict[int, int]] = []

    for g in range(n_groups):
        r_tpb = remote_tokens_per_block[g]
        l_tpb = local_tokens_per_block[g]

        # Token-level block ratio of this group (again, only one side > 1).
        if r_tpb >= l_tpb:
            blocks_per_remote.append(r_tpb // l_tpb)
            remote_blocks_per_local.append(1)
        else:
            blocks_per_remote.append(1)
            remote_blocks_per_local.append(l_tpb // r_tpb)

        # Each group gets its own TP mapping, computed as a single-group
        # problem with that group's head count.
        K_g = total_num_kv_heads_per_group[g]
        m_g = _compute_tp_mapping(
            tp_rank,
            tp_size,
            remote_tp_size,
            is_mla,
            K_g,
            (group_spec_types[g],),
        )
        source_ranks_all.append(m_g.source_ranks_per_group[0])
        rank_to_slot_all.append(m_g.rank_to_attention_slot)

        # Head-split groups (split-read only): rank_offset selects descriptor.
        if blocks_per_remote[-1] == 1 and descs_per_remote_block > 1:
            remote_desc_offset.append(m_g.rank_offset_factor)
        else:
            remote_desc_offset.append(0)

    # Union of the source ranks of every group, sorted ascending.
    all_ranks: set[int] = set()
    for ranks in source_ranks_all:
        all_ranks.update(ranks)
    all_source_ranks = tuple(sorted(all_ranks))

    # --- Diagnostic logging for HeteroTP plan ---
    logger.info(
        "[HeteroTP Plan] tp_rank=%d, tp_size=%d, remote_tp_size=%d, "
        "local_page=%d, remote_page=%d, "
        "descs_per_remote_block=%d, descs_per_local_block=%d",
        tp_rank,
        tp_size,
        remote_tp_size,
        local_page,
        remote_page,
        descs_per_remote_block,
        descs_per_local_block,
    )
    for g in range(n_groups):
        logger.info(
            "[HeteroTP Plan] group=%d spec=%s: K=%d, "
            "local_tpb=%d, remote_tpb=%d, "
            "blocks_per_remote=%d, remote_blocks_per_local=%d, "
            "desc_offset=%d, source_ranks=%s, slot_map=%s",
            g,
            group_spec_types[g].__name__,
            total_num_kv_heads_per_group[g],
            local_tokens_per_block[g],
            remote_tokens_per_block[g],
            blocks_per_remote[g],
            remote_blocks_per_local[g],
            remote_desc_offset[g],
            source_ranks_all[g],
            rank_to_slot_all[g],
        )

    # HMA: one K pool (+ optional V pool) shared by all groups.
    fa_regions: list[RegionPlan] = []
    for i in range(len(remote_meta.block_lens)):
        local_block_len = _get_kv_block_len(
            i,
            block_len_per_layer,
            is_blocks_first,
        )
        page_stride = remote_meta.block_lens[i]

        if descs_per_remote_block > 1:
            # Split-read: remote blocks produce descriptors of local page size
            desc_bytes = local_block_len
            descs_per_block = descs_per_remote_block
            desc_stride = local_block_len
        elif descs_per_local_block > 1:
            # Gather-read: standard remote descs at remote page size
            remote_block_len = _get_kv_block_len(
                i, remote_meta.block_lens, is_blocks_first
            )
            desc_bytes = remote_block_len
            descs_per_block = 1
            desc_stride = 0
        else:
            # Equal page sizes: one descriptor per remote block.
            desc_bytes = local_block_len
            descs_per_block = 1
            desc_stride = 0

        # First region of the layer starts at the beginning of the page.
        fa_regions.append(
            RegionPlan(
                layer_idx=i,
                descriptor_bytes=desc_bytes,
                offset_in_page=0,
                page_stride=page_stride,
                num_blocks=remote_meta.num_blocks,
                descs_per_block=descs_per_block,
                desc_stride_bytes=desc_stride,
            )
        )

        # Blocks-first layout: a second region with the same shape starts
        # half a remote page later.
        if is_blocks_first:
            fa_regions.append(
                RegionPlan(
                    layer_idx=i,
                    descriptor_bytes=desc_bytes,
                    offset_in_page=page_stride // 2,
                    page_stride=page_stride,
                    num_blocks=remote_meta.num_blocks,
                    descs_per_block=descs_per_block,
                    desc_stride_bytes=desc_stride,
                )
            )

    return EngineTransferPlan(
        fa_regions=tuple(fa_regions),
        ssm_regions=(),
        group_spec_types=group_spec_types,
        source_ranks_per_group=tuple(source_ranks_all),
        all_source_ranks=all_source_ranks,
        rank_to_attention_slot=tuple(rank_to_slot_all),
        remote_expansion_stride=1,
        local_page_size=local_page,
        remote_page_size=remote_page,
        local_blocks_per_remote_block=tuple(blocks_per_remote),
        remote_desc_offset_per_group=tuple(remote_desc_offset),
        remote_blocks_per_local_block=tuple(remote_blocks_per_local),
    )
Local descriptor building +# ====================================================================== + + +def _remap_remote_blocks_to_desc_ids( + plan: EngineTransferPlan, + remote_block_ids: BlockIds, + local_block_ids: BlockIds, +) -> tuple[BlockIds, BlockIds]: + """Convert remote block IDs into descriptor-level indices. + + When ``remote_page_size > local_page_size`` (split-read), each remote + physical block is registered as multiple descriptors (one per + local-page-sized chunk). This function converts remote block IDs + into the descriptor index space so that ``_compute_desc_ids_from_plan`` + can look up the correct descriptors. + + Two per-group cases: + + * **Multi-block** (``local_blocks_per_remote_block > 1``, e.g. FA): + One remote block covers multiple local blocks. + Remote block ``b`` → descriptor indices + ``[b*N, b*N+1, ..., b*N+(n-1)]`` where N = descs_per_remote_block. + Example: FA block 10, N=2 → desc indices [20, 21]. + + * **Head-split** (``local_blocks_per_remote_block == 1``, e.g. SWA): + Local reads one specific chunk of the remote block. + Remote block ``b`` → descriptor index + ``b*N + remote_desc_offset_per_group[g]``. + Example: SWA block 10, N=2, offset=1 → desc index 21. + + Local block IDs are returned unchanged. 
+ """ + if plan.remote_page_size <= plan.local_page_size: + return remote_block_ids, local_block_ids + if not plan.local_blocks_per_remote_block: + return remote_block_ids, local_block_ids + + descs_per_remote_block = plan.remote_page_size // plan.local_page_size + num_groups = len(remote_block_ids) + new_remote: list[list[int]] = [] + new_local: list[list[int]] = [] + + for g in range(num_groups): + n_local = plan.local_blocks_per_remote_block[g] + r_ids = list(remote_block_ids[g]) + l_ids = list(local_block_ids[g]) + + if n_local > 1: + remapped: list[int] = [] + for b in r_ids: + remapped.extend(b * descs_per_remote_block + s for s in range(n_local)) + new_remote.append(remapped) + else: + idx = plan.remote_desc_offset_per_group[g] + new_remote.append([b * descs_per_remote_block + idx for b in r_ids]) + + new_local.append(l_ids) + + return new_remote, new_local + + +def _build_gather_read_specs( + plan: EngineTransferPlan, + local_block_ids: BlockIds, + remote_block_ids: BlockIds, +) -> list[ReadSpec]: + """Build read specs for gather-read (local page > remote page). + + In gather-read, each local block is registered as multiple NIXL + descriptors (``descs_per_block`` in ``RegionPlan``), each matching + the remote block byte size. Each rank's read targets specific + local descriptor IDs: + + * **Gather groups** (``remote_blocks_per_local_block > 1``, e.g. FA): + N remote blocks fill one local block. + Local block ``b`` → descriptor IDs + ``[b*descs_per_local_block, ..., b*descs_per_local_block+(N-1)]``. + Remote block IDs are kept as-is (one remote block = one + remote descriptor). The matched-length invariant + ``len(local_desc_ids) == len(remote_block_ids)`` must hold; + it is enforced by an assertion after construction. + + * **Concat groups** (``remote_blocks_per_local_block == 1``, e.g. SWA): + Each rank writes to a specific descriptor within the local block. + Local block ``b`` → descriptor ID + ``b*descs_per_local_block + rank_slot``. 
+ """ + descs_per_local_block = plan.local_page_size // plan.remote_page_size + num_groups = len(local_block_ids) + + def _pair_gather_group( + g_local_block_ids: list[int], + g_remote_block_ids: list[int], + remote_blocks_per_local: int, + ) -> tuple[list[int], list[int]]: + """Pair local descriptor IDs with remote block IDs for a gather group. + + With HMA, all groups receive the same block ID list. For gather + groups (``remote_blocks_per_local > 1``), every + ``remote_blocks_per_local`` consecutive remote blocks map to + descriptors of a single local block: + + local block b, remote blocks [r0, r1] → + local desc b*descs_per_local_block + 0 paired with r0 + local desc b*descs_per_local_block + 1 paired with r1 + + When the remote block count is not a multiple of + ``remote_blocks_per_local``, the remainder fills the first + descriptors of the next local block (partial fill). + + Returns matched-length lists: + (local_desc_ids, paired_remote_block_ids) + """ + n_local = len(g_local_block_ids) + n_remote = len(g_remote_block_ids) + n_full = min(n_remote // remote_blocks_per_local, n_local) + remainder_remote = n_remote - n_full * remote_blocks_per_local + + local_desc_ids: list[int] = [] + paired_remote: list[int] = [] + + for i in range(n_full): + b = g_local_block_ids[i] + for s in range(remote_blocks_per_local): + local_desc_ids.append(b * descs_per_local_block + s) + paired_remote.append( + g_remote_block_ids[i * remote_blocks_per_local + s] + ) + + if remainder_remote > 0 and n_full < n_local: + b = g_local_block_ids[n_full] + base = n_full * remote_blocks_per_local + for s in range(remainder_remote): + local_desc_ids.append(b * descs_per_local_block + s) + paired_remote.append(g_remote_block_ids[base + s]) + + return local_desc_ids, paired_remote + + specs: list[ReadSpec] = [] + + for rank in plan.all_source_ranks: + rank_local: list[list[int]] = [] + rank_remote: list[list[int]] = [] + + for g in range(num_groups): + if rank not in 
plan.source_ranks_per_group[g]: + rank_local.append([]) + rank_remote.append([]) + continue + + n_remote_per_local = plan.remote_blocks_per_local_block[g] + + if n_remote_per_local > 1: + g_local, g_remote = _pair_gather_group( + local_block_ids[g], + remote_block_ids[g], + n_remote_per_local, + ) + rank_local.append(g_local) + rank_remote.append(g_remote) + else: + slot = plan.rank_to_attention_slot[g].get(rank, 0) + l_ids = local_block_ids[g] + r_ids = remote_block_ids[g] + n = min(len(l_ids), len(r_ids)) + rank_local.append( + [ + l_ids[i] * descs_per_local_block + slot + for i in range(len(l_ids) - n, len(l_ids)) + ] + ) + rank_remote.append(list(r_ids[len(r_ids) - n :])) + + for g in range(num_groups): + assert len(rank_local[g]) == len(rank_remote[g]), ( + f"Gather-read length mismatch: group={g}, rank={rank}, " + f"n_local_descs={len(rank_local[g])}, " + f"n_remote_blocks={len(rank_remote[g])}. " + f"Each local descriptor must pair with exactly one " + f"remote block ID." + ) + + specs.append( + ReadSpec( + remote_rank=rank, + local_block_ids=rank_local, + remote_block_ids=rank_remote, + ) + ) + + return specs + + +logger = logging.getLogger(__name__) + + +# ====================================================================== +# 4. 
def build_fa_local_descs_for_gather_read(
    base_addresses: list[int],
    device_id: int,
    num_blocks: int,
    block_len_per_layer: list[int],
    is_blocks_first: bool,
    gather_page_ratio: int,
) -> list[tuple[int, int, int]]:
    """Build FA local descriptors for gather-read.

    Every local block is cut into ``gather_page_ratio`` equally sized
    descriptors so that each one can be paired with a remote descriptor of
    the (smaller) remote size.

    Args:
        base_addresses: Base address of each layer's region.
        device_id: Device ID stored in every descriptor tuple.
        num_blocks: Number of local blocks per layer.
        block_len_per_layer: Per-layer block length in bytes, also used as
            the stride between consecutive blocks.
        is_blocks_first: When true, each block holds two halves; the
            descriptors of all first halves of a layer are emitted before
            those of the second halves.
        gather_page_ratio: Number of descriptors per block half/region.

    Returns:
        ``(address, length, device_id)`` tuples, ordered by layer, then
        region (first half before second half), then block, then sub-index.

    Raises:
        ValueError: If ``gather_page_ratio`` is not positive, or does not
            evenly divide a layer's region length. Without this check the
            floor division would silently leave the tail of every block
            uncovered by any descriptor.
    """
    if gather_page_ratio <= 0:
        raise ValueError(
            f"gather_page_ratio must be positive, got {gather_page_ratio}"
        )

    result: list[tuple[int, int, int]] = []
    for layer, base_addr in enumerate(base_addresses):
        page_stride = block_len_per_layer[layer]
        # Same rule as _get_kv_block_len: blocks-first pages hold two halves.
        kv_block_len = page_stride // 2 if is_blocks_first else page_stride
        if kv_block_len % gather_page_ratio:
            raise ValueError(
                f"Layer {layer}: region length {kv_block_len} is not divisible "
                f"by gather_page_ratio {gather_page_ratio}"
            )
        desc_bytes = kv_block_len // gather_page_ratio

        # Region start offsets inside a block: first half only, or both halves.
        region_offsets = (0, kv_block_len) if is_blocks_first else (0,)
        for region_offset in region_offsets:
            for block_id in range(num_blocks):
                region_addr = base_addr + block_id * page_stride + region_offset
                result.extend(
                    (region_addr + sub * desc_bytes, desc_bytes, device_id)
                    for sub in range(gather_page_ratio)
                )

    return result
+ regions.append( + RegionPlan( + layer_idx=i, + descriptor_bytes=ssm_size, + offset_in_page=conv_size, + page_stride=page_stride, + num_blocks=n_blocks, + ) + ) + return regions diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 607bf4b988ff..2466293468e4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -21,7 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, EngineId, - MambaEngineTransferInfo, + EngineTransferInfo, TransferTopology, get_current_attn_backends, kv_postprocess_blksize_and_layout_on_receive, @@ -43,13 +43,26 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import ( NixlKVConnectorStats, ) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + EngineTransferPlan, + ReadSpec, + _build_gather_read_specs, + _is_attention_spec, + _is_ssm_spec, + _remap_remote_blocks_to_desc_ids, + build_fa_local_descs_for_gather_read, + build_fa_local_regions, + build_mamba_local_regions, + generate_dense_plan, + generate_gemma4_plan, + generate_mamba_plan, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( _NIXL_SUPPORTED_DEVICE, zmq_ctx, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, - compute_physical_blocks_per_logical, derive_mamba_conv_split, ) from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config @@ -58,11 +71,11 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.kv_cache_interface import ( + AttentionSpec, FullAttentionSpec, MambaSpec, 
UniformTypeKVCacheSpecs, @@ -80,6 +93,263 @@ class NixlConnectorWorker: """Implementation of Worker side methods""" + # ------------------------------------------------------------------ + # Plan executors (static — no self access) + # ------------------------------------------------------------------ + + @staticmethod + def _build_remote_descs_from_plan( + plan: EngineTransferPlan, + nixl_agent_meta: "NixlAgentMetadata", + ) -> list[tuple[int, int, int]]: + """Build (addr, len, dev_id) descriptor tuples from plan.""" + result: list[tuple[int, int, int]] = [] + dev_id = nixl_agent_meta.device_id + + for region in plan.all_regions: + base_addr = nixl_agent_meta.kv_caches_base_addr[region.layer_idx] + for blk in range(region.num_blocks): + blk_addr = base_addr + blk * region.page_stride + region.offset_in_page + for sub in range(region.descs_per_block): + addr = blk_addr + sub * region.desc_stride_bytes + result.append((addr, region.descriptor_bytes, dev_id)) + + return result + + @staticmethod + def _compute_desc_ids_from_plan( + plan: EngineTransferPlan, + block_ids: BlockIds, + dst_num_blocks: int, + block_size_ratio: float | None, + physical_blocks_per_logical: int, + ) -> np.ndarray: + """Compute NIXL descriptor IDs for given block IDs.""" + assert len(block_ids) == len(plan.group_spec_types), ( + f"block_ids has {len(block_ids)} groups but plan has " + f"{len(plan.group_spec_types)} group_spec_types" + ) + num_fa_regions = len(plan.fa_regions) + num_ssm_regions = len(plan.ssm_regions) + + num_blocks = dst_num_blocks + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) + ratio = physical_blocks_per_logical + logical_blocks = num_blocks // ratio + + num_fa_descs = num_fa_regions * num_blocks + + # NOTE (NickLucche) With HMA, every kv group has the same number + # of layers and layers from different groups share the same kv + # tensor. 
Therefore we compute desc IDs per group using the + # right stride: + # FA descs have num_blocks entries per region (kernel granularity), + # SSM descs have logical_blocks entries per region (no kernel + # splitting). + all_descs: list[np.ndarray] = [] + for i, group in enumerate(block_ids): + group_arr = np.asarray(group) + spec_type = plan.group_spec_types[i] + if _is_attention_spec(spec_type): + fa_region_ids = np.arange(num_fa_regions)[:, None] + all_descs.append( + (fa_region_ids * num_blocks + group_arr[None, :]).flatten() + ) + elif _is_ssm_spec(spec_type): + # NOTE (NickLucche) SSM and Attention block regions can + # be exchanged arbitrarily by manager. Therefore, descs + # are laid out as: + # [descs_fa (all regions) | descs_ssm (all regions)]. + # num_fa_descs offset must be computed per-engine since + # P and D can have different num_blocks (and thus + # different FA desc counts). + ssm_region_ids = np.arange(num_ssm_regions)[:, None] + all_descs.append( + ( + ssm_region_ids * logical_blocks + + group_arr[None, :] + + num_fa_descs + ).flatten() + ) + else: + raise ValueError(f"Unknown spec type {spec_type} at index {i}") + + return np.concatenate(all_descs) + + @staticmethod + def _compute_read_specs_from_plan( + plan: EngineTransferPlan, + local_block_ids: BlockIds, + remote_block_ids: BlockIds, + ) -> list[ReadSpec]: + """Compute read specs from plan. + + Dispatches to the correct remapping strategy: + + - **Gather-read** (``local_page_size > remote_page_size``): per-rank + local descriptor pairing via ``_build_gather_read_specs``. + - **Split-read** (``remote_page_size > local_page_size``): + rank-independent remote descriptor remapping via + ``_remap_remote_blocks_to_desc_ids``. + - **Standard**: direct per-rank group filtering. 
+ """ + if ( + plan.local_page_size > plan.remote_page_size + and plan.remote_blocks_per_local_block + ): + specs = _build_gather_read_specs(plan, local_block_ids, remote_block_ids) + if logger.isEnabledFor(logging.DEBUG): + for s in specs: + for g in range(len(s.local_block_ids)): + if s.local_block_ids[g]: + logger.debug( + "[ReadSpec gather] rank=%d group=%d: " + "local[:5]=%s remote[:5]=%s " + "(n_local=%d, n_remote=%d)", + s.remote_rank, + g, + s.local_block_ids[g][:5], + s.remote_block_ids[g][:5], + len(s.local_block_ids[g]), + len(s.remote_block_ids[g]), + ) + return specs + + remote_block_ids, local_block_ids = _remap_remote_blocks_to_desc_ids( + plan, remote_block_ids, local_block_ids + ) + + num_groups = len(local_block_ids) + specs = [ + ReadSpec( + remote_rank=rank, + local_block_ids=[ + list(local_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] + for g in range(num_groups) + ], + remote_block_ids=[ + list(remote_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] + for g in range(num_groups) + ], + ) + for rank in plan.all_source_ranks + ] + if logger.isEnabledFor(logging.DEBUG): + for s in specs: + for g in range(num_groups): + if s.local_block_ids[g]: + logger.debug( + "[ReadSpec std/split] rank=%d group=%d: " + "local[:5]=%s remote[:5]=%s " + "(n_local=%d, n_remote=%d)", + s.remote_rank, + g, + s.local_block_ids[g][:5], + s.remote_block_ids[g][:5], + len(s.local_block_ids[g]), + len(s.remote_block_ids[g]), + ) + return specs + + @staticmethod + def _build_local_splits_from_plan( + plan: EngineTransferPlan, + src_blocks_data: list[tuple[int, int, int]], + num_fa_descs: int, + ) -> list[list[tuple[int, int, int]]]: + """Build split handle data for P_TP > D_TP scenario. + + num_fa_descs is the boundary between FA and SSM descriptors. + Split counts are derived from source_ranks_per_group lengths. + FA uses rank_to_attention_slot for the slot offset; + SSM uses the rank's positional index. 
+ """ + # Mamba-HMA: FA and Mamba use different split factors. + fa_num_splits = next( + len(ranks) + for t, ranks in zip(plan.group_spec_types, plan.source_ranks_per_group) + if _is_attention_spec(t) + ) + + has_ssm_descs = num_fa_descs < len(src_blocks_data) + ssm_num_splits = ( + next( + ( + len(ranks) + for t, ranks in zip( + plan.group_spec_types, plan.source_ranks_per_group + ) + if _is_ssm_spec(t) + ), + 0, + ) + if has_ssm_descs + else 0 + ) + + fa_slot_map = plan.rank_to_attention_slot[0] + + result: list[list[tuple[int, int, int]]] = [] + + for p_idx, p_rank in enumerate(plan.all_source_ranks): + fa_slot = fa_slot_map.get(p_rank, 0) + + handle: list[tuple[int, int, int]] = [] + for j, (addr, local_len, dev) in enumerate(src_blocks_data): + if j < num_fa_descs: + chunk = local_len // fa_num_splits + handle.append((addr + fa_slot * chunk, chunk, dev)) + else: + chunk = local_len // ssm_num_splits + handle.append((addr + p_idx * chunk, chunk, dev)) + result.append(handle) + + return result + + def _build_local_descs( + self, + base_addresses: list[int], + block_size_ratio: int, + ) -> list[tuple[int, int, int]]: + """Build local (src) descriptor tuples for NIXL registration.""" + assert self.transfer_topo is not None + fa_regions = build_fa_local_regions( + self.num_blocks, + block_size_ratio, + self.block_len_per_layer, + self.transfer_topo.is_kv_layout_blocks_first, + ) + if self._has_mamba: + # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the + # 3-read split is unnecessary — a single conv desc per block + # suffices. Consider adding a fast path. Currently we always + # register 4 regions because local descs are created before + # knowing the remote TP. 
+ assert self._conv_decomp is not None + mamba_regions = build_mamba_local_regions( + self.block_len_per_layer, + self._logical_num_blocks, + block_size_ratio, + self._conv_decomp, + self._mamba_ssm_size, + self._physical_blocks_per_logical_kv_block, + ) + else: + mamba_regions = [] + + result: list[tuple[int, int, int]] = [] + for region in fa_regions + mamba_regions: + base = base_addresses[region.layer_idx] + for blk in range(region.num_blocks): + addr = base + blk * region.page_stride + region.offset_in_page + result.append((addr, region.descriptor_bytes, self.device_id)) + return result + def __init__( self, vllm_config: "VllmConfig", @@ -119,44 +389,60 @@ def __init__( } self.hma_group_size = len(kv_cache_config.kv_cache_tensors) - # ---- Mamba model state (derived from model config) ---- - self._is_mamba_group = [ - isinstance(group.kv_cache_spec, MambaSpec) - for group in kv_cache_config.kv_cache_groups - ] + # ---- Model state (derived from model config) ---- mamba_ssm_size = (0, 0) - self._has_mamba = any(self._is_mamba_group) + # Conv state sub-projection decomposition (None when no Mamba). + # The 3-read transfer requires DS (dim, state_len) conv layout so + # that x/B/C sub-projections are contiguous in memory. + self._conv_decomp: MambaConvSplitInfo | None = None + self._has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) + for g in kv_cache_config.kv_cache_groups + ) if self._has_mamba: assert self._is_hma_required + from vllm.model_executor.layers.mamba.mamba_utils import ( + is_conv_state_dim_first, + ) + + assert is_conv_state_dim_first(), ( + "3-read Mamba conv transfer requires DS conv state layout. 
" + "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" + ) mamba_spec = next( spec for spec in self._layer_specs.values() if isinstance(spec, MambaSpec) ) - conv_nbytes, ssm_nbytes = ( - torch.tensor([], dtype=mamba_spec.dtypes[0]).element_size(), # type: ignore[misc] - torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(), # type: ignore[misc] - ) - conv_shape, ssm_shape = ( - torch.Size(mamba_spec.shapes[0]), - torch.Size(mamba_spec.shapes[1]), - ) - mamba_ssm_size = ( - conv_shape.numel() * conv_nbytes, - ssm_shape.numel() * ssm_nbytes, + self._conv_decomp = derive_mamba_conv_split( + mamba_spec, + vllm_config.parallel_config.tensor_parallel_size, ) + mamba_ssm_size = self._conv_decomp.ssm_sizes self._mamba_ssm_size = mamba_ssm_size - # Conv state sub-projection decomposition (None when no Mamba). - # The 3-read transfer requires DS (dim, state_len) conv layout so - # that x/B/C sub-projections are contiguous in memory. - self._conv_decomp: MambaConvSplitInfo | None = None - if self._has_mamba: - assert is_conv_state_dim_first(), ( - "3-read Mamba conv transfer requires DS conv state layout. " - "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" + + # ---- Heterogeneous attention detection (e.g. 
Gemma4 SWA + FA) ---- + tp_size = vllm_config.parallel_config.tensor_parallel_size + attn_specs = [ + g.kv_cache_spec + for g in kv_cache_config.kv_cache_groups + if isinstance(g.kv_cache_spec, AttentionSpec) + ] + self._is_hetero_attn = ( + len(attn_specs) > 1 and len({s.num_kv_heads for s in attn_specs}) > 1 + ) + if self._is_hetero_attn: + self._total_kv_heads_per_group = tuple( + s.total_num_kv_heads + if s.total_num_kv_heads is not None + else s.num_kv_heads * tp_size + for s in attn_specs + ) + unified_page = max(s.page_size_bytes for s in attn_specs) + self._local_tokens_per_block_per_group = tuple( + s.block_size * unified_page // s.real_page_size_bytes + for s in attn_specs ) - local_tp = vllm_config.parallel_config.tensor_parallel_size - self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp) # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] @@ -260,6 +546,9 @@ def __init__( # Populated dynamically during handshake based on remote configuration. # Keep track of regions at different tp_ratio values. tp_ratio->handles self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} + # Gather-read local handles: local blocks split into descriptors + # matching remote page size. Keyed by engine_id. + self._gather_read_handles: dict[EngineId, int] = {} # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}. self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict) @@ -268,14 +557,6 @@ def __init__( self.dst_num_blocks: dict[EngineId, int] = {} self._registered_descs: list[Any] = [] - # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ---- - # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine. - # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len) - # where conv/ssm bytes are per-TP-rank (dimension-sharded). With - # heterogeneous TP the per-rank sizes differ, so the ratio differs: - # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261. 
- self._physical_blocks_per_logical: dict[EngineId, int] = {} - # In progress transfers. # [req_id -> list[handle]] self._recving_metadata: dict[ReqId, ReqMeta] = {} @@ -330,6 +611,10 @@ def __init__( self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() + # Per-engine transfer plans. Generated during handshake, used by + # per-request hot path (model-agnostic). + self._transfer_plans: dict[EngineId, EngineTransferPlan] = {} + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True ) @@ -812,9 +1097,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.dst_num_blocks[self.engine_id] = self.num_blocks if self._has_mamba: - self._physical_blocks_per_logical[self.engine_id] = ( - self._physical_blocks_per_logical_kv_block - ) logger.info( "Hybrid SSM registration: num_blocks=%s, " "logical_num_blocks=%s, ratio=%s, num_regions=%s, " @@ -847,6 +1129,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_size=self.block_size, ssm_sizes=self._mamba_ssm_size, attn_backend_name=self.backend_name, + physical_blocks_per_logical_kv_block=( + self._physical_blocks_per_logical_kv_block + ), + tokens_per_block_per_group=( + list(self._local_tokens_per_block_per_group) + if self._is_hetero_attn + else None + ), ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -856,157 +1146,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata_bytes=encoder.encode(agent_metadata), ) - def _build_mamba_local( - self, - base_addresses: list[int], - block_size_ratio: int, - ) -> list[tuple[int, int, int]]: - """Build 4 desc regions (x, B, C, ssm) per layer for local mamba - blocks, enabling the 3-read transfer with DS conv layout.""" - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." 
- ) - assert self._conv_decomp is not None - conv_offsets = self._conv_decomp.local_conv_offsets - conv_size, ssm_size = self._mamba_ssm_size - num_blocks = self._logical_num_blocks * block_size_ratio - physical_per_logical = self._physical_blocks_per_logical_kv_block - - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(base_addresses): - page_stride = ( - self.block_len_per_layer[i] // block_size_ratio * physical_per_logical - ) - for off, sz in conv_offsets: - for blk in range(num_blocks): - result.append( - (base_addr + blk * page_stride + off, sz, self.device_id) - ) - # SSM temporal state follows the conv state. - for blk in range(num_blocks): - result.append( - ( - base_addr + blk * page_stride + conv_size, - ssm_size, - self.device_id, - ) - ) - return result - - def _build_fa_remote_for_mamba( - self, - nixl_agent_meta: NixlAgentMetadata, - block_size_ratio: int, - transfer_topo: TransferTopology, - remote_engine_id: EngineId, - ) -> list[tuple[int, int, int]]: - """Build remote FA descriptors for mamba models. - - Uses TransferTopology for GQA-aware FA divisor and head-based rank - offset instead of the standard uniform tp_ratio split. - """ - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." - ) - # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA - # hetero-TP logic stabilizes. 
- mamba_info = transfer_topo.get_engine_info(remote_engine_id) - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size) - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - local_block_len = remote_kv_block_len - - if tp_ratio < 0 and not self.use_mla: - local_block_len = local_block_len // mamba_info.remote_num_fa_reads - - rank_offset = transfer_topo.fa_rank_offset( - remote_engine_id, remote_kv_block_len - ) - - num_blocks = nixl_agent_meta.num_blocks - page_size = nixl_agent_meta.block_lens[i] - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - result.append((addr, local_block_len, nixl_agent_meta.device_id)) - - if transfer_topo.is_kv_layout_blocks_first: - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) - if tp_ratio < 0 and not self.use_mla: - second_split = second_split // mamba_info.remote_num_fa_reads - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - result.append((v_addr, second_split, nixl_agent_meta.device_id)) - return result - - def _build_mamba_remote( - self, - nixl_agent_meta: NixlAgentMetadata, - tp_ratio: int, - ) -> list[tuple[int, int, int]]: - """Build 4 remote desc regions (x, B, C, ssm) per layer for - the 3-read transfer. 
For hetero-TP, each D rank reads only its - sub-projection slice from the P rank.""" - assert self._conv_decomp is not None - effective_ratio = max(tp_ratio, 1) - # Mamba conv state is always TP-sharded, even when attention KV - # is replicated (num_kv_heads < tp_size). - local_offset = self.tp_rank % effective_ratio - conv_size_remote = nixl_agent_meta.ssm_sizes[0] - - if tp_ratio >= 1: - # D_TP >= P_TP: P page is larger, D reads its slice. - conv_offsets = self._conv_decomp.remote_conv_offsets( - local_offset, effective_ratio - ) - ssm_read_size = self._mamba_ssm_size[1] - else: - # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages - # are smaller than D's. self._conv_decomp has D-sized dimensions, - # but we need P-sized offsets. Scale down by |tp_ratio|. - abs_ratio = -tp_ratio - xb_p = self._conv_decomp.x_bytes // abs_ratio - bb_p = self._conv_decomp.b_bytes // abs_ratio - conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)] - ssm_read_size = nixl_agent_meta.ssm_sizes[1] - - remote_physical_per_logical = self._physical_blocks_per_logical[ - nixl_agent_meta.engine_id - ] - num_blocks = nixl_agent_meta.num_blocks // remote_physical_per_logical - device_id = nixl_agent_meta.device_id - - result: list[tuple[int, int, int]] = [] - # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case - # block lengths vary across layers (e.g. MLA). - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical - for off, sz in conv_offsets: - for blk in range(num_blocks): - result.append((base_addr + blk * page_stride + off, sz, device_id)) - # SSM temporal state is also TP-sharded on the heads dimension. 
- for blk in range(num_blocks): - ssm_addr = ( - base_addr - + blk * page_stride - + conv_size_remote - + local_offset * ssm_read_size - ) - result.append((ssm_addr, ssm_read_size, device_id)) - return result - def register_local_xfer_handler( self, block_size: int, @@ -1023,74 +1162,17 @@ def register_local_xfer_handler( data copy correctness. """ assert self.transfer_topo is not None - transfer_topo = self.transfer_topo - block_size_ratio = self.block_size // block_size - blocks_data: list[tuple[int, int, int]] = [] local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] - def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool): - for i, base_addr in enumerate(local_base_addresses): - # The new block_len is using prefill block_len; - # and num_blocks is multiple with N - kv_block_len = ( - self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=mamba - ) - // block_size_ratio - ) - # Jump one page_size, but ssm page_size may be bigger when kernel - # locks block size to a specific value. - block_len_per_layer = ( - self.block_len_per_layer[i] - // block_size_ratio - * (1 if not mamba else self._physical_blocks_per_logical_kv_block) - ) - num_blocks = self._logical_num_blocks if mamba else self.num_blocks - num_blocks = num_blocks * block_size_ratio - for block_id in range(num_blocks): - block_offset = block_id * block_len_per_layer - addr = base_addr + block_offset - # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.device_id)) - - if transfer_topo.is_kv_layout_blocks_first: - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=mamba - ) - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. 
- for block_id in range(num_blocks): - block_offset = block_id * block_len_per_layer - addr = base_addr + block_offset - # Register addresses for V cache (K registered first). - v_addr = addr + kv_block_len - blocks_data.append((v_addr, second_split, self.device_id)) - logger.debug( - "Created %s blocks for src engine %s and rank %s on device id %s", - len(blocks_data), - self.engine_id, - self.tp_rank, - self.device_id, - ) - - # NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used - # right now — we use _build_mamba_local instead for the 3-read - # approach. However, we might still need this as a fallback for homogeneous TP. - register_blocks(blocks_data, mamba=False) - if self._has_mamba: - assert self.num_descs == len(blocks_data) - # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is - # unnecessary — a single conv desc per block suffices. Consider - # adding a fast path that falls back to the standard 2-region - # registration (register_blocks mamba=True) when no hetero-TP - # remote has been seen. Currently we always register 4 regions - # because local descs are created before knowing the remote TP. - logger.debug("Registering local Mamba descriptors (4 regions/layer)") - blocks_data.extend( - self._build_mamba_local(local_base_addresses, block_size_ratio) - ) + blocks_data = self._build_local_descs(local_base_addresses, block_size_ratio) + logger.debug( + "Created %s blocks for src engine %s and rank %s on device id %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + self.device_id, + ) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. 
@@ -1160,26 +1242,80 @@ def add_remote_agent( assert self.transfer_topo is not None transfer_topo = self.transfer_topo physical_blocks_per_logical = ( - compute_physical_blocks_per_logical( - nixl_agent_meta.ssm_sizes, - nixl_agent_meta.block_lens[0], - ) - if self._has_mamba - else 1 + nixl_agent_meta.physical_blocks_per_logical_kv_block ) - transfer_topo.register_remote_engine( - remote_engine_id=engine_id, + transfer_info = EngineTransferInfo( remote_tp_size=remote_tp_size, remote_block_size=nixl_agent_meta.block_size, remote_block_len=nixl_agent_meta.block_lens[0], remote_physical_blocks_per_logical=physical_blocks_per_logical, - local_block_len=self.block_len_per_layer[0], ) - if self._has_mamba and engine_id not in self._physical_blocks_per_logical: - self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical - + transfer_topo.register_remote_engine(engine_id, transfer_info) logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) + # Generate the transfer plan for this remote engine. 
+ if self._has_mamba: + assert self._conv_decomp is not None + self._transfer_plans[engine_id] = generate_mamba_plan( + transfer_topo=transfer_topo, + block_len_per_layer=self.block_len_per_layer, + remote_info=transfer_info, + remote_meta=nixl_agent_meta, + group_spec_types=tuple( + type(g.kv_cache_spec) for g in self.kv_cache_config.kv_cache_groups + ), + conv_decomp=self._conv_decomp, + ssm_sizes=self._mamba_ssm_size, + ) + elif self._is_hetero_attn: + remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group or ()) + group_spec_types = tuple( + type(g.kv_cache_spec) for g in self.kv_cache_config.kv_cache_groups + ) + assert len(remote_tpb) == len(group_spec_types), ( + f"Remote tokens_per_block_per_group length " + f"{len(remote_tpb)} != {len(group_spec_types)} groups" + ) + logger.info( + "[HeteroTP] Generating Gemma4 plan: " + "group_spec_types=%s, total_kv_heads_per_group=%s, " + "local_tpb_per_group=%s, remote_tpb_per_group=%s, " + "local_block_len_per_layer=%s, remote_block_lens=%s, " + "local_tp=%d, remote_tp=%d, tp_rank=%d", + [t.__name__ for t in group_spec_types], + self._total_kv_heads_per_group, + self._local_tokens_per_block_per_group, + remote_tpb, + self.block_len_per_layer[:3], + nixl_agent_meta.block_lens[:3], + transfer_topo.tp_size, + remote_tp_size, + transfer_topo.tp_rank, + ) + self._transfer_plans[engine_id] = generate_gemma4_plan( + transfer_topo=transfer_topo, + block_len_per_layer=self.block_len_per_layer, + remote_info=transfer_info, + remote_meta=nixl_agent_meta, + group_spec_types=group_spec_types, + total_num_kv_heads_per_group=(self._total_kv_heads_per_group), + local_tokens_per_block=(self._local_tokens_per_block_per_group), + remote_tokens_per_block=remote_tpb, + ) + else: + self._transfer_plans[engine_id] = generate_dense_plan( + transfer_topo=transfer_topo, + block_len_per_layer=self.block_len_per_layer, + remote_info=transfer_info, + remote_meta=nixl_agent_meta, + group_spec_types=tuple( + type(g.kv_cache_spec) for g in 
self.kv_cache_config.kv_cache_groups + ), + local_physical_blocks_per_logical=( + self._physical_blocks_per_logical_kv_block + ), + ) + remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata ) @@ -1206,11 +1342,6 @@ def add_remote_agent( # this is the ratio between the two sizes. tp_ratio = transfer_topo.tp_ratio(remote_tp_size) - # Handle tp_size>num_kv_heads: replicate KV cache. - indexes_into_remote = ( - not transfer_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 - ) - logger.debug( "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", engine_id, @@ -1218,158 +1349,66 @@ def add_remote_agent( tp_ratio, ) - ### (Optional) Register local agent memory regions. MLA is not split. + plan = self._transfer_plans[engine_id] + + ### (Optional) Register local agent memory regions. if ( + plan.local_page_size > plan.remote_page_size + and plan.remote_blocks_per_local_block + and engine_id not in self._gather_read_handles + ): + # Gather-read: local page > remote page. Register local + # descriptors matching the remote block size. + assert self.transfer_topo is not None + descs_per_local_block = plan.local_page_size // plan.remote_page_size + local_base_addresses = self.kv_caches_base_addr[self.engine_id][ + self.tp_rank + ] + gather_blocks_data = build_fa_local_descs_for_gather_read( + base_addresses=local_base_addresses, + device_id=self.device_id, + num_blocks=self.num_blocks, + block_len_per_layer=self.block_len_per_layer, + is_blocks_first=self.transfer_topo.is_kv_layout_blocks_first, + gather_page_ratio=descs_per_local_block, + ) + descs = self.nixl_wrapper.get_xfer_descs( + gather_blocks_data, self.nixl_memory_type + ) + self._gather_read_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs + ) + elif ( tp_ratio < 0 and not self.use_mla and tp_ratio not in self.src_xfer_handles_by_tp_ratio ): + # MLA is not split. # Remote tp_size > local tp_size: read from multiple remote ranks. 
# Logically "split" own regions into |tp_ratio| chunks. Mind that # we only do this once per remote tp_size (replica-friendly). - abs_tp = -tp_ratio self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - if self._has_mamba: - if transfer_topo.needs_split_handles(engine_id): - # Mamba-HMA: FA and Mamba use different split factors. - for handle_data in transfer_topo.compute_split_handle_data( - engine_id, self.src_blocks_data, self.num_descs, abs_tp - ): - descs = self.nixl_wrapper.get_xfer_descs( - handle_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs - ) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - - logger.info( - "Mamba-HMA split handles: %s, num_descs=%s", - transfer_topo.describe(engine_id), - self.num_descs, - ) - else: - # Original path: uniform divide by abs_tp (non-Mamba-HMA). - for i in range(abs_tp): - blocks_data = [] - for memory_region in self.src_blocks_data: - addr, local_block_len, own_tp_rank = memory_region - remote_block_len = local_block_len // abs_tp - addr = addr + i * remote_block_len - blocks_data.append((addr, remote_block_len, own_tp_rank)) - descs = self.nixl_wrapper.get_xfer_descs( - blocks_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - - ### Register remote agent memory regions - blocks_data = [] - # With homogeneous TP, D pulls the whole kv cache from corresponding - # rank. With heterogeneous TP, prepare the descriptors by splitting the - # P KV cache along kv_head dim, of D worker's kv_head size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - - # Register all remote blocks, but only the corresponding kv heads. - def register_remote_blocks( - blocks_data: list[tuple[int, int, int]], mamba: bool - ): - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # Read our whole local region size from remote. 
- local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=mamba - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # using remote kv_block_len as transfer unit - local_block_len = remote_kv_block_len - - if tp_ratio < 0 and not self.use_mla: - # Remote tp is bigger: read a chunk of local region from remote - local_block_len = local_block_len // (-tp_ratio) - rank_offset = ( - self.tp_rank % tp_ratio * remote_kv_block_len - if indexes_into_remote - else 0 - ) - - # Assume same num_blocks for mamba and fa - num_blocks = ( - nixl_agent_meta.num_blocks - if not mamba - else nixl_agent_meta.num_blocks - // self._physical_blocks_per_logical_kv_block - ) - page_size = nixl_agent_meta.block_lens[i] * ( - 1 if not mamba else self._physical_blocks_per_logical_kv_block + for handle_data in self._build_local_splits_from_plan( + plan, + self.src_blocks_data, + self.num_descs, + ): + descs = self.nixl_wrapper.get_xfer_descs( + handle_data, self.nixl_memory_type ) - for block_id in range(num_blocks): - block_offset = block_id * page_size - # For each block, grab the heads chunk belonging to rank_i - # of size remote_nheads // tp_ratio, which correspond to - # self.block_len == remote_block_len//tp_ratio bytes. - addr = base_addr + block_offset + rank_offset - # (addr, len, device id) - blocks_data.append( - (addr, local_block_len, nixl_agent_meta.device_id) - ) - - if transfer_topo.is_kv_layout_blocks_first: - # With FlashInfer index V separately to allow head splitting. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=mamba - ) - # Apply the same scaling as local_block_len above for when we read - # a chunk of local V from `tp_ratio` separate remote workers. 
- if tp_ratio < 0 and not self.use_mla: - second_split = second_split // (-tp_ratio) - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - # Hop over the first split of remote page: either K or Conv. - if mamba: - v_addr = addr + nixl_agent_meta.ssm_sizes[0] - else: - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append( - (v_addr, second_split, nixl_agent_meta.device_id) - ) + handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - logger.debug( - "Created %s blocks for dst engine %s" - " with remote rank %s and local rank %s", - len(blocks_data), - engine_id, - remote_tp_rank, - self.tp_rank, - ) - - if self._has_mamba: - # Mamba-HMA: separate FA registration with GQA-aware sizing, - # plus mamba 3-read registration for the Mamba "view" of the - # same KV cache tensors. - logger.debug( - "Registering remote Mamba blocks for engine %s rank %s", - engine_id, - remote_tp_rank, - ) - blocks_data.extend( - self._build_fa_remote_for_mamba( - nixl_agent_meta, - block_size_ratio, - transfer_topo, - engine_id, - ) - ) - blocks_data.extend( - self._build_mamba_remote( - nixl_agent_meta, - tp_ratio, - ) - ) - else: - register_remote_blocks(blocks_data, mamba=False) + ### Register remote agent memory regions + blocks_data = self._build_remote_descs_from_plan(plan, nixl_agent_meta) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), + engine_id, + remote_tp_rank, + self.tp_rank, + ) # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) @@ -1897,27 +1936,56 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.transfer_topo is not None engine_id = meta.remote.engine_id - remote_ranks = self.transfer_topo.target_remote_ranks(engine_id) + plan = self._transfer_plans[engine_id] remote_info = self.transfer_topo.get_engine_info(engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) - if self._has_mamba: - # Expand remote logical → kernel block IDs. - meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( - meta.remote.block_ids, - self._physical_blocks_per_logical[meta.remote.engine_id], - ) - else: - meta.remote.block_ids = self._logical_to_kernel_block_ids( - meta.remote.block_ids - ) + meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( + meta.remote.block_ids, + plan.remote_expansion_stride, + ) + remote_block_ids = meta.remote.block_ids + read_specs = self._compute_read_specs_from_plan( + plan, + local_block_ids=meta.local_physical_block_ids, + remote_block_ids=remote_block_ids, + ) + # D may have to perform multiple reads from different remote ranks. - for i, remote_rank in enumerate(remote_ranks): - if self.use_mla and tp_ratio < 0 and i > 0: - # MLA opt: when P TP > D TP, only a single read is executed for - # the first remote rank (cache is duplicated).. - break + # MLA opt: when P TP > D TP, only a single read is executed for + # the first remote rank (cache is duplicated). 
+ if self.use_mla and tp_ratio < 0: + read_specs = read_specs[:1] + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[HeteroTP _read_blocks_for_req] req=%s engine=%s " + "tp_ratio=%d, n_read_specs=%d, " + "plan: local_page=%d, remote_page=%d, " + "group_spec_types=%s, " + "local_physical_block_ids=[%s], " + "remote_block_ids=[%s]", + req_id, + engine_id, + tp_ratio, + len(read_specs), + plan.local_page_size, + plan.remote_page_size, + [t.__name__ for t in plan.group_spec_types], + ", ".join( + f"g{i}:{meta.local_physical_block_ids[i][:5]}" + for i in range(len(meta.local_physical_block_ids)) + ), + ", ".join( + f"g{i}:{meta.remote.block_ids[i][:5]}" + for i in range(len(meta.remote.block_ids)) + ), + ) + + for i, spec in enumerate(read_specs): + remote_rank = spec.remote_rank + local_block_ids = spec.local_block_ids + remote_block_ids = spec.remote_block_ids remote_block_size = remote_info.remote_block_size logger.debug( "Remote agent %s available, calling _read_blocks" @@ -1928,7 +1996,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): req_id, ) # Get side handles. - if tp_ratio < 0 and not self.use_mla: + if engine_id in self._gather_read_handles: + # Gather-read: local descriptor handle matches remote page size. + local_xfer_side_handle = self._gather_read_handles[engine_id] + elif tp_ratio < 0 and not self.use_mla: assert remote_block_size == self.block_size # Remote tp_size > local tp_size: we must perform multiple # reads. Get the memory chunk onto which we will write to. @@ -1945,37 +2016,26 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_rank ] - local_ids: BlockIds = meta.local_physical_block_ids - remote_ids: BlockIds = meta.remote.block_ids - if self._has_mamba: - # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets. 
- local_ids, remote_ids = self.transfer_topo.filter_block_ids_for_rank( - engine_id, - remote_rank, - local_ids, - remote_ids, - self._is_mamba_group, - ) - self._read_blocks( request_id=req_id, dst_engine_id=meta.remote.engine_id, remote_request_id=meta.remote.request_id, - local_block_ids=local_ids, - remote_block_ids=remote_ids, + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, remote_rank=remote_rank, local_xfer_side_handle=local_xfer_side_handle, remote_xfer_side_handle=remote_xfer_side_handle, ) - if self.use_mla and tp_ratio < 0: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. - notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() - remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify, agent in remote_agents.items(): - if rank_to_notify != remote_rank: - self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + if self.use_mla and tp_ratio < 0 and read_specs: + # ..but we still need to notify the other remote ranks that we + # have the blocks we need so they can update the request state. + notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() + remote_agents = self._remote_agents[meta.remote.engine_id] + read_ranks = {s.remote_rank for s in read_specs} + for rank_to_notify, agent in remote_agents.items(): + if rank_to_notify not in read_ranks: + self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) def _read_blocks( self, @@ -1993,6 +2053,7 @@ def _read_blocks( a single remote worker. """ assert self.transfer_topo is not None + plan = self._transfer_plans[dst_engine_id] remote_info = self.transfer_topo.get_engine_info(dst_engine_id) block_size_ratio = self.transfer_topo.block_size_ratio( remote_info.remote_block_size @@ -2061,35 +2122,112 @@ def _read_blocks( == len(local_block_ids) == len(self.kv_cache_config.kv_cache_groups) ) + # Partial prefix cache hit: trim to the shorter of local/remote. 
+ # Skip mamba groups — their blocks represent full state (conv+ssm), + # not per-token data, so trimming would corrupt the transfer. + # After ReadSpec construction, local descriptor IDs and remote + # block IDs should already have matched lengths per group + # (gather-read pairing ensures this). Trim from the head to + # keep the tail (newest blocks). remote_block_ids = list(remote_block_ids) - for i, remote_group in enumerate(remote_block_ids): - num_remote_blocks = len(remote_group) - num_local_blocks = len(local_block_ids[i]) - if not self._is_mamba_group[i]: - assert num_local_blocks <= num_remote_blocks - # Partial prefix cache hit: just read uncomputed blocks. - # Skip mamba groups — their blocks represent full state (conv+ssm), - # not per-token data, so trimming would corrupt the transfer. - if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: - remote_block_ids[i] = remote_group[-num_local_blocks:] + local_block_ids = list(local_block_ids) + group_specs = self.kv_cache_config.kv_cache_groups + for i in range(len(remote_block_ids)): + is_mamba = isinstance(group_specs[i].kv_cache_spec, MambaSpec) + if is_mamba: + continue + n_local = len(local_block_ids[i]) + n_remote = len(remote_block_ids[i]) + n = min(n_local, n_remote) + if n_local > n: + local_block_ids[i] = local_block_ids[i][-n:] + if n_remote > n: + remote_block_ids[i] = remote_block_ids[i][-n:] + + for i in range(len(remote_block_ids)): + assert len(local_block_ids[i]) == len(remote_block_ids[i]), ( + f"Block ID length mismatch after trim: group={i}, " + f"n_local={len(local_block_ids[i])}, " + f"n_remote={len(remote_block_ids[i])}. " + f"ReadSpec should produce matched lengths." + ) # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. - # Get descs ids. 
- remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, - remote_block_ids, + # Get descs ids. Both calls use the same plan since region counts + # (len(fa_regions), len(ssm_regions)) are model-determined and + # identical across engines. + if ( + plan.remote_page_size > plan.local_page_size + and plan.local_blocks_per_remote_block + ): + # Split-read (Gemma4): each remote block → multiple descriptors. + remote_desc_blocks = ( + self.dst_num_blocks[dst_engine_id] + * plan.remote_page_size + // plan.local_page_size + ) + local_desc_blocks = self.dst_num_blocks[self.engine_id] + elif ( + plan.local_page_size > plan.remote_page_size + and plan.remote_blocks_per_local_block + ): + # Gather-read (Gemma4): each local block → multiple descriptors. + remote_desc_blocks = self.dst_num_blocks[dst_engine_id] + local_desc_blocks = ( + self.dst_num_blocks[self.engine_id] + * plan.local_page_size + // plan.remote_page_size + ) + else: + # Standard: 1:1 block-to-descriptor mapping. + remote_desc_blocks = self.dst_num_blocks[dst_engine_id] + local_desc_blocks = self.dst_num_blocks[self.engine_id] + + remote_block_descs_ids = self._compute_desc_ids_from_plan( + plan, + block_ids=remote_block_ids, + dst_num_blocks=remote_desc_blocks, + block_size_ratio=None, + physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, ) - local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, - local_block_ids, + local_block_descs_ids = self._compute_desc_ids_from_plan( + plan, + block_ids=local_block_ids, + dst_num_blocks=local_desc_blocks, block_size_ratio=block_size_ratio, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) assert len(local_block_descs_ids) == len(remote_block_descs_ids) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[HeteroTP _read_blocks] req=%s rank=%d: " + "n_descs=%d, block_size_ratio=%s, " + "remote_desc_blocks=%d, local_desc_blocks=%d, " + "local_desc_ids[:10]=%s, remote_desc_ids[:10]=%s, " 
+ "local_block_ids=[%s], remote_block_ids=[%s]", + request_id, + remote_rank, + len(local_block_descs_ids), + block_size_ratio, + remote_desc_blocks, + local_desc_blocks, + local_block_descs_ids[:10].tolist(), + remote_block_descs_ids[:10].tolist(), + ", ".join( + f"g{i}:{local_block_ids[i][:5]}" + for i in range(len(local_block_ids)) + ), + ", ".join( + f"g{i}:{remote_block_ids[i][:5]}" + for i in range(len(remote_block_ids)) + ), + ) + # Prepare transfer with Nixl. handle = None try: @@ -2147,63 +2285,6 @@ def get_mapped_blocks( return mapped_2d.flatten().astype(np.int64) - def _get_block_descs_ids( - self, - engine_id: str, - block_ids: BlockIds, - block_size_ratio: float | None = None, - ) -> np.ndarray: - """ - Get the descs ids for a set of block ids. - When HMA is enabled number of descriptors across kv cache groups might differ. - A single flattened array is returned for all groups anyway. - """ - region_ids = np.arange(self.num_regions) - - # NOTE (NickLucche) With HMA, every kv group has the same number of layers and - # layers from different groups share the same kv tensor. - # eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions, - # same for [3], but group0-group1 blocks will always differ (different areas). - # Therefore we can just flatten the block_ids and compute the descs ids for all - # groups at once. - num_blocks = self.dst_num_blocks[engine_id] - if block_size_ratio is not None: - num_blocks = int(num_blocks * block_size_ratio) - - # Compute desc ids per group using the right stride: FA descs have - # num_blocks entries per region (kernel granularity), SSM descs have - # logical_blocks entries per region (no kernel splitting). - region_ids = region_ids[:, None] - if not self._has_mamba: - block_ids = np.concatenate(block_ids)[None, :] - descs_ids = region_ids * num_blocks + block_ids - return descs_ids.flatten() - else: - # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged - # arbitrarily by manager. 
Therefore, descs are duplicated for SSM and - # Attention like so: - # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)]. - # This is like having two "low-level views" of the same storage. - # `num_fa_descs` offset must be computed per-engine since P and D can - # have different num_blocks (and thus different FA descs counts). - physical_per_logical = self._physical_blocks_per_logical[engine_id] - logical_blocks = num_blocks // physical_per_logical - num_fa_descs = self.num_regions * num_blocks - # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). - mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None] - all_descs = [] - for i, group in enumerate(block_ids): - group_arr = np.asarray(group)[None, :] - if self._is_mamba_group[i]: - all_descs.append( - ( - mamba_region_ids * logical_blocks + group_arr + num_fa_descs - ).flatten() - ) - else: - all_descs.append((region_ids * num_blocks + group_arr).flatten()) - return np.concatenate(all_descs) - def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: """ Convert logical block ids to kernel physical block ids. @@ -2260,52 +2341,6 @@ def _logical_to_remote_kernel_block_ids( result.append(group) return result - def get_backend_aware_kv_block_len( - self, layer_idx: int, first_split: bool = True, mamba_view: bool = False - ) -> int: - """ - Get the block length for one K/V element (K and V have the same size). - - For FA and other backends, this is equal to the length of the whole - block, as K and V are in separate regions. - For FlashInfer, this is half the length of the whole block, as K and V - share the same region. - Similarly, for SSM-based models, state and conv are interleaved, but crucially - the their size differs. 
- Reference diagram: - KVCacheTensor (Shared) - / \\ - / \\ - / \\ - Attention (FlashInfer) View Mamba View - | | - | | - +-------------------+ +-------------------+ - | KVCacheTensor | | KVCacheTensor | - | | | | - |<----- page ------>| |<----- page ------->| - | size | | size | - | Key 0 | Val 0 | |Conv 0 | SSM 0 | - | Key 1 | Val 1 | |Conv 1 | SSM 1 | - | ... | ... | | ... | ... | - | Key N-2 | Val N-2 | |Conv N-2| SSM N-2 | - | Key N-1 | Val N-1 | |Conv N-1| SSM N-1 | - +-------------------+ +--------------------+ - |1st_split-2nd_split| |1st_split-2nd_split | - """ - assert self.transfer_topo is not None - if self.transfer_topo.is_kv_layout_blocks_first: - # For indexing only half (either just the K or V part). - if mamba_view: - # NOTE (NickLucche) Mamba Opt: this is already skipping the padding so - # we're only transferring the minimum required bytes. - block_len = self._mamba_ssm_size[not first_split] - else: - block_len = self.block_len_per_layer[layer_idx] // 2 - else: - block_len = self.block_len_per_layer[layer_idx] - return block_len - def get_kv_connector_stats(self) -> KVConnectorStats | None: """ Get the KV transfer stats for the connector. 
@@ -2346,6 +2381,9 @@ def shutdown(self): for handle in handles: self.nixl_wrapper.release_dlist_handle(handle) self.src_xfer_handles_by_tp_ratio.clear() + for handle in self._gather_read_handles.values(): + self.nixl_wrapper.release_dlist_handle(handle) + self._gather_read_handles.clear() for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): for dst_xfer_side_handle in dst_xfer_side_handles.values(): self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py index 309426814c68..00b8e2bb7275 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py @@ -31,6 +31,7 @@ class MambaConvSplitInfo: x_local: int # intermediate_size / TP (columns for x) b_local: int # groups_ss / TP (columns for B; C is same size) conv_dtype_size: int # bytes per element (e.g. 2 for float16) + ssm_sizes: tuple[int, int] # (conv_state_bytes, ssm_state_bytes) @property def conv_dim_local(self) -> int: @@ -99,8 +100,8 @@ def derive_mamba_conv_split( local_tp: this engine's tensor-parallel size. Returns: - MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, and - conv_dtype_size. + MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, + conv_dtype_size, and ssm_sizes (conv_state_bytes, ssm_state_bytes). """ if mamba_spec.mamba_type != "mamba2": raise NotImplementedError( @@ -142,12 +143,20 @@ def derive_mamba_conv_split( dtype=mamba_spec.dtypes[0], # type: ignore[misc] ).element_size() + ssm_dtype_size = torch.tensor( + [], + dtype=mamba_spec.dtypes[1], # type: ignore[misc] + ).element_size() + conv_state_bytes = torch.Size(mamba_spec.shapes[0]).numel() * conv_dtype_size + ssm_state_bytes = torch.Size(mamba_spec.shapes[1]).numel() * ssm_dtype_size + # Divide by TP to get per-rank column counts. 
return MambaConvSplitInfo( conv_rows=conv_rows, x_local=intermediate_size // local_tp, b_local=groups_ss // local_tp, conv_dtype_size=conv_dtype_size, + ssm_sizes=(conv_state_bytes, ssm_state_bytes), ) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index db9ae2bbda34..c0b83cf35e68 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -203,6 +203,7 @@ def __init__( kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, head_size_v: int | None = None, + total_num_kv_heads: int | None = None, **extra_impl_args, ) -> None: """ @@ -285,6 +286,7 @@ def __init__( self.head_size = head_size self.head_size_v = self.head_size if head_size_v is None else head_size_v self.num_kv_heads = num_kv_heads + self.total_num_kv_heads = total_num_kv_heads self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None @@ -552,6 +554,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, kv_quant_mode=quant_mode, + total_num_kv_heads=self.total_num_kv_heads, sliding_window=self.sliding_window, ) elif self.kv_cache_dtype.startswith("turboquant_"): @@ -579,6 +582,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, kv_quant_mode=quant_mode, + total_num_kv_heads=self.total_num_kv_heads, ) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index b724fa71968c..52543f04d654 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -492,6 +492,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + total_num_kv_heads=self.total_num_kv_heads, cache_config=cache_config, quant_config=quant_config, 
logits_soft_cap=attn_logits_soft_cap, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 2545c440368a..69661bf3d7e5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -133,6 +133,7 @@ class AttentionSpec(KVCacheSpec): dtype: torch.dtype kv_quant_mode: KVQuantMode = KVQuantMode.NONE page_size_padded: int | None = None + total_num_kv_heads: int | None = None @property def page_size_bytes(self) -> int: @@ -234,6 +235,7 @@ def merge(cls, specs: list[Self]) -> Self: dtype=specs[0].dtype, kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, + total_num_kv_heads=specs[0].total_num_kv_heads, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) @@ -601,6 +603,7 @@ def merge(cls, specs: list[Self]) -> Self: dtype=specs[0].dtype, kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, + total_num_kv_heads=specs[0].total_num_kv_heads, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), )