diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 50e83aa2ef20..662ed466c69c 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -472,6 +472,7 @@ def __init__( is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, tensor_shape=test_shape, ) @@ -726,6 +727,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size( worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] + worker.num_descs = len(worker.src_blocks_data) def check_handshake(remote_tp_size: int): tp_ratio = remote_tp_size // local_tp_size @@ -2435,6 +2437,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) is_mamba=False, total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), attn_backends=[backend], + physical_blocks_per_logical=decode_worker._physical_blocks_per_logical_kv_block, tensor_shape=test_shape, ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 3f5a9b9cc031..267505546cda 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -153,15 +153,19 @@ def test_read_blocks_for_req_expands_remote_ids( ) remote_engine_id = "remote-engine" - if has_mamba: - worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio} # Mock transfer_topo: empty remote ranks skips the transfer machinery # entirely, isolating the block-ID expansion logic. worker.transfer_topo = MagicMock() worker.transfer_topo.target_remote_ranks.return_value = [] - worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1) + worker.transfer_topo.get_engine_info.return_value = MagicMock( + remote_tp_size=1, + remote_physical_blocks_per_logical=remote_ratio if has_mamba else 1, + ) worker.transfer_topo.tp_ratio.return_value = 1 + worker.transfer_policy = MagicMock() + worker.transfer_policy.compute_read_specs.return_value = [] + worker.use_mla = False metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( @@ -308,29 +312,28 @@ def test_nixl_metadata_hma_block_ids_structure(): @pytest.mark.cpu_test def test_get_block_descs_ids_hybrid_ssm(): - """Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM - when ratio=1 (no kernel block size mismatch).""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( - NixlConnectorWorker, + """Test get_block_descs_ids uses per-group strides for hybrid + FA+SSM when ratio=1 (no kernel block size mismatch).""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 + MambaModelBlockTransferPolicy, ) - worker = object.__new__(NixlConnectorWorker) + policy = object.__new__(MambaModelBlockTransferPolicy) + policy._is_mamba_group = [False, True] num_blocks = 100 - engine_id = "test-engine" - worker.num_regions = 2 - worker.dst_num_blocks = {engine_id: num_blocks} - worker._has_mamba = True - worker._is_mamba_group = [False, True] - worker._physical_blocks_per_logical_kv_block = 1 - worker._physical_blocks_per_logical = {engine_id: 1} - worker.block_len_per_layer = [100] - # num_descs = num_regions * num_blocks (no blocks_first doubling) - worker.num_descs = 2 * num_blocks + num_regions = 2 + block_len_per_layer = [100] fa_blocks = [3, 5] ssm_blocks = [1, 2] - result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) + result = policy.get_block_descs_ids( + block_ids=(fa_blocks, ssm_blocks), + num_regions=num_regions, + dst_num_blocks=num_blocks, + block_len_per_layer=block_len_per_layer, + physical_blocks_per_logical=1, + ) # FA group: stride=num_blocks=100, offset=0 # region0: [3, 5], region1: [103, 105] @@ -344,30 +347,30 @@ def test_get_block_descs_ids_hybrid_ssm(): @pytest.mark.cpu_test def test_get_block_descs_ids_kernel_block_mismatch(): - """Test _get_block_descs_ids uses different strides for FA (kernel blocks) - vs SSM (logical blocks) when ratio > 1.""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( - NixlConnectorWorker, + """Test get_block_descs_ids uses different strides for FA + (kernel blocks) vs SSM (logical blocks) when ratio > 1.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 + MambaModelBlockTransferPolicy, ) - worker = object.__new__(NixlConnectorWorker) + policy = object.__new__(MambaModelBlockTransferPolicy) + policy._is_mamba_group = [False, True] ratio = 4 logical_blocks = 100 num_blocks = logical_blocks * ratio # 400 kernel blocks - engine_id = "test-engine" - worker.num_regions = 2 - worker.dst_num_blocks = {engine_id: num_blocks} - worker._has_mamba = True - worker._is_mamba_group = [False, True] - worker._physical_blocks_per_logical_kv_block = ratio - worker._physical_blocks_per_logical = {engine_id: ratio} - worker.block_len_per_layer = [100] - worker.num_descs = 2 * num_blocks # 800 + num_regions = 2 + block_len_per_layer = [100] fa_blocks = [3, 7] # kernel-level block IDs ssm_blocks = [1, 2] # logical block IDs - result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) + result = policy.get_block_descs_ids( + block_ids=(fa_blocks, ssm_blocks), + num_regions=num_regions, + dst_num_blocks=num_blocks, + block_len_per_layer=block_len_per_layer, + physical_blocks_per_logical=ratio, + ) # FA group: stride=num_blocks=400, offset=0 # region0: [3, 7], region1: [403, 407] diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 63b56eddfaed..1662824aee67 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -417,6 +417,14 @@ class MambaEngineTransferInfo(EngineTransferInfo): remote_physical_heads: int """Physical KV heads stored per remote rank.""" + @property + def fa_source_set(self) -> frozenset[int]: + return frozenset(self.remote_fa_source_ranks) + + @property + def fa_source_indices(self) -> dict[int, int]: + return {r: i for i, r in enumerate(self.remote_fa_source_ranks)} + # ---- Transfer topology ---- @@ -433,14 +441,13 @@ class TransferTopology: is_mamba: bool total_num_kv_heads: int attn_backends: list[type[AttentionBackend]] + physical_blocks_per_logical: int tensor_shape: torch.Size | None = None def __post_init__(self): self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size) self._engines: dict[EngineId, EngineTransferInfo] = {} - self._fa_source_sets: dict[EngineId, frozenset[int]] = {} - self._fa_source_indices: dict[EngineId, dict[int, int]] = {} # Figure out whether the first dimension of the cache is K/V # or num_blocks. @@ -487,24 +494,12 @@ def __post_init__(self): def register_remote_engine( self, remote_engine_id: EngineId, - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - *, - local_block_len: int = 0, + info: EngineTransferInfo, ) -> EngineTransferInfo: """Register a remote engine, unifying worker dicts state. - Only remote engines should be registered here — the local engine's - identity (tp_size, block_size, etc.) is set via ``__init__`` params. - - For Mamba models, also computes the Mamba transfer plan and - builds the FA source lookup caches. - - Args: - local_block_len: Local representative block_len (bytes). - Required for Mamba models to compute ``fa_descriptor_bytes``. + The caller (worker) is responsible for computing the info via + the transfer policy. This method only stores and deduplicates. """ assert remote_engine_id != self.engine_id, ( f"Cannot register local engine {self.engine_id} as remote. " @@ -512,29 +507,6 @@ def register_remote_engine( ) if remote_engine_id in self._engines: return self._engines[remote_engine_id] - info: EngineTransferInfo - if self.is_mamba: - info = self._build_mamba_info( - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - remote_block_len=remote_block_len, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - local_block_len=local_block_len, - ) - assert isinstance(info, MambaEngineTransferInfo) - self._fa_source_sets[remote_engine_id] = frozenset( - info.remote_fa_source_ranks - ) - self._fa_source_indices[remote_engine_id] = { - r: i for i, r in enumerate(info.remote_fa_source_ranks) - } - else: - info = EngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - ) self._engines[remote_engine_id] = info return info @@ -662,131 +634,6 @@ def get_transfer_cache_regions( # Regular case: backends like FA register K/V in separate regions return cache if self.split_k_and_v else [cache] - # ============================================================ - # Mamba-specific methods - # ============================================================ - - def should_skip_fa(self, remote_engine_id: EngineId, remote_rank: int) -> bool: - """Whether to skip FA groups for this remote rank (mamba-only).""" - return remote_rank not in self._fa_source_sets[remote_engine_id] - - def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int: - """Index into local FA block for this remote rank's head data. - - For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. - For ranks NOT in ``fa_source_ranks`` (replicated duplicates), - returns the slot of the matching source rank with the same head. - """ - fa_index = self._fa_source_indices[remote_engine_id] - if remote_rank in fa_index: - return fa_index[remote_rank] - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - K = self.total_num_kv_heads - remote_tp = mamba_info.remote_tp_size - r_head = self._physical_head_range(remote_tp, K, remote_rank) - for target in mamba_info.remote_fa_source_ranks: - t_head = self._physical_head_range(remote_tp, K, target) - if self._range_overlap(r_head, t_head): - return fa_index[target] - return 0 - - def fa_rank_offset( - self, remote_engine_id: EngineId, remote_kv_block_len: int - ) -> int: - """Byte offset into remote FA block for this local rank. - - When local TP is replicated (local_tp > K), multiple local ranks - share a head. Computes offset *relative to the target remote - rank's first head* so it works regardless of how many heads the - remote has. Returns 0 when local does not index into remote. - """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) - if self.is_mla or tp_ratio <= 0: - return 0 - K = self.total_num_kv_heads - is_local_replicated = self.tp_size > K - if is_local_replicated: - local_head = self.tp_rank * K // self.tp_size - p_rank = mamba_info.remote_fa_source_ranks[0] - p_start = p_rank * K // mamba_info.remote_tp_size - return (local_head - p_start) * remote_kv_block_len - return self.tp_rank % tp_ratio * remote_kv_block_len - - def needs_split_handles(self, remote_engine_id: EngineId) -> bool: - """Whether per-remote-rank split handles are needed. - - True when FA and mamba have different read counts, requiring - different splitting factors in the local handle. - """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) - return ( - tp_ratio < 0 - and not self.is_mla - and len(mamba_info.remote_all_source_ranks) > 1 - ) - - def compute_split_handle_data( - self, - remote_engine_id: EngineId, - src_blocks_data: list[tuple[int, int, int]], - num_fa_descs: int, - abs_tp: int, - ) -> list[list[tuple[int, int, int]]]: - """Per-remote-rank (addr, len, dev) triples for Mamba-HMA split - handles. - - FA descriptors (indices < num_fa_descs) are sliced by - ``remote_num_fa_reads``; mamba descriptors are sliced uniformly - by ``abs_tp``. - """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - all_handle_data: list[list[tuple[int, int, int]]] = [] - for p_idx, p_rank in enumerate(mamba_info.remote_all_source_ranks): - handle_data: list[tuple[int, int, int]] = [] - skip_fa = self.should_skip_fa(remote_engine_id, p_rank) - fa_slot = self.fa_head_slot(remote_engine_id, p_rank) if not skip_fa else 0 - for j, (addr, local_len, dev) in enumerate(src_blocks_data): - if j < num_fa_descs: - assert mamba_info.remote_num_fa_reads >= 1 - fa_chunk = local_len // mamba_info.remote_num_fa_reads - handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) - else: - mamba_chunk = local_len // abs_tp - handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) - all_handle_data.append(handle_data) - return all_handle_data - - def filter_block_ids_for_rank( - self, - remote_engine_id: EngineId, - remote_rank: int, - local_ids: BlockIds, - remote_ids: BlockIds, - is_mamba_group: list[bool], - ) -> tuple[BlockIds, BlockIds]: - """Zero out FA groups for remote ranks outside ``fa_source_ranks``. - - Returns (filtered_local_ids, filtered_remote_ids). When the - remote rank carries FA data for this local rank, returns the - inputs unchanged. - """ - if not self.should_skip_fa(remote_engine_id, remote_rank): - return local_ids, remote_ids - num_groups = len(local_ids) - filtered_local: list[list[int]] = [ - [] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups) - ] - filtered_remote: list[list[int]] = [ - [] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups) - ] - return filtered_local, filtered_remote - def describe(self, remote_engine_id: EngineId) -> str: """One-line summary of transfer config for logging.""" info = self._engines[remote_engine_id] @@ -808,163 +655,3 @@ def describe(self, remote_engine_id: EngineId) -> str: f"fa_desc_bytes={info.remote_fa_descriptor_bytes})" ) return f"TransferTopology({base})" - - # ============================================================ - # Private helpers - # ============================================================ - # Mamba-HMA hetero-TP transfer config: - # With hetero-TP (P_TP > D_TP), FA KV cache may be replicated across - # P ranks (when P_TP > num_kv_heads), but Mamba conv/SSM state is - # almost always uniquely sharded per P rank. So the number of P - # ranks D must read from can differ between FA and Mamba, and they - # must be handled separately. - - @staticmethod - def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: - """Physical KV head range stored in a rank's KV cache tensor. - - When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. - When ``tp_size > num_heads``: 1 physical head per rank. Heads are - distributed **contiguously** (matching vLLM's GQA weight partitioning): - consecutive ranks share a head before moving to the next one. - """ - if tp_size <= num_heads: - assert num_heads % tp_size == 0 - per_rank = num_heads // tp_size - return range(rank * per_rank, (rank + 1) * per_rank) - else: - h = rank * num_heads // tp_size - return range(h, h + 1) - - @staticmethod - def _range_overlap(a: range, b: range) -> range: - start = max(a.start, b.start) - stop = min(a.stop, b.stop) - return range(start, max(start, stop)) - - # ============================================================ - # Private: build Mamba transfer info - # ============================================================ - - def _build_mamba_info( - self, - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - local_block_len: int, - ) -> MambaEngineTransferInfo: - """Compute Mamba transfer plan.""" - K = self.total_num_kv_heads - local_tp = self.tp_size - local_rank = self.tp_rank - - is_remote_replicated = remote_tp_size > K - remote_physical_heads = max(1, K // remote_tp_size) - - if local_tp >= remote_tp_size: - assert local_tp % remote_tp_size == 0 - tp_ratio = local_tp // remote_tp_size - else: - assert remote_tp_size % local_tp == 0 - tp_ratio = -(remote_tp_size // local_tp) - - abs_tp = -tp_ratio if tp_ratio < 0 else 1 - - mamba_range: range | None = None - if tp_ratio < 0: - mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp) - - # ---- FA read targets ---- - if self.is_mla or tp_ratio >= 0: - num_fa_reads = 1 - fa_source_ranks: list[int] = ( - [0] - if self.is_mla - else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] - ) - else: - local_needs = self._physical_head_range(local_tp, K, local_rank) - search_range = ( - mamba_range if mamba_range is not None else range(remote_tp_size) - ) - seen: set[tuple[int, int]] = set() - fa_source_ranks = [] - for p in search_range: - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - if not fa_source_ranks: - for p in range(remote_tp_size): - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - num_fa_reads = len(fa_source_ranks) - - # ---- All source ranks (mamba + FA) ---- - if mamba_range is not None and abs_tp > num_fa_reads: - num_mamba_reads = abs_tp - all_source_ranks = list(mamba_range) - else: - num_mamba_reads = num_fa_reads - all_source_ranks = list(fa_source_ranks) - - # ---- FA descriptor bytes ---- - effective_block_len = min(local_block_len, remote_block_len) - if self.is_kv_layout_blocks_first: - fa_descriptor_bytes = effective_block_len // 2 - else: - fa_descriptor_bytes = effective_block_len - - # ---- Validation ---- - is_local_replicated = local_tp > K - if is_local_replicated and is_remote_replicated and tp_ratio > 0: - logger.info( - "Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.", - local_tp, - remote_tp_size, - K, - ) - tt_set = set(all_source_ranks) - for t in fa_source_ranks: - if t not in tt_set: - logger.error( - "FA source rank %d NOT in all_source_ranks %s.", - t, - all_source_ranks, - ) - if self.is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0: - local_k_half = local_block_len // 2 - remote_k_half = remote_block_len // 2 - expected = local_k_half // num_fa_reads - if expected != remote_k_half: - logger.warning( - "FA size mismatch: local_k_half=%d / reads=%d = %d, " - "but remote_k_half=%d.", - local_k_half, - num_fa_reads, - expected, - remote_k_half, - ) - - return MambaEngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - remote_fa_source_ranks=tuple(fa_source_ranks), - remote_all_source_ranks=tuple(all_source_ranks), - remote_num_fa_reads=num_fa_reads, - remote_num_mamba_reads=num_mamba_reads, - remote_fa_descriptor_bytes=fa_descriptor_bytes, - is_remote_replicated=is_remote_replicated, - remote_physical_heads=remote_physical_heads, - ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 2057c79fa58c..bf7ab3bb2897 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -753,6 +753,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.use_mla = self.model_config.use_mla + self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() # Get the attention backend from the first layer @@ -773,6 +774,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=[backend], + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) self.async_zmq_ctx = zmq.asyncio.Context() @@ -795,6 +797,9 @@ def _sync_block_size_with_kernel(self) -> None: kernel_block_size, ) assert self.block_size > kernel_block_size + self._physical_blocks_per_logical_kv_block = ( + self.block_size // kernel_block_size + ) self.block_size = kernel_block_size def __del__(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py new file mode 100644 index 000000000000..17d91c86a96f --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -0,0 +1,1215 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ModelBlockTransferPolicy: model-specific transfer intelligence. + +This module defines the ``ModelBlockTransferPolicy`` ABC and its concrete +implementations for Dense and Mamba models. The policy encapsulates all +model-specific logic that was previously scattered across ``worker.py`` +and ``TransferTopology``, making both of those model-agnostic. + +The policy is an *immutable config holder*: its state is set once at +``__init__`` and never mutated. It computes results but stores no +per-engine state (that lives on ``TransferTopology._engines``). +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import torch + +from vllm.distributed.kv_transfer.kv_connector.utils import ( + BlockIds, + EngineId, + EngineTransferInfo, + MambaEngineTransferInfo, + TransferTopology, +) +from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( + MambaConvSplitInfo, + derive_mamba_conv_split, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first +from vllm.v1.kv_cache_interface import MambaSpec + +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NixlAgentMetadata, + ) + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec + +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class ReadSpec: + """Specification for a single remote block read operation. + + Computed upfront by ``compute_read_specs`` so that the worker's read + loop is a simple iteration with no model-specific branching. + """ + + remote_rank: int + local_block_ids: BlockIds + remote_block_ids: BlockIds + + +# ------------------------------------------------------------------ +# Private module-level helpers +# ------------------------------------------------------------------ + + +def _get_kv_block_len( + layer_idx: int, + block_len_per_layer: list[int], + is_blocks_first: bool, +) -> int: + """Byte length of one K or V descriptor for a layer. + + For FA/KV layers only — Mamba SSM descriptors use ``_ssm_sizes`` + directly. When ``is_blocks_first``, K and V are interleaved so + the descriptor covers half the block. + """ + if is_blocks_first: + return block_len_per_layer[layer_idx] // 2 + return block_len_per_layer[layer_idx] + + +def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: + """Physical KV head range stored in a rank's KV cache tensor. + + When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. + When ``tp_size > num_heads``: 1 physical head per rank. Heads are + distributed **contiguously** (matching vLLM's GQA weight + partitioning): consecutive ranks share a head before moving to + the next one. + """ + if tp_size <= num_heads: + assert num_heads % tp_size == 0 + per_rank = num_heads // tp_size + return range(rank * per_rank, (rank + 1) * per_rank) + else: + h = rank * num_heads // tp_size + return range(h, h + 1) + + +def _range_overlap(a: range, b: range) -> range: + start = max(a.start, b.start) + stop = min(a.stop, b.stop) + return range(start, max(start, stop)) + + +def _should_skip_fa(info: EngineTransferInfo, remote_rank: int) -> bool: + """Whether to skip FA groups for this remote rank. + + Returns False unless ``info`` carries FA source tracking + (``MambaEngineTransferInfo``) and the rank is a replicated + duplicate. + """ + if not isinstance(info, MambaEngineTransferInfo): + return False + return remote_rank not in info.fa_source_set + + +def _fa_head_slot( + info: EngineTransferInfo, + remote_rank: int, + total_num_kv_heads: int, +) -> int: + """Index into local FA block for this remote rank's head data. + + For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. + For ranks NOT in ``fa_source_ranks`` (replicated duplicates), + returns the slot of the matching source rank with the same head. + Returns 0 when ``info`` has no FA source tracking. + """ + if not isinstance(info, MambaEngineTransferInfo): + return 0 + fa_index = info.fa_source_indices + if remote_rank in fa_index: + return fa_index[remote_rank] + K = total_num_kv_heads + remote_tp = info.remote_tp_size + r_head = _physical_head_range(remote_tp, K, remote_rank) + for target in info.remote_fa_source_ranks: + t_head = _physical_head_range(remote_tp, K, target) + if _range_overlap(r_head, t_head): + return fa_index[target] + return 0 + + +def _fa_rank_offset( + info: EngineTransferInfo, + remote_kv_block_len: int, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, +) -> int: + """Byte offset into remote FA block for this local rank. + + When local TP is replicated (local_tp > K), multiple local ranks + share a head. Computes offset *relative to the target remote + rank's first head* so it works regardless of how many heads the + remote has. Returns 0 when ``info`` has no FA source tracking + or when local does not index into remote. + """ + if not isinstance(info, MambaEngineTransferInfo): + return 0 + tp_ratio = ( + tp_size // info.remote_tp_size + if tp_size >= info.remote_tp_size + else -(info.remote_tp_size // tp_size) + ) + if is_mla or tp_ratio <= 0: + return 0 + K = total_num_kv_heads + is_local_replicated = tp_size > K + if is_local_replicated: + local_head = tp_rank * K // tp_size + p_rank = info.remote_fa_source_ranks[0] + p_start = p_rank * K // info.remote_tp_size + return (local_head - p_start) * remote_kv_block_len + return tp_rank % tp_ratio * remote_kv_block_len + + +class ModelBlockTransferPolicy(ABC): + """Abstract base for model-specific block transfer logic. + + Encapsulates genuinely model-specific algorithms: descriptor building, + transfer info computation, split handles, read spec filtering, and + orchestration. Simple per-layer branches (``isinstance(MambaSpec)``) + and block-ID mapping remain on ``worker.py``. + """ + + def __init__( + self, + kv_cache_config: KVCacheConfig, + physical_blocks_per_logical: int, + ): + self._kv_cache_config = kv_cache_config + self._physical_blocks_per_logical = physical_blocks_per_logical + + # ------------------------------------------------------------------ + # Per-engine transfer info (data operations) + # ------------------------------------------------------------------ + + @abstractmethod + def build_engine_transfer_info( + self, + *, + # Local topology + transfer_topo: TransferTopology, + # Block geometry + local_block_len: int, + # Remote facts (from NixlAgentMetadata handshake) + remote_tp_size: int, + remote_block_size: int, + remote_block_len: int, + remote_physical_blocks_per_logical: int, + ) -> EngineTransferInfo: + """Compute transfer info for a remote engine. + + Dense models return ``EngineTransferInfo``. + Mamba models return ``MambaEngineTransferInfo``. + """ + + # ------------------------------------------------------------------ + # Descriptor ID computation (abstract — genuinely different per model) + # ------------------------------------------------------------------ + + @abstractmethod + def get_block_descs_ids( + self, + # Input + block_ids: BlockIds, + # Block geometry + num_regions: int, + dst_num_blocks: int, + block_len_per_layer: list[int], + block_size_ratio: float | None = None, + physical_blocks_per_logical: int = 1, + ) -> np.ndarray: + """Compute NIXL descriptor IDs for a set of block IDs.""" + ... + + # ------------------------------------------------------------------ + # Local descriptor building (concrete default = FA-only) + # ------------------------------------------------------------------ + + @abstractmethod + def build_local_descs( + self, + # Memory + base_addresses: list[int], + device_id: int, + # Block geometry + num_blocks: int, + logical_num_blocks: int, + block_size_ratio: int, + block_len_per_layer: list[int], + # Layout + is_blocks_first: bool, + ) -> list[tuple[int, int, int]]: + """Build local (src) descriptor tuples for NIXL registration.""" + ... + + def _build_fa_local_descs( + self, + base_addresses: list[int], + device_id: int, + num_blocks: int, + block_size_ratio: int, + block_len_per_layer: list[int], + is_blocks_first: bool, + ) -> list[tuple[int, int, int]]: + """Build FA local descriptors (shared by Dense and Mamba).""" + result: list[tuple[int, int, int]] = [] + n_blocks = num_blocks * block_size_ratio + for i, base_addr in enumerate(base_addresses): + # The new block_len is using prefill block_len; + # and num_blocks is multiple with N + kv_block_len = ( + _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + // block_size_ratio + ) + # Jump one page_size, but ssm page_size may be bigger when kernel + # locks block size to a specific value. + page_stride = block_len_per_layer[i] // block_size_ratio + for block_id in range(n_blocks): + # (addr, len, device id) + result.append( + ( + base_addr + block_id * page_stride, + kv_block_len, + device_id, + ) + ) + if is_blocks_first: + second_split = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + # Separate and interleave K/V regions to maintain the same + # descs ordering. This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(n_blocks): + # Register addresses for V cache (K registered first). + v_addr = base_addr + block_id * page_stride + kv_block_len + result.append( + ( + v_addr, + second_split, + device_id, + ) + ) + return result + + # ------------------------------------------------------------------ + # Remote descriptor building (abstract — genuinely different) + # ------------------------------------------------------------------ + + @abstractmethod + def build_remote_descs( + self, + transfer_topo: TransferTopology, + engine_id: EngineId, + nixl_agent_meta: NixlAgentMetadata, + block_len_per_layer: list[int], + ) -> list[tuple[int, int, int]]: + """Build remote (dst) descriptor tuples.""" + ... + + @abstractmethod + def build_src_split_handles( + self, + transfer_topo: TransferTopology, + engine_id: EngineId, + src_blocks_data: list[tuple[int, int, int]], + num_descs: int, + ) -> list[list[tuple[int, int, int]]]: + """Build split handle data for P_TP > D_TP scenario.""" + ... + + # ------------------------------------------------------------------ + # Read spec computation + # ------------------------------------------------------------------ + + @abstractmethod + def compute_read_specs( + self, + local_block_ids: BlockIds, + remote_block_ids: BlockIds, + remote_ranks: list[int], + remote_info: EngineTransferInfo, + ) -> list[ReadSpec]: + """Compute the full set of read operations needed for a request. + + Returns one ``ReadSpec`` per remote rank that requires a read. + The worker iterates the result without model-specific branching. + MLA trimming (keeping only the first spec) is handled by the worker. + + Block ID expansion (logical→kernel) is done by the worker before + calling this method. + """ + ... + + # ------------------------------------------------------------------ + # Factory + # ------------------------------------------------------------------ + + @staticmethod + def create( + kv_cache_config: KVCacheConfig, + layer_specs: dict[str, KVCacheSpec], + physical_blocks_per_logical: int, + tp_size: int, + ) -> ModelBlockTransferPolicy: + """Create the appropriate policy based on model architecture.""" + has_mamba = any( + isinstance(group.kv_cache_spec, MambaSpec) + for group in kv_cache_config.kv_cache_groups + ) + if has_mamba: + return MambaModelBlockTransferPolicy( + kv_cache_config=kv_cache_config, + tp_size=tp_size, + layer_specs=layer_specs, + physical_blocks_per_logical=physical_blocks_per_logical, + ) + return DenseModelBlockTransferPolicy( + kv_cache_config, + physical_blocks_per_logical, + ) + + +# ====================================================================== +# Dense (pure-attention) policy +# ====================================================================== + + +class DenseModelBlockTransferPolicy(ModelBlockTransferPolicy): + """Policy for pure-attention (dense) models. + + Inherits ``get_block_len`` from the ABC. + """ + + def __init__( + self, + kv_cache_config: KVCacheConfig, + physical_blocks_per_logical: int, + ): + super().__init__(kv_cache_config, physical_blocks_per_logical) + + def build_local_descs( + self, + base_addresses, + device_id, + num_blocks, + logical_num_blocks, + block_size_ratio, + block_len_per_layer, + is_blocks_first, + ): + return self._build_fa_local_descs( + base_addresses, + device_id, + num_blocks, + block_size_ratio, + block_len_per_layer, + is_blocks_first, + ) + + def get_block_descs_ids( + self, + block_ids, + num_regions, + dst_num_blocks, + block_len_per_layer, + block_size_ratio=None, + physical_blocks_per_logical=1, + ): + num_blocks = dst_num_blocks + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) + block_ids_arr = np.concatenate(block_ids) + region_ids = np.arange(num_regions)[:, None] + return (region_ids * num_blocks + block_ids_arr[None, :]).flatten() + + def build_remote_descs( + self, + transfer_topo, + engine_id, + nixl_agent_meta, + block_len_per_layer, + ): + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + # Register all remote blocks, but only the corresponding kv heads. + remote_info = transfer_topo.get_engine_info(engine_id) + assert isinstance(remote_info, EngineTransferInfo) + tp_rank = transfer_topo.tp_rank + is_mla = transfer_topo.is_mla + is_blocks_first = transfer_topo.is_kv_layout_blocks_first + tp_ratio = transfer_topo.tp_ratio(remote_info.remote_tp_size) + block_size_ratio = transfer_topo.block_size_ratio(remote_info.remote_block_size) + indexes_into_remote = ( + not ( + is_mla or remote_info.remote_tp_size > transfer_topo.total_num_kv_heads + ) + and tp_ratio > 0 + ) + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate( + nixl_agent_meta.kv_caches_base_addr, + ): + local_block_len = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + local_block_len = remote_kv_block_len + if tp_ratio < 0 and not is_mla: + local_block_len = local_block_len // (-tp_ratio) + rank_offset = ( + tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 + ) + num_blocks = nixl_agent_meta.num_blocks + page_size = nixl_agent_meta.block_lens[i] + dev_id = nixl_agent_meta.device_id + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + result.append((addr, local_block_len, dev_id)) + if is_blocks_first: + second_split = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + if tp_ratio < 0 and not is_mla: + second_split = second_split // (-tp_ratio) + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + result.append((v_addr, second_split, dev_id)) + return result + + def build_src_split_handles( + self, + transfer_topo, + engine_id, + src_blocks_data, + num_descs, + ): + remote_info = transfer_topo.get_engine_info(engine_id) + assert isinstance(remote_info, EngineTransferInfo) + _ = num_descs + assert remote_info.remote_tp_size > transfer_topo.tp_size + abs_tp = remote_info.remote_tp_size // transfer_topo.tp_size + result: list[list[tuple[int, int, int]]] = [] + for i in range(abs_tp): + blocks_data: list[tuple[int, int, int]] = [] + for addr, local_block_len, own_rank in src_blocks_data: + remote_block_len = local_block_len // abs_tp + blocks_data.append( + ( + addr + i * remote_block_len, + remote_block_len, + own_rank, + ) + ) + result.append(blocks_data) + return result + + def compute_read_specs( + self, + local_block_ids, + remote_block_ids, + remote_ranks, + remote_info, + ): + assert isinstance(remote_info, EngineTransferInfo) + return [ + ReadSpec( + remote_rank=rank, + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + ) + for rank in remote_ranks + ] + + def build_engine_transfer_info( + self, + *, + # Local topology + transfer_topo: TransferTopology, + # Block geometry + local_block_len: int, + # Remote facts (from NixlAgentMetadata handshake) + remote_tp_size: int, + remote_block_size: int, + remote_block_len: int, + remote_physical_blocks_per_logical: int, + ) -> EngineTransferInfo: + _ = (transfer_topo, local_block_len) + return EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_len=remote_block_len, + remote_block_size=remote_block_size, + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ) + + +# ====================================================================== +# Mamba (hybrid SSM+Attention) policy +# ====================================================================== + + +class MambaModelBlockTransferPolicy(ModelBlockTransferPolicy): + """Policy for hybrid Mamba+Attention (SSM) models. + + Stores Mamba-specific state (SSM sizes, conv decomposition, + per-group flags) and overrides methods that differ from the + FA-only base: descriptor building, split handles, read specs, + registration helpers for MambaSpec layers, and orchestration. + """ + + def __init__( + self, + kv_cache_config: KVCacheConfig, + tp_size: int, + layer_specs: dict[str, KVCacheSpec], + physical_blocks_per_logical: int, + ): + super().__init__(kv_cache_config, physical_blocks_per_logical) + self._is_mamba_group = [ + isinstance(group.kv_cache_spec, MambaSpec) + for group in kv_cache_config.kv_cache_groups + ] + + mamba_spec = next( + spec for spec in layer_specs.values() if isinstance(spec, MambaSpec) + ) + conv_nbytes = torch.tensor( + [], + dtype=mamba_spec.dtypes[0], # type: ignore[misc] + ).element_size() + ssm_nbytes = torch.tensor( + [], + dtype=mamba_spec.dtypes[1], # type: ignore[misc] + ).element_size() + conv_shape = torch.Size(mamba_spec.shapes[0]) + ssm_shape = torch.Size(mamba_spec.shapes[1]) + self._ssm_sizes = ( + conv_shape.numel() * conv_nbytes, + ssm_shape.numel() * ssm_nbytes, + ) + + assert is_conv_state_dim_first(), ( + "3-read Mamba conv transfer requires DS conv state layout. " + "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" + ) + self._conv_decomp = derive_mamba_conv_split(mamba_spec, tp_size) + + @property + def ssm_sizes(self) -> tuple[int, int]: + """(conv_state_bytes, ssm_state_bytes) per logical block.""" + return self._ssm_sizes + + @property + def conv_decomp(self) -> MambaConvSplitInfo: + """Conv-state sub-projection decomposition.""" + return self._conv_decomp + + def get_block_descs_ids( + self, + block_ids, + num_regions, + dst_num_blocks, + block_len_per_layer, + block_size_ratio=None, + physical_blocks_per_logical=1, + ): + # NOTE (NickLucche) With HMA, every kv group has the same number of + # layers and layers from different groups share the same kv tensor. + # eg block_ids=[[1,2],[3]]->blocks [1,2] need to be read across all + # regions, same for [3], but group0-group1 blocks will always differ + # (different areas). Therefore we can just flatten the block_ids and + # compute the descs ids for all groups at once. + + # Compute desc ids per group using the right stride: FA descs have + # num_blocks entries per region (kernel granularity), SSM descs have + # logical_blocks entries per region (no kernel splitting). + num_blocks = dst_num_blocks + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) + ratio = physical_blocks_per_logical + logical_blocks = num_blocks // ratio + num_fa_descs = num_regions * num_blocks + # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged + # arbitrarily by manager. Therefore, descs are duplicated for SSM and + # Attention like so: + # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)]. + # This is like having two "low-level views" of the same storage. + # `num_fa_descs` offset must be computed per-engine since P and D can + # have different num_blocks (and thus different FA descs counts). + # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). + mamba_region_ids = np.arange( + len(block_len_per_layer) * 4, + )[:, None] + all_descs: list[np.ndarray] = [] + for i, group in enumerate(block_ids): + group_arr = np.asarray(group) + if self._is_mamba_group[i]: + # Mamba blocks are 1:1 logical-to-physical (no expansion). + all_descs.append( + ( + mamba_region_ids * logical_blocks + + group_arr[None, :] + + num_fa_descs + ).flatten() + ) + else: + region_ids = np.arange(num_regions)[:, None] + all_descs.append( + (region_ids * num_blocks + group_arr[None, :]).flatten() + ) + return np.concatenate(all_descs) + + def build_local_descs( + self, + base_addresses, + device_id, + num_blocks, + logical_num_blocks, + block_size_ratio, + block_len_per_layer, + is_blocks_first, + ): + fa_descs = self._build_fa_local_descs( + base_addresses, + device_id, + num_blocks, + block_size_ratio, + block_len_per_layer, + is_blocks_first, + ) + num_regions = len(base_addresses) * (2 if is_blocks_first else 1) + assert len(fa_descs) == num_regions * num_blocks + # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read + # split is unnecessary — a single conv desc per block suffices. + # Consider adding a fast path that falls back to the standard 2-region + # registration when no hetero-TP remote has been seen. Currently we + # always register 4 regions because local descs are created before + # knowing the remote TP. + logger.debug("Registering local Mamba descriptors (4 regions/layer)") + mamba_descs = self._build_mamba_local_descs( + base_addresses, + block_len_per_layer, + logical_num_blocks, + block_size_ratio, + device_id, + ) + return fa_descs + mamba_descs + + def _build_mamba_local_descs( + self, + base_addresses: list[int], + block_len_per_layer: list[int], + logical_num_blocks: int, + block_size_ratio: int, + device_id: int, + ) -> list[tuple[int, int, int]]: + """Build 4 desc regions (x, B, C, ssm) per layer for local + mamba blocks, enabling the 3-read transfer with DS conv layout. + + Conv state sub-projection decomposition requires DS (dim, state_len) + conv layout so that x/B/C sub-projections are contiguous in memory. + """ + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 " + f"is not tested. Got block_size_ratio={block_size_ratio}." + ) + conv_offsets = self._conv_decomp.local_conv_offsets + # SSM States come in tuples (conv_size, ssm_state_size) + conv_size, ssm_size = self._ssm_sizes + n_blocks = logical_num_blocks * block_size_ratio + phys_ratio = self._physical_blocks_per_logical + + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate(base_addresses): + page_stride = block_len_per_layer[i] // block_size_ratio * phys_ratio + # 3 conv sub-projection regions (x, B, C) + for off, sz in conv_offsets: + for blk in range(n_blocks): + result.append( + ( + base_addr + blk * page_stride + off, + sz, + device_id, + ) + ) + # SSM temporal state follows the conv state. + for blk in range(n_blocks): + result.append( + ( + base_addr + blk * page_stride + conv_size, + ssm_size, + device_id, + ) + ) + return result + + def build_remote_descs( + self, + transfer_topo, + engine_id, + nixl_agent_meta, + block_len_per_layer, + ): + remote_info = transfer_topo.get_engine_info(engine_id) + assert isinstance(remote_info, MambaEngineTransferInfo) + info = remote_info + tp_ratio = transfer_topo.tp_ratio(info.remote_tp_size) + result: list[tuple[int, int, int]] = [] + result.extend( + self._build_fa_remote_descs( + transfer_topo, + nixl_agent_meta, + info, + tp_ratio, + block_len_per_layer, + ) + ) + result.extend( + self._build_mamba_remote_descs( + nixl_agent_meta, + tp_ratio, + transfer_topo.tp_rank, + info.remote_physical_blocks_per_logical, + ) + ) + return result + + # NOTE (ZhanqiuHu): See ABC comment on _should_skip_fa for context. + # This method also handles FA replication (see ABC helpers above). + def _build_fa_remote_descs( + self, + transfer_topo: TransferTopology, + nixl_agent_meta: NixlAgentMetadata, + info: MambaEngineTransferInfo, + tp_ratio: int, + block_len_per_layer: list[int], + ): + """Build remote FA descriptors for mamba models using + transfer_cfg for GQA-aware sizing.""" + tp_rank = transfer_topo.tp_rank + tp_size = transfer_topo.tp_size + is_mla = transfer_topo.is_mla + total_num_kv_heads = transfer_topo.total_num_kv_heads + is_blocks_first = transfer_topo.is_kv_layout_blocks_first + block_size_ratio = transfer_topo.block_size_ratio(info.remote_block_size) + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 " + f"is not tested. Got {block_size_ratio=}." + ) + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate( + nixl_agent_meta.kv_caches_base_addr, + ): + local_block_len = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + local_block_len = remote_kv_block_len + if tp_ratio < 0 and not is_mla: + local_block_len = local_block_len // info.remote_num_fa_reads + rank_offset = _fa_rank_offset( + info, + remote_kv_block_len, + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=is_mla, + total_num_kv_heads=total_num_kv_heads, + ) + num_blocks = nixl_agent_meta.num_blocks + page_size = nixl_agent_meta.block_lens[i] + dev_id = nixl_agent_meta.device_id + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + result.append((addr, local_block_len, dev_id)) + if is_blocks_first: + second_split = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + if tp_ratio < 0 and not is_mla: + second_split = second_split // info.remote_num_fa_reads + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + result.append( + ( + v_addr, + second_split, + dev_id, + ) + ) + return result + + def _build_mamba_remote_descs( + self, + nixl_agent_meta, + tp_ratio, + tp_rank, + physical_blocks_per_logical, + ): + """Build 4 remote desc regions (x, B, C, ssm) per layer + for the 3-read transfer. + + Mamba conv state is always TP-sharded, even when attention KV + is replicated (num_kv_heads < tp_size). + """ + effective_ratio = max(tp_ratio, 1) + local_offset = tp_rank % effective_ratio + conv_size_remote = nixl_agent_meta.ssm_sizes[0] + + if tp_ratio >= 1: + # D_TP >= P_TP: P page is larger, D reads its slice. + conv_offsets = self._conv_decomp.remote_conv_offsets( + local_offset, + effective_ratio, + ) + # SSM temporal state is also TP-sharded on the heads dimension. + ssm_read_size = self._ssm_sizes[1] + else: + # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages + # are smaller than D's. self._conv_decomp has D-sized dimensions, + # but we need P-sized offsets. Scale down by |tp_ratio|. + abs_ratio = -tp_ratio + xb_p = self._conv_decomp.x_bytes // abs_ratio + bb_p = self._conv_decomp.b_bytes // abs_ratio + conv_offsets = [ + (0, xb_p), + (xb_p, bb_p), + (xb_p + bb_p, bb_p), + ] + ssm_read_size = nixl_agent_meta.ssm_sizes[1] + + # Assume same num_blocks for mamba and fa + remote_ratio = physical_blocks_per_logical + num_blocks = nixl_agent_meta.num_blocks // remote_ratio + dev_id = nixl_agent_meta.device_id + + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate( + nixl_agent_meta.kv_caches_base_addr, + ): + # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case + # block lengths vary across layers (e.g. MLA). + page_stride = nixl_agent_meta.block_lens[i] * remote_ratio + for off, sz in conv_offsets: + for blk in range(num_blocks): + result.append( + ( + base_addr + blk * page_stride + off, + sz, + dev_id, + ) + ) + for blk in range(num_blocks): + ssm_addr = ( + base_addr + + blk * page_stride + + conv_size_remote + + local_offset * ssm_read_size + ) + result.append((ssm_addr, ssm_read_size, dev_id)) + return result + + def build_src_split_handles( + self, + transfer_topo, + engine_id, + src_blocks_data, + num_descs, + ): + remote_info = transfer_topo.get_engine_info(engine_id) + assert isinstance(remote_info, MambaEngineTransferInfo) + info = remote_info + tp_size = transfer_topo.tp_size + assert info.remote_tp_size > tp_size + abs_tp = info.remote_tp_size // tp_size + if self.needs_split_handles( + info, + tp_size=tp_size, + is_mla=transfer_topo.is_mla, + ): + result = list( + self.compute_split_handle_data( + info, + src_blocks_data, + num_descs, + abs_tp, + total_num_kv_heads=transfer_topo.total_num_kv_heads, + ) + ) + logger.info( + "Mamba-HMA split handles: targets=%s, fa_reads=%s, " + "fa_entry=%s, mamba_reads=%s, num_descs=%s", + info.remote_all_source_ranks, + info.remote_num_fa_reads, + info.remote_fa_descriptor_bytes, + info.remote_num_mamba_reads, + num_descs, + ) + return result + return [] + + def compute_read_specs( + self, + local_block_ids, + remote_block_ids, + remote_ranks, + remote_info, + ): + assert isinstance(remote_info, MambaEngineTransferInfo) + info = remote_info + specs: list[ReadSpec] = [] + for rank in remote_ranks: + filtered_local, filtered_remote = self.filter_block_ids_for_rank( + info, + rank, + local_block_ids, + remote_block_ids, + ) + specs.append( + ReadSpec( + remote_rank=rank, + local_block_ids=filtered_local, + remote_block_ids=filtered_remote, + ) + ) + return specs + + def build_engine_transfer_info( + self, + *, + # Local topology + transfer_topo: TransferTopology, + # Block geometry + local_block_len: int, + # Remote facts (from NixlAgentMetadata handshake) + remote_tp_size: int, + remote_block_size: int, + remote_block_len: int, + remote_physical_blocks_per_logical: int, + ) -> MambaEngineTransferInfo: + K = transfer_topo.total_num_kv_heads + local_tp = transfer_topo.tp_size + local_rank = transfer_topo.tp_rank + is_mla = transfer_topo.is_mla + is_kv_layout_blocks_first = transfer_topo.is_kv_layout_blocks_first + + is_remote_replicated = remote_tp_size > K + remote_physical_heads = max(1, K // remote_tp_size) + + if local_tp >= remote_tp_size: + assert local_tp % remote_tp_size == 0 + tp_ratio = local_tp // remote_tp_size + else: + assert remote_tp_size % local_tp == 0 + tp_ratio = -(remote_tp_size // local_tp) + + abs_tp = -tp_ratio if tp_ratio < 0 else 1 + + mamba_range: range | None = None + if tp_ratio < 0: + mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp) + + # ---- FA read targets ---- + if is_mla or tp_ratio >= 0: + num_fa_reads = 1 + fa_source_ranks: list[int] = ( + [0] + if is_mla + else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] + ) + else: + local_needs = _physical_head_range(local_tp, K, local_rank) + search_range = ( + mamba_range if mamba_range is not None else range(remote_tp_size) + ) + seen: set[tuple[int, int]] = set() + fa_source_ranks = [] + for p in search_range: + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + if not fa_source_ranks: + for p in range(remote_tp_size): + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + num_fa_reads = len(fa_source_ranks) + + # ---- All source ranks (mamba + FA) ---- + if mamba_range is not None and abs_tp > num_fa_reads: + num_mamba_reads = abs_tp + all_source_ranks = list(mamba_range) + else: + num_mamba_reads = num_fa_reads + all_source_ranks = list(fa_source_ranks) + + # ---- FA descriptor bytes ---- + effective_block_len = min(local_block_len, remote_block_len) + if is_kv_layout_blocks_first: + fa_descriptor_bytes = effective_block_len // 2 + else: + fa_descriptor_bytes = effective_block_len + + # ---- Validation ---- + is_local_replicated = local_tp > K + if is_local_replicated and is_remote_replicated and tp_ratio > 0: + logger.info( + "Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.", + local_tp, + remote_tp_size, + K, + ) + tt_set = set(all_source_ranks) + for t in fa_source_ranks: + if t not in tt_set: + logger.error( + "FA source rank %d NOT in all_source_ranks %s.", + t, + all_source_ranks, + ) + if is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0: + local_k_half = local_block_len // 2 + remote_k_half = remote_block_len // 2 + expected = local_k_half // num_fa_reads + if expected != remote_k_half: + logger.warning( + "FA size mismatch: local_k_half=%d / reads=%d = %d, " + "but remote_k_half=%d.", + local_k_half, + num_fa_reads, + expected, + remote_k_half, + ) + + return MambaEngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_len=remote_block_len, + remote_block_size=remote_block_size, + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + remote_fa_source_ranks=tuple(fa_source_ranks), + remote_all_source_ranks=tuple(all_source_ranks), + remote_num_fa_reads=num_fa_reads, + remote_num_mamba_reads=num_mamba_reads, + remote_fa_descriptor_bytes=fa_descriptor_bytes, + is_remote_replicated=is_remote_replicated, + remote_physical_heads=remote_physical_heads, + ) + + def needs_split_handles( + self, + info: MambaEngineTransferInfo, + tp_size: int, + is_mla: bool, + ) -> bool: + """Whether per-remote-rank split handles are needed. + + True when FA and mamba have different read counts, requiring + different splitting factors in the local handle. + """ + tp_ratio = ( + tp_size // info.remote_tp_size + if tp_size >= info.remote_tp_size + else -(info.remote_tp_size // tp_size) + ) # noqa: E501 + return tp_ratio < 0 and not is_mla and len(info.remote_all_source_ranks) > 1 + + def compute_split_handle_data( + self, + info: MambaEngineTransferInfo, + src_blocks_data: list[tuple[int, int, int]], + num_fa_descs: int, + abs_tp: int, + total_num_kv_heads: int, + ) -> list[list[tuple[int, int, int]]]: + """Per-remote-rank (addr, len, dev) triples for split handles. + + FA descriptors (indices < num_fa_descs) are sliced by + ``remote_num_fa_reads``; mamba descriptors are sliced uniformly + by ``abs_tp``. + """ + all_handle_data: list[list[tuple[int, int, int]]] = [] + for p_idx, p_rank in enumerate(info.remote_all_source_ranks): + handle_data: list[tuple[int, int, int]] = [] + skip_fa = _should_skip_fa(info, p_rank) + fa_slot = ( + _fa_head_slot(info, p_rank, total_num_kv_heads) if not skip_fa else 0 + ) + for j, (addr, local_len, dev) in enumerate(src_blocks_data): + if j < num_fa_descs: + assert info.remote_num_fa_reads >= 1 + fa_chunk = local_len // info.remote_num_fa_reads + handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) + else: + mamba_chunk = local_len // abs_tp + handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) + all_handle_data.append(handle_data) + return all_handle_data + + def filter_block_ids_for_rank( + self, + info: MambaEngineTransferInfo, + remote_rank: int, + local_ids: BlockIds, + remote_ids: BlockIds, + ) -> tuple[BlockIds, BlockIds]: + """Zero out FA groups for remote ranks outside ``fa_source_ranks``. + + Returns (filtered_local_ids, filtered_remote_ids). When the + remote rank carries FA data for this local rank, returns the + inputs unchanged. + """ + if not _should_skip_fa(info, remote_rank): + return local_ids, remote_ids + num_groups = len(local_ids) + filtered_local: list[list[int]] = [ + [] if not self._is_mamba_group[g] else local_ids[g] + for g in range(num_groups) + ] + filtered_remote: list[list[int]] = [ + [] if not self._is_mamba_group[g] else remote_ids[g] + for g in range(num_groups) + ] + return filtered_local, filtered_remote diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index bd7ef5973f62..9039d145482d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -21,7 +21,6 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, EngineId, - MambaEngineTransferInfo, TransferTopology, get_current_attn_backends, kv_postprocess_blksize_and_layout_on_receive, @@ -30,6 +29,9 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.base import CopyBlocksOp from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( + ModelBlockTransferPolicy, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( GET_META_MSG, NixlAgentMetadata, @@ -48,9 +50,7 @@ zmq_ctx, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( - MambaConvSplitInfo, compute_physical_blocks_per_logical, - derive_mamba_conv_split, ) from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config from vllm.distributed.parallel_state import ( @@ -58,7 +58,6 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path from vllm.v1.attention.backends.utils import get_kv_cache_layout @@ -146,17 +145,6 @@ def __init__( ssm_shape.numel() * ssm_nbytes, ) self._mamba_ssm_size = mamba_ssm_size - # Conv state sub-projection decomposition (None when no Mamba). - # The 3-read transfer requires DS (dim, state_len) conv layout so - # that x/B/C sub-projections are contiguous in memory. - self._conv_decomp: MambaConvSplitInfo | None = None - if self._has_mamba: - assert is_conv_state_dim_first(), ( - "3-read Mamba conv transfer requires DS conv state layout. " - "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" - ) - local_tp = vllm_config.parallel_config.tensor_parallel_size - self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp) # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] @@ -268,14 +256,6 @@ def __init__( self.dst_num_blocks: dict[EngineId, int] = {} self._registered_descs: list[Any] = [] - # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ---- - # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine. - # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len) - # where conv/ssm bytes are per-TP-rank (dimension-sharded). With - # heterogeneous TP the per-rank sizes differ, so the ratio differs: - # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261. - self._physical_blocks_per_logical: dict[EngineId, int] = {} - # In progress transfers. # [req_id -> list[handle]] self._recving_metadata: dict[ReqId, ReqMeta] = {} @@ -330,6 +310,13 @@ def __init__( self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() + self.transfer_policy = ModelBlockTransferPolicy.create( + kv_cache_config=kv_cache_config, + layer_specs=self._layer_specs, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, + tp_size=vllm_config.parallel_config.tensor_parallel_size, + ) + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True ) @@ -652,6 +639,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, # SSM States come in tuples (ssm, conv) tensor_shape=next(iter(kv_caches.values())).shape if not self._has_mamba @@ -812,9 +800,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.dst_num_blocks[self.engine_id] = self.num_blocks if self._has_mamba: - self._physical_blocks_per_logical[self.engine_id] = ( - self._physical_blocks_per_logical_kv_block - ) logger.info( "Hybrid SSM registration: num_blocks=%s, " "logical_num_blocks=%s, ratio=%s, num_regions=%s, " @@ -856,157 +841,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata_bytes=encoder.encode(agent_metadata), ) - def _build_mamba_local( - self, - base_addresses: list[int], - block_size_ratio: int, - ) -> list[tuple[int, int, int]]: - """Build 4 desc regions (x, B, C, ssm) per layer for local mamba - blocks, enabling the 3-read transfer with DS conv layout.""" - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." - ) - assert self._conv_decomp is not None - conv_offsets = self._conv_decomp.local_conv_offsets - conv_size, ssm_size = self._mamba_ssm_size - num_blocks = self._logical_num_blocks * block_size_ratio - physical_per_logical = self._physical_blocks_per_logical_kv_block - - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(base_addresses): - page_stride = ( - self.block_len_per_layer[i] // block_size_ratio * physical_per_logical - ) - for off, sz in conv_offsets: - for blk in range(num_blocks): - result.append( - (base_addr + blk * page_stride + off, sz, self.device_id) - ) - # SSM temporal state follows the conv state. - for blk in range(num_blocks): - result.append( - ( - base_addr + blk * page_stride + conv_size, - ssm_size, - self.device_id, - ) - ) - return result - - def _build_fa_remote_for_mamba( - self, - nixl_agent_meta: NixlAgentMetadata, - block_size_ratio: int, - transfer_topo: TransferTopology, - remote_engine_id: EngineId, - ) -> list[tuple[int, int, int]]: - """Build remote FA descriptors for mamba models. - - Uses TransferTopology for GQA-aware FA divisor and head-based rank - offset instead of the standard uniform tp_ratio split. - """ - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." - ) - # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA - # hetero-TP logic stabilizes. - mamba_info = transfer_topo.get_engine_info(remote_engine_id) - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size) - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - local_block_len = remote_kv_block_len - - if tp_ratio < 0 and not self.use_mla: - local_block_len = local_block_len // mamba_info.remote_num_fa_reads - - rank_offset = transfer_topo.fa_rank_offset( - remote_engine_id, remote_kv_block_len - ) - - num_blocks = nixl_agent_meta.num_blocks - page_size = nixl_agent_meta.block_lens[i] - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - result.append((addr, local_block_len, nixl_agent_meta.device_id)) - - if transfer_topo.is_kv_layout_blocks_first: - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) - if tp_ratio < 0 and not self.use_mla: - second_split = second_split // mamba_info.remote_num_fa_reads - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - result.append((v_addr, second_split, nixl_agent_meta.device_id)) - return result - - def _build_mamba_remote( - self, - nixl_agent_meta: NixlAgentMetadata, - tp_ratio: int, - ) -> list[tuple[int, int, int]]: - """Build 4 remote desc regions (x, B, C, ssm) per layer for - the 3-read transfer. For hetero-TP, each D rank reads only its - sub-projection slice from the P rank.""" - assert self._conv_decomp is not None - effective_ratio = max(tp_ratio, 1) - # Mamba conv state is always TP-sharded, even when attention KV - # is replicated (num_kv_heads < tp_size). - local_offset = self.tp_rank % effective_ratio - conv_size_remote = nixl_agent_meta.ssm_sizes[0] - - if tp_ratio >= 1: - # D_TP >= P_TP: P page is larger, D reads its slice. - conv_offsets = self._conv_decomp.remote_conv_offsets( - local_offset, effective_ratio - ) - ssm_read_size = self._mamba_ssm_size[1] - else: - # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages - # are smaller than D's. self._conv_decomp has D-sized dimensions, - # but we need P-sized offsets. Scale down by |tp_ratio|. - abs_ratio = -tp_ratio - xb_p = self._conv_decomp.x_bytes // abs_ratio - bb_p = self._conv_decomp.b_bytes // abs_ratio - conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)] - ssm_read_size = nixl_agent_meta.ssm_sizes[1] - - remote_physical_per_logical = self._physical_blocks_per_logical[ - nixl_agent_meta.engine_id - ] - num_blocks = nixl_agent_meta.num_blocks // remote_physical_per_logical - device_id = nixl_agent_meta.device_id - - result: list[tuple[int, int, int]] = [] - # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case - # block lengths vary across layers (e.g. MLA). - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical - for off, sz in conv_offsets: - for blk in range(num_blocks): - result.append((base_addr + blk * page_stride + off, sz, device_id)) - # SSM temporal state is also TP-sharded on the heads dimension. - for blk in range(num_blocks): - ssm_addr = ( - base_addr - + blk * page_stride - + conv_size_remote - + local_offset * ssm_read_size - ) - result.append((ssm_addr, ssm_read_size, device_id)) - return result - def register_local_xfer_handler( self, block_size: int, @@ -1026,71 +860,27 @@ def register_local_xfer_handler( transfer_topo = self.transfer_topo block_size_ratio = self.block_size // block_size - blocks_data: list[tuple[int, int, int]] = [] local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] - def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool): - for i, base_addr in enumerate(local_base_addresses): - # The new block_len is using prefill block_len; - # and num_blocks is multiple with N - kv_block_len = ( - self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=mamba - ) - // block_size_ratio - ) - # Jump one page_size, but ssm page_size may be bigger when kernel - # locks block size to a specific value. - block_len_per_layer = ( - self.block_len_per_layer[i] - // block_size_ratio - * (1 if not mamba else self._physical_blocks_per_logical_kv_block) - ) - num_blocks = self._logical_num_blocks if mamba else self.num_blocks - num_blocks = num_blocks * block_size_ratio - for block_id in range(num_blocks): - block_offset = block_id * block_len_per_layer - addr = base_addr + block_offset - # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.device_id)) - - if transfer_topo.is_kv_layout_blocks_first: - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=mamba - ) - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. - for block_id in range(num_blocks): - block_offset = block_id * block_len_per_layer - addr = base_addr + block_offset - # Register addresses for V cache (K registered first). - v_addr = addr + kv_block_len - blocks_data.append((v_addr, second_split, self.device_id)) - logger.debug( - "Created %s blocks for src engine %s and rank %s on device id %s", - len(blocks_data), - self.engine_id, - self.tp_rank, - self.device_id, - ) - - # NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used - # right now — we use _build_mamba_local instead for the 3-read - # approach. However, we might still need this as a fallback for homogeneous TP. - register_blocks(blocks_data, mamba=False) - if self._has_mamba: - assert self.num_descs == len(blocks_data) - # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is - # unnecessary — a single conv desc per block suffices. Consider - # adding a fast path that falls back to the standard 2-region - # registration (register_blocks mamba=True) when no hetero-TP - # remote has been seen. Currently we always register 4 regions - # because local descs are created before knowing the remote TP. - logger.debug("Registering local Mamba descriptors (4 regions/layer)") - blocks_data.extend( - self._build_mamba_local(local_base_addresses, block_size_ratio) - ) + blocks_data = self.transfer_policy.build_local_descs( + # Memory + base_addresses=local_base_addresses, + device_id=self.device_id, + # Block geometry + num_blocks=self.num_blocks, + logical_num_blocks=self._logical_num_blocks, + block_size_ratio=block_size_ratio, + block_len_per_layer=self.block_len_per_layer, + # Layout + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + ) + logger.debug( + "Created %s blocks for src engine %s and rank %s on device id %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + self.device_id, + ) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. @@ -1167,17 +957,18 @@ def add_remote_agent( if self._has_mamba else 1 ) - transfer_topo.register_remote_engine( - remote_engine_id=engine_id, + transfer_info = self.transfer_policy.build_engine_transfer_info( + # Local topology + transfer_topo=transfer_topo, + # Block geometry + local_block_len=self.block_len_per_layer[0], + # Remote facts (from NixlAgentMetadata handshake) remote_tp_size=remote_tp_size, remote_block_size=nixl_agent_meta.block_size, remote_block_len=nixl_agent_meta.block_lens[0], remote_physical_blocks_per_logical=physical_blocks_per_logical, - local_block_len=self.block_len_per_layer[0], ) - if self._has_mamba and engine_id not in self._physical_blocks_per_logical: - self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical - + transfer_topo.register_remote_engine(engine_id, transfer_info) logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) remote_agent_name = self.nixl_wrapper.add_remote_agent( @@ -1206,11 +997,6 @@ def add_remote_agent( # this is the ratio between the two sizes. tp_ratio = transfer_topo.tp_ratio(remote_tp_size) - # Handle tp_size>num_kv_heads: replicate KV cache. - indexes_into_remote = ( - not transfer_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 - ) - logger.debug( "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", engine_id, @@ -1227,149 +1013,34 @@ def add_remote_agent( # Remote tp_size > local tp_size: read from multiple remote ranks. # Logically "split" own regions into |tp_ratio| chunks. Mind that # we only do this once per remote tp_size (replica-friendly). - abs_tp = -tp_ratio self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - if self._has_mamba: - if transfer_topo.needs_split_handles(engine_id): - # Mamba-HMA: FA and Mamba use different split factors. - for handle_data in transfer_topo.compute_split_handle_data( - engine_id, self.src_blocks_data, self.num_descs, abs_tp - ): - descs = self.nixl_wrapper.get_xfer_descs( - handle_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs - ) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - - logger.info( - "Mamba-HMA split handles: %s, num_descs=%s", - transfer_topo.describe(engine_id), - self.num_descs, - ) - else: - # Original path: uniform divide by abs_tp (non-Mamba-HMA). - for i in range(abs_tp): - blocks_data = [] - for memory_region in self.src_blocks_data: - addr, local_block_len, own_tp_rank = memory_region - remote_block_len = local_block_len // abs_tp - addr = addr + i * remote_block_len - blocks_data.append((addr, remote_block_len, own_tp_rank)) - descs = self.nixl_wrapper.get_xfer_descs( - blocks_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - - ### Register remote agent memory regions - blocks_data = [] - # With homogeneous TP, D pulls the whole kv cache from corresponding - # rank. With heterogeneous TP, prepare the descriptors by splitting the - # P KV cache along kv_head dim, of D worker's kv_head size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - - # Register all remote blocks, but only the corresponding kv heads. - def register_remote_blocks( - blocks_data: list[tuple[int, int, int]], mamba: bool - ): - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # Read our whole local region size from remote. - local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=mamba - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # using remote kv_block_len as transfer unit - local_block_len = remote_kv_block_len - - if tp_ratio < 0 and not self.use_mla: - # Remote tp is bigger: read a chunk of local region from remote - local_block_len = local_block_len // (-tp_ratio) - rank_offset = ( - self.tp_rank % tp_ratio * remote_kv_block_len - if indexes_into_remote - else 0 - ) - - # Assume same num_blocks for mamba and fa - num_blocks = ( - nixl_agent_meta.num_blocks - if not mamba - else nixl_agent_meta.num_blocks - // self._physical_blocks_per_logical_kv_block - ) - page_size = nixl_agent_meta.block_lens[i] * ( - 1 if not mamba else self._physical_blocks_per_logical_kv_block - ) - for block_id in range(num_blocks): - block_offset = block_id * page_size - # For each block, grab the heads chunk belonging to rank_i - # of size remote_nheads // tp_ratio, which correspond to - # self.block_len == remote_block_len//tp_ratio bytes. - addr = base_addr + block_offset + rank_offset - # (addr, len, device id) - blocks_data.append( - (addr, local_block_len, nixl_agent_meta.device_id) - ) - - if transfer_topo.is_kv_layout_blocks_first: - # With FlashInfer index V separately to allow head splitting. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=mamba - ) - # Apply the same scaling as local_block_len above for when we read - # a chunk of local V from `tp_ratio` separate remote workers. - if tp_ratio < 0 and not self.use_mla: - second_split = second_split // (-tp_ratio) - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - # Hop over the first split of remote page: either K or Conv. - if mamba: - v_addr = addr + nixl_agent_meta.ssm_sizes[0] - else: - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append( - (v_addr, second_split, nixl_agent_meta.device_id) - ) - - logger.debug( - "Created %s blocks for dst engine %s" - " with remote rank %s and local rank %s", - len(blocks_data), + for handle_data in self.transfer_policy.build_src_split_handles( + transfer_topo, engine_id, - remote_tp_rank, - self.tp_rank, - ) - - if self._has_mamba: - # Mamba-HMA: separate FA registration with GQA-aware sizing, - # plus mamba 3-read registration for the Mamba "view" of the - # same KV cache tensors. - logger.debug( - "Registering remote Mamba blocks for engine %s rank %s", - engine_id, - remote_tp_rank, - ) - blocks_data.extend( - self._build_fa_remote_for_mamba( - nixl_agent_meta, - block_size_ratio, - transfer_topo, - engine_id, - ) - ) - blocks_data.extend( - self._build_mamba_remote( - nixl_agent_meta, - tp_ratio, + self.src_blocks_data, + self.num_descs, + ): + descs = self.nixl_wrapper.get_xfer_descs( + handle_data, self.nixl_memory_type ) - ) - else: - register_remote_blocks(blocks_data, mamba=False) + handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) + + ### Register remote agent memory regions + blocks_data = self.transfer_policy.build_remote_descs( + transfer_topo, + engine_id, + nixl_agent_meta, + self.block_len_per_layer, + ) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), + engine_id, + remote_tp_rank, + self.tp_rank, + ) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) @@ -1901,23 +1572,36 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_info = self.transfer_topo.get_engine_info(engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) + # TODO (ZhanqiuHu): Unify logical_to_kernel_block_ids + # and logical_to_remote_kernel_block_ids. if self._has_mamba: # Expand remote logical → kernel block IDs. meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( meta.remote.block_ids, - self._physical_blocks_per_logical[meta.remote.engine_id], + remote_info.remote_physical_blocks_per_logical, ) else: meta.remote.block_ids = self._logical_to_kernel_block_ids( meta.remote.block_ids ) - # D may have to perform multiple reads from different remote ranks. - for i, remote_rank in enumerate(remote_ranks): - if self.use_mla and tp_ratio < 0 and i > 0: - # MLA opt: when P TP > D TP, only a single read is executed for - # the first remote rank (cache is duplicated).. - break + remote_block_ids = meta.remote.block_ids + read_specs = self.transfer_policy.compute_read_specs( + local_block_ids=meta.local_physical_block_ids, + remote_block_ids=remote_block_ids, + remote_ranks=remote_ranks, + remote_info=remote_info, + ) + # D may have to perform multiple reads from different remote ranks. + # MLA opt: when P TP > D TP, only a single read is executed for + # the first remote rank (cache is duplicated). + if self.use_mla and tp_ratio < 0: + read_specs = read_specs[:1] + + for i, spec in enumerate(read_specs): + remote_rank = spec.remote_rank + local_block_ids = spec.local_block_ids + remote_block_ids = spec.remote_block_ids remote_block_size = remote_info.remote_block_size logger.debug( "Remote agent %s available, calling _read_blocks" @@ -1945,37 +1629,26 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_rank ] - local_ids: BlockIds = meta.local_physical_block_ids - remote_ids: BlockIds = meta.remote.block_ids - if self._has_mamba: - # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets. - local_ids, remote_ids = self.transfer_topo.filter_block_ids_for_rank( - engine_id, - remote_rank, - local_ids, - remote_ids, - self._is_mamba_group, - ) - self._read_blocks( request_id=req_id, dst_engine_id=meta.remote.engine_id, remote_request_id=meta.remote.request_id, - local_block_ids=local_ids, - remote_block_ids=remote_ids, + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, remote_rank=remote_rank, local_xfer_side_handle=local_xfer_side_handle, remote_xfer_side_handle=remote_xfer_side_handle, ) - if self.use_mla and tp_ratio < 0: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. - notif_id = f"{req_id}:{self.world_size}".encode() - remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify, agent in remote_agents.items(): - if rank_to_notify != remote_rank: - self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + if self.use_mla and tp_ratio < 0 and read_specs: + # ..but we still need to notify the other remote ranks that we + # have the blocks we need so they can update the request state. + notif_id = f"{req_id}:{self.world_size}".encode() + remote_agents = self._remote_agents[meta.remote.engine_id] + read_ranks = {s.remote_rank for s in read_specs} + for rank_to_notify, agent in remote_agents.items(): + if rank_to_notify not in read_ranks: + self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) def _read_blocks( self, @@ -2078,14 +1751,20 @@ def _read_blocks( # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. - remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, - remote_block_ids, + remote_block_descs_ids = self.transfer_policy.get_block_descs_ids( + block_ids=remote_block_ids, + num_regions=self.num_regions, + dst_num_blocks=self.dst_num_blocks[dst_engine_id], + block_len_per_layer=self.block_len_per_layer, + physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, ) - local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, - local_block_ids, + local_block_descs_ids = self.transfer_policy.get_block_descs_ids( + block_ids=local_block_ids, + num_regions=self.num_regions, + dst_num_blocks=self.dst_num_blocks[self.engine_id], + block_len_per_layer=self.block_len_per_layer, block_size_ratio=block_size_ratio, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -2147,63 +1826,6 @@ def get_mapped_blocks( return mapped_2d.flatten().astype(np.int64) - def _get_block_descs_ids( - self, - engine_id: str, - block_ids: BlockIds, - block_size_ratio: float | None = None, - ) -> np.ndarray: - """ - Get the descs ids for a set of block ids. - When HMA is enabled number of descriptors across kv cache groups might differ. - A single flattened array is returned for all groups anyway. - """ - region_ids = np.arange(self.num_regions) - - # NOTE (NickLucche) With HMA, every kv group has the same number of layers and - # layers from different groups share the same kv tensor. - # eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions, - # same for [3], but group0-group1 blocks will always differ (different areas). - # Therefore we can just flatten the block_ids and compute the descs ids for all - # groups at once. - num_blocks = self.dst_num_blocks[engine_id] - if block_size_ratio is not None: - num_blocks = int(num_blocks * block_size_ratio) - - # Compute desc ids per group using the right stride: FA descs have - # num_blocks entries per region (kernel granularity), SSM descs have - # logical_blocks entries per region (no kernel splitting). - region_ids = region_ids[:, None] - if not self._has_mamba: - block_ids = np.concatenate(block_ids)[None, :] - descs_ids = region_ids * num_blocks + block_ids - return descs_ids.flatten() - else: - # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged - # arbitrarily by manager. Therefore, descs are duplicated for SSM and - # Attention like so: - # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)]. - # This is like having two "low-level views" of the same storage. - # `num_fa_descs` offset must be computed per-engine since P and D can - # have different num_blocks (and thus different FA descs counts). - physical_per_logical = self._physical_blocks_per_logical[engine_id] - logical_blocks = num_blocks // physical_per_logical - num_fa_descs = self.num_regions * num_blocks - # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). - mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None] - all_descs = [] - for i, group in enumerate(block_ids): - group_arr = np.asarray(group)[None, :] - if self._is_mamba_group[i]: - all_descs.append( - ( - mamba_region_ids * logical_blocks + group_arr + num_fa_descs - ).flatten() - ) - else: - all_descs.append((region_ids * num_blocks + group_arr).flatten()) - return np.concatenate(all_descs) - def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: """ Convert logical block ids to kernel physical block ids. @@ -2260,52 +1882,6 @@ def _logical_to_remote_kernel_block_ids( result.append(group) return result - def get_backend_aware_kv_block_len( - self, layer_idx: int, first_split: bool = True, mamba_view: bool = False - ) -> int: - """ - Get the block length for one K/V element (K and V have the same size). - - For FA and other backends, this is equal to the length of the whole - block, as K and V are in separate regions. - For FlashInfer, this is half the length of the whole block, as K and V - share the same region. - Similarly, for SSM-based models, state and conv are interleaved, but crucially - the their size differs. - Reference diagram: - KVCacheTensor (Shared) - / \\ - / \\ - / \\ - Attention (FlashInfer) View Mamba View - | | - | | - +-------------------+ +-------------------+ - | KVCacheTensor | | KVCacheTensor | - | | | | - |<----- page ------>| |<----- page ------->| - | size | | size | - | Key 0 | Val 0 | |Conv 0 | SSM 0 | - | Key 1 | Val 1 | |Conv 1 | SSM 1 | - | ... | ... | | ... | ... | - | Key N-2 | Val N-2 | |Conv N-2| SSM N-2 | - | Key N-1 | Val N-1 | |Conv N-1| SSM N-1 | - +-------------------+ +--------------------+ - |1st_split-2nd_split| |1st_split-2nd_split | - """ - assert self.transfer_topo is not None - if self.transfer_topo.is_kv_layout_blocks_first: - # For indexing only half (either just the K or V part). - if mamba_view: - # NOTE (NickLucche) Mamba Opt: this is already skipping the padding so - # we're only transferring the minimum required bytes. - block_len = self._mamba_ssm_size[not first_split] - else: - block_len = self.block_len_per_layer[layer_idx] // 2 - else: - block_len = self.block_len_per_layer[layer_idx] - return block_len - def get_kv_connector_stats(self) -> KVConnectorStats | None: """ Get the KV transfer stats for the connector.