diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 6e399db7b140..57f5f37d48a5 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -3,7 +3,7 @@ """Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill.""" import gc -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest import torch @@ -583,6 +583,14 @@ def _make_mock_worker_for_desc_ids( worker._has_mamba = has_mamba worker._group_spec_types = group_spec_types worker.block_len_per_layer = block_len_per_layer or [100] + # Derive _is_ssm_region from group_spec_types: one entry per + # unique base-address region. With _cross_layers_blocks each group + # maps to one region, so we use a simple per-group flag. + from vllm.v1.kv_cache_interface import MambaSpec + + worker._is_ssm_region = [issubclass(t, MambaSpec) for t in group_spec_types] + worker._is_attn_region = [not s for s in worker._is_ssm_region] + worker._attn_block_len = {} worker._compute_desc_ids = NixlConnectorWorker._compute_desc_ids.__get__( worker, NixlConnectorWorker ) @@ -1081,3 +1089,441 @@ def test_logical_to_remote_kernel_block_ids( assert list(result) == expected_kernel_block_ids, ( f"Expected {expected_kernel_block_ids}, got {result}" ) + + +# ── Dual-purpose HMA region tests ─────────────────────────────────────── + + +def _make_mock_worker_for_desc(**overrides): + """Build a mock NixlConnectorWorker with attrs for descriptor tests.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, + ) + + worker = object.__new__(NixlConnectorWorker) + defaults = { + "num_blocks": 4, + "_logical_num_blocks": 4, + "_physical_blocks_per_logical_kv_block": 1, + "device_id": 0, + "block_len_per_layer": [], + "_is_ssm_region": [], + "_is_attn_region": [], + "_attn_block_len": {}, + "_has_mamba": True, + "use_mla": False, + "num_regions": 0, + "_group_spec_types": (), + "_conv_decomp": None, + "_mamba_ssm_size": (0, 0), + } + for k, v in defaults.items(): + setattr(worker, k, overrides.get(k, v)) + + worker.transfer_topo = MagicMock() + worker.transfer_topo.virtually_split_kv_in_blocks = False + + return worker + + +def _make_mock_nixl_meta( + base_addrs, block_lens, num_blocks=4, device_id=0, ssm_sizes=(96, 64) +): + """Build a mock NixlAgentMetadata.""" + meta = MagicMock() + meta.kv_caches_base_addr = base_addrs + meta.block_lens = block_lens + meta.num_blocks = num_blocks + meta.device_id = device_id + meta.ssm_sizes = ssm_sizes + meta.physical_blocks_per_logical_kv_block = 1 + return meta + + +class TestBuildFaLocalDualPurpose: + """Tests for _build_fa_local with dual-purpose HMA regions (MLA models).""" + + @pytest.mark.cpu_test + def test_dual_purpose_uses_mla_stride(self): + """Dual-purpose regions use _attn_block_len (MLA) stride, not KDA.""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], # KDA stride + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, # MLA stride + num_blocks=3, + ) + result = worker._build_fa_local([0x2000], block_size_ratio=1) + + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert addr == 0x2000 + i * 256 # MLA stride, not KDA 200 + assert size == 256 + + @pytest.mark.cpu_test + def test_skips_ssm_only_regions(self): + """Pure SSM regions (not attn) are skipped entirely.""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200, 200], + _is_ssm_region=[True, True], + _is_attn_region=[True, False], # region 1 is SSM-only + _attn_block_len={0: 256}, + num_blocks=2, + ) + result = worker._build_fa_local([0x1000, 0x2000], block_size_ratio=1) + + assert len(result) == 2 + # Only region 0 (dual-purpose) generates FA descs + assert all(a < 0x2000 for a, _, _ in result) + + @pytest.mark.cpu_test + def test_no_block_size_ratio_for_dual_purpose(self): + """Dual-purpose: block_size_ratio does NOT scale MLA stride.""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + num_blocks=2, + ) + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # MLA stride 256 unscaled (not 256//2=128) + assert result[0][1] == 256 + assert result[1][0] == 0x1000 + 256 + + @pytest.mark.cpu_test + def test_mixed_regions(self): + """Mix of dual-purpose, pure SSM, and pure attention regions.""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200, 128, 300], + _is_ssm_region=[True, True, False], + _is_attn_region=[True, False, True], + _attn_block_len={0: 256}, # only region 0 is dual-purpose + num_blocks=2, + ) + result = worker._build_fa_local([0x1000, 0x2000, 0x3000], block_size_ratio=1) + + # Region 0 (dual-purpose, MLA stride 256) + Region 2 (pure attn, 300) + assert len(result) == 4 + assert result[0] == (0x1000, 256, 0) + assert result[1] == (0x1000 + 256, 256, 0) + assert result[2] == (0x3000, 300, 0) + assert result[3] == (0x3000 + 300, 300, 0) + + +class TestBuildFaRemoteDualPurpose: + """Tests for _build_fa_remote with dual-purpose HMA regions (MLA models).""" + + @pytest.mark.cpu_test + def test_dual_purpose_uses_local_mla_stride(self): + """Remote FA descs for dual-purpose use local MLA stride.""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + _group_spec_types=(FullAttentionSpec, MambaSpec), + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[200], num_blocks=3) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=1) + + assert len(result) == 3 + # page_size = _attn_block_len[0] = 256 (not remote's 200) + assert result[0] == (0x5000, 256, 0) + assert result[1] == (0x5000 + 256, 256, 0) + assert result[2] == (0x5000 + 512, 256, 0) + + @pytest.mark.cpu_test + def test_no_ratio_scaling_for_dual_purpose(self): + """Dual-purpose: block_size_ratio doesn't scale MLA stride.""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + _group_spec_types=(FullAttentionSpec, MambaSpec), + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[100], num_blocks=2) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=2) + + for addr, size, dev in result: + assert size == 256 # unscaled MLA stride + + +class TestNumRegionsDualPurpose: + """Tests for num_regions = sum(_is_attn_region).""" + + @pytest.mark.cpu_test + def test_pure_gdn(self): + """Qwen3.5 GDN: all SSM → num_regions = 0.""" + assert sum([False, False, False]) == 0 + + @pytest.mark.cpu_test + def test_pure_mla(self): + """Pure MLA: all attn → num_regions = N.""" + assert sum([True, True, True]) == 3 + + @pytest.mark.cpu_test + def test_kimilinear_dual_purpose(self): + """KimiLinear: 7 dual-purpose + 13 SSM-only → num_regions = 7. + + Old formula (len - sum(_is_ssm_region)) = 20 - 20 = 0 missed + the 7 dual-purpose regions. New formula correctly counts them. + """ + _is_attn = [True] * 7 + [False] * 13 + _is_ssm = [True] * 7 + [True] * 13 + + old_formula = len(_is_attn) - sum(_is_ssm) # 0 (wrong) + new_formula = sum(_is_attn) # 7 (correct) + + assert old_formula == 0 + assert new_formula == 7 + + +class TestNonHMARegression: + """Verify non-HMA models (Qwen3.5 GDN) are unaffected.""" + + @pytest.mark.cpu_test + def test_qwen35_gdn_skips_all_fa(self): + """Qwen3.5 GDN: all SSM, no attn → _build_fa_local produces 0.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[128], + _is_ssm_region=[True], + _is_attn_region=[False], + _attn_block_len={}, + num_blocks=4, + ) + result = worker._build_fa_local([0x1000], block_size_ratio=1) + assert len(result) == 0 + + @pytest.mark.cpu_test + def test_pure_mla_no_dual_purpose(self): + """Pure MLA: _attn_block_len empty → standard path.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[512], + _is_ssm_region=[False], + _is_attn_region=[True], + _attn_block_len={}, # empty → .get(i) returns None → else + _has_mamba=False, + num_blocks=3, + ) + result = worker._build_fa_local([0x1000], block_size_ratio=1) + + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert addr == 0x1000 + i * 512 + + +# ── Qwen heterogeneous TP regression tests ─────────────────────────────── + + +class TestQwenHeteroTPRegression: + """Verify Qwen (standard GQA, use_mla=False) heterogeneous TP is not + broken by the _attn_block_len mechanism introduced for KimiLinear. + + The root cause of the regression: _attn_block_len was populated for + Qwen HMA dual-purpose regions, causing _build_fa_local/_build_fa_remote + to use the MLA stride path which skips block_size_ratio scaling. + For standard attention, stride IS TP-dependent, so skipping the ratio + produces wrong descriptor sizes and 0 successful KV transfers. + + The fix: guard attn_stride with `self.use_mla`. + """ + + @pytest.mark.cpu_test + def test_qwen_hma_dual_purpose_ignores_attn_stride_local(self): + """Qwen HMA with dual-purpose regions: _build_fa_local must NOT + use _attn_block_len stride when use_mla=False. + + Without the use_mla guard, attn_stride=1024 would be used directly + (no block_size_ratio scaling), producing wrong descriptor sizes. + With the guard, the standard path applies block_size_ratio correctly. + """ + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[512], # SSM stride (local, TP=2) + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 1024}, # attention stride (local, TP=2) + num_blocks=2, + ) + # block_size_ratio=2: local blocks are 2× remote blocks + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # Standard path: block_len_per_layer[0] // block_size_ratio = 512//2 = 256 + # num_blocks = 2 * 2 = 4 descriptors with stride 256 + assert len(result) == 4 + for i, (addr, size, dev) in enumerate(result): + assert size == 256 # 512 // 2, NOT 1024 (unscaled attn stride) + assert addr == 0x1000 + i * 256 + + @pytest.mark.cpu_test + def test_qwen_hma_dual_purpose_ignores_attn_stride_remote(self): + """Qwen HMA with dual-purpose regions: _build_fa_remote must NOT + use _attn_block_len stride when use_mla=False.""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[512], # local SSM stride + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 1024}, # local attention stride + _group_spec_types=(FullAttentionSpec, MambaSpec), + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + # Remote block_lens = P-side SSM stride (TP=4, smaller) + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[256], num_blocks=4) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=1) + + # Standard path with block_size_ratio=1: + # get_backend_aware_kv_block_len → block_len_per_layer[0] = 512 + # remote_kv_block_len = 512 // 1 = 512 + # local_block_len = 512 // 1 (num_attn_reads) = 512 + # page_size = remote's block_lens[0] = 256 + assert len(result) == 4 + for i, (addr, size, dev) in enumerate(result): + assert addr == 0x5000 + i * 256 # steps by remote page_size + assert size == 512 + + @pytest.mark.cpu_test + def test_qwen_hetero_tp_local_applies_block_size_ratio(self): + """Qwen P4D2 (D-side, TP=2): local descriptors must scale by ratio. + + Without the fix, attn_stride (D's attention stride) was used without + //block_size_ratio, producing descriptors 2× the correct size. + """ + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[1024], # D-side SSM stride (large, TP=2) + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 2048}, # D-side attention stride (2×KV heads) + num_blocks=2, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # num_blocks = 2 * 2 = 4 + # Standard path: stride = 1024 // 2 = 512, size = 1024 // 2 = 512 + # Without fix: stride = 2048 (wrong!), size = 2048 (wrong!) + assert len(result) == 4 + for i, (addr, size, dev) in enumerate(result): + assert size == 512 + assert addr == 0x1000 + i * 512 + + @pytest.mark.cpu_test + def test_qwen_homo_tp_unaffected(self): + """Qwen homogeneous TP: block_size_ratio=1, both paths produce + same result. This verifies no regression for the working case.""" + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[512], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 1024}, + num_blocks=3, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=1) + + # Standard path with ratio=1: stride=512, size=512 + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert size == 512 + assert addr == 0x1000 + i * 512 + + +class TestKimiLinearDualPurpose: + """Verify KimiLinear (use_mla=True, KDA+MLA) dual-purpose regions + correctly use the MLA stride path with the use_mla guard.""" + + @pytest.mark.cpu_test + def test_mla_stride_path_activates_for_kimilinear(self): + """KimiLinear: use_mla=True → attn_stride path is taken.""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], # KDA stride + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, # MLA stride (TP-independent) + num_blocks=3, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=1) + + # MLA path: stride=256 (unscaled), size=256 + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert size == 256 # MLA stride, not KDA's 200 + assert addr == 0x1000 + i * 256 + + @pytest.mark.cpu_test + def test_mla_stride_unscaled_by_block_size_ratio(self): + """KimiLinear: MLA stride is TP-independent, block_size_ratio + must NOT be applied (MLA num_kv_heads=1 regardless of TP).""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + num_blocks=2, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # MLA path: stride=256 (NOT 256//2=128) + assert result[0][1] == 256 + assert result[1][0] == 0x1000 + 256 + + @pytest.mark.cpu_test + def test_mla_remote_uses_attn_stride_as_page_size(self): + """KimiLinear remote: page_size = attn_stride (MLA, TP-independent).""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + _group_spec_types=(FullAttentionSpec, MambaSpec), + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[200], num_blocks=3) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=1) + + # MLA path: page_size = attn_stride = 256 (not remote's 200) + assert result[0] == (0x5000, 256, 0) + assert result[1] == (0x5000 + 256, 256, 0) + assert result[2] == (0x5000 + 512, 256, 0) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 0ab694b7e73d..b4048af61ca9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -533,11 +533,16 @@ def tp_ratio(self, remote_tp_size: int) -> int: return -(remote_tp_size // self.tp_size) def block_size_ratio(self, remote_block_size: int) -> int: - """Calculate the block size ratio between local and remote.""" - assert self.block_size % remote_block_size == 0, ( - f"Local block size {self.block_size} is not divisible " - f"by remote block size {remote_block_size} or vice versa." - ) + """Calculate the block size ratio between local and remote. + + When the local and remote block sizes are not evenly divisible + (e.g. hybrid MLA+GDN models whose MLA component is TP-independent), + returns ``1`` as a safe fallback rather than raising. Downstream + code in the nixl worker handles the non-divisible case via + byte-level ``block_lens``. + """ + if self.block_size % remote_block_size != 0: + return 1 return self.block_size // remote_block_size def is_kv_replicated( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 806a87c582fa..46a797d393a7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -95,7 +95,7 @@ def _compute_desc_ids( ) -> np.ndarray: """Compute NIXL descriptor IDs for given block IDs.""" num_fa_regions = self.num_regions - num_ssm_regions = len(self.block_len_per_layer) * 4 if self._has_mamba else 0 + num_ssm_regions = sum(self._is_ssm_region) * 4 if self._has_mamba else 0 num_blocks = dst_num_blocks if block_size_ratio is not None: @@ -365,6 +365,17 @@ def __init__( # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 + # Per-region flags: True if the region has ANY SSM/Mamba layers + # or ANY attention/MLA layers, respectively. HMA may back both + # layer types with the same physical tensor, making a region + # dual-purpose (both flags True). + self._is_ssm_region: list[bool] = [] + self._is_attn_region: list[bool] = [] + # For HMA dual-purpose regions: maps region_idx to the attention + # spec's physical page size (bytes). When KDA (MambaSpec) and MLA + # (AttentionSpec) share the same backing tensor, block_len_per_layer + # stores KDA's stride. FA descriptors need MLA's stride instead. + self._attn_block_len: dict[int, int] = {} # nixl_prepped_dlist_handle. self.src_xfer_handles_by_block_size: dict[int, int] = {} @@ -829,7 +840,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): caches_data = [] # With hybrid allocator, layers can share a kv cache tensor - seen_base_addresses = [] + seen_base_addresses: list[int] = [] # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -844,6 +855,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Enable different block lengths for different layers *only* when MLA is used. # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. self.block_len_per_layer = list[int]() + self._is_ssm_region = list[bool]() + self._is_attn_region = list[bool]() + self._attn_block_len = dict[int, int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. @@ -894,19 +908,28 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # registering a single tensor for both K/V and splitting logically like FI. for cache in cache_list: base_addr = cache.data_ptr() + is_ssm = isinstance(layer_spec, MambaSpec) + is_attn = not is_ssm if base_addr in seen_base_addresses: # NOTE (NickLucche) HMA employs memory pooling to share tensors # across groups. This results in skipping all tensors but the ones - # pointed to by group0. Also, generally we will have more blocks - # per tensor but fewer regions. - logger.debug("Skipping %s because it's already seen", layer_name) + # pointed to by group0. However, the same physical tensor may + # back both SSM and attention layers. Accumulate both flags so + # that the region can be dual-purpose. + idx = seen_base_addresses.index(base_addr) + self._is_ssm_region[idx] = self._is_ssm_region[idx] or is_ssm + self._is_attn_region[idx] = self._is_attn_region[idx] or is_attn + # Record the attention spec's stride so that FA descriptors + # use MLA's page size instead of KDA's for dual-purpose + # regions. + if is_attn: + self._attn_block_len[idx] = physical_page_size continue - logger.debug( - "Registering layer %s with cache shape: %s", layer_name, cache.shape - ) seen_base_addresses.append(base_addr) + self._is_ssm_region.append(is_ssm) + self._is_attn_region.append(is_attn) # Only record non-Mamba page sizes. - if isinstance(layer_spec, MambaSpec): + if is_ssm: self.block_len_per_layer.append( physical_page_size // self._physical_blocks_per_logical_kv_block ) @@ -946,10 +969,36 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "Different block lengths collected: %s", set(self.block_len_per_layer) ) assert len(self.block_len_per_layer) == len(seen_base_addresses) - + # Expand NIXL registration size for dual-purpose regions. KDA + # registers first with its stride, but the physical allocation + # uses max(KDA, MLA). Without this, the registered region is + # too small for FA descriptors that address MLA data. + for idx in range(len(caches_data)): + if ( + idx < len(self._is_ssm_region) + and self._is_ssm_region[idx] + and idx < len(self._is_attn_region) + and self._is_attn_region[idx] + ): + max_page = max( + self.block_len_per_layer[idx], + self._attn_block_len.get(idx, 0), + ) + old_size = caches_data[idx][1] + new_size = self.num_blocks * max_page + if old_size != new_size: + base, _, device, label = caches_data[idx] + caches_data[idx] = (base, new_size, device, label) self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses - self.num_regions = len(caches_data) - + # FA regions count only attention layers. SSM/Mamba regions are + # served by Mamba descriptors, not FA descriptors, and must be + # excluded here so that FA descriptor IDs do not reference them. + # (For hybrid MLA+GDN models, building FA descriptors for the GDN + # region with virtual K/V split would produce out-of-bounds addresses + # under heterogeneous TP, failing NIXL prepXferDlist.) + # HMA dual-purpose regions (both SSM and attn) are counted here. + num_attention_base = sum(self._is_attn_region) + self.num_regions = num_attention_base if self.transfer_topo.virtually_split_kv_in_blocks: # NOTE (NickLucche) When FlashInfer is used, memory is registered # with joint KV for each block. This minimizes the overhead in @@ -1025,10 +1074,8 @@ def _build_mamba_local( ) -> list[tuple[int, int, int]]: """Build 4 desc regions (x, B, C, ssm) per layer for local mamba blocks, enabling the 3-read transfer with DS conv layout.""" - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." - ) + # block_size_ratio is not used here because local descriptors always + # use local strides. Only remote descriptors need ratio scaling. assert self._conv_decomp is not None conv_offsets = self._conv_decomp.local_conv_offsets conv_size, ssm_size = self._mamba_ssm_size @@ -1037,6 +1084,13 @@ def _build_mamba_local( result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(base_addresses): + # Only build Mamba descriptors for SSM/Mamba regions. + # Attention/MLA regions do not contain conv or temporal state; + # building Mamba descriptors for them would reference addresses + # outside their registered memory, causing prepXferDlist failures. + # Dual-purpose regions (HMA) get both FA and Mamba descs. + if not self._is_ssm_region[i]: + continue # Jump one page_size, but ssm page_size may be bigger when kernel # locks block size to a specific value (physical_per_logical scale). page_stride = ( @@ -1088,6 +1142,10 @@ def _build_mamba_remote( # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case # block lengths vary across layers (e.g. MLA). for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + # Only build Mamba descriptors for SSM/Mamba regions. + # Attention/MLA regions do not contain conv or temporal state. + if i < len(self._is_ssm_region) and not self._is_ssm_region[i]: + continue page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical for off, sz in conv_offsets: for blk in range(num_blocks): @@ -1113,13 +1171,28 @@ def _build_fa_local( num_blocks = self.num_blocks * block_size_ratio result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(base_addresses): - kv_block_len = ( - self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False + # Only build FA descriptors for attention regions. + # SSM/Mamba regions are served by Mamba descriptors. + # Dual-purpose regions (HMA) get both FA and Mamba descs. + if i < len(self._is_attn_region) and not self._is_attn_region[i]: + continue + # Dual-purpose HMA regions: use MLA's TP-independent stride. + # Standard attention falls through to block_size_ratio path. + attn_stride = self._attn_block_len.get(i) if self.use_mla else None + if attn_stride is not None: + if self.transfer_topo.virtually_split_kv_in_blocks: + kv_block_len = attn_stride // 2 + else: + kv_block_len = attn_stride + page_stride = attn_stride + else: + kv_block_len = ( + self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=False + ) + // block_size_ratio ) - // block_size_ratio - ) - page_stride = self.block_len_per_layer[i] // block_size_ratio + page_stride = self.block_len_per_layer[i] // block_size_ratio for block_id in range(num_blocks): block_offset = block_id * page_stride addr = base_addr + block_offset @@ -1129,9 +1202,12 @@ def _build_fa_local( # Separate and interleave K/V regions to maintain the same # descs ordering. This is needed for selecting contiguous heads # when split across TP ranks. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) + if attn_stride is not None: + second_split = attn_stride // 2 + else: + second_split = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=False + ) for block_id in range(num_blocks): block_offset = block_id * page_stride addr = base_addr + block_offset @@ -1154,37 +1230,57 @@ def _build_fa_remote( num_blocks = nixl_agent_meta.num_blocks result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # Read our whole local region size from remote.. - local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # ..using remote kv_block_len as transfer unit - local_block_len = remote_kv_block_len + # Only build FA descriptors for attention regions. + # SSM/Mamba regions are served by Mamba descriptors. + # Dual-purpose regions (HMA) get both FA and Mamba descs. + if i < len(self._is_attn_region) and not self._is_attn_region[i]: + continue + # MLA: TP-independent stride; standard attention: use block_size_ratio. + attn_stride = self._attn_block_len.get(i) if self.use_mla else None + if attn_stride is not None: + # MLA stride is TP-independent (num_kv_heads=1), + # so local attn_stride equals remote's MLA stride. + if self.transfer_topo.virtually_split_kv_in_blocks: + local_block_len = attn_stride // 2 + else: + local_block_len = attn_stride + remote_kv_block_len = local_block_len + page_size = attn_stride + else: + local_block_len = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=False + ) + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + local_block_len = remote_kv_block_len + page_size = nixl_agent_meta.block_lens[i] local_block_len = local_block_len // num_attn_reads rank_offset = plan.rank_offset_factor * remote_kv_block_len - page_size = nixl_agent_meta.block_lens[i] for block_id in range(num_blocks): block_offset = block_id * page_size - # For each block, grab the kv heads chunk belonging to current local - # tp rank of size local_block_len. + # For each block, grab the kv heads chunk belonging to current + # local tp rank of size local_block_len. addr = base_addr + block_offset + rank_offset result.append((addr, local_block_len, nixl_agent_meta.device_id)) if self.transfer_topo.virtually_split_kv_in_blocks: # With FlashInfer index V separately to allow head splitting. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) + if attn_stride is not None: + second_split = attn_stride // 2 + v_stride = attn_stride + else: + second_split = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=False + ) + v_stride = nixl_agent_meta.block_lens[i] second_split = second_split // num_attn_reads for block_id in range(num_blocks): block_offset = block_id * page_size addr = base_addr + block_offset + rank_offset # Hop over the first split of remote page, K, to read V. - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + v_addr = addr + v_stride // 2 result.append((v_addr, second_split, nixl_agent_meta.device_id)) return result @@ -1324,7 +1420,33 @@ def add_remote_agent( # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| - block_size_ratio = transfer_topo.block_size_ratio(nixl_agent_meta.block_size) + # Compute block_size_ratio. + # For dual-purpose HMA regions, byte-level strides differ across TP + # even when token-level block_size is the same. Compute byte-level + # ratio for MLA models. Standard attention models use the existing + # num_attn_reads mechanism instead. + if ( + self.use_mla # Only for MLA where stride is TP-independent + and self._attn_block_len # Only for dual-purpose HMA regions + and self.block_len_per_layer + and nixl_agent_meta.block_lens + and self.block_len_per_layer[0] != nixl_agent_meta.block_lens[0] + ): + local_bytes = self.block_len_per_layer[0] + remote_bytes = nixl_agent_meta.block_lens[0] + if local_bytes > remote_bytes and local_bytes % remote_bytes == 0: + block_size_ratio = local_bytes // remote_bytes + elif remote_bytes > local_bytes and remote_bytes % local_bytes == 0: + block_size_ratio = -(remote_bytes // local_bytes) + else: + # Non-exact byte division (e.g. hybrid models with + # TP-independent MLA component). Use 1 as fallback; + # _build_fa_remote handles bytes via remote block_lens. + block_size_ratio = 1 + else: + block_size_ratio = transfer_topo.block_size_ratio( + nixl_agent_meta.block_size + ) if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks @@ -1349,9 +1471,12 @@ def add_remote_agent( plan = self.tp_mappings[engine_id] ### (Optional) Register local agent memory regions. MLA is not split. + # For hybrid MLA+GDN models, SSM state is TP-sharded and must be split + # to assemble data from multiple remote P ranks even when MLA + # attention is replicated. if ( tp_ratio < 0 - and not self.use_mla + and (not self.use_mla or self._has_mamba) and tp_ratio not in self.src_xfer_handles_by_tp_ratio ): # Remote tp_size > local tp_size: read from multiple remote ranks. @@ -1402,7 +1527,6 @@ def add_remote_agent( transfer_info, ) ) - # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( @@ -1432,6 +1556,8 @@ def _validate_remote_agent_handshake( assert remote_info.remote_tp_size == remote_tp_size tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) + # Heterogeneous TP with non-divisible block sizes (e.g. hybrid + # MLA+GDN) falls back to 1 inside block_size_ratio. block_size_ratio = self.transfer_topo.block_size_ratio( nixl_agent_meta.block_size ) @@ -1460,10 +1586,20 @@ def _validate_remote_agent_handshake( "Disable prefix caching with --no-enable-prefix-caching." ) - if self._is_hma_required: - assert block_size_ratio == 1, ( - "HMA does not support different remote block size yet" - ) + if ( + self._is_hma_required + and block_size_ratio != 1 + and not (self.use_mla and self._has_mamba) + ): + # For hybrid models with HMA, block_size_ratio != 1 means + # token-level block sizes differ. MLA+SSM models are + # exempted because MLA stride is TP-independent. + # Standard attention + SSM models (e.g. Qwen GDN) are + # handled by _build_fa_remote which applies block_size_ratio + # to compute remote attention dimensions. + # SSM descriptors use their own hetero-TP-aware + # remote_conv_offsets. + raise AssertionError("HMA does not support different remote block size yet") kv_cache_layout = ( self.kv_cache_layout if not self.use_host_buffer @@ -1506,7 +1642,8 @@ def _validate_remote_agent_handshake( # Heterogeneous TP requires head-splitting, which only works with # HND layout. MLA and replicated-KV cases don't split on heads. - # Mamba doesn't support heterogeneous TP. + # The attention component of hybrid models still requires HND for + # head-dimension splitting under heterogeneous TP. if ( abs(tp_ratio) != 1 and not self.use_mla @@ -1523,9 +1660,46 @@ def _validate_remote_agent_handshake( remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or self.transfer_topo.is_kv_replicated(remote_engine_id): # With replicated KV cache, only the number of blocks can differ. - # TODO (ZhanqiuHu): For mamba models, validate FA and mamba - # block_lens separately. - if not self._has_mamba: + if self._has_mamba and self._is_hma_required: + # Hybrid MLA+SSM with HMA: dual-purpose regions have both + # SSM and attention block lengths. Validate separately. + for i in range(len(self.block_len_per_layer)): + is_ssm = i < len(self._is_ssm_region) and self._is_ssm_region[i] + is_attn = i < len(self._is_attn_region) and self._is_attn_region[i] + remote_bl = nixl_agent_meta.block_lens[i] + if is_ssm and not is_attn: + # Pure SSM region: stride scales with TP. + assert ( + self.block_len_per_layer[i] // block_size_ratio == remote_bl + ), ( + f"SSM region {i} block_len mismatch: " + f"local={self.block_len_per_layer[i]} // " + f"ratio={block_size_ratio} != remote={remote_bl}" + ) + elif is_attn and not is_ssm: + # Pure attention region: MLA is replicated (TP-independent). + assert self.block_len_per_layer[i] == remote_bl, ( + f"Attention region {i} block_len mismatch: " + f"local={self.block_len_per_layer[i]} != remote={remote_bl}" + ) + elif is_ssm and is_attn and block_size_ratio != 1: + # Dual-purpose region: block_len_per_layer stores + # SSM stride (may differ with TP). MLA stride is + # stored in _attn_block_len and is TP-independent. + # Remote block_lens stores the remote's SSM stride. + assert ( + self.block_len_per_layer[i] // block_size_ratio == remote_bl + ), ( + f"Dual-purpose region {i} SSM stride mismatch: " + f"local={self.block_len_per_layer[i]} // " + f"ratio={block_size_ratio} != remote={remote_bl}" + ) + # When block_size_ratio == 1 for dual-purpose regions + # (fallback from non-exact byte division), SSM strides + # differ by approximately |tp_ratio| but page padding + # makes the ratio inexact. Skip strict validation — + # remote_conv_offsets handles TP-scaled addressing. + else: for i in range(len(self.block_len_per_layer)): assert ( self.block_len_per_layer[i] // block_size_ratio @@ -2061,7 +2235,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # D may have to perform multiple reads from different remote ranks. # MLA opt: when P TP > D TP, only a single read is executed for # the first remote rank (cache is duplicated).. - if self.use_mla and tp_ratio < 0: + # For hybrid MLA+GDN models, SSM state is TP-sharded, so multiple + # remote ranks are still needed for the SSM group even when MLA + # attention only needs one rank. + if self.use_mla and tp_ratio < 0 and not self._has_mamba: assert len(read_specs) == 1 for i, spec in enumerate(read_specs): @@ -2075,17 +2252,20 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): req_id, ) # Get side handles. - if tp_ratio < 0 and not self.use_mla: - assert remote_block_size == self.block_size + # Hybrid MLA+GDN: SSM needs split handles for multi-rank assembly. + if tp_ratio < 0 and (not self.use_mla or self._has_mamba): # Remote tp_size > local tp_size: we must perform multiple # reads. Get the memory chunk onto which we will write to. local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] else: # Single read from remote, we write to the whole memory region. # Also handle remote block size different from local block size. - local_xfer_side_handle = self.src_xfer_handles_by_block_size[ - remote_block_size - ] + # Use remote block_size handle if registered (block_size_ratio > 1), + # otherwise fall back to local block_size handle. + local_xfer_side_handle = self.src_xfer_handles_by_block_size.get( + remote_block_size, + self.src_xfer_handles_by_block_size[self.block_size], + ) # Destination handle: remote_engine_id -> remote_rank -> handle. remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ @@ -2101,9 +2281,19 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_xfer_side_handle=remote_xfer_side_handle, ) - if self.use_mla and tp_ratio < 0 and read_specs: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. + if self.use_mla and tp_ratio < 0 and read_specs and not self._has_mamba: + # MLA attention KV is replicated across P ranks. D reads from + # one P rank, but other P ranks also hold the same KV and need + # to be notified so they can release blocks early (vs. timeout). + # + # For hybrid MLA+GDN models, skip MLA notification entirely. + # MLA attention has only one source rank (replicated), so no + # *other* MLA source needs notifying. SSM source ranks are + # already notified via the NIXL read's built-in notif_msg, and + # sending extra MLA notifications to them causes "unrecognized + # request" errors because the P-side request tracking dicts + # (_reqs_to_send / _reqs_to_process) are only populated on the + # decode side. notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() remote_agents = self._remote_agents[meta.remote.engine_id] for rank_to_notify, agent in remote_agents.items():