From 659826a2d817b2cc6058169df3299b8e8bfb1c8c Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 16 Apr 2026 20:29:18 +0000 Subject: [PATCH 01/49] move mamba-specifc states to mamba-engine-info; add policy class Signed-off-by: Zhanqiu Hu --- .../kv_transfer/kv_connector/utils.py | 27 +- .../v1/nixl/block_transfer_policy.py | 358 ++++++++++++++++++ 2 files changed, 372 insertions(+), 13 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 63b56eddfaed..5aae2ef4daf3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -417,6 +417,14 @@ class MambaEngineTransferInfo(EngineTransferInfo): remote_physical_heads: int """Physical KV heads stored per remote rank.""" + @property + def fa_source_set(self) -> frozenset[int]: + return frozenset(self.remote_fa_source_ranks) + + @property + def fa_source_indices(self) -> dict[int, int]: + return {r: i for i, r in enumerate(self.remote_fa_source_ranks)} + # ---- Transfer topology ---- @@ -439,8 +447,6 @@ def __post_init__(self): self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size) self._engines: dict[EngineId, EngineTransferInfo] = {} - self._fa_source_sets: dict[EngineId, frozenset[int]] = {} - self._fa_source_indices: dict[EngineId, dict[int, int]] = {} # Figure out whether the first dimension of the cache is K/V # or num_blocks. 
@@ -521,13 +527,6 @@ def register_remote_engine( remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), local_block_len=local_block_len, ) - assert isinstance(info, MambaEngineTransferInfo) - self._fa_source_sets[remote_engine_id] = frozenset( - info.remote_fa_source_ranks - ) - self._fa_source_indices[remote_engine_id] = { - r: i for i, r in enumerate(info.remote_fa_source_ranks) - } else: info = EngineTransferInfo( remote_tp_size=remote_tp_size, @@ -668,7 +667,9 @@ def get_transfer_cache_regions( def should_skip_fa(self, remote_engine_id: EngineId, remote_rank: int) -> bool: """Whether to skip FA groups for this remote rank (mamba-only).""" - return remote_rank not in self._fa_source_sets[remote_engine_id] + mamba_info = self._engines[remote_engine_id] + assert isinstance(mamba_info, MambaEngineTransferInfo) + return remote_rank not in mamba_info.fa_source_set def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int: """Index into local FA block for this remote rank's head data. @@ -677,11 +678,11 @@ def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int: For ranks NOT in ``fa_source_ranks`` (replicated duplicates), returns the slot of the matching source rank with the same head. 
""" - fa_index = self._fa_source_indices[remote_engine_id] - if remote_rank in fa_index: - return fa_index[remote_rank] mamba_info = self._engines[remote_engine_id] assert isinstance(mamba_info, MambaEngineTransferInfo) + fa_index = mamba_info.fa_source_indices + if remote_rank in fa_index: + return fa_index[remote_rank] K = self.total_num_kv_heads remote_tp = mamba_info.remote_tp_size r_head = self._physical_head_range(remote_tp, K, remote_rank) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py new file mode 100644 index 000000000000..c7cf57d33fbc --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -0,0 +1,358 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ModelBlockTransferPolicy: model-specific transfer intelligence. + +This module defines the ``ModelBlockTransferPolicy`` ABC and its concrete +implementations for Dense and Mamba models. The policy encapsulates all +model-specific logic that was previously scattered across ``worker.py`` +and ``TransferTopology``, making both of those model-agnostic. + +The policy is an *immutable config holder*: its state is set once at +``__init__`` and never mutated. It computes results but stores no +per-engine state (that lives on ``TransferTopology._engines``). 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import torch + +from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineTransferInfo, + MambaEngineTransferInfo, + _physical_head_range, + _range_overlap, +) +from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( + MambaConvSplitInfo, + derive_mamba_conv_split, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first +from vllm.v1.kv_cache_interface import MambaSpec + +if TYPE_CHECKING: + from vllm.v1.kv_cache_interface import KVCacheConfig + +logger = init_logger(__name__) + + +class ModelBlockTransferPolicy(ABC): + """Abstract base for model-specific block transfer logic. + + Concrete subclasses encapsulate: + - Model identity (is_mamba, per-group flags) + - Mamba state sizes and conv decomposition + - Per-engine transfer info computation (``build_engine_transfer_info``) + """ + + # ------------------------------------------------------------------ + # Model identity + # ------------------------------------------------------------------ + + @property + @abstractmethod + def is_mamba(self) -> bool: + """Whether this policy handles a hybrid Mamba+Attention model.""" + + @property + @abstractmethod + def mamba_group_flags(self) -> list[bool]: + """Per-group flag: True if the group is a Mamba (SSM) group.""" + + def is_mamba_group(self, group_idx: int) -> bool: + return self.mamba_group_flags[group_idx] + + @property + @abstractmethod + def ssm_sizes(self) -> tuple[int, int]: + """(conv_state_bytes, ssm_state_bytes) per logical block. + + Returns (0, 0) for dense models. 
+ """ + + @property + @abstractmethod + def conv_decomp(self) -> MambaConvSplitInfo | None: + """Conv-state sub-projection decomposition, or None for dense.""" + + # ------------------------------------------------------------------ + # Per-engine transfer info + # ------------------------------------------------------------------ + + @abstractmethod + def build_engine_transfer_info( + self, + *, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + is_kv_layout_blocks_first: bool, + remote_tp_size: int, + remote_block_size: int, + remote_block_len: int, + remote_physical_blocks_per_logical: int, + local_block_len: int, + ) -> EngineTransferInfo: + """Compute transfer info for a remote engine. + + Dense models return ``EngineTransferInfo``. + Mamba models return ``MambaEngineTransferInfo``. + """ + + # ------------------------------------------------------------------ + # Factory + # ------------------------------------------------------------------ + + @staticmethod + def create( + kv_cache_config: KVCacheConfig, tp_size: int + ) -> ModelBlockTransferPolicy: + """Create the appropriate policy based on model architecture.""" + group_flags = [ + isinstance(group.kv_cache_spec, MambaSpec) + for group in kv_cache_config.kv_cache_groups + ] + if any(group_flags): + return MambaModelBlockTransferPolicy( + kv_cache_config=kv_cache_config, + group_flags=group_flags, + tp_size=tp_size, + ) + return DenseModelBlockTransferPolicy(group_flags=group_flags) + + +# ====================================================================== +# Dense (pure-attention) policy +# ====================================================================== + + +class DenseModelBlockTransferPolicy(ModelBlockTransferPolicy): + def __init__(self, group_flags: list[bool]): + self._group_flags = group_flags + + @property + def is_mamba(self) -> bool: + return False + + @property + def mamba_group_flags(self) -> list[bool]: + return self._group_flags + + @property + def 
ssm_sizes(self) -> tuple[int, int]: + return (0, 0) + + @property + def conv_decomp(self) -> MambaConvSplitInfo | None: + return None + + def build_engine_transfer_info( + self, + *, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + is_kv_layout_blocks_first: bool, + remote_tp_size: int, + remote_block_size: int, + remote_block_len: int, + remote_physical_blocks_per_logical: int, + local_block_len: int, + ) -> EngineTransferInfo: + return EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_len=remote_block_len, + remote_block_size=remote_block_size, + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ) + + +# ====================================================================== +# Mamba (hybrid SSM+Attention) policy +# ====================================================================== + + +class MambaModelBlockTransferPolicy(ModelBlockTransferPolicy): + def __init__( + self, + kv_cache_config: KVCacheConfig, + group_flags: list[bool], + tp_size: int, + ): + self._group_flags = group_flags + + mamba_spec = next( + spec + for group in kv_cache_config.kv_cache_groups + for spec in [group.kv_cache_spec] + if isinstance(spec, MambaSpec) + ) + conv_nbytes = torch.tensor( + [], + dtype=mamba_spec.dtypes[0], # type: ignore[misc] + ).element_size() + ssm_nbytes = torch.tensor( + [], + dtype=mamba_spec.dtypes[1], # type: ignore[misc] + ).element_size() + conv_shape = torch.Size(mamba_spec.shapes[0]) + ssm_shape = torch.Size(mamba_spec.shapes[1]) + self._ssm_sizes = ( + conv_shape.numel() * conv_nbytes, + ssm_shape.numel() * ssm_nbytes, + ) + + assert is_conv_state_dim_first(), ( + "3-read Mamba conv transfer requires DS conv state layout. 
" + "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" + ) + self._conv_decomp = derive_mamba_conv_split(mamba_spec, tp_size) + + @property + def is_mamba(self) -> bool: + return True + + @property + def mamba_group_flags(self) -> list[bool]: + return self._group_flags + + @property + def ssm_sizes(self) -> tuple[int, int]: + return self._ssm_sizes + + @property + def conv_decomp(self) -> MambaConvSplitInfo | None: + return self._conv_decomp + + def build_engine_transfer_info( + self, + *, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + is_kv_layout_blocks_first: bool, + remote_tp_size: int, + remote_block_size: int, + remote_block_len: int, + remote_physical_blocks_per_logical: int, + local_block_len: int, + ) -> MambaEngineTransferInfo: + K = total_num_kv_heads + local_tp = tp_size + local_rank = tp_rank + + is_remote_replicated = remote_tp_size > K + remote_physical_heads = max(1, K // remote_tp_size) + + if local_tp >= remote_tp_size: + assert local_tp % remote_tp_size == 0 + tp_ratio = local_tp // remote_tp_size + else: + assert remote_tp_size % local_tp == 0 + tp_ratio = -(remote_tp_size // local_tp) + + abs_tp = -tp_ratio if tp_ratio < 0 else 1 + + mamba_range: range | None = None + if tp_ratio < 0: + mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp) + + # ---- FA read targets ---- + if is_mla or tp_ratio >= 0: + num_fa_reads = 1 + fa_source_ranks: list[int] = ( + [0] + if is_mla + else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] + ) + else: + local_needs = _physical_head_range(local_tp, K, local_rank) + search_range = ( + mamba_range if mamba_range is not None else range(remote_tp_size) + ) + seen: set[tuple[int, int]] = set() + fa_source_ranks = [] + for p in search_range: + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + if not fa_source_ranks: + for p in 
range(remote_tp_size): + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + num_fa_reads = len(fa_source_ranks) + + # ---- All source ranks (mamba + FA) ---- + if mamba_range is not None and abs_tp > num_fa_reads: + num_mamba_reads = abs_tp + all_source_ranks = list(mamba_range) + else: + num_mamba_reads = num_fa_reads + all_source_ranks = list(fa_source_ranks) + + # ---- FA descriptor bytes ---- + effective_block_len = min(local_block_len, remote_block_len) + if is_kv_layout_blocks_first: + fa_descriptor_bytes = effective_block_len // 2 + else: + fa_descriptor_bytes = effective_block_len + + # ---- Validation ---- + is_local_replicated = local_tp > K + if is_local_replicated and is_remote_replicated and tp_ratio > 0: + logger.info( + "Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.", + local_tp, + remote_tp_size, + K, + ) + tt_set = set(all_source_ranks) + for t in fa_source_ranks: + if t not in tt_set: + logger.error( + "FA source rank %d NOT in all_source_ranks %s.", + t, + all_source_ranks, + ) + if is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0: + local_k_half = local_block_len // 2 + remote_k_half = remote_block_len // 2 + expected = local_k_half // num_fa_reads + if expected != remote_k_half: + logger.warning( + "FA size mismatch: local_k_half=%d / reads=%d = %d, " + "but remote_k_half=%d.", + local_k_half, + num_fa_reads, + expected, + remote_k_half, + ) + + return MambaEngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_len=remote_block_len, + remote_block_size=remote_block_size, + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + remote_fa_source_ranks=tuple(fa_source_ranks), + remote_all_source_ranks=tuple(all_source_ranks), + remote_num_fa_reads=num_fa_reads, + remote_num_mamba_reads=num_mamba_reads, + 
remote_fa_descriptor_bytes=fa_descriptor_bytes, + is_remote_replicated=is_remote_replicated, + remote_physical_heads=remote_physical_heads, + ) From f2524c0d24070aa6db2298bd069f54cdd15ad91f Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 16 Apr 2026 22:02:05 +0000 Subject: [PATCH 02/49] wire policy into TransferTopology, delegate build_engine_transfer_info Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_nixl_connector.py | 5 +++ .../kv_transfer/kv_connector/utils.py | 39 +++++++++++-------- .../v1/nixl/block_transfer_policy.py | 25 ++++++------ .../kv_connector/v1/nixl/worker.py | 9 +++++ 4 files changed, 48 insertions(+), 30 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index fb4b641e1376..2f44895816e0 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -39,6 +39,9 @@ NixlHandshakePayload, NixlKVConnectorStats, ) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( + DenseModelBlockTransferPolicy, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( compute_nixl_compatibility_hash, ) @@ -472,6 +475,7 @@ def __init__( is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, + policy=DenseModelBlockTransferPolicy(kv_cache_config), tensor_shape=test_shape, ) @@ -2435,6 +2439,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) is_mamba=False, total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), attn_backends=[backend], + policy=DenseModelBlockTransferPolicy(decode_worker.kv_cache_config), tensor_shape=test_shape, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 5aae2ef4daf3..c0953405e73c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -20,7 +20,12 @@ from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput if TYPE_CHECKING: - from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase + from vllm.distributed.kv_transfer.kv_connector.base import ( + KVConnectorBase, + ) + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 + ModelBlockTransferPolicy, + ) from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -441,6 +446,7 @@ class TransferTopology: is_mamba: bool total_num_kv_heads: int attn_backends: list[type[AttentionBackend]] + policy: "ModelBlockTransferPolicy | None" = None tensor_shape: torch.Size | None = None def __post_init__(self): @@ -518,22 +524,21 @@ def register_remote_engine( ) if remote_engine_id in self._engines: return self._engines[remote_engine_id] - info: EngineTransferInfo - if self.is_mamba: - info = self._build_mamba_info( - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - remote_block_len=remote_block_len, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - local_block_len=local_block_len, - ) - else: - info = EngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - ) + assert self.policy is not None, ( + "TransferTopology.policy must be set before registering engines" + ) + info = self.policy.build_engine_transfer_info( + tp_rank=self.tp_rank, + tp_size=self.tp_size, + is_mla=self.is_mla, + total_num_kv_heads=self.total_num_kv_heads, + is_kv_layout_blocks_first=self.is_kv_layout_blocks_first, + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + remote_block_len=remote_block_len, + remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), + local_block_len=local_block_len, + ) 
self._engines[remote_engine_id] = info return info diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index c7cf57d33fbc..a644c1807c64 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -112,17 +112,17 @@ def create( kv_cache_config: KVCacheConfig, tp_size: int ) -> ModelBlockTransferPolicy: """Create the appropriate policy based on model architecture.""" - group_flags = [ + is_mamba_group = [ isinstance(group.kv_cache_spec, MambaSpec) for group in kv_cache_config.kv_cache_groups ] - if any(group_flags): + if any(is_mamba_group): return MambaModelBlockTransferPolicy( kv_cache_config=kv_cache_config, - group_flags=group_flags, + is_mamba_group=is_mamba_group, tp_size=tp_size, ) - return DenseModelBlockTransferPolicy(group_flags=group_flags) + return DenseModelBlockTransferPolicy(kv_cache_config) # ====================================================================== @@ -131,8 +131,8 @@ def create( class DenseModelBlockTransferPolicy(ModelBlockTransferPolicy): - def __init__(self, group_flags: list[bool]): - self._group_flags = group_flags + def __init__(self, kv_cache_config: KVCacheConfig): + self._num_groups = len(kv_cache_config.kv_cache_groups) @property def is_mamba(self) -> bool: @@ -140,7 +140,7 @@ def is_mamba(self) -> bool: @property def mamba_group_flags(self) -> list[bool]: - return self._group_flags + return [False] * self._num_groups @property def ssm_sizes(self) -> tuple[int, int]: @@ -181,16 +181,15 @@ class MambaModelBlockTransferPolicy(ModelBlockTransferPolicy): def __init__( self, kv_cache_config: KVCacheConfig, - group_flags: list[bool], + is_mamba_group: list[bool], tp_size: int, ): - self._group_flags = group_flags + self._is_mamba_group = is_mamba_group mamba_spec = next( - spec + group.kv_cache_spec for group in 
kv_cache_config.kv_cache_groups - for spec in [group.kv_cache_spec] - if isinstance(spec, MambaSpec) + if isinstance(group.kv_cache_spec, MambaSpec) ) conv_nbytes = torch.tensor( [], @@ -219,7 +218,7 @@ def is_mamba(self) -> bool: @property def mamba_group_flags(self) -> list[bool]: - return self._group_flags + return self._is_mamba_group @property def ssm_sizes(self) -> tuple[int, int]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 607bf4b988ff..3341eb7a46e7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -30,6 +30,9 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.base import CopyBlocksOp from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( + ModelBlockTransferPolicy, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( GET_META_MSG, NixlAgentMetadata, @@ -158,6 +161,11 @@ def __init__( local_tp = vllm_config.parallel_config.tensor_parallel_size self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp) + self.block_transfer_policy = ModelBlockTransferPolicy.create( + kv_cache_config=kv_cache_config, + tp_size=vllm_config.parallel_config.tensor_parallel_size, + ) + # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] # Configure NIXL num_threads to avoid UAR exhaustion on Mellanox NICs. 
@@ -652,6 +660,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, + policy=self.block_transfer_policy, # SSM States come in tuples (ssm, conv) tensor_shape=next(iter(kv_caches.values())).shape if not self._has_mamba From d3b618a0214979ce9631cd25e6029f003fdddd1e Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Fri, 17 Apr 2026 00:32:12 +0000 Subject: [PATCH 03/49] decouple policy from topology; extract orchestration methods Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_nixl_connector.py | 5 - .../kv_transfer/kv_connector/utils.py | 41 +--- .../v1/nixl/block_transfer_policy.py | 182 +++++++++++++++++- .../kv_connector/v1/nixl/worker.py | 51 +++-- 4 files changed, 220 insertions(+), 59 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 2f44895816e0..fb4b641e1376 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -39,9 +39,6 @@ NixlHandshakePayload, NixlKVConnectorStats, ) -from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( - DenseModelBlockTransferPolicy, -) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( compute_nixl_compatibility_hash, ) @@ -475,7 +472,6 @@ def __init__( is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, - policy=DenseModelBlockTransferPolicy(kv_cache_config), tensor_shape=test_shape, ) @@ -2439,7 +2435,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) is_mamba=False, total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), attn_backends=[backend], - policy=DenseModelBlockTransferPolicy(decode_worker.kv_cache_config), tensor_shape=test_shape, ) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index c0953405e73c..f17c446cf7f7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -20,12 +20,7 @@ from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput if TYPE_CHECKING: - from vllm.distributed.kv_transfer.kv_connector.base import ( - KVConnectorBase, - ) - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 - ModelBlockTransferPolicy, - ) + from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -446,7 +441,6 @@ class TransferTopology: is_mamba: bool total_num_kv_heads: int attn_backends: list[type[AttentionBackend]] - policy: "ModelBlockTransferPolicy | None" = None tensor_shape: torch.Size | None = None def __post_init__(self): @@ -499,24 +493,12 @@ def __post_init__(self): def register_remote_engine( self, remote_engine_id: EngineId, - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - *, - local_block_len: int = 0, + info: EngineTransferInfo, ) -> EngineTransferInfo: """Register a remote engine, unifying worker dicts state. - Only remote engines should be registered here — the local engine's - identity (tp_size, block_size, etc.) is set via ``__init__`` params. - - For Mamba models, also computes the Mamba transfer plan and - builds the FA source lookup caches. - - Args: - local_block_len: Local representative block_len (bytes). - Required for Mamba models to compute ``fa_descriptor_bytes``. + The caller (worker) is responsible for computing the info via + the transfer policy. This method only stores and deduplicates. """ assert remote_engine_id != self.engine_id, ( f"Cannot register local engine {self.engine_id} as remote. 
" @@ -524,21 +506,6 @@ def register_remote_engine( ) if remote_engine_id in self._engines: return self._engines[remote_engine_id] - assert self.policy is not None, ( - "TransferTopology.policy must be set before registering engines" - ) - info = self.policy.build_engine_transfer_info( - tp_rank=self.tp_rank, - tp_size=self.tp_size, - is_mla=self.is_mla, - total_num_kv_heads=self.total_num_kv_heads, - is_kv_layout_blocks_first=self.is_kv_layout_blocks_first, - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - remote_block_len=remote_block_len, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - local_block_len=local_block_len, - ) self._engines[remote_engine_id] = info return info diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index a644c1807c64..7cf661826beb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -20,6 +20,7 @@ import torch from vllm.distributed.kv_transfer.kv_connector.utils import ( + BlockIds, EngineTransferInfo, MambaEngineTransferInfo, _physical_head_range, @@ -79,23 +80,26 @@ def conv_decomp(self) -> MambaConvSplitInfo | None: """Conv-state sub-projection decomposition, or None for dense.""" # ------------------------------------------------------------------ - # Per-engine transfer info + # Per-engine transfer info (data operations) # ------------------------------------------------------------------ + # TODO (ZhanqiuHu): Revisit data packing for local facts and remote facts. @abstractmethod def build_engine_transfer_info( self, *, + # Local facts (from TransferTopology). tp_rank: int, tp_size: int, is_mla: bool, total_num_kv_heads: int, is_kv_layout_blocks_first: bool, + local_block_len: int, + # Remote facts (from NixlAgentMetadata handshake). 
remote_tp_size: int, remote_block_size: int, remote_block_len: int, remote_physical_blocks_per_logical: int, - local_block_len: int, ) -> EngineTransferInfo: """Compute transfer info for a remote engine. @@ -158,11 +162,11 @@ def build_engine_transfer_info( is_mla: bool, total_num_kv_heads: int, is_kv_layout_blocks_first: bool, + local_block_len: int, remote_tp_size: int, remote_block_size: int, remote_block_len: int, remote_physical_blocks_per_logical: int, - local_block_len: int, ) -> EngineTransferInfo: return EngineTransferInfo( remote_tp_size=remote_tp_size, @@ -236,11 +240,11 @@ def build_engine_transfer_info( is_mla: bool, total_num_kv_heads: int, is_kv_layout_blocks_first: bool, + local_block_len: int, remote_tp_size: int, remote_block_size: int, remote_block_len: int, remote_physical_blocks_per_logical: int, - local_block_len: int, ) -> MambaEngineTransferInfo: K = total_num_kv_heads local_tp = tp_size @@ -355,3 +359,173 @@ def build_engine_transfer_info( is_remote_replicated=is_remote_replicated, remote_physical_heads=remote_physical_heads, ) + + # ------------------------------------------------------------------ + # Orchestration methods + # ------------------------------------------------------------------ + + def should_skip_fa(self, info: MambaEngineTransferInfo, remote_rank: int) -> bool: + """Whether to skip FA groups for this remote rank.""" + return remote_rank not in info.fa_source_set + + def fa_head_slot( + self, + info: MambaEngineTransferInfo, + remote_rank: int, + total_num_kv_heads: int, + ) -> int: + """Index into local FA block for this remote rank's head data. + + For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. + For ranks NOT in ``fa_source_ranks`` (replicated duplicates), + returns the slot of the matching source rank with the same head. 
+ """ + fa_index = info.fa_source_indices + if remote_rank in fa_index: + return fa_index[remote_rank] + K = total_num_kv_heads + remote_tp = info.remote_tp_size + r_head = _physical_head_range(remote_tp, K, remote_rank) + for target in info.remote_fa_source_ranks: + t_head = _physical_head_range(remote_tp, K, target) + if _range_overlap(r_head, t_head): + return fa_index[target] + return 0 + + def fa_rank_offset( + self, + info: MambaEngineTransferInfo, + remote_kv_block_len: int, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + ) -> int: + """Byte offset into remote FA block for this local rank. + + When local TP is replicated (local_tp > K), multiple local ranks + share a head. Computes offset *relative to the target remote + rank's first head* so it works regardless of how many heads the + remote has. Returns 0 when local does not index into remote. + """ + tp_ratio = ( + tp_size // info.remote_tp_size + if tp_size >= info.remote_tp_size + else -(info.remote_tp_size // tp_size) + ) # noqa: E501 + if is_mla or tp_ratio <= 0: + return 0 + K = total_num_kv_heads + is_local_replicated = tp_size > K + if is_local_replicated: + local_head = tp_rank * K // tp_size + p_rank = info.remote_fa_source_ranks[0] + p_start = p_rank * K // info.remote_tp_size + return (local_head - p_start) * remote_kv_block_len + return tp_rank % tp_ratio * remote_kv_block_len + + def needs_split_handles( + self, + info: MambaEngineTransferInfo, + tp_size: int, + is_mla: bool, + ) -> bool: + """Whether per-remote-rank split handles are needed. + + True when FA and mamba have different read counts, requiring + different splitting factors in the local handle. 
+ """ + tp_ratio = ( + tp_size // info.remote_tp_size + if tp_size >= info.remote_tp_size + else -(info.remote_tp_size // tp_size) + ) # noqa: E501 + return tp_ratio < 0 and not is_mla and len(info.remote_all_source_ranks) > 1 + + def compute_split_handle_data( + self, + info: MambaEngineTransferInfo, + src_blocks_data: list[tuple[int, int, int]], + num_fa_descs: int, + abs_tp: int, + total_num_kv_heads: int, + ) -> list[list[tuple[int, int, int]]]: + """Per-remote-rank (addr, len, dev) triples for split handles. + + FA descriptors (indices < num_fa_descs) are sliced by + ``remote_num_fa_reads``; mamba descriptors are sliced uniformly + by ``abs_tp``. + """ + all_handle_data: list[list[tuple[int, int, int]]] = [] + for p_idx, p_rank in enumerate(info.remote_all_source_ranks): + handle_data: list[tuple[int, int, int]] = [] + skip_fa = self.should_skip_fa(info, p_rank) + fa_slot = ( + self.fa_head_slot(info, p_rank, total_num_kv_heads) + if not skip_fa + else 0 + ) + for j, (addr, local_len, dev) in enumerate(src_blocks_data): + if j < num_fa_descs: + assert info.remote_num_fa_reads >= 1 + fa_chunk = local_len // info.remote_num_fa_reads + handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) + else: + mamba_chunk = local_len // abs_tp + handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) + all_handle_data.append(handle_data) + return all_handle_data + + def filter_block_ids_for_rank( + self, + info: MambaEngineTransferInfo, + remote_rank: int, + local_ids: BlockIds, + remote_ids: BlockIds, + ) -> tuple[BlockIds, BlockIds]: + """Zero out FA groups for remote ranks outside ``fa_source_ranks``. + + Returns (filtered_local_ids, filtered_remote_ids). When the + remote rank carries FA data for this local rank, returns the + inputs unchanged. 
+ """ + if not self.should_skip_fa(info, remote_rank): + return local_ids, remote_ids + num_groups = len(local_ids) + filtered_local: list[list[int]] = [ + [] if not self._is_mamba_group[g] else local_ids[g] + for g in range(num_groups) + ] + filtered_remote: list[list[int]] = [ + [] if not self._is_mamba_group[g] else remote_ids[g] + for g in range(num_groups) + ] + return filtered_local, filtered_remote + + def describe_mamba( + self, + info: MambaEngineTransferInfo, + tp_rank: int, + tp_size: int, + total_num_kv_heads: int, + ) -> str: + """One-line summary of Mamba transfer config for logging.""" + tp_ratio = ( + tp_size // info.remote_tp_size + if tp_size >= info.remote_tp_size + else -(info.remote_tp_size // tp_size) + ) # noqa: E501 + return ( + f"MambaTransferPolicy(" + f"tp_ratio={tp_ratio}, " + f"K={total_num_kv_heads}, " + f"local_tp={tp_size}, " + f"remote_tp={info.remote_tp_size}, " + f"local_rank={tp_rank}, " + f"fa_reads={info.remote_num_fa_reads}, " + f"mamba_reads={info.remote_num_mamba_reads}, " + f"fa_sources={list(info.remote_fa_source_ranks)}, " + f"all_sources={list(info.remote_all_source_ranks)}, " + f"fa_desc_bytes={info.remote_fa_descriptor_bytes}, " + f"remote_block_len={info.remote_block_len})" + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 3341eb7a46e7..03552fa9884c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -31,6 +31,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import CopyBlocksOp from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( + MambaModelBlockTransferPolicy, ModelBlockTransferPolicy, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( @@ -161,7 +162,7 @@ def __init__( local_tp = 
vllm_config.parallel_config.tensor_parallel_size self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp) - self.block_transfer_policy = ModelBlockTransferPolicy.create( + self.transfer_policy = ModelBlockTransferPolicy.create( kv_cache_config=kv_cache_config, tp_size=vllm_config.parallel_config.tensor_parallel_size, ) @@ -660,7 +661,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, - policy=self.block_transfer_policy, # SSM States come in tuples (ssm, conv) tensor_shape=next(iter(kv_caches.values())).shape if not self._has_mamba @@ -921,6 +921,7 @@ def _build_fa_remote_for_mamba( ) # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA # hetero-TP logic stabilizes. + assert isinstance(self.transfer_policy, MambaModelBlockTransferPolicy) mamba_info = transfer_topo.get_engine_info(remote_engine_id) assert isinstance(mamba_info, MambaEngineTransferInfo) tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size) @@ -936,8 +937,13 @@ def _build_fa_remote_for_mamba( if tp_ratio < 0 and not self.use_mla: local_block_len = local_block_len // mamba_info.remote_num_fa_reads - rank_offset = transfer_topo.fa_rank_offset( - remote_engine_id, remote_kv_block_len + rank_offset = self.transfer_policy.fa_rank_offset( + mamba_info, + remote_kv_block_len, + tp_rank=transfer_topo.tp_rank, + tp_size=transfer_topo.tp_size, + is_mla=transfer_topo.is_mla, + total_num_kv_heads=transfer_topo.total_num_kv_heads, ) num_blocks = nixl_agent_meta.num_blocks @@ -1176,14 +1182,21 @@ def add_remote_agent( if self._has_mamba else 1 ) - transfer_topo.register_remote_engine( - remote_engine_id=engine_id, + transfer_info = self.transfer_policy.build_engine_transfer_info( + # Local facts (from TransferTopology). 
+ tp_rank=transfer_topo.tp_rank, + tp_size=transfer_topo.tp_size, + is_mla=transfer_topo.is_mla, + total_num_kv_heads=transfer_topo.total_num_kv_heads, + is_kv_layout_blocks_first=transfer_topo.is_kv_layout_blocks_first, + local_block_len=self.block_len_per_layer[0], + # Remote facts (from NixlAgentMetadata handshake). remote_tp_size=remote_tp_size, remote_block_size=nixl_agent_meta.block_size, remote_block_len=nixl_agent_meta.block_lens[0], remote_physical_blocks_per_logical=physical_blocks_per_logical, - local_block_len=self.block_len_per_layer[0], ) + transfer_topo.register_remote_engine(engine_id, transfer_info) if self._has_mamba and engine_id not in self._physical_blocks_per_logical: self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical @@ -1240,10 +1253,21 @@ def add_remote_agent( self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] if self._has_mamba: - if transfer_topo.needs_split_handles(engine_id): + assert isinstance(self.transfer_policy, MambaModelBlockTransferPolicy) + mamba_info = transfer_topo.get_engine_info(engine_id) + assert isinstance(mamba_info, MambaEngineTransferInfo) + if self.transfer_policy.needs_split_handles( + mamba_info, + tp_size=transfer_topo.tp_size, + is_mla=transfer_topo.is_mla, + ): # Mamba-HMA: FA and Mamba use different split factors. - for handle_data in transfer_topo.compute_split_handle_data( - engine_id, self.src_blocks_data, self.num_descs, abs_tp + for handle_data in self.transfer_policy.compute_split_handle_data( + mamba_info, + self.src_blocks_data, + self.num_descs, + abs_tp, + total_num_kv_heads=transfer_topo.total_num_kv_heads, ): descs = self.nixl_wrapper.get_xfer_descs( handle_data, self.nixl_memory_type @@ -1958,12 +1982,13 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_ids: BlockIds = meta.remote.block_ids if self._has_mamba: # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets. 
- local_ids, remote_ids = self.transfer_topo.filter_block_ids_for_rank( - engine_id, + assert isinstance(self.transfer_policy, MambaModelBlockTransferPolicy) + assert isinstance(remote_info, MambaEngineTransferInfo) + local_ids, remote_ids = self.transfer_policy.filter_block_ids_for_rank( + remote_info, remote_rank, local_ids, remote_ids, - self._is_mamba_group, ) self._read_blocks( From 6bde491c5f1c1dde148a920ecbc891facf22620a Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Fri, 17 Apr 2026 01:27:01 +0000 Subject: [PATCH 04/49] extract descriptor building, block ID mapping, and read specs into ModelBlockTransferPolicy Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 864 +++++++++++++++++- .../kv_connector/v1/nixl/worker.py | 796 +++------------- 2 files changed, 992 insertions(+), 668 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 7cf661826beb..2e23da1624c1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -15,8 +15,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast +import numpy as np import torch from vllm.distributed.kv_transfer.kv_connector.utils import ( @@ -33,13 +35,30 @@ from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first from vllm.v1.kv_cache_interface import MambaSpec +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: - from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NixlAgentMetadata, + ) + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec logger = 
init_logger(__name__) +@dataclass(frozen=True) +class ReadSpec: + """Specification for a single remote block read operation. + + Computed upfront by ``compute_read_specs`` so that the worker's read + loop is a simple iteration with no model-specific branching. + """ + + remote_rank: int + local_block_ids: BlockIds + remote_block_ids: BlockIds + + class ModelBlockTransferPolicy(ABC): """Abstract base for model-specific block transfer logic. @@ -49,6 +68,14 @@ class ModelBlockTransferPolicy(ABC): - Per-engine transfer info computation (``build_engine_transfer_info``) """ + def __init__( + self, + kv_cache_config: KVCacheConfig, + physical_blocks_per_logical: int, + ): + self._kv_cache_config = kv_cache_config + self._physical_blocks_per_logical = physical_blocks_per_logical + # ------------------------------------------------------------------ # Model identity # ------------------------------------------------------------------ @@ -107,13 +134,230 @@ def build_engine_transfer_info( Mamba models return ``MambaEngineTransferInfo``. """ + # ------------------------------------------------------------------ + # Registration helpers + # ------------------------------------------------------------------ + + @abstractmethod + def compute_page_size( + self, + layer_spec: KVCacheSpec, + physical_ratio: int, + ) -> int: + """Physical page size in bytes for one layer.""" + ... + + @abstractmethod + def get_num_blocks( + self, + layer_spec: KVCacheSpec, + num_blocks: int, + logical_num_blocks: int, + ) -> int: + """Number of blocks to register for this layer spec.""" + ... + + @abstractmethod + def compute_layer_block_bytes( + self, + layer_spec: KVCacheSpec, + physical_page_size: int, + physical_ratio: int, + ) -> int: + """Block byte size for one layer (entry for ``block_len_per_layer``).""" + ... + + @abstractmethod + def get_tensor_shape( + self, + kv_caches: dict[str, torch.Tensor], + ) -> torch.Size | None: + """Tensor shape for ``TpKVTopology`` (None for Mamba).""" + ... 
+ + @abstractmethod + def get_block_len( + self, + layer_idx: int, + first_split: bool, + block_len_per_layer: list[int], + is_blocks_first: bool, + mamba_view: bool = False, + ) -> int: + """Block length for one K/V (or conv/ssm) element.""" + ... + + # ------------------------------------------------------------------ + # Descriptor ID computation + block ID mapping + # ------------------------------------------------------------------ + + @abstractmethod + def get_block_descs_ids( + self, + block_ids: BlockIds, + num_regions: int, + dst_num_blocks: int, + block_len_per_layer: list[int], + block_size_ratio: float | None = None, + physical_blocks_per_logical: int = 1, + ) -> np.ndarray: + """Compute NIXL descriptor IDs for a set of block IDs.""" + ... + + @abstractmethod + def logical_to_kernel_block_ids( + self, + block_ids: BlockIds, + ) -> BlockIds: + """Convert logical block IDs to kernel physical block IDs.""" + ... + + @abstractmethod + def logical_to_remote_kernel_block_ids( + self, + block_ids: BlockIds, + remote_ratio: int, + ) -> BlockIds: + """Map logical block IDs to physical kernel block IDs on remote.""" + ... + + # ------------------------------------------------------------------ + # Local descriptor building + # ------------------------------------------------------------------ + + @abstractmethod + def build_local_descs( + self, + base_addresses: list[int], + block_len_per_layer: list[int], + num_blocks: int, + logical_num_blocks: int, + block_size_ratio: int, + device_id: int, + is_blocks_first: bool, + ) -> list[tuple[int, int, int]]: + """Build local (src) descriptor tuples for NIXL registration.""" + ... 
+ + def _build_fa_local_descs( + self, + base_addresses: list[int], + block_len_per_layer: list[int], + num_blocks: int, + block_size_ratio: int, + device_id: int, + is_blocks_first: bool, + ) -> list[tuple[int, int, int]]: + """Build FA local descriptors (shared by Dense and Mamba).""" + result: list[tuple[int, int, int]] = [] + n_blocks = num_blocks * block_size_ratio + for i, base_addr in enumerate(base_addresses): + kv_block_len = ( + self.get_block_len( + i, + True, + block_len_per_layer, + is_blocks_first, + ) + // block_size_ratio + ) + page_stride = block_len_per_layer[i] // block_size_ratio + for block_id in range(n_blocks): + result.append( + ( + base_addr + block_id * page_stride, + kv_block_len, + device_id, + ) + ) + if is_blocks_first: + second_split = self.get_block_len( + i, + False, + block_len_per_layer, + is_blocks_first, + ) + for block_id in range(n_blocks): + v_addr = base_addr + block_id * page_stride + kv_block_len + result.append( + ( + v_addr, + second_split, + device_id, + ) + ) + return result + + # ------------------------------------------------------------------ + # Remote descriptor building + # ------------------------------------------------------------------ + + @abstractmethod + def build_remote_descs( + self, + nixl_agent_meta: NixlAgentMetadata, + block_size_ratio: int, + tp_ratio: int, + tp_rank: int, + use_mla: bool, + block_len_per_layer: list[int], + is_blocks_first: bool, + indexes_into_remote: bool, + transfer_config: Any | None = None, + physical_blocks_per_logical: int = 1, + tp_size: int = 1, + total_num_kv_heads: int = 1, + ) -> list[tuple[int, int, int]]: + """Build remote (dst) descriptor tuples.""" + ... 
+ + @abstractmethod + def build_src_split_handles( + self, + src_blocks_data: list[tuple[int, int, int]], + num_descs: int, + abs_tp: int, + transfer_config: Any | None = None, + tp_size: int = 1, + is_mla: bool = False, + total_num_kv_heads: int = 1, + ) -> list[list[tuple[int, int, int]]]: + """Build split handle data for P_TP > D_TP scenario.""" + ... + + def compute_read_specs( + self, + local_block_ids: BlockIds, + remote_block_ids: BlockIds, + remote_ranks: list[int], + physical_blocks_per_logical: int = 1, + transfer_config: Any | None = None, + ) -> list[ReadSpec]: + """Compute the full set of read operations needed for a request. + + Returns one ``ReadSpec`` per remote rank that requires a read. + The worker iterates the result without model-specific branching. + MLA trimming (keeping only the first spec) is handled by the worker. + """ + return [ + ReadSpec( + remote_rank=rank, + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + ) + for rank in remote_ranks + ] + # ------------------------------------------------------------------ # Factory # ------------------------------------------------------------------ @staticmethod def create( - kv_cache_config: KVCacheConfig, tp_size: int + kv_cache_config: KVCacheConfig, + layer_specs: dict[str, KVCacheSpec], + physical_blocks_per_logical: int, + tp_size: int, ) -> ModelBlockTransferPolicy: """Create the appropriate policy based on model architecture.""" is_mamba_group = [ @@ -125,8 +369,13 @@ def create( kv_cache_config=kv_cache_config, is_mamba_group=is_mamba_group, tp_size=tp_size, + layer_specs=layer_specs, + physical_blocks_per_logical=physical_blocks_per_logical, ) - return DenseModelBlockTransferPolicy(kv_cache_config) + return DenseModelBlockTransferPolicy( + kv_cache_config, + physical_blocks_per_logical, + ) # ====================================================================== @@ -135,7 +384,12 @@ def create( class DenseModelBlockTransferPolicy(ModelBlockTransferPolicy): - def 
__init__(self, kv_cache_config: KVCacheConfig): + def __init__( + self, + kv_cache_config: KVCacheConfig, + physical_blocks_per_logical: int, + ): + super().__init__(kv_cache_config, physical_blocks_per_logical) self._num_groups = len(kv_cache_config.kv_cache_groups) @property @@ -154,6 +408,179 @@ def ssm_sizes(self) -> tuple[int, int]: def conv_decomp(self) -> MambaConvSplitInfo | None: return None + def compute_page_size(self, layer_spec, physical_ratio): + return layer_spec.page_size_bytes // physical_ratio + + def get_num_blocks(self, layer_spec, num_blocks, logical_num_blocks): + return num_blocks + + def compute_layer_block_bytes(self, layer_spec, physical_page_size, physical_ratio): + return physical_page_size + + def get_tensor_shape(self, kv_caches): + return next(iter(kv_caches.values())).shape + + def get_block_len( + self, + layer_idx, + first_split, + block_len_per_layer, + is_blocks_first, + mamba_view=False, + ): + if is_blocks_first: + return block_len_per_layer[layer_idx] // 2 + return block_len_per_layer[layer_idx] + + def get_block_descs_ids( + self, + block_ids, + num_regions, + dst_num_blocks, + block_len_per_layer, + block_size_ratio=None, + physical_blocks_per_logical=1, + ): + num_blocks = dst_num_blocks + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) + region_ids = np.arange(num_regions)[:, None] + block_ids_arr = np.concatenate(block_ids)[None, :] + return (region_ids * num_blocks + block_ids_arr).flatten() + + def logical_to_kernel_block_ids(self, block_ids): + if self._physical_blocks_per_logical == 1: + return block_ids + block_arange = np.arange( + 0, + self._physical_blocks_per_logical, + ).reshape(1, -1) + return [ + BlockTable.map_to_kernel_blocks( + np.array(group), + self._physical_blocks_per_logical, + block_arange, + ).tolist() + for group in block_ids + ] + + def logical_to_remote_kernel_block_ids( + self, + block_ids, + remote_ratio, + ): + if remote_ratio == 1: + return block_ids + 
local_arange = np.arange( + self._physical_blocks_per_logical, + ).reshape(1, -1) + return [ + (np.array(group).reshape(-1, 1) * remote_ratio + local_arange) + .flatten() + .tolist() + for group in block_ids + ] + + def build_local_descs( + self, + base_addresses, + block_len_per_layer, + num_blocks, + logical_num_blocks, + block_size_ratio, + device_id, + is_blocks_first, + ): + return self._build_fa_local_descs( + base_addresses, + block_len_per_layer, + num_blocks, + block_size_ratio, + device_id, + is_blocks_first, + ) + + def build_remote_descs( + self, + nixl_agent_meta, + block_size_ratio, + tp_ratio, + tp_rank, + use_mla, + block_len_per_layer, + is_blocks_first, + indexes_into_remote, + transfer_config=None, + physical_blocks_per_logical=1, + tp_size: int = 1, + total_num_kv_heads: int = 1, + ): + _ = (tp_size, total_num_kv_heads, physical_blocks_per_logical, transfer_config) + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate( + nixl_agent_meta.kv_caches_base_addr, + ): + local_block_len = self.get_block_len( + i, + True, + block_len_per_layer, + is_blocks_first, + ) + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + local_block_len = remote_kv_block_len + if tp_ratio < 0 and not use_mla: + local_block_len = local_block_len // (-tp_ratio) + rank_offset = ( + tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 + ) + num_blocks = nixl_agent_meta.num_blocks + page_size = nixl_agent_meta.block_lens[i] + dev_id = nixl_agent_meta.device_id + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + result.append((addr, local_block_len, dev_id)) + if is_blocks_first: + second_split = self.get_block_len( + i, + False, + block_len_per_layer, + is_blocks_first, + ) + if tp_ratio < 0 and not use_mla: + second_split = second_split // (-tp_ratio) + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + v_addr = addr + 
nixl_agent_meta.block_lens[i] // 2 + result.append((v_addr, second_split, dev_id)) + return result + + def build_src_split_handles( + self, + src_blocks_data, + num_descs, + abs_tp, + transfer_config=None, + tp_size: int = 1, + is_mla: bool = False, + total_num_kv_heads: int = 1, + ): + _ = (num_descs, transfer_config, tp_size, is_mla, total_num_kv_heads) + result: list[list[tuple[int, int, int]]] = [] + for i in range(abs_tp): + blocks_data: list[tuple[int, int, int]] = [] + for addr, local_block_len, own_rank in src_blocks_data: + remote_block_len = local_block_len // abs_tp + blocks_data.append( + ( + addr + i * remote_block_len, + remote_block_len, + own_rank, + ) + ) + result.append(blocks_data) + return result + def build_engine_transfer_info( self, *, @@ -187,13 +614,14 @@ def __init__( kv_cache_config: KVCacheConfig, is_mamba_group: list[bool], tp_size: int, + layer_specs: dict[str, KVCacheSpec], + physical_blocks_per_logical: int, ): + super().__init__(kv_cache_config, physical_blocks_per_logical) self._is_mamba_group = is_mamba_group mamba_spec = next( - group.kv_cache_spec - for group in kv_cache_config.kv_cache_groups - if isinstance(group.kv_cache_spec, MambaSpec) + spec for spec in layer_specs.values() if isinstance(spec, MambaSpec) ) conv_nbytes = torch.tensor( [], @@ -232,6 +660,426 @@ def ssm_sizes(self) -> tuple[int, int]: def conv_decomp(self) -> MambaConvSplitInfo | None: return self._conv_decomp + def compute_page_size(self, layer_spec, physical_ratio): + if isinstance(layer_spec, MambaSpec): + return layer_spec.page_size_bytes + return layer_spec.page_size_bytes // physical_ratio + + def get_num_blocks(self, layer_spec, num_blocks, logical_num_blocks): + if isinstance(layer_spec, MambaSpec): + return logical_num_blocks + return num_blocks + + def compute_layer_block_bytes(self, layer_spec, physical_page_size, physical_ratio): + if isinstance(layer_spec, MambaSpec): + return physical_page_size // physical_ratio + return physical_page_size + + 
def get_tensor_shape(self, kv_caches): + return None + + def get_block_len( + self, + layer_idx, + first_split, + block_len_per_layer, + is_blocks_first, + mamba_view=False, + ): + if is_blocks_first: + if mamba_view: + return self._ssm_sizes[not first_split] + return block_len_per_layer[layer_idx] // 2 + return block_len_per_layer[layer_idx] + + def get_block_descs_ids( + self, + block_ids, + num_regions, + dst_num_blocks, + block_len_per_layer, + block_size_ratio=None, + physical_blocks_per_logical=1, + ): + num_blocks = dst_num_blocks + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) + region_ids = np.arange(num_regions)[:, None] + ratio = physical_blocks_per_logical + logical_blocks = num_blocks // ratio + num_fa_descs = num_regions * num_blocks + mamba_region_ids = np.arange( + len(block_len_per_layer) * 4, + )[:, None] + all_descs: list[np.ndarray] = [] + for i, group in enumerate(block_ids): + group_arr = np.asarray(group)[None, :] + if self._is_mamba_group[i]: + all_descs.append( + ( + mamba_region_ids * logical_blocks + group_arr + num_fa_descs + ).flatten() + ) + else: + all_descs.append((region_ids * num_blocks + group_arr).flatten()) + return np.concatenate(all_descs) + + def logical_to_kernel_block_ids(self, block_ids): + if self._physical_blocks_per_logical == 1: + return block_ids + block_arange = np.arange( + 0, + self._physical_blocks_per_logical, + ).reshape(1, -1) + group_specs = self._kv_cache_config.kv_cache_groups + return [ + BlockTable.map_to_kernel_blocks( + np.array(group), + self._physical_blocks_per_logical, + block_arange, + ).tolist() + if not isinstance( + group_specs[i].kv_cache_spec, + MambaSpec, + ) + else group + for i, group in enumerate(block_ids) + ] + + def logical_to_remote_kernel_block_ids( + self, + block_ids, + remote_ratio, + ): + if remote_ratio == 1: + return block_ids + local_arange = np.arange( + self._physical_blocks_per_logical, + ).reshape(1, -1) + group_specs = 
self._kv_cache_config.kv_cache_groups + result: list[list[int]] = [] + for i, group in enumerate(block_ids): + if not isinstance( + group_specs[i].kv_cache_spec, + MambaSpec, + ): + arr = np.array(group).reshape(-1, 1) + expanded = (arr * remote_ratio + local_arange).flatten() + result.append(expanded.tolist()) + else: + result.append(group) + return result + + def build_local_descs( + self, + base_addresses, + block_len_per_layer, + num_blocks, + logical_num_blocks, + block_size_ratio, + device_id, + is_blocks_first, + ): + fa_descs = self._build_fa_local_descs( + base_addresses, + block_len_per_layer, + num_blocks, + block_size_ratio, + device_id, + is_blocks_first, + ) + num_regions = len(base_addresses) * (2 if is_blocks_first else 1) + assert len(fa_descs) == num_regions * num_blocks + logger.debug("Registering local Mamba descriptors (4 regions/layer)") + mamba_descs = self._build_mamba_local_descs( + base_addresses, + block_len_per_layer, + logical_num_blocks, + block_size_ratio, + device_id, + ) + return fa_descs + mamba_descs + + def _build_mamba_local_descs( + self, + base_addresses: list[int], + block_len_per_layer: list[int], + logical_num_blocks: int, + block_size_ratio: int, + device_id: int, + ) -> list[tuple[int, int, int]]: + """Build 4 desc regions (x, B, C, ssm) per layer for local + mamba blocks, enabling the 3-read transfer with DS conv layout. + """ + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 " + f"is not tested. Got block_size_ratio={block_size_ratio}." 
+ ) + conv_offsets = self._conv_decomp.local_conv_offsets + conv_size, ssm_size = self._ssm_sizes + n_blocks = logical_num_blocks * block_size_ratio + phys_ratio = self._physical_blocks_per_logical + + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate(base_addresses): + page_stride = block_len_per_layer[i] // block_size_ratio * phys_ratio + for off, sz in conv_offsets: + for blk in range(n_blocks): + result.append( + ( + base_addr + blk * page_stride + off, + sz, + device_id, + ) + ) + for blk in range(n_blocks): + result.append( + ( + base_addr + blk * page_stride + conv_size, + ssm_size, + device_id, + ) + ) + return result + + def build_remote_descs( + self, + nixl_agent_meta, + block_size_ratio, + tp_ratio, + tp_rank, + use_mla, + block_len_per_layer, + is_blocks_first, + indexes_into_remote, + transfer_config=None, + physical_blocks_per_logical=1, + tp_size: int = 1, + total_num_kv_heads: int = 1, + ): + _ = indexes_into_remote + info = cast(MambaEngineTransferInfo, transfer_config) + result: list[tuple[int, int, int]] = [] + result.extend( + self._build_fa_remote_descs( + nixl_agent_meta, + info, + tp_ratio, + tp_rank, + tp_size, + total_num_kv_heads, + block_size_ratio, + is_blocks_first, + use_mla, + block_len_per_layer, + ) + ) + result.extend( + self._build_mamba_remote_descs( + nixl_agent_meta, + tp_ratio, + tp_rank, + physical_blocks_per_logical, + ) + ) + return result + + def _build_fa_remote_descs( + self, + nixl_agent_meta, + info: MambaEngineTransferInfo, + tp_ratio: int, + tp_rank: int, + tp_size: int, + total_num_kv_heads: int, + block_size_ratio: int, + is_blocks_first: bool, + use_mla: bool, + block_len_per_layer: list[int], + ): + """Build remote FA descriptors for mamba models using + transfer_cfg for GQA-aware sizing.""" + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 " + f"is not tested. Got {block_size_ratio=}." 
+ ) + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate( + nixl_agent_meta.kv_caches_base_addr, + ): + local_block_len = self.get_block_len( + i, + True, + block_len_per_layer, + is_blocks_first, + ) + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + local_block_len = remote_kv_block_len + if tp_ratio < 0 and not use_mla: + local_block_len = local_block_len // info.remote_num_fa_reads + rank_offset = self.fa_rank_offset( + info, + remote_kv_block_len, + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=use_mla, + total_num_kv_heads=total_num_kv_heads, + ) + num_blocks = nixl_agent_meta.num_blocks + page_size = nixl_agent_meta.block_lens[i] + dev_id = nixl_agent_meta.device_id + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + result.append((addr, local_block_len, dev_id)) + if is_blocks_first: + second_split = self.get_block_len( + i, + False, + block_len_per_layer, + is_blocks_first, + ) + if tp_ratio < 0 and not use_mla: + second_split = second_split // info.remote_num_fa_reads + for blk in range(num_blocks): + addr = base_addr + blk * page_size + rank_offset + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + result.append( + ( + v_addr, + second_split, + dev_id, + ) + ) + return result + + def _build_mamba_remote_descs( + self, + nixl_agent_meta, + tp_ratio, + tp_rank, + physical_blocks_per_logical, + ): + """Build 4 remote desc regions (x, B, C, ssm) per layer + for the 3-read transfer.""" + effective_ratio = max(tp_ratio, 1) + local_offset = tp_rank % effective_ratio + conv_size_remote = nixl_agent_meta.ssm_sizes[0] + + if tp_ratio >= 1: + conv_offsets = self._conv_decomp.remote_conv_offsets( + local_offset, + effective_ratio, + ) + ssm_read_size = self._ssm_sizes[1] + else: + abs_ratio = -tp_ratio + xb_p = self._conv_decomp.x_bytes // abs_ratio + bb_p = self._conv_decomp.b_bytes // abs_ratio + conv_offsets = [ + (0, xb_p), + (xb_p, bb_p), + (xb_p + bb_p, bb_p), + ] + 
ssm_read_size = nixl_agent_meta.ssm_sizes[1] + + remote_ratio = physical_blocks_per_logical + num_blocks = nixl_agent_meta.num_blocks // remote_ratio + dev_id = nixl_agent_meta.device_id + + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate( + nixl_agent_meta.kv_caches_base_addr, + ): + page_stride = nixl_agent_meta.block_lens[i] * remote_ratio + for off, sz in conv_offsets: + for blk in range(num_blocks): + result.append( + ( + base_addr + blk * page_stride + off, + sz, + dev_id, + ) + ) + for blk in range(num_blocks): + ssm_addr = ( + base_addr + + blk * page_stride + + conv_size_remote + + local_offset * ssm_read_size + ) + result.append((ssm_addr, ssm_read_size, dev_id)) + return result + + def build_src_split_handles( + self, + src_blocks_data, + num_descs, + abs_tp, + transfer_config=None, + tp_size: int = 1, + is_mla: bool = False, + total_num_kv_heads: int = 1, + ): + info = cast(MambaEngineTransferInfo, transfer_config) + assert transfer_config is not None + if self.needs_split_handles( + info, + tp_size=tp_size, + is_mla=is_mla, + ): + result = list( + self.compute_split_handle_data( + info, + src_blocks_data, + num_descs, + abs_tp, + total_num_kv_heads=total_num_kv_heads, + ) + ) + logger.info( + "Mamba-HMA split handles: targets=%s, fa_reads=%s, " + "fa_entry=%s, mamba_reads=%s, num_descs=%s", + info.remote_all_source_ranks, + info.remote_num_fa_reads, + info.remote_fa_descriptor_bytes, + info.remote_num_mamba_reads, + num_descs, + ) + return result + return [] + + def compute_read_specs( + self, + local_block_ids, + remote_block_ids, + remote_ranks, + physical_blocks_per_logical=1, + transfer_config=None, + ): + expanded = self.logical_to_remote_kernel_block_ids( + remote_block_ids, + physical_blocks_per_logical, + ) + info = cast(MambaEngineTransferInfo, transfer_config) + assert transfer_config is not None + specs: list[ReadSpec] = [] + for rank in remote_ranks: + filtered_local, filtered_remote = self.filter_block_ids_for_rank( 
+ info, + rank, + local_block_ids, + expanded, + ) + specs.append( + ReadSpec( + remote_rank=rank, + local_block_ids=filtered_local, + remote_block_ids=filtered_remote, + ) + ) + return specs + def build_engine_transfer_info( self, *, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 03552fa9884c..0147109249dd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -52,9 +52,7 @@ zmq_ctx, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( - MambaConvSplitInfo, compute_physical_blocks_per_logical, - derive_mamba_conv_split, ) from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config from vllm.distributed.parallel_state import ( @@ -62,16 +60,13 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.kv_cache_interface import ( FullAttentionSpec, - MambaSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.utils import select_common_block_size if TYPE_CHECKING: @@ -123,50 +118,6 @@ def __init__( } self.hma_group_size = len(kv_cache_config.kv_cache_tensors) - # ---- Mamba model state (derived from model config) ---- - self._is_mamba_group = [ - isinstance(group.kv_cache_spec, MambaSpec) - for group in kv_cache_config.kv_cache_groups - ] - mamba_ssm_size = (0, 0) - self._has_mamba = any(self._is_mamba_group) - if self._has_mamba: - assert self._is_hma_required - mamba_spec = next( - spec - for spec in self._layer_specs.values() - if isinstance(spec, MambaSpec) - ) - conv_nbytes, ssm_nbytes = ( - torch.tensor([], 
dtype=mamba_spec.dtypes[0]).element_size(), # type: ignore[misc] - torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(), # type: ignore[misc] - ) - conv_shape, ssm_shape = ( - torch.Size(mamba_spec.shapes[0]), - torch.Size(mamba_spec.shapes[1]), - ) - mamba_ssm_size = ( - conv_shape.numel() * conv_nbytes, - ssm_shape.numel() * ssm_nbytes, - ) - self._mamba_ssm_size = mamba_ssm_size - # Conv state sub-projection decomposition (None when no Mamba). - # The 3-read transfer requires DS (dim, state_len) conv layout so - # that x/B/C sub-projections are contiguous in memory. - self._conv_decomp: MambaConvSplitInfo | None = None - if self._has_mamba: - assert is_conv_state_dim_first(), ( - "3-read Mamba conv transfer requires DS conv state layout. " - "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" - ) - local_tp = vllm_config.parallel_config.tensor_parallel_size - self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp) - - self.transfer_policy = ModelBlockTransferPolicy.create( - kv_cache_config=kv_cache_config, - tp_size=vllm_config.parallel_config.tensor_parallel_size, - ) - # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] # Configure NIXL num_threads to avoid UAR exhaustion on Mellanox NICs. @@ -277,12 +228,7 @@ def __init__( self.dst_num_blocks: dict[EngineId, int] = {} self._registered_descs: list[Any] = [] - # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ---- - # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine. - # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len) - # where conv/ssm bytes are per-TP-rank (dimension-sharded). With - # heterogeneous TP the per-rank sizes differ, so the ratio differs: - # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261. + # Per-engine physical-blocks-per-logical ratio (only used for Mamba). self._physical_blocks_per_logical: dict[EngineId, int] = {} # In progress transfers. 
@@ -339,6 +285,13 @@ def __init__( self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() + self.transfer_policy = ModelBlockTransferPolicy.create( + kv_cache_config=kv_cache_config, + layer_specs=self._layer_specs, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, + tp_size=vllm_config.parallel_config.tensor_parallel_size, + ) + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True ) @@ -661,11 +614,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, - # SSM States come in tuples (ssm, conv) - tensor_shape=next(iter(kv_caches.values())).shape - if not self._has_mamba - else None, - is_mamba=self._has_mamba, + tensor_shape=self.transfer_policy.get_tensor_shape(kv_caches), + is_mamba=self.transfer_policy.is_mamba, ) self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks @@ -712,7 +662,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len_per_layer = list[int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to - # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. + # that of FI, with block laid out as in `get_block_len`. # However, physical page_size may differ when kernel requires a specific # block size. This leads to SSM and FA layers having different num_blocks. # `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this. @@ -725,11 +675,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) # `layer_spec.page_size_bytes` only accounts for logical page_size, that is # the page_size assuming constant `self._logical_num_blocks`. 
- physical_page_size = ( - layer_spec.page_size_bytes - if isinstance(layer_spec, MambaSpec) - else layer_spec.page_size_bytes - // self._physical_blocks_per_logical_kv_block + physical_page_size = self.transfer_policy.compute_page_size( + layer_spec, self._physical_blocks_per_logical_kv_block ) # For when registering multiple tensors eg K/V in separate regions. physical_page_size = physical_page_size // len(cache_list) @@ -738,10 +685,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): physical_page_size = physical_page_size * len( self.kv_cache_config.kv_cache_tensors ) - num_blocks = ( - self._logical_num_blocks - if isinstance(layer_spec, MambaSpec) - else self.num_blocks + num_blocks = self.transfer_policy.get_num_blocks( + layer_spec, self.num_blocks, self._logical_num_blocks ) # `page_size` accounts for physical blocks, st KVCache is always # [`num_blocks` * `page_size`] @@ -764,13 +709,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "Registering layer %s with cache shape: %s", layer_name, cache.shape ) seen_base_addresses.append(base_addr) - # Only record non-Mamba page sizes. 
- if isinstance(layer_spec, MambaSpec): - self.block_len_per_layer.append( - physical_page_size // self._physical_blocks_per_logical_kv_block + self.block_len_per_layer.append( + self.transfer_policy.compute_layer_block_bytes( + layer_spec, + physical_page_size, + self._physical_blocks_per_logical_kv_block, ) - else: - self.block_len_per_layer.append(physical_page_size) + ) assert cache.shape[0] == num_blocks, ( "All kv cache tensors must have the same number of blocks" @@ -820,7 +765,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self._has_mamba: + if self.transfer_policy.is_mamba: self._physical_blocks_per_logical[self.engine_id] = ( self._physical_blocks_per_logical_kv_block ) @@ -833,7 +778,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._physical_blocks_per_logical_kv_block, self.num_regions, self.num_descs, - self._mamba_ssm_size, + self.transfer_policy.ssm_sizes, set(self.block_len_per_layer), ) @@ -854,7 +799,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if not self.use_host_buffer else self.host_buffer_kv_cache_layout, block_size=self.block_size, - ssm_sizes=self._mamba_ssm_size, + ssm_sizes=self.transfer_policy.ssm_sizes, attn_backend_name=self.backend_name, ) # Wrap metadata in payload with hash for defensive decoding @@ -865,163 +810,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): agent_metadata_bytes=encoder.encode(agent_metadata), ) - def _build_mamba_local( - self, - base_addresses: list[int], - block_size_ratio: int, - ) -> list[tuple[int, int, int]]: - """Build 4 desc regions (x, B, C, ssm) per layer for local mamba - blocks, enabling the 3-read transfer with DS conv layout.""" - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." 
- ) - assert self._conv_decomp is not None - conv_offsets = self._conv_decomp.local_conv_offsets - conv_size, ssm_size = self._mamba_ssm_size - num_blocks = self._logical_num_blocks * block_size_ratio - physical_per_logical = self._physical_blocks_per_logical_kv_block - - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(base_addresses): - page_stride = ( - self.block_len_per_layer[i] // block_size_ratio * physical_per_logical - ) - for off, sz in conv_offsets: - for blk in range(num_blocks): - result.append( - (base_addr + blk * page_stride + off, sz, self.device_id) - ) - # SSM temporal state follows the conv state. - for blk in range(num_blocks): - result.append( - ( - base_addr + blk * page_stride + conv_size, - ssm_size, - self.device_id, - ) - ) - return result - - def _build_fa_remote_for_mamba( - self, - nixl_agent_meta: NixlAgentMetadata, - block_size_ratio: int, - transfer_topo: TransferTopology, - remote_engine_id: EngineId, - ) -> list[tuple[int, int, int]]: - """Build remote FA descriptors for mamba models. - - Uses TransferTopology for GQA-aware FA divisor and head-based rank - offset instead of the standard uniform tp_ratio split. - """ - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." - ) - # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA - # hetero-TP logic stabilizes. 
- assert isinstance(self.transfer_policy, MambaModelBlockTransferPolicy) - mamba_info = transfer_topo.get_engine_info(remote_engine_id) - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size) - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - local_block_len = remote_kv_block_len - - if tp_ratio < 0 and not self.use_mla: - local_block_len = local_block_len // mamba_info.remote_num_fa_reads - - rank_offset = self.transfer_policy.fa_rank_offset( - mamba_info, - remote_kv_block_len, - tp_rank=transfer_topo.tp_rank, - tp_size=transfer_topo.tp_size, - is_mla=transfer_topo.is_mla, - total_num_kv_heads=transfer_topo.total_num_kv_heads, - ) - - num_blocks = nixl_agent_meta.num_blocks - page_size = nixl_agent_meta.block_lens[i] - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - result.append((addr, local_block_len, nixl_agent_meta.device_id)) - - if transfer_topo.is_kv_layout_blocks_first: - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) - if tp_ratio < 0 and not self.use_mla: - second_split = second_split // mamba_info.remote_num_fa_reads - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - result.append((v_addr, second_split, nixl_agent_meta.device_id)) - return result - - def _build_mamba_remote( - self, - nixl_agent_meta: NixlAgentMetadata, - tp_ratio: int, - ) -> list[tuple[int, int, int]]: - """Build 4 remote desc regions (x, B, C, ssm) per layer for - the 3-read transfer. 
For hetero-TP, each D rank reads only its - sub-projection slice from the P rank.""" - assert self._conv_decomp is not None - effective_ratio = max(tp_ratio, 1) - # Mamba conv state is always TP-sharded, even when attention KV - # is replicated (num_kv_heads < tp_size). - local_offset = self.tp_rank % effective_ratio - conv_size_remote = nixl_agent_meta.ssm_sizes[0] - - if tp_ratio >= 1: - # D_TP >= P_TP: P page is larger, D reads its slice. - conv_offsets = self._conv_decomp.remote_conv_offsets( - local_offset, effective_ratio - ) - ssm_read_size = self._mamba_ssm_size[1] - else: - # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages - # are smaller than D's. self._conv_decomp has D-sized dimensions, - # but we need P-sized offsets. Scale down by |tp_ratio|. - abs_ratio = -tp_ratio - xb_p = self._conv_decomp.x_bytes // abs_ratio - bb_p = self._conv_decomp.b_bytes // abs_ratio - conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)] - ssm_read_size = nixl_agent_meta.ssm_sizes[1] - - remote_physical_per_logical = self._physical_blocks_per_logical[ - nixl_agent_meta.engine_id - ] - num_blocks = nixl_agent_meta.num_blocks // remote_physical_per_logical - device_id = nixl_agent_meta.device_id - - result: list[tuple[int, int, int]] = [] - # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case - # block lengths vary across layers (e.g. MLA). - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical - for off, sz in conv_offsets: - for blk in range(num_blocks): - result.append((base_addr + blk * page_stride + off, sz, device_id)) - # SSM temporal state is also TP-sharded on the heads dimension. 
- for blk in range(num_blocks): - ssm_addr = ( - base_addr - + blk * page_stride - + conv_size_remote - + local_offset * ssm_read_size - ) - result.append((ssm_addr, ssm_read_size, device_id)) - return result - def register_local_xfer_handler( self, block_size: int, @@ -1041,71 +829,24 @@ def register_local_xfer_handler( transfer_topo = self.transfer_topo block_size_ratio = self.block_size // block_size - blocks_data: list[tuple[int, int, int]] = [] local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] - def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool): - for i, base_addr in enumerate(local_base_addresses): - # The new block_len is using prefill block_len; - # and num_blocks is multiple with N - kv_block_len = ( - self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=mamba - ) - // block_size_ratio - ) - # Jump one page_size, but ssm page_size may be bigger when kernel - # locks block size to a specific value. - block_len_per_layer = ( - self.block_len_per_layer[i] - // block_size_ratio - * (1 if not mamba else self._physical_blocks_per_logical_kv_block) - ) - num_blocks = self._logical_num_blocks if mamba else self.num_blocks - num_blocks = num_blocks * block_size_ratio - for block_id in range(num_blocks): - block_offset = block_id * block_len_per_layer - addr = base_addr + block_offset - # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.device_id)) - - if transfer_topo.is_kv_layout_blocks_first: - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=mamba - ) - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. - for block_id in range(num_blocks): - block_offset = block_id * block_len_per_layer - addr = base_addr + block_offset - # Register addresses for V cache (K registered first). 
- v_addr = addr + kv_block_len - blocks_data.append((v_addr, second_split, self.device_id)) - logger.debug( - "Created %s blocks for src engine %s and rank %s on device id %s", - len(blocks_data), - self.engine_id, - self.tp_rank, - self.device_id, - ) - - # NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used - # right now — we use _build_mamba_local instead for the 3-read - # approach. However, we might still need this as a fallback for homogeneous TP. - register_blocks(blocks_data, mamba=False) - if self._has_mamba: - assert self.num_descs == len(blocks_data) - # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is - # unnecessary — a single conv desc per block suffices. Consider - # adding a fast path that falls back to the standard 2-region - # registration (register_blocks mamba=True) when no hetero-TP - # remote has been seen. Currently we always register 4 regions - # because local descs are created before knowing the remote TP. - logger.debug("Registering local Mamba descriptors (4 regions/layer)") - blocks_data.extend( - self._build_mamba_local(local_base_addresses, block_size_ratio) - ) + blocks_data = self.transfer_policy.build_local_descs( + base_addresses=local_base_addresses, + block_len_per_layer=self.block_len_per_layer, + num_blocks=self.num_blocks, + logical_num_blocks=self._logical_num_blocks, + block_size_ratio=block_size_ratio, + device_id=self.device_id, + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + ) + logger.debug( + "Created %s blocks for src engine %s and rank %s on device id %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + self.device_id, + ) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. 
@@ -1179,7 +920,7 @@ def add_remote_agent( nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0], ) - if self._has_mamba + if self.transfer_policy.is_mamba else 1 ) transfer_info = self.transfer_policy.build_engine_transfer_info( @@ -1197,7 +938,10 @@ def add_remote_agent( remote_physical_blocks_per_logical=physical_blocks_per_logical, ) transfer_topo.register_remote_engine(engine_id, transfer_info) - if self._has_mamba and engine_id not in self._physical_blocks_per_logical: + if ( + self.transfer_policy.is_mamba + and engine_id not in self._physical_blocks_per_logical + ): self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) @@ -1246,163 +990,52 @@ def add_remote_agent( and not self.use_mla and tp_ratio not in self.src_xfer_handles_by_tp_ratio ): - # Remote tp_size > local tp_size: read from multiple remote ranks. - # Logically "split" own regions into |tp_ratio| chunks. Mind that - # we only do this once per remote tp_size (replica-friendly). abs_tp = -tp_ratio self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - if self._has_mamba: - assert isinstance(self.transfer_policy, MambaModelBlockTransferPolicy) - mamba_info = transfer_topo.get_engine_info(engine_id) - assert isinstance(mamba_info, MambaEngineTransferInfo) - if self.transfer_policy.needs_split_handles( - mamba_info, - tp_size=transfer_topo.tp_size, - is_mla=transfer_topo.is_mla, - ): - # Mamba-HMA: FA and Mamba use different split factors. 
- for handle_data in self.transfer_policy.compute_split_handle_data( - mamba_info, - self.src_blocks_data, - self.num_descs, - abs_tp, - total_num_kv_heads=transfer_topo.total_num_kv_heads, - ): - descs = self.nixl_wrapper.get_xfer_descs( - handle_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs - ) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - - logger.info( - "Mamba-HMA split handles: %s, num_descs=%s", - transfer_topo.describe(engine_id), - self.num_descs, - ) - else: - # Original path: uniform divide by abs_tp (non-Mamba-HMA). - for i in range(abs_tp): - blocks_data = [] - for memory_region in self.src_blocks_data: - addr, local_block_len, own_tp_rank = memory_region - remote_block_len = local_block_len // abs_tp - addr = addr + i * remote_block_len - blocks_data.append((addr, remote_block_len, own_tp_rank)) - descs = self.nixl_wrapper.get_xfer_descs( - blocks_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - - ### Register remote agent memory regions - blocks_data = [] - # With homogeneous TP, D pulls the whole kv cache from corresponding - # rank. With heterogeneous TP, prepare the descriptors by splitting the - # P KV cache along kv_head dim, of D worker's kv_head size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - - # Register all remote blocks, but only the corresponding kv heads. - def register_remote_blocks( - blocks_data: list[tuple[int, int, int]], mamba: bool - ): - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # Read our whole local region size from remote. 
- local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=mamba - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # using remote kv_block_len as transfer unit - local_block_len = remote_kv_block_len - - if tp_ratio < 0 and not self.use_mla: - # Remote tp is bigger: read a chunk of local region from remote - local_block_len = local_block_len // (-tp_ratio) - rank_offset = ( - self.tp_rank % tp_ratio * remote_kv_block_len - if indexes_into_remote - else 0 - ) - - # Assume same num_blocks for mamba and fa - num_blocks = ( - nixl_agent_meta.num_blocks - if not mamba - else nixl_agent_meta.num_blocks - // self._physical_blocks_per_logical_kv_block - ) - page_size = nixl_agent_meta.block_lens[i] * ( - 1 if not mamba else self._physical_blocks_per_logical_kv_block + for handle_data in self.transfer_policy.build_src_split_handles( + self.src_blocks_data, + self.num_descs, + abs_tp, + transfer_config=transfer_topo.get_engine_info(engine_id) + if self.transfer_policy.is_mamba + else None, + tp_size=transfer_topo.tp_size, + is_mla=transfer_topo.is_mla, + total_num_kv_heads=transfer_topo.total_num_kv_heads, + ): + descs = self.nixl_wrapper.get_xfer_descs( + handle_data, self.nixl_memory_type ) - for block_id in range(num_blocks): - block_offset = block_id * page_size - # For each block, grab the heads chunk belonging to rank_i - # of size remote_nheads // tp_ratio, which correspond to - # self.block_len == remote_block_len//tp_ratio bytes. - addr = base_addr + block_offset + rank_offset - # (addr, len, device id) - blocks_data.append( - (addr, local_block_len, nixl_agent_meta.device_id) - ) - - if transfer_topo.is_kv_layout_blocks_first: - # With FlashInfer index V separately to allow head splitting. 
- second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=mamba - ) - # Apply the same scaling as local_block_len above for when we read - # a chunk of local V from `tp_ratio` separate remote workers. - if tp_ratio < 0 and not self.use_mla: - second_split = second_split // (-tp_ratio) - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - # Hop over the first split of remote page: either K or Conv. - if mamba: - v_addr = addr + nixl_agent_meta.ssm_sizes[0] - else: - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append( - (v_addr, second_split, nixl_agent_meta.device_id) - ) - - logger.debug( - "Created %s blocks for dst engine %s" - " with remote rank %s and local rank %s", - len(blocks_data), - engine_id, - remote_tp_rank, - self.tp_rank, - ) + handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - if self._has_mamba: - # Mamba-HMA: separate FA registration with GQA-aware sizing, - # plus mamba 3-read registration for the Mamba "view" of the - # same KV cache tensors. 
- logger.debug( - "Registering remote Mamba blocks for engine %s rank %s", - engine_id, - remote_tp_rank, - ) - blocks_data.extend( - self._build_fa_remote_for_mamba( - nixl_agent_meta, - block_size_ratio, - transfer_topo, - engine_id, - ) - ) - blocks_data.extend( - self._build_mamba_remote( - nixl_agent_meta, - tp_ratio, - ) - ) - else: - register_remote_blocks(blocks_data, mamba=False) + ### Register remote agent memory regions + blocks_data = self.transfer_policy.build_remote_descs( + nixl_agent_meta=nixl_agent_meta, + block_size_ratio=block_size_ratio, + tp_ratio=tp_ratio, + tp_rank=self.tp_rank, + use_mla=self.use_mla, + block_len_per_layer=self.block_len_per_layer, + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + indexes_into_remote=indexes_into_remote, + transfer_config=transfer_topo.get_engine_info(engine_id) + if self.transfer_policy.is_mamba + else None, + physical_blocks_per_logical=self._physical_blocks_per_logical.get( + engine_id, 1 + ), + tp_size=transfer_topo.tp_size, + total_num_kv_heads=transfer_topo.total_num_kv_heads, + ) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), + engine_id, + remote_tp_rank, + self.tp_rank, + ) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) @@ -1438,7 +1071,7 @@ def _validate_remote_agent_handshake( ) # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba. # Mamba models can have replicated FA KV with tp_ratio < 0. - if not self._has_mamba: + if not self.transfer_policy.is_mamba: assert not ( tp_ratio < 0 and self.transfer_topo.is_kv_replicated(remote_engine_id) ) @@ -1508,7 +1141,7 @@ def _validate_remote_agent_handshake( # With replicated KV cache, only the number of blocks can differ. # TODO (ZhanqiuHu): For mamba models, validate FA and mamba # block_lens separately. 
- if not self._has_mamba: + if not self.transfer_policy.is_mamba: for i in range(len(self.block_len_per_layer)): assert ( self.block_len_per_layer[i] // block_size_ratio @@ -1524,7 +1157,7 @@ def _validate_remote_agent_handshake( # HMA hybrid models (mamba+attention) pad block_len to # max(attn_page, mamba_page), so the linear tp_ratio scaling # assumption only holds for pure-attention models. - if not self._has_mamba: + if not self.transfer_policy.is_mamba: if tp_ratio > 0: assert ( remote_block_len @@ -1581,8 +1214,8 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): assert self.copy_blocks is not None for req_id, meta in metadata.reqs_to_save.items(): - meta.local_physical_block_ids = self._logical_to_kernel_block_ids( - meta.local_block_ids + meta.local_physical_block_ids = ( + self.transfer_policy.logical_to_kernel_block_ids(meta.local_block_ids) ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -1876,8 +1509,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.reqs_to_recv.items(): - meta.local_physical_block_ids = self._logical_to_kernel_block_ids( - meta.local_block_ids + meta.local_physical_block_ids = ( + self.transfer_policy.logical_to_kernel_block_ids(meta.local_block_ids) ) assert meta.remote is not None # Remote block IDs are kept logical here; expanded in @@ -1934,82 +1567,63 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_info = self.transfer_topo.get_engine_info(engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) - if self._has_mamba: - # Expand remote logical → kernel block IDs. 
- meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( - meta.remote.block_ids, - self._physical_blocks_per_logical[meta.remote.engine_id], - ) - else: - meta.remote.block_ids = self._logical_to_kernel_block_ids( - meta.remote.block_ids - ) - # D may have to perform multiple reads from different remote ranks. - for i, remote_rank in enumerate(remote_ranks): - if self.use_mla and tp_ratio < 0 and i > 0: - # MLA opt: when P TP > D TP, only a single read is executed for - # the first remote rank (cache is duplicated).. - break + read_specs = self.transfer_policy.compute_read_specs( + local_block_ids=meta.local_physical_block_ids, + remote_block_ids=meta.remote.block_ids, + remote_ranks=remote_ranks, + physical_blocks_per_logical=self._physical_blocks_per_logical.get( + engine_id, 1 + ), + transfer_config=remote_info if self.transfer_policy.is_mamba else None, + ) + + # MLA opt: when P TP > D TP, only a single read is needed. + if self.use_mla and tp_ratio < 0: + read_specs = read_specs[:1] + for i, spec in enumerate(read_specs): remote_block_size = remote_info.remote_block_size logger.debug( "Remote agent %s available, calling _read_blocks" " on remote rank %s with remote block size %s for req %s", meta.remote.engine_id, - remote_rank, + spec.remote_rank, remote_block_size, req_id, ) # Get side handles. if tp_ratio < 0 and not self.use_mla: assert remote_block_size == self.block_size - # Remote tp_size > local tp_size: we must perform multiple - # reads. Get the memory chunk onto which we will write to. local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] else: - # Single read from remote, we write to the whole memory region. - # Also handle remote block size different from local block size. local_xfer_side_handle = self.src_xfer_handles_by_block_size[ remote_block_size ] - # Destination handle: remote_engine_id -> remote_rank -> handle. 
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ - remote_rank + spec.remote_rank ] - local_ids: BlockIds = meta.local_physical_block_ids - remote_ids: BlockIds = meta.remote.block_ids - if self._has_mamba: - # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets. - assert isinstance(self.transfer_policy, MambaModelBlockTransferPolicy) - assert isinstance(remote_info, MambaEngineTransferInfo) - local_ids, remote_ids = self.transfer_policy.filter_block_ids_for_rank( - remote_info, - remote_rank, - local_ids, - remote_ids, - ) - self._read_blocks( request_id=req_id, dst_engine_id=meta.remote.engine_id, remote_request_id=meta.remote.request_id, - local_block_ids=local_ids, - remote_block_ids=remote_ids, - remote_rank=remote_rank, + local_block_ids=spec.local_block_ids, + remote_block_ids=spec.remote_block_ids, + remote_rank=spec.remote_rank, local_xfer_side_handle=local_xfer_side_handle, remote_xfer_side_handle=remote_xfer_side_handle, ) - if self.use_mla and tp_ratio < 0: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. - notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() - remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify, agent in remote_agents.items(): - if rank_to_notify != remote_rank: - self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + if self.use_mla and tp_ratio < 0 and read_specs: + # ..but we still need to notify the other remote ranks that we + # have the blocks we need so they can update the request state. 
+ notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() + remote_agents = self._remote_agents[meta.remote.engine_id] + read_ranks = {s.remote_rank for s in read_specs} + for rank_to_notify, agent in remote_agents.items(): + if rank_to_notify not in read_ranks: + self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) def _read_blocks( self, @@ -2099,12 +1713,15 @@ def _read_blocks( for i, remote_group in enumerate(remote_block_ids): num_remote_blocks = len(remote_group) num_local_blocks = len(local_block_ids[i]) - if not self._is_mamba_group[i]: + if not self.transfer_policy.is_mamba_group(i): assert num_local_blocks <= num_remote_blocks # Partial prefix cache hit: just read uncomputed blocks. # Skip mamba groups — their blocks represent full state (conv+ssm), # not per-token data, so trimming would corrupt the transfer. - if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: + if ( + num_local_blocks < num_remote_blocks + and not self.transfer_policy.is_mamba_group(i) + ): remote_block_ids[i] = remote_group[-num_local_blocks:] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from @@ -2187,158 +1804,17 @@ def _get_block_descs_ids( block_ids: BlockIds, block_size_ratio: float | None = None, ) -> np.ndarray: - """ - Get the descs ids for a set of block ids. - When HMA is enabled number of descriptors across kv cache groups might differ. - A single flattened array is returned for all groups anyway. - """ - region_ids = np.arange(self.num_regions) - - # NOTE (NickLucche) With HMA, every kv group has the same number of layers and - # layers from different groups share the same kv tensor. - # eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions, - # same for [3], but group0-group1 blocks will always differ (different areas). - # Therefore we can just flatten the block_ids and compute the descs ids for all - # groups at once. 
- num_blocks = self.dst_num_blocks[engine_id] - if block_size_ratio is not None: - num_blocks = int(num_blocks * block_size_ratio) - - # Compute desc ids per group using the right stride: FA descs have - # num_blocks entries per region (kernel granularity), SSM descs have - # logical_blocks entries per region (no kernel splitting). - region_ids = region_ids[:, None] - if not self._has_mamba: - block_ids = np.concatenate(block_ids)[None, :] - descs_ids = region_ids * num_blocks + block_ids - return descs_ids.flatten() - else: - # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged - # arbitrarily by manager. Therefore, descs are duplicated for SSM and - # Attention like so: - # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)]. - # This is like having two "low-level views" of the same storage. - # `num_fa_descs` offset must be computed per-engine since P and D can - # have different num_blocks (and thus different FA descs counts). - physical_per_logical = self._physical_blocks_per_logical[engine_id] - logical_blocks = num_blocks // physical_per_logical - num_fa_descs = self.num_regions * num_blocks - # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). - mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None] - all_descs = [] - for i, group in enumerate(block_ids): - group_arr = np.asarray(group)[None, :] - if self._is_mamba_group[i]: - all_descs.append( - ( - mamba_region_ids * logical_blocks + group_arr + num_fa_descs - ).flatten() - ) - else: - all_descs.append((region_ids * num_blocks + group_arr).flatten()) - return np.concatenate(all_descs) - - def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: - """ - Convert logical block ids to kernel physical block ids. - This is required when the logical block size (the one set by the user) - does not match the one required by the attn backend. 
- """ - if self._physical_blocks_per_logical_kv_block == 1: - # Noop when physical and logical block sizes are the same - return block_ids - block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( - 1, -1 + """Thin wrapper delegating to the block transfer policy.""" + return self.transfer_policy.get_block_descs_ids( + block_ids=block_ids, + num_regions=self.num_regions, + dst_num_blocks=self.dst_num_blocks[engine_id], + block_len_per_layer=self.block_len_per_layer, + block_size_ratio=block_size_ratio, + physical_blocks_per_logical=self._physical_blocks_per_logical.get( + engine_id, 1 + ), ) - # Mamba blocks have no logical<>physical discrepancy - group_specs = self.kv_cache_config.kv_cache_groups - return [ - BlockTable.map_to_kernel_blocks( - np.array(group), - self._physical_blocks_per_logical_kv_block, - block_arange, - ).tolist() - if not isinstance(group_specs[i].kv_cache_spec, MambaSpec) - else group - for i, group in enumerate(block_ids) - ] - - def _logical_to_remote_kernel_block_ids( - self, block_ids: BlockIds, remote_physical_per_logical: int - ) -> BlockIds: - """Map logical block IDs to physical kernel block IDs on the remote. - - Args: - block_ids: per-group lists of logical block IDs. - remote_physical_per_logical: remote engine's physical blocks - per logical block. - - Returns: - Same structure with FA groups expanded (each logical block L - becomes kernel blocks [L*ratio .. L*ratio + local_ratio - 1]). - Mamba groups are passed through unchanged. 
- """ - local_ratio = self._physical_blocks_per_logical_kv_block - if remote_physical_per_logical == 1: - return block_ids - local_arange = np.arange(local_ratio).reshape(1, -1) - group_specs = self.kv_cache_config.kv_cache_groups - result: list[list[int]] = [] - for i, group in enumerate(block_ids): - if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): - arr = np.array(group).reshape(-1, 1) - expanded = (arr * remote_physical_per_logical + local_arange).flatten() - result.append(expanded.tolist()) - else: - # Mamba blocks are 1:1 logical-to-physical (no expansion). - result.append(group) - return result - - def get_backend_aware_kv_block_len( - self, layer_idx: int, first_split: bool = True, mamba_view: bool = False - ) -> int: - """ - Get the block length for one K/V element (K and V have the same size). - - For FA and other backends, this is equal to the length of the whole - block, as K and V are in separate regions. - For FlashInfer, this is half the length of the whole block, as K and V - share the same region. - Similarly, for SSM-based models, state and conv are interleaved, but crucially - the their size differs. - Reference diagram: - KVCacheTensor (Shared) - / \\ - / \\ - / \\ - Attention (FlashInfer) View Mamba View - | | - | | - +-------------------+ +-------------------+ - | KVCacheTensor | | KVCacheTensor | - | | | | - |<----- page ------>| |<----- page ------->| - | size | | size | - | Key 0 | Val 0 | |Conv 0 | SSM 0 | - | Key 1 | Val 1 | |Conv 1 | SSM 1 | - | ... | ... | | ... | ... | - | Key N-2 | Val N-2 | |Conv N-2| SSM N-2 | - | Key N-1 | Val N-1 | |Conv N-1| SSM N-1 | - +-------------------+ +--------------------+ - |1st_split-2nd_split| |1st_split-2nd_split | - """ - assert self.transfer_topo is not None - if self.transfer_topo.is_kv_layout_blocks_first: - # For indexing only half (either just the K or V part). 
- if mamba_view: - # NOTE (NickLucche) Mamba Opt: this is already skipping the padding so - # we're only transferring the minimum required bytes. - block_len = self._mamba_ssm_size[not first_split] - else: - block_len = self.block_len_per_layer[layer_idx] // 2 - else: - block_len = self.block_len_per_layer[layer_idx] - return block_len def get_kv_connector_stats(self) -> KVConnectorStats | None: """ From 6961ae05271ac3ddb541a25e1230644a67727695 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Fri, 17 Apr 2026 16:43:10 +0000 Subject: [PATCH 05/49] updates Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 87 ++++++++++++++++++- .../kv_connector/v1/nixl/worker.py | 21 ++++- 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 2e23da1624c1..5d860d29b0fc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -252,6 +252,8 @@ def _build_fa_local_descs( result: list[tuple[int, int, int]] = [] n_blocks = num_blocks * block_size_ratio for i, base_addr in enumerate(base_addresses): + # The new block_len is using prefill block_len; + # and num_blocks is multiple with N kv_block_len = ( self.get_block_len( i, @@ -261,8 +263,11 @@ def _build_fa_local_descs( ) // block_size_ratio ) + # Jump one page_size, but ssm page_size may be bigger when kernel + # locks block size to a specific value. page_stride = block_len_per_layer[i] // block_size_ratio for block_id in range(n_blocks): + # (addr, len, device id) result.append( ( base_addr + block_id * page_stride, @@ -277,7 +282,11 @@ def _build_fa_local_descs( block_len_per_layer, is_blocks_first, ) + # Separate and interleave K/V regions to maintain the same + # descs ordering. 
This is needed for selecting contiguous heads + # when split across TP ranks. for block_id in range(n_blocks): + # Register addresses for V cache (K registered first). v_addr = base_addr + block_id * page_stride + kv_block_len result.append( ( @@ -339,6 +348,7 @@ def compute_read_specs( The worker iterates the result without model-specific branching. MLA trimming (keeping only the first spec) is handled by the worker. """ + _ = (physical_blocks_per_logical, transfer_config) return [ ReadSpec( remote_rank=rank, @@ -481,6 +491,25 @@ def logical_to_remote_kernel_block_ids( for group in block_ids ] + def compute_read_specs( + self, + local_block_ids, + remote_block_ids, + remote_ranks, + physical_blocks_per_logical=1, + transfer_config=None, + ): + _ = (physical_blocks_per_logical, transfer_config) + expanded = self.logical_to_kernel_block_ids(remote_block_ids) + return [ + ReadSpec( + remote_rank=rank, + local_block_ids=local_block_ids, + remote_block_ids=expanded, + ) + for rank in remote_ranks + ] + def build_local_descs( self, base_addresses, @@ -515,21 +544,29 @@ def build_remote_descs( tp_size: int = 1, total_num_kv_heads: int = 1, ): + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + # Register all remote blocks, but only the corresponding kv heads. _ = (tp_size, total_num_kv_heads, physical_blocks_per_logical, transfer_config) result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate( nixl_agent_meta.kv_caches_base_addr, ): + # Read our whole local region size from remote. 
local_block_len = self.get_block_len( i, True, block_len_per_layer, is_blocks_first, ) + # using remote kv_block_len as transfer unit remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: local_block_len = remote_kv_block_len if tp_ratio < 0 and not use_mla: + # Remote tp is bigger: read a chunk of local region from remote local_block_len = local_block_len // (-tp_ratio) rank_offset = ( tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 @@ -686,8 +723,11 @@ def get_block_len( is_blocks_first, mamba_view=False, ): + # For indexing only half (either just the K or V part). if is_blocks_first: if mamba_view: + # NOTE (NickLucche) Mamba Opt: this is already skipping the + # padding so we're only transferring the minimum required bytes. return self._ssm_sizes[not first_split] return block_len_per_layer[layer_idx] // 2 return block_len_per_layer[layer_idx] @@ -701,6 +741,16 @@ def get_block_descs_ids( block_size_ratio=None, physical_blocks_per_logical=1, ): + # NOTE (NickLucche) With HMA, every kv group has the same number of + # layers and layers from different groups share the same kv tensor. + # eg block_ids=[[1,2],[3]]->blocks [1,2] need to be read across all + # regions, same for [3], but group0-group1 blocks will always differ + # (different areas). Therefore we can just flatten the block_ids and + # compute the descs ids for all groups at once. + + # Compute desc ids per group using the right stride: FA descs have + # num_blocks entries per region (kernel granularity), SSM descs have + # logical_blocks entries per region (no kernel splitting). num_blocks = dst_num_blocks if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) @@ -708,6 +758,14 @@ def get_block_descs_ids( ratio = physical_blocks_per_logical logical_blocks = num_blocks // ratio num_fa_descs = num_regions * num_blocks + # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged + # arbitrarily by manager. 
Therefore, descs are duplicated for SSM and + # Attention like so: + # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)]. + # This is like having two "low-level views" of the same storage. + # `num_fa_descs` offset must be computed per-engine since P and D can + # have different num_blocks (and thus different FA descs counts). + # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). mamba_region_ids = np.arange( len(block_len_per_layer) * 4, )[:, None] @@ -715,6 +773,7 @@ def get_block_descs_ids( for i, group in enumerate(block_ids): group_arr = np.asarray(group)[None, :] if self._is_mamba_group[i]: + # Mamba blocks are 1:1 logical-to-physical (no expansion). all_descs.append( ( mamba_region_ids * logical_blocks + group_arr + num_fa_descs @@ -790,6 +849,12 @@ def build_local_descs( ) num_regions = len(base_addresses) * (2 if is_blocks_first else 1) assert len(fa_descs) == num_regions * num_blocks + # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read + # split is unnecessary — a single conv desc per block suffices. + # Consider adding a fast path that falls back to the standard 2-region + # registration when no hetero-TP remote has been seen. Currently we + # always register 4 regions because local descs are created before + # knowing the remote TP. logger.debug("Registering local Mamba descriptors (4 regions/layer)") mamba_descs = self._build_mamba_local_descs( base_addresses, @@ -810,12 +875,16 @@ def _build_mamba_local_descs( ) -> list[tuple[int, int, int]]: """Build 4 desc regions (x, B, C, ssm) per layer for local mamba blocks, enabling the 3-read transfer with DS conv layout. + + Conv state sub-projection decomposition requires DS (dim, state_len) + conv layout so that x/B/C sub-projections are contiguous in memory. """ assert block_size_ratio == 1, ( "Mamba 3-read transfer with block_size_ratio != 1 " f"is not tested. Got block_size_ratio={block_size_ratio}." 
) conv_offsets = self._conv_decomp.local_conv_offsets + # SSM States come in tuples (conv_size, ssm_state_size) conv_size, ssm_size = self._ssm_sizes n_blocks = logical_num_blocks * block_size_ratio phys_ratio = self._physical_blocks_per_logical @@ -823,6 +892,7 @@ def _build_mamba_local_descs( result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(base_addresses): page_stride = block_len_per_layer[i] // block_size_ratio * phys_ratio + # 3 conv sub-projection regions (x, B, C) for off, sz in conv_offsets: for blk in range(n_blocks): result.append( @@ -832,6 +902,7 @@ def _build_mamba_local_descs( device_id, ) ) + # SSM temporal state follows the conv state. for blk in range(n_blocks): result.append( ( @@ -857,6 +928,8 @@ def build_remote_descs( tp_size: int = 1, total_num_kv_heads: int = 1, ): + # indexes_into_remote is not used for Mamba: FA offset is computed + # via fa_rank_offset which accounts for GQA/HMA head mapping. _ = indexes_into_remote info = cast(MambaEngineTransferInfo, transfer_config) result: list[tuple[int, int, int]] = [] @@ -961,18 +1034,27 @@ def _build_mamba_remote_descs( physical_blocks_per_logical, ): """Build 4 remote desc regions (x, B, C, ssm) per layer - for the 3-read transfer.""" + for the 3-read transfer. + + Mamba conv state is always TP-sharded, even when attention KV + is replicated (num_kv_heads < tp_size). + """ effective_ratio = max(tp_ratio, 1) local_offset = tp_rank % effective_ratio conv_size_remote = nixl_agent_meta.ssm_sizes[0] if tp_ratio >= 1: + # D_TP >= P_TP: P page is larger, D reads its slice. conv_offsets = self._conv_decomp.remote_conv_offsets( local_offset, effective_ratio, ) + # SSM temporal state is also TP-sharded on the heads dimension. ssm_read_size = self._ssm_sizes[1] else: + # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages + # are smaller than D's. self._conv_decomp has D-sized dimensions, + # but we need P-sized offsets. Scale down by |tp_ratio|. 
abs_ratio = -tp_ratio xb_p = self._conv_decomp.x_bytes // abs_ratio bb_p = self._conv_decomp.b_bytes // abs_ratio @@ -983,6 +1065,7 @@ def _build_mamba_remote_descs( ] ssm_read_size = nixl_agent_meta.ssm_sizes[1] + # Assume same num_blocks for mamba and fa remote_ratio = physical_blocks_per_logical num_blocks = nixl_agent_meta.num_blocks // remote_ratio dev_id = nixl_agent_meta.device_id @@ -991,6 +1074,8 @@ def _build_mamba_remote_descs( for i, base_addr in enumerate( nixl_agent_meta.kv_caches_base_addr, ): + # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case + # block lengths vary across layers (e.g. MLA). page_stride = nixl_agent_meta.block_lens[i] * remote_ratio for off, sz in conv_offsets: for blk in range(num_blocks): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 0147109249dd..f12984b241e4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -228,7 +228,12 @@ def __init__( self.dst_num_blocks: dict[EngineId, int] = {} self._registered_descs: list[Any] = [] - # Per-engine physical-blocks-per-logical ratio (only used for Mamba). + # Mamba-HMA per-engine state (only used when is_mamba). + # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine. + # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len) + # where conv/ssm bytes are per-TP-rank (dimension-sharded). With + # heterogeneous TP the per-rank sizes differ, so the ratio differs: + # e.g. Nemotron 30B: P(TP=4) -> 131, D(TP=1) -> 261. self._physical_blocks_per_logical: dict[EngineId, int] = {} # In progress transfers. 
@@ -291,6 +296,8 @@ def __init__( physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, tp_size=vllm_config.parallel_config.tensor_parallel_size, ) + if self.transfer_policy.is_mamba: + assert self._is_hma_required self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True @@ -985,6 +992,9 @@ def add_remote_agent( ) ### (Optional) Register local agent memory regions. MLA is not split. + # Remote tp_size > local tp_size: read from multiple remote ranks. + # Logically "split" own regions into |tp_ratio| chunks. Mind that + # we only do this once per remote tp_size (replica-friendly). if ( tp_ratio < 0 and not self.use_mla @@ -1594,12 +1604,17 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # Get side handles. if tp_ratio < 0 and not self.use_mla: assert remote_block_size == self.block_size + # Remote tp_size > local tp_size: we must perform multiple + # reads. Get the memory chunk onto which we will write to. local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] else: + # Single read from remote, we write to the whole memory region. + # Also handle remote block size different from local block size. local_xfer_side_handle = self.src_xfer_handles_by_block_size[ remote_block_size ] + # Destination handle: remote_engine_id -> remote_rank -> handle. remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ spec.remote_rank ] @@ -1616,8 +1631,8 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): ) if self.use_mla and tp_ratio < 0 and read_specs: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. + # Notify remote ranks that were not read from so they can update + # request state (cache is duplicated under MLA). 
notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() remote_agents = self._remote_agents[meta.remote.engine_id] read_ranks = {s.remote_rank for s in read_specs} From 00a44dcca677e227b01afb1187826e6d017db3a1 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Sat, 18 Apr 2026 01:39:03 +0000 Subject: [PATCH 06/49] restore original comments; remove redundant _physical_blocks_per_logical dict Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 377 ++++-------------- .../kv_connector/v1/nixl/worker.py | 220 ++++++---- 2 files changed, 218 insertions(+), 379 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 5d860d29b0fc..dbd2eb2e37f9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -35,7 +35,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first from vllm.v1.kv_cache_interface import MambaSpec -from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( @@ -62,10 +61,10 @@ class ReadSpec: class ModelBlockTransferPolicy(ABC): """Abstract base for model-specific block transfer logic. - Concrete subclasses encapsulate: - - Model identity (is_mamba, per-group flags) - - Mamba state sizes and conv decomposition - - Per-engine transfer info computation (``build_engine_transfer_info``) + Encapsulates genuinely model-specific algorithms: descriptor building, + transfer info computation, split handles, read spec filtering, and + orchestration. Simple per-layer branches (``isinstance(MambaSpec)``) + and block-ID mapping remain on ``worker.py``. 
""" def __init__( @@ -76,36 +75,6 @@ def __init__( self._kv_cache_config = kv_cache_config self._physical_blocks_per_logical = physical_blocks_per_logical - # ------------------------------------------------------------------ - # Model identity - # ------------------------------------------------------------------ - - @property - @abstractmethod - def is_mamba(self) -> bool: - """Whether this policy handles a hybrid Mamba+Attention model.""" - - @property - @abstractmethod - def mamba_group_flags(self) -> list[bool]: - """Per-group flag: True if the group is a Mamba (SSM) group.""" - - def is_mamba_group(self, group_idx: int) -> bool: - return self.mamba_group_flags[group_idx] - - @property - @abstractmethod - def ssm_sizes(self) -> tuple[int, int]: - """(conv_state_bytes, ssm_state_bytes) per logical block. - - Returns (0, 0) for dense models. - """ - - @property - @abstractmethod - def conv_decomp(self) -> MambaConvSplitInfo | None: - """Conv-state sub-projection decomposition, or None for dense.""" - # ------------------------------------------------------------------ # Per-engine transfer info (data operations) # ------------------------------------------------------------------ @@ -115,14 +84,12 @@ def conv_decomp(self) -> MambaConvSplitInfo | None: def build_engine_transfer_info( self, *, - # Local facts (from TransferTopology). tp_rank: int, tp_size: int, is_mla: bool, total_num_kv_heads: int, is_kv_layout_blocks_first: bool, local_block_len: int, - # Remote facts (from NixlAgentMetadata handshake). 
remote_tp_size: int, remote_block_size: int, remote_block_len: int, @@ -135,47 +102,9 @@ def build_engine_transfer_info( """ # ------------------------------------------------------------------ - # Registration helpers + # Block length helper (used by descriptor building) # ------------------------------------------------------------------ - @abstractmethod - def compute_page_size( - self, - layer_spec: KVCacheSpec, - physical_ratio: int, - ) -> int: - """Physical page size in bytes for one layer.""" - ... - - @abstractmethod - def get_num_blocks( - self, - layer_spec: KVCacheSpec, - num_blocks: int, - logical_num_blocks: int, - ) -> int: - """Number of blocks to register for this layer spec.""" - ... - - @abstractmethod - def compute_layer_block_bytes( - self, - layer_spec: KVCacheSpec, - physical_page_size: int, - physical_ratio: int, - ) -> int: - """Block byte size for one layer (entry for ``block_len_per_layer``).""" - ... - - @abstractmethod - def get_tensor_shape( - self, - kv_caches: dict[str, torch.Tensor], - ) -> torch.Size | None: - """Tensor shape for ``TpKVTopology`` (None for Mamba).""" - ... - - @abstractmethod def get_block_len( self, layer_idx: int, @@ -184,11 +113,13 @@ def get_block_len( is_blocks_first: bool, mamba_view: bool = False, ) -> int: - """Block length for one K/V (or conv/ssm) element.""" - ... + """Block length for one K/V element. Mamba overrides for SSM view.""" + if is_blocks_first: + return block_len_per_layer[layer_idx] // 2 + return block_len_per_layer[layer_idx] # ------------------------------------------------------------------ - # Descriptor ID computation + block ID mapping + # Descriptor ID computation (abstract — genuinely different per model) # ------------------------------------------------------------------ @abstractmethod @@ -204,28 +135,10 @@ def get_block_descs_ids( """Compute NIXL descriptor IDs for a set of block IDs.""" ... 
- @abstractmethod - def logical_to_kernel_block_ids( - self, - block_ids: BlockIds, - ) -> BlockIds: - """Convert logical block IDs to kernel physical block IDs.""" - ... - - @abstractmethod - def logical_to_remote_kernel_block_ids( - self, - block_ids: BlockIds, - remote_ratio: int, - ) -> BlockIds: - """Map logical block IDs to physical kernel block IDs on remote.""" - ... - # ------------------------------------------------------------------ - # Local descriptor building + # Local descriptor building (concrete default = FA-only) # ------------------------------------------------------------------ - @abstractmethod def build_local_descs( self, base_addresses: list[int], @@ -236,8 +149,19 @@ def build_local_descs( device_id: int, is_blocks_first: bool, ) -> list[tuple[int, int, int]]: - """Build local (src) descriptor tuples for NIXL registration.""" - ... + """Build local (src) descriptor tuples for NIXL registration. + + Default builds FA descriptors only. Mamba overrides to extend + with SSM (conv + temporal state) descriptors. + """ + return self._build_fa_local_descs( + base_addresses, + block_len_per_layer, + num_blocks, + block_size_ratio, + device_id, + is_blocks_first, + ) def _build_fa_local_descs( self, @@ -298,7 +222,7 @@ def _build_fa_local_descs( return result # ------------------------------------------------------------------ - # Remote descriptor building + # Remote descriptor building (abstract — genuinely different) # ------------------------------------------------------------------ @abstractmethod @@ -334,12 +258,15 @@ def build_src_split_handles( """Build split handle data for P_TP > D_TP scenario.""" ... 
+ # ------------------------------------------------------------------ + # Read spec computation (concrete default = one spec per rank) + # ------------------------------------------------------------------ + def compute_read_specs( self, local_block_ids: BlockIds, remote_block_ids: BlockIds, remote_ranks: list[int], - physical_blocks_per_logical: int = 1, transfer_config: Any | None = None, ) -> list[ReadSpec]: """Compute the full set of read operations needed for a request. @@ -347,8 +274,12 @@ def compute_read_specs( Returns one ``ReadSpec`` per remote rank that requires a read. The worker iterates the result without model-specific branching. MLA trimming (keeping only the first spec) is handled by the worker. + + Block ID expansion (logical→kernel) is done by the worker before + calling this method. Mamba overrides to additionally filter + block IDs per rank via ``filter_block_ids_for_rank``. """ - _ = (physical_blocks_per_logical, transfer_config) + _ = transfer_config return [ ReadSpec( remote_rank=rank, @@ -370,14 +301,13 @@ def create( tp_size: int, ) -> ModelBlockTransferPolicy: """Create the appropriate policy based on model architecture.""" - is_mamba_group = [ + has_mamba = any( isinstance(group.kv_cache_spec, MambaSpec) for group in kv_cache_config.kv_cache_groups - ] - if any(is_mamba_group): + ) + if has_mamba: return MambaModelBlockTransferPolicy( kv_cache_config=kv_cache_config, - is_mamba_group=is_mamba_group, tp_size=tp_size, layer_specs=layer_specs, physical_blocks_per_logical=physical_blocks_per_logical, @@ -394,53 +324,21 @@ def create( class DenseModelBlockTransferPolicy(ModelBlockTransferPolicy): + """Policy for pure-attention (dense) models. + + Inherits all registration helpers, block ID mapping, + ``compute_read_specs``, ``build_local_descs``, ``get_block_len`` + from the ABC. Only overrides genuinely different methods: + ``get_block_descs_ids``, ``build_remote_descs``, + ``build_src_split_handles``, and ``build_engine_transfer_info``. 
+ """ + def __init__( self, kv_cache_config: KVCacheConfig, physical_blocks_per_logical: int, ): super().__init__(kv_cache_config, physical_blocks_per_logical) - self._num_groups = len(kv_cache_config.kv_cache_groups) - - @property - def is_mamba(self) -> bool: - return False - - @property - def mamba_group_flags(self) -> list[bool]: - return [False] * self._num_groups - - @property - def ssm_sizes(self) -> tuple[int, int]: - return (0, 0) - - @property - def conv_decomp(self) -> MambaConvSplitInfo | None: - return None - - def compute_page_size(self, layer_spec, physical_ratio): - return layer_spec.page_size_bytes // physical_ratio - - def get_num_blocks(self, layer_spec, num_blocks, logical_num_blocks): - return num_blocks - - def compute_layer_block_bytes(self, layer_spec, physical_page_size, physical_ratio): - return physical_page_size - - def get_tensor_shape(self, kv_caches): - return next(iter(kv_caches.values())).shape - - def get_block_len( - self, - layer_idx, - first_split, - block_len_per_layer, - is_blocks_first, - mamba_view=False, - ): - if is_blocks_first: - return block_len_per_layer[layer_idx] // 2 - return block_len_per_layer[layer_idx] def get_block_descs_ids( self, @@ -458,77 +356,6 @@ def get_block_descs_ids( block_ids_arr = np.concatenate(block_ids)[None, :] return (region_ids * num_blocks + block_ids_arr).flatten() - def logical_to_kernel_block_ids(self, block_ids): - if self._physical_blocks_per_logical == 1: - return block_ids - block_arange = np.arange( - 0, - self._physical_blocks_per_logical, - ).reshape(1, -1) - return [ - BlockTable.map_to_kernel_blocks( - np.array(group), - self._physical_blocks_per_logical, - block_arange, - ).tolist() - for group in block_ids - ] - - def logical_to_remote_kernel_block_ids( - self, - block_ids, - remote_ratio, - ): - if remote_ratio == 1: - return block_ids - local_arange = np.arange( - self._physical_blocks_per_logical, - ).reshape(1, -1) - return [ - (np.array(group).reshape(-1, 1) * remote_ratio 
+ local_arange) - .flatten() - .tolist() - for group in block_ids - ] - - def compute_read_specs( - self, - local_block_ids, - remote_block_ids, - remote_ranks, - physical_blocks_per_logical=1, - transfer_config=None, - ): - _ = (physical_blocks_per_logical, transfer_config) - expanded = self.logical_to_kernel_block_ids(remote_block_ids) - return [ - ReadSpec( - remote_rank=rank, - local_block_ids=local_block_ids, - remote_block_ids=expanded, - ) - for rank in remote_ranks - ] - - def build_local_descs( - self, - base_addresses, - block_len_per_layer, - num_blocks, - logical_num_blocks, - block_size_ratio, - device_id, - is_blocks_first, - ): - return self._build_fa_local_descs( - base_addresses, - block_len_per_layer, - num_blocks, - block_size_ratio, - device_id, - is_blocks_first, - ) - def build_remote_descs( self, nixl_agent_meta, @@ -646,16 +473,26 @@ def build_engine_transfer_info( class MambaModelBlockTransferPolicy(ModelBlockTransferPolicy): + """Policy for hybrid Mamba+Attention (SSM) models. + + Stores Mamba-specific state (SSM sizes, conv decomposition, + per-group flags) and overrides methods that differ from the + FA-only base: descriptor building, split handles, read specs, + registration helpers for MambaSpec layers, and orchestration. 
+ """ + def __init__( self, kv_cache_config: KVCacheConfig, - is_mamba_group: list[bool], tp_size: int, layer_specs: dict[str, KVCacheSpec], physical_blocks_per_logical: int, ): super().__init__(kv_cache_config, physical_blocks_per_logical) - self._is_mamba_group = is_mamba_group + self._is_mamba_group = [ + isinstance(group.kv_cache_spec, MambaSpec) + for group in kv_cache_config.kv_cache_groups + ] mamba_spec = next( spec for spec in layer_specs.values() if isinstance(spec, MambaSpec) @@ -681,40 +518,16 @@ def __init__( ) self._conv_decomp = derive_mamba_conv_split(mamba_spec, tp_size) - @property - def is_mamba(self) -> bool: - return True - - @property - def mamba_group_flags(self) -> list[bool]: - return self._is_mamba_group - @property def ssm_sizes(self) -> tuple[int, int]: + """(conv_state_bytes, ssm_state_bytes) per logical block.""" return self._ssm_sizes @property - def conv_decomp(self) -> MambaConvSplitInfo | None: + def conv_decomp(self) -> MambaConvSplitInfo: + """Conv-state sub-projection decomposition.""" return self._conv_decomp - def compute_page_size(self, layer_spec, physical_ratio): - if isinstance(layer_spec, MambaSpec): - return layer_spec.page_size_bytes - return layer_spec.page_size_bytes // physical_ratio - - def get_num_blocks(self, layer_spec, num_blocks, logical_num_blocks): - if isinstance(layer_spec, MambaSpec): - return logical_num_blocks - return num_blocks - - def compute_layer_block_bytes(self, layer_spec, physical_page_size, physical_ratio): - if isinstance(layer_spec, MambaSpec): - return physical_page_size // physical_ratio - return physical_page_size - - def get_tensor_shape(self, kv_caches): - return None - def get_block_len( self, layer_idx, @@ -723,14 +536,13 @@ def get_block_len( is_blocks_first, mamba_view=False, ): - # For indexing only half (either just the K or V part). 
- if is_blocks_first: - if mamba_view: - # NOTE (NickLucche) Mamba Opt: this is already skipping the - # padding so we're only transferring the minimum required bytes. - return self._ssm_sizes[not first_split] - return block_len_per_layer[layer_idx] // 2 - return block_len_per_layer[layer_idx] + if mamba_view and is_blocks_first: + # NOTE (NickLucche) Mamba Opt: this is already skipping the + # padding so we're only transferring the minimum required bytes. + return self._ssm_sizes[not first_split] + return super().get_block_len( + layer_idx, first_split, block_len_per_layer, is_blocks_first + ) def get_block_descs_ids( self, @@ -783,52 +595,6 @@ def get_block_descs_ids( all_descs.append((region_ids * num_blocks + group_arr).flatten()) return np.concatenate(all_descs) - def logical_to_kernel_block_ids(self, block_ids): - if self._physical_blocks_per_logical == 1: - return block_ids - block_arange = np.arange( - 0, - self._physical_blocks_per_logical, - ).reshape(1, -1) - group_specs = self._kv_cache_config.kv_cache_groups - return [ - BlockTable.map_to_kernel_blocks( - np.array(group), - self._physical_blocks_per_logical, - block_arange, - ).tolist() - if not isinstance( - group_specs[i].kv_cache_spec, - MambaSpec, - ) - else group - for i, group in enumerate(block_ids) - ] - - def logical_to_remote_kernel_block_ids( - self, - block_ids, - remote_ratio, - ): - if remote_ratio == 1: - return block_ids - local_arange = np.arange( - self._physical_blocks_per_logical, - ).reshape(1, -1) - group_specs = self._kv_cache_config.kv_cache_groups - result: list[list[int]] = [] - for i, group in enumerate(block_ids): - if not isinstance( - group_specs[i].kv_cache_spec, - MambaSpec, - ): - arr = np.array(group).reshape(-1, 1) - expanded = (arr * remote_ratio + local_arange).flatten() - result.append(expanded.tolist()) - else: - result.append(group) - return result - def build_local_descs( self, base_addresses, @@ -1139,13 +905,8 @@ def compute_read_specs( local_block_ids, 
remote_block_ids, remote_ranks, - physical_blocks_per_logical=1, transfer_config=None, ): - expanded = self.logical_to_remote_kernel_block_ids( - remote_block_ids, - physical_blocks_per_logical, - ) info = cast(MambaEngineTransferInfo, transfer_config) assert transfer_config is not None specs: list[ReadSpec] = [] @@ -1154,7 +915,7 @@ def compute_read_specs( info, rank, local_block_ids, - expanded, + remote_block_ids, ) specs.append( ReadSpec( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index f12984b241e4..4965ccf979c2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -65,8 +65,10 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.kv_cache_interface import ( FullAttentionSpec, + MambaSpec, UniformTypeKVCacheSpecs, ) +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.utils import select_common_block_size if TYPE_CHECKING: @@ -118,6 +120,34 @@ def __init__( } self.hma_group_size = len(kv_cache_config.kv_cache_tensors) + # ---- Mamba model state (derived from model config) ---- + self._is_mamba_group = [ + isinstance(group.kv_cache_spec, MambaSpec) + for group in kv_cache_config.kv_cache_groups + ] + mamba_ssm_size = (0, 0) + self._has_mamba = any(self._is_mamba_group) + if self._has_mamba: + assert self._is_hma_required + mamba_spec = next( + spec + for spec in self._layer_specs.values() + if isinstance(spec, MambaSpec) + ) + conv_nbytes, ssm_nbytes = ( + torch.tensor([], dtype=mamba_spec.dtypes[0]).element_size(), # type: ignore[misc] + torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(), # type: ignore[misc] + ) + conv_shape, ssm_shape = ( + torch.Size(mamba_spec.shapes[0]), + torch.Size(mamba_spec.shapes[1]), + ) + mamba_ssm_size = ( + conv_shape.numel() * conv_nbytes, + ssm_shape.numel() * ssm_nbytes, + ) + self._mamba_ssm_size 
= mamba_ssm_size + # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] # Configure NIXL num_threads to avoid UAR exhaustion on Mellanox NICs. @@ -228,14 +258,6 @@ def __init__( self.dst_num_blocks: dict[EngineId, int] = {} self._registered_descs: list[Any] = [] - # Mamba-HMA per-engine state (only used when is_mamba). - # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine. - # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len) - # where conv/ssm bytes are per-TP-rank (dimension-sharded). With - # heterogeneous TP the per-rank sizes differ, so the ratio differs: - # e.g. Nemotron 30B: P(TP=4) -> 131, D(TP=1) -> 261. - self._physical_blocks_per_logical: dict[EngineId, int] = {} - # In progress transfers. # [req_id -> list[handle]] self._recving_metadata: dict[ReqId, ReqMeta] = {} @@ -296,8 +318,6 @@ def __init__( physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, tp_size=vllm_config.parallel_config.tensor_parallel_size, ) - if self.transfer_policy.is_mamba: - assert self._is_hma_required self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True @@ -621,8 +641,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, - tensor_shape=self.transfer_policy.get_tensor_shape(kv_caches), - is_mamba=self.transfer_policy.is_mamba, + # SSM States come in tuples (ssm, conv) + tensor_shape=next(iter(kv_caches.values())).shape + if not self._has_mamba + else None, + is_mamba=self._has_mamba, ) self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks @@ -669,7 +692,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len_per_layer = list[int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid 
SSM models assume a layout that is similar to - # that of FI, with block laid out as in `get_block_len`. + # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. # However, physical page_size may differ when kernel requires a specific # block size. This leads to SSM and FA layers having different num_blocks. # `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this. @@ -682,8 +705,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) # `layer_spec.page_size_bytes` only accounts for logical page_size, that is # the page_size assuming constant `self._logical_num_blocks`. - physical_page_size = self.transfer_policy.compute_page_size( - layer_spec, self._physical_blocks_per_logical_kv_block + physical_page_size = ( + layer_spec.page_size_bytes + if isinstance(layer_spec, MambaSpec) + else layer_spec.page_size_bytes + // self._physical_blocks_per_logical_kv_block ) # For when registering multiple tensors eg K/V in separate regions. physical_page_size = physical_page_size // len(cache_list) @@ -692,8 +718,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): physical_page_size = physical_page_size * len( self.kv_cache_config.kv_cache_tensors ) - num_blocks = self.transfer_policy.get_num_blocks( - layer_spec, self.num_blocks, self._logical_num_blocks + num_blocks = ( + self._logical_num_blocks + if isinstance(layer_spec, MambaSpec) + else self.num_blocks ) # `page_size` accounts for physical blocks, st KVCache is always # [`num_blocks` * `page_size`] @@ -716,13 +744,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "Registering layer %s with cache shape: %s", layer_name, cache.shape ) seen_base_addresses.append(base_addr) - self.block_len_per_layer.append( - self.transfer_policy.compute_layer_block_bytes( - layer_spec, - physical_page_size, - self._physical_blocks_per_logical_kv_block, + # Only record non-Mamba page sizes. 
+ if isinstance(layer_spec, MambaSpec): + self.block_len_per_layer.append( + physical_page_size // self._physical_blocks_per_logical_kv_block ) - ) + else: + self.block_len_per_layer.append(physical_page_size) assert cache.shape[0] == num_blocks, ( "All kv cache tensors must have the same number of blocks" @@ -772,10 +800,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self.transfer_policy.is_mamba: - self._physical_blocks_per_logical[self.engine_id] = ( - self._physical_blocks_per_logical_kv_block - ) + if self._has_mamba: logger.info( "Hybrid SSM registration: num_blocks=%s, " "logical_num_blocks=%s, ratio=%s, num_regions=%s, " @@ -785,7 +810,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._physical_blocks_per_logical_kv_block, self.num_regions, self.num_descs, - self.transfer_policy.ssm_sizes, + self._mamba_ssm_size, set(self.block_len_per_layer), ) @@ -806,7 +831,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if not self.use_host_buffer else self.host_buffer_kv_cache_layout, block_size=self.block_size, - ssm_sizes=self.transfer_policy.ssm_sizes, + ssm_sizes=self._mamba_ssm_size, attn_backend_name=self.backend_name, ) # Wrap metadata in payload with hash for defensive decoding @@ -927,7 +952,7 @@ def add_remote_agent( nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0], ) - if self.transfer_policy.is_mamba + if self._has_mamba else 1 ) transfer_info = self.transfer_policy.build_engine_transfer_info( @@ -945,12 +970,6 @@ def add_remote_agent( remote_physical_blocks_per_logical=physical_blocks_per_logical, ) transfer_topo.register_remote_engine(engine_id, transfer_info) - if ( - self.transfer_policy.is_mamba - and engine_id not in self._physical_blocks_per_logical - ): - self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical - logger.info("Transfer plan: %s", 
transfer_topo.describe(engine_id)) remote_agent_name = self.nixl_wrapper.add_remote_agent( @@ -992,14 +1011,14 @@ def add_remote_agent( ) ### (Optional) Register local agent memory regions. MLA is not split. - # Remote tp_size > local tp_size: read from multiple remote ranks. - # Logically "split" own regions into |tp_ratio| chunks. Mind that - # we only do this once per remote tp_size (replica-friendly). if ( tp_ratio < 0 and not self.use_mla and tp_ratio not in self.src_xfer_handles_by_tp_ratio ): + # Remote tp_size > local tp_size: read from multiple remote ranks. + # Logically "split" own regions into |tp_ratio| chunks. Mind that + # we only do this once per remote tp_size (replica-friendly). abs_tp = -tp_ratio self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] @@ -1008,7 +1027,7 @@ def add_remote_agent( self.num_descs, abs_tp, transfer_config=transfer_topo.get_engine_info(engine_id) - if self.transfer_policy.is_mamba + if self._has_mamba else None, tp_size=transfer_topo.tp_size, is_mla=transfer_topo.is_mla, @@ -1031,11 +1050,9 @@ def add_remote_agent( is_blocks_first=transfer_topo.is_kv_layout_blocks_first, indexes_into_remote=indexes_into_remote, transfer_config=transfer_topo.get_engine_info(engine_id) - if self.transfer_policy.is_mamba + if self._has_mamba else None, - physical_blocks_per_logical=self._physical_blocks_per_logical.get( - engine_id, 1 - ), + physical_blocks_per_logical=transfer_info.remote_physical_blocks_per_logical, tp_size=transfer_topo.tp_size, total_num_kv_heads=transfer_topo.total_num_kv_heads, ) @@ -1081,7 +1098,7 @@ def _validate_remote_agent_handshake( ) # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba. # Mamba models can have replicated FA KV with tp_ratio < 0. 
- if not self.transfer_policy.is_mamba: + if not self._has_mamba: assert not ( tp_ratio < 0 and self.transfer_topo.is_kv_replicated(remote_engine_id) ) @@ -1151,7 +1168,7 @@ def _validate_remote_agent_handshake( # With replicated KV cache, only the number of blocks can differ. # TODO (ZhanqiuHu): For mamba models, validate FA and mamba # block_lens separately. - if not self.transfer_policy.is_mamba: + if not self._has_mamba: for i in range(len(self.block_len_per_layer)): assert ( self.block_len_per_layer[i] // block_size_ratio @@ -1167,7 +1184,7 @@ def _validate_remote_agent_handshake( # HMA hybrid models (mamba+attention) pad block_len to # max(attn_page, mamba_page), so the linear tp_ratio scaling # assumption only holds for pure-attention models. - if not self.transfer_policy.is_mamba: + if not self._has_mamba: if tp_ratio > 0: assert ( remote_block_len @@ -1224,8 +1241,8 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): assert self.copy_blocks is not None for req_id, meta in metadata.reqs_to_save.items(): - meta.local_physical_block_ids = ( - self.transfer_policy.logical_to_kernel_block_ids(meta.local_block_ids) + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -1519,8 +1536,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): We check for these trnxs to complete in each step(). 
""" for req_id, meta in metadata.reqs_to_recv.items(): - meta.local_physical_block_ids = ( - self.transfer_policy.logical_to_kernel_block_ids(meta.local_block_ids) + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids ) assert meta.remote is not None # Remote block IDs are kept logical here; expanded in @@ -1577,17 +1594,24 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_info = self.transfer_topo.get_engine_info(engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) + if self._has_mamba: + # Expand remote logical → kernel block IDs. + expanded_remote = self._logical_to_remote_kernel_block_ids( + meta.remote.block_ids, + remote_info.remote_physical_blocks_per_logical, + ) + else: + expanded_remote = self._logical_to_kernel_block_ids(meta.remote.block_ids) read_specs = self.transfer_policy.compute_read_specs( local_block_ids=meta.local_physical_block_ids, - remote_block_ids=meta.remote.block_ids, + remote_block_ids=expanded_remote, remote_ranks=remote_ranks, - physical_blocks_per_logical=self._physical_blocks_per_logical.get( - engine_id, 1 - ), - transfer_config=remote_info if self.transfer_policy.is_mamba else None, + transfer_config=remote_info if self._has_mamba else None, ) - # MLA opt: when P TP > D TP, only a single read is needed. + # D may have to perform multiple reads from different remote ranks. + # MLA opt: when P TP > D TP, only a single read is executed for + # the first remote rank (cache is duplicated). if self.use_mla and tp_ratio < 0: read_specs = read_specs[:1] @@ -1631,8 +1655,8 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): ) if self.use_mla and tp_ratio < 0 and read_specs: - # Notify remote ranks that were not read from so they can update - # request state (cache is duplicated under MLA). + # ..but we still need to notify the other remote ranks that we + # have the blocks we need so they can update the request state. 
notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() remote_agents = self._remote_agents[meta.remote.engine_id] read_ranks = {s.remote_rank for s in read_specs} @@ -1728,15 +1752,12 @@ def _read_blocks( for i, remote_group in enumerate(remote_block_ids): num_remote_blocks = len(remote_group) num_local_blocks = len(local_block_ids[i]) - if not self.transfer_policy.is_mamba_group(i): + if not self._is_mamba_group[i]: assert num_local_blocks <= num_remote_blocks # Partial prefix cache hit: just read uncomputed blocks. # Skip mamba groups — their blocks represent full state (conv+ssm), # not per-token data, so trimming would corrupt the transfer. - if ( - num_local_blocks < num_remote_blocks - and not self.transfer_policy.is_mamba_group(i) - ): + if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: remote_block_ids[i] = remote_group[-num_local_blocks:] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from @@ -1745,13 +1766,15 @@ def _read_blocks( # Get descs ids. remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids, + self.dst_num_blocks[dst_engine_id], + physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, ) local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids, + self.dst_num_blocks[self.engine_id], block_size_ratio=block_size_ratio, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -1813,22 +1836,77 @@ def get_mapped_blocks( return mapped_2d.flatten().astype(np.int64) + def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: + """ + Convert logical block ids to kernel physical block ids. + This is required when the logical block size (the one set by the user) + does not match the one required by the attn backend. 
+ """ + if self._physical_blocks_per_logical_kv_block == 1: + # Noop when physical and logical block sizes are the same + return block_ids + block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( + 1, -1 + ) + # Mamba blocks have no logical<>physical discrepancy + group_specs = self.kv_cache_config.kv_cache_groups + return [ + BlockTable.map_to_kernel_blocks( + np.array(group), + self._physical_blocks_per_logical_kv_block, + block_arange, + ).tolist() + if not isinstance(group_specs[i].kv_cache_spec, MambaSpec) + else group + for i, group in enumerate(block_ids) + ] + + def _logical_to_remote_kernel_block_ids( + self, block_ids: BlockIds, remote_physical_per_logical: int + ) -> BlockIds: + """Map logical block IDs to physical kernel block IDs on the remote. + + Args: + block_ids: per-group lists of logical block IDs. + remote_physical_per_logical: remote engine's physical blocks + per logical block. + + Returns: + Same structure with FA groups expanded (each logical block L + becomes kernel blocks [L*ratio .. L*ratio + local_ratio - 1]). + Mamba groups are passed through unchanged. + """ + local_ratio = self._physical_blocks_per_logical_kv_block + if remote_physical_per_logical == 1: + return block_ids + local_arange = np.arange(local_ratio).reshape(1, -1) + group_specs = self.kv_cache_config.kv_cache_groups + result: list[list[int]] = [] + for i, group in enumerate(block_ids): + if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): + arr = np.array(group).reshape(-1, 1) + expanded = (arr * remote_physical_per_logical + local_arange).flatten() + result.append(expanded.tolist()) + else: + # Mamba blocks are 1:1 logical-to-physical (no expansion). 
+ result.append(group) + return result + def _get_block_descs_ids( self, - engine_id: str, block_ids: BlockIds, + dst_num_blocks: int, block_size_ratio: float | None = None, + physical_blocks_per_logical: int = 1, ) -> np.ndarray: """Thin wrapper delegating to the block transfer policy.""" return self.transfer_policy.get_block_descs_ids( block_ids=block_ids, num_regions=self.num_regions, - dst_num_blocks=self.dst_num_blocks[engine_id], + dst_num_blocks=dst_num_blocks, block_len_per_layer=self.block_len_per_layer, block_size_ratio=block_size_ratio, - physical_blocks_per_logical=self._physical_blocks_per_logical.get( - engine_id, 1 - ), + physical_blocks_per_logical=physical_blocks_per_logical, ) def get_kv_connector_stats(self) -> KVConnectorStats | None: From 3891e97dd058533168550ba13c32ccf533694356 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Mon, 20 Apr 2026 12:43:31 +0000 Subject: [PATCH 07/49] policy ABC cleanup: static helpers, abstract methods, remove dead code Signed-off-by: Zhanqiu Hu --- .../kv_transfer/kv_connector/utils.py | 287 -------------- .../v1/nixl/block_transfer_policy.py | 351 +++++++++--------- .../kv_connector/v1/nixl/worker.py | 4 +- 3 files changed, 179 insertions(+), 463 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f17c446cf7f7..e2d6cab01185 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -633,133 +633,6 @@ def get_transfer_cache_regions( # Regular case: backends like FA register K/V in separate regions return cache if self.split_k_and_v else [cache] - # ============================================================ - # Mamba-specific methods - # ============================================================ - - def should_skip_fa(self, remote_engine_id: EngineId, remote_rank: int) -> bool: - """Whether to skip FA groups for this remote rank (mamba-only).""" - mamba_info = 
self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - return remote_rank not in mamba_info.fa_source_set - - def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int: - """Index into local FA block for this remote rank's head data. - - For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. - For ranks NOT in ``fa_source_ranks`` (replicated duplicates), - returns the slot of the matching source rank with the same head. - """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - fa_index = mamba_info.fa_source_indices - if remote_rank in fa_index: - return fa_index[remote_rank] - K = self.total_num_kv_heads - remote_tp = mamba_info.remote_tp_size - r_head = self._physical_head_range(remote_tp, K, remote_rank) - for target in mamba_info.remote_fa_source_ranks: - t_head = self._physical_head_range(remote_tp, K, target) - if self._range_overlap(r_head, t_head): - return fa_index[target] - return 0 - - def fa_rank_offset( - self, remote_engine_id: EngineId, remote_kv_block_len: int - ) -> int: - """Byte offset into remote FA block for this local rank. - - When local TP is replicated (local_tp > K), multiple local ranks - share a head. Computes offset *relative to the target remote - rank's first head* so it works regardless of how many heads the - remote has. Returns 0 when local does not index into remote. 
- """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) - if self.is_mla or tp_ratio <= 0: - return 0 - K = self.total_num_kv_heads - is_local_replicated = self.tp_size > K - if is_local_replicated: - local_head = self.tp_rank * K // self.tp_size - p_rank = mamba_info.remote_fa_source_ranks[0] - p_start = p_rank * K // mamba_info.remote_tp_size - return (local_head - p_start) * remote_kv_block_len - return self.tp_rank % tp_ratio * remote_kv_block_len - - def needs_split_handles(self, remote_engine_id: EngineId) -> bool: - """Whether per-remote-rank split handles are needed. - - True when FA and mamba have different read counts, requiring - different splitting factors in the local handle. - """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) - return ( - tp_ratio < 0 - and not self.is_mla - and len(mamba_info.remote_all_source_ranks) > 1 - ) - - def compute_split_handle_data( - self, - remote_engine_id: EngineId, - src_blocks_data: list[tuple[int, int, int]], - num_fa_descs: int, - abs_tp: int, - ) -> list[list[tuple[int, int, int]]]: - """Per-remote-rank (addr, len, dev) triples for Mamba-HMA split - handles. - - FA descriptors (indices < num_fa_descs) are sliced by - ``remote_num_fa_reads``; mamba descriptors are sliced uniformly - by ``abs_tp``. 
- """ - mamba_info = self._engines[remote_engine_id] - assert isinstance(mamba_info, MambaEngineTransferInfo) - all_handle_data: list[list[tuple[int, int, int]]] = [] - for p_idx, p_rank in enumerate(mamba_info.remote_all_source_ranks): - handle_data: list[tuple[int, int, int]] = [] - skip_fa = self.should_skip_fa(remote_engine_id, p_rank) - fa_slot = self.fa_head_slot(remote_engine_id, p_rank) if not skip_fa else 0 - for j, (addr, local_len, dev) in enumerate(src_blocks_data): - if j < num_fa_descs: - assert mamba_info.remote_num_fa_reads >= 1 - fa_chunk = local_len // mamba_info.remote_num_fa_reads - handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) - else: - mamba_chunk = local_len // abs_tp - handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) - all_handle_data.append(handle_data) - return all_handle_data - - def filter_block_ids_for_rank( - self, - remote_engine_id: EngineId, - remote_rank: int, - local_ids: BlockIds, - remote_ids: BlockIds, - is_mamba_group: list[bool], - ) -> tuple[BlockIds, BlockIds]: - """Zero out FA groups for remote ranks outside ``fa_source_ranks``. - - Returns (filtered_local_ids, filtered_remote_ids). When the - remote rank carries FA data for this local rank, returns the - inputs unchanged. 
- """ - if not self.should_skip_fa(remote_engine_id, remote_rank): - return local_ids, remote_ids - num_groups = len(local_ids) - filtered_local: list[list[int]] = [ - [] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups) - ] - filtered_remote: list[list[int]] = [ - [] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups) - ] - return filtered_local, filtered_remote - def describe(self, remote_engine_id: EngineId) -> str: """One-line summary of transfer config for logging.""" info = self._engines[remote_engine_id] @@ -781,163 +654,3 @@ def describe(self, remote_engine_id: EngineId) -> str: f"fa_desc_bytes={info.remote_fa_descriptor_bytes})" ) return f"TransferTopology({base})" - - # ============================================================ - # Private helpers - # ============================================================ - # Mamba-HMA hetero-TP transfer config: - # With hetero-TP (P_TP > D_TP), FA KV cache may be replicated across - # P ranks (when P_TP > num_kv_heads), but Mamba conv/SSM state is - # almost always uniquely sharded per P rank. So the number of P - # ranks D must read from can differ between FA and Mamba, and they - # must be handled separately. - - @staticmethod - def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: - """Physical KV head range stored in a rank's KV cache tensor. - - When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. - When ``tp_size > num_heads``: 1 physical head per rank. Heads are - distributed **contiguously** (matching vLLM's GQA weight partitioning): - consecutive ranks share a head before moving to the next one. 
- """ - if tp_size <= num_heads: - assert num_heads % tp_size == 0 - per_rank = num_heads // tp_size - return range(rank * per_rank, (rank + 1) * per_rank) - else: - h = rank * num_heads // tp_size - return range(h, h + 1) - - @staticmethod - def _range_overlap(a: range, b: range) -> range: - start = max(a.start, b.start) - stop = min(a.stop, b.stop) - return range(start, max(start, stop)) - - # ============================================================ - # Private: build Mamba transfer info - # ============================================================ - - def _build_mamba_info( - self, - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - local_block_len: int, - ) -> MambaEngineTransferInfo: - """Compute Mamba transfer plan.""" - K = self.total_num_kv_heads - local_tp = self.tp_size - local_rank = self.tp_rank - - is_remote_replicated = remote_tp_size > K - remote_physical_heads = max(1, K // remote_tp_size) - - if local_tp >= remote_tp_size: - assert local_tp % remote_tp_size == 0 - tp_ratio = local_tp // remote_tp_size - else: - assert remote_tp_size % local_tp == 0 - tp_ratio = -(remote_tp_size // local_tp) - - abs_tp = -tp_ratio if tp_ratio < 0 else 1 - - mamba_range: range | None = None - if tp_ratio < 0: - mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp) - - # ---- FA read targets ---- - if self.is_mla or tp_ratio >= 0: - num_fa_reads = 1 - fa_source_ranks: list[int] = ( - [0] - if self.is_mla - else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] - ) - else: - local_needs = self._physical_head_range(local_tp, K, local_rank) - search_range = ( - mamba_range if mamba_range is not None else range(remote_tp_size) - ) - seen: set[tuple[int, int]] = set() - fa_source_ranks = [] - for p in search_range: - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in 
seen: - seen.add(key) - fa_source_ranks.append(p) - if not fa_source_ranks: - for p in range(remote_tp_size): - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - num_fa_reads = len(fa_source_ranks) - - # ---- All source ranks (mamba + FA) ---- - if mamba_range is not None and abs_tp > num_fa_reads: - num_mamba_reads = abs_tp - all_source_ranks = list(mamba_range) - else: - num_mamba_reads = num_fa_reads - all_source_ranks = list(fa_source_ranks) - - # ---- FA descriptor bytes ---- - effective_block_len = min(local_block_len, remote_block_len) - if self.is_kv_layout_blocks_first: - fa_descriptor_bytes = effective_block_len // 2 - else: - fa_descriptor_bytes = effective_block_len - - # ---- Validation ---- - is_local_replicated = local_tp > K - if is_local_replicated and is_remote_replicated and tp_ratio > 0: - logger.info( - "Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.", - local_tp, - remote_tp_size, - K, - ) - tt_set = set(all_source_ranks) - for t in fa_source_ranks: - if t not in tt_set: - logger.error( - "FA source rank %d NOT in all_source_ranks %s.", - t, - all_source_ranks, - ) - if self.is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0: - local_k_half = local_block_len // 2 - remote_k_half = remote_block_len // 2 - expected = local_k_half // num_fa_reads - if expected != remote_k_half: - logger.warning( - "FA size mismatch: local_k_half=%d / reads=%d = %d, " - "but remote_k_half=%d.", - local_k_half, - num_fa_reads, - expected, - remote_k_half, - ) - - return MambaEngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), - remote_fa_source_ranks=tuple(fa_source_ranks), - remote_all_source_ranks=tuple(all_source_ranks), 
- remote_num_fa_reads=num_fa_reads, - remote_num_mamba_reads=num_mamba_reads, - remote_fa_descriptor_bytes=fa_descriptor_bytes, - is_remote_replicated=is_remote_replicated, - remote_physical_heads=remote_physical_heads, - ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index dbd2eb2e37f9..088334d87981 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -25,8 +25,6 @@ BlockIds, EngineTransferInfo, MambaEngineTransferInfo, - _physical_head_range, - _range_overlap, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, @@ -102,18 +100,21 @@ def build_engine_transfer_info( """ # ------------------------------------------------------------------ - # Block length helper (used by descriptor building) + # KV block length helper (used by FA descriptor building) # ------------------------------------------------------------------ - def get_block_len( - self, + @staticmethod + def get_kv_block_len( layer_idx: int, - first_split: bool, block_len_per_layer: list[int], is_blocks_first: bool, - mamba_view: bool = False, ) -> int: - """Block length for one K/V element. Mamba overrides for SSM view.""" + """Byte length of one K or V descriptor for a layer. + + For FA/KV layers only — Mamba SSM descriptors use ``_ssm_sizes`` + directly. When ``is_blocks_first``, K and V are interleaved so + the descriptor covers half the block. 
+ """ if is_blocks_first: return block_len_per_layer[layer_idx] // 2 return block_len_per_layer[layer_idx] @@ -139,6 +140,7 @@ def get_block_descs_ids( # Local descriptor building (concrete default = FA-only) # ------------------------------------------------------------------ + @abstractmethod def build_local_descs( self, base_addresses: list[int], @@ -149,19 +151,8 @@ def build_local_descs( device_id: int, is_blocks_first: bool, ) -> list[tuple[int, int, int]]: - """Build local (src) descriptor tuples for NIXL registration. - - Default builds FA descriptors only. Mamba overrides to extend - with SSM (conv + temporal state) descriptors. - """ - return self._build_fa_local_descs( - base_addresses, - block_len_per_layer, - num_blocks, - block_size_ratio, - device_id, - is_blocks_first, - ) + """Build local (src) descriptor tuples for NIXL registration.""" + ... def _build_fa_local_descs( self, @@ -179,9 +170,8 @@ def _build_fa_local_descs( # The new block_len is using prefill block_len; # and num_blocks is multiple with N kv_block_len = ( - self.get_block_len( + self.get_kv_block_len( i, - True, block_len_per_layer, is_blocks_first, ) @@ -200,9 +190,8 @@ def _build_fa_local_descs( ) ) if is_blocks_first: - second_split = self.get_block_len( + second_split = self.get_kv_block_len( i, - False, block_len_per_layer, is_blocks_first, ) @@ -259,9 +248,10 @@ def build_src_split_handles( ... # ------------------------------------------------------------------ - # Read spec computation (concrete default = one spec per rank) + # Read spec computation # ------------------------------------------------------------------ + @abstractmethod def compute_read_specs( self, local_block_ids: BlockIds, @@ -276,18 +266,113 @@ def compute_read_specs( MLA trimming (keeping only the first spec) is handled by the worker. Block ID expansion (logical→kernel) is done by the worker before - calling this method. 
Mamba overrides to additionally filter - block IDs per rank via ``filter_block_ids_for_rank``. + calling this method. """ - _ = transfer_config - return [ - ReadSpec( - remote_rank=rank, - local_block_ids=local_block_ids, - remote_block_ids=remote_block_ids, - ) - for rank in remote_ranks - ] + ... + + # ------------------------------------------------------------------ + # FA head replication helpers (hetero-TP: tp_size > num_kv_heads) + # ------------------------------------------------------------------ + + @staticmethod + def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: + """Physical KV head range stored in a rank's KV cache tensor. + + When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. + When ``tp_size > num_heads``: 1 physical head per rank. Heads are + distributed **contiguously** (matching vLLM's GQA weight + partitioning): consecutive ranks share a head before moving to + the next one. + """ + if tp_size <= num_heads: + assert num_heads % tp_size == 0 + per_rank = num_heads // tp_size + return range(rank * per_rank, (rank + 1) * per_rank) + else: + h = rank * num_heads // tp_size + return range(h, h + 1) + + @staticmethod + def _range_overlap(a: range, b: range) -> range: + start = max(a.start, b.start) + stop = min(a.stop, b.stop) + return range(start, max(start, stop)) + + @staticmethod + def _should_skip_fa(info: EngineTransferInfo, remote_rank: int) -> bool: + """Whether to skip FA groups for this remote rank. + + Returns False unless ``info`` carries FA source tracking + (``MambaEngineTransferInfo``) and the rank is a replicated + duplicate. + """ + if not isinstance(info, MambaEngineTransferInfo): + return False + return remote_rank not in info.fa_source_set + + @staticmethod + def _fa_head_slot( + info: EngineTransferInfo, + remote_rank: int, + total_num_kv_heads: int, + ) -> int: + """Index into local FA block for this remote rank's head data. 
+ + For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. + For ranks NOT in ``fa_source_ranks`` (replicated duplicates), + returns the slot of the matching source rank with the same head. + Returns 0 when ``info`` has no FA source tracking. + """ + if not isinstance(info, MambaEngineTransferInfo): + return 0 + fa_index = info.fa_source_indices + if remote_rank in fa_index: + return fa_index[remote_rank] + K = total_num_kv_heads + remote_tp = info.remote_tp_size + phr = ModelBlockTransferPolicy._physical_head_range + rov = ModelBlockTransferPolicy._range_overlap + r_head = phr(remote_tp, K, remote_rank) + for target in info.remote_fa_source_ranks: + t_head = phr(remote_tp, K, target) + if rov(r_head, t_head): + return fa_index[target] + return 0 + + @staticmethod + def _fa_rank_offset( + info: EngineTransferInfo, + remote_kv_block_len: int, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + ) -> int: + """Byte offset into remote FA block for this local rank. + + When local TP is replicated (local_tp > K), multiple local ranks + share a head. Computes offset *relative to the target remote + rank's first head* so it works regardless of how many heads the + remote has. Returns 0 when ``info`` has no FA source tracking + or when local does not index into remote. 
+ """ + if not isinstance(info, MambaEngineTransferInfo): + return 0 + tp_ratio = ( + tp_size // info.remote_tp_size + if tp_size >= info.remote_tp_size + else -(info.remote_tp_size // tp_size) + ) # noqa: E501 + if is_mla or tp_ratio <= 0: + return 0 + K = total_num_kv_heads + is_local_replicated = tp_size > K + if is_local_replicated: + local_head = tp_rank * K // tp_size + p_rank = info.remote_fa_source_ranks[0] + p_start = p_rank * K // info.remote_tp_size + return (local_head - p_start) * remote_kv_block_len + return tp_rank % tp_ratio * remote_kv_block_len # ------------------------------------------------------------------ # Factory @@ -326,11 +411,7 @@ def create( class DenseModelBlockTransferPolicy(ModelBlockTransferPolicy): """Policy for pure-attention (dense) models. - Inherits all registration helpers, block ID mapping, - ``compute_read_specs``, ``build_local_descs``, ``get_block_len`` - from the ABC. Only overrides genuinely different methods: - ``get_block_descs_ids``, ``build_remote_descs``, - ``build_src_split_handles``, and ``build_engine_transfer_info``. + Inherits ``get_block_len`` from the ABC. """ def __init__( @@ -340,6 +421,25 @@ def __init__( ): super().__init__(kv_cache_config, physical_blocks_per_logical) + def build_local_descs( + self, + base_addresses, + block_len_per_layer, + num_blocks, + logical_num_blocks, + block_size_ratio, + device_id, + is_blocks_first, + ): + return self._build_fa_local_descs( + base_addresses, + block_len_per_layer, + num_blocks, + block_size_ratio, + device_id, + is_blocks_first, + ) + def get_block_descs_ids( self, block_ids, @@ -382,9 +482,8 @@ def build_remote_descs( nixl_agent_meta.kv_caches_base_addr, ): # Read our whole local region size from remote. 
- local_block_len = self.get_block_len( + local_block_len = self.get_kv_block_len( i, - True, block_len_per_layer, is_blocks_first, ) @@ -405,9 +504,8 @@ def build_remote_descs( addr = base_addr + blk * page_size + rank_offset result.append((addr, local_block_len, dev_id)) if is_blocks_first: - second_split = self.get_block_len( + second_split = self.get_kv_block_len( i, - False, block_len_per_layer, is_blocks_first, ) @@ -445,6 +543,22 @@ def build_src_split_handles( result.append(blocks_data) return result + def compute_read_specs( + self, + local_block_ids, + remote_block_ids, + remote_ranks, + transfer_config=None, + ): + return [ + ReadSpec( + remote_rank=rank, + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + ) + for rank in remote_ranks + ] + def build_engine_transfer_info( self, *, @@ -528,22 +642,6 @@ def conv_decomp(self) -> MambaConvSplitInfo: """Conv-state sub-projection decomposition.""" return self._conv_decomp - def get_block_len( - self, - layer_idx, - first_split, - block_len_per_layer, - is_blocks_first, - mamba_view=False, - ): - if mamba_view and is_blocks_first: - # NOTE (NickLucche) Mamba Opt: this is already skipping the - # padding so we're only transferring the minimum required bytes. - return self._ssm_sizes[not first_split] - return super().get_block_len( - layer_idx, first_split, block_len_per_layer, is_blocks_first - ) - def get_block_descs_ids( self, block_ids, @@ -697,7 +795,8 @@ def build_remote_descs( # indexes_into_remote is not used for Mamba: FA offset is computed # via fa_rank_offset which accounts for GQA/HMA head mapping. 
_ = indexes_into_remote - info = cast(MambaEngineTransferInfo, transfer_config) + assert isinstance(transfer_config, MambaEngineTransferInfo) + info = transfer_config result: list[tuple[int, int, int]] = [] result.extend( self._build_fa_remote_descs( @@ -746,9 +845,8 @@ def _build_fa_remote_descs( for i, base_addr in enumerate( nixl_agent_meta.kv_caches_base_addr, ): - local_block_len = self.get_block_len( + local_block_len = self.get_kv_block_len( i, - True, block_len_per_layer, is_blocks_first, ) @@ -757,7 +855,7 @@ def _build_fa_remote_descs( local_block_len = remote_kv_block_len if tp_ratio < 0 and not use_mla: local_block_len = local_block_len // info.remote_num_fa_reads - rank_offset = self.fa_rank_offset( + rank_offset = self._fa_rank_offset( info, remote_kv_block_len, tp_rank=tp_rank, @@ -772,9 +870,8 @@ def _build_fa_remote_descs( addr = base_addr + blk * page_size + rank_offset result.append((addr, local_block_len, dev_id)) if is_blocks_first: - second_split = self.get_block_len( + second_split = self.get_kv_block_len( i, - False, block_len_per_layer, is_blocks_first, ) @@ -872,8 +969,8 @@ def build_src_split_handles( is_mla: bool = False, total_num_kv_heads: int = 1, ): - info = cast(MambaEngineTransferInfo, transfer_config) - assert transfer_config is not None + assert isinstance(transfer_config, MambaEngineTransferInfo) + info = transfer_config if self.needs_split_handles( info, tp_size=tp_size, @@ -907,8 +1004,8 @@ def compute_read_specs( remote_ranks, transfer_config=None, ): - info = cast(MambaEngineTransferInfo, transfer_config) - assert transfer_config is not None + assert isinstance(transfer_config, MambaEngineTransferInfo) + info = transfer_config specs: list[ReadSpec] = [] for rank in remote_ranks: filtered_local, filtered_remote = self.filter_block_ids_for_rank( @@ -969,15 +1066,15 @@ def build_engine_transfer_info( else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] ) else: - local_needs = _physical_head_range(local_tp, K, 
local_rank) + local_needs = self._physical_head_range(local_tp, K, local_rank) search_range = ( mamba_range if mamba_range is not None else range(remote_tp_size) ) seen: set[tuple[int, int]] = set() fa_source_ranks = [] for p in search_range: - p_has = _physical_head_range(remote_tp_size, K, p) - ov = _range_overlap(local_needs, p_has) + p_has = self._physical_head_range(remote_tp_size, K, p) + ov = self._range_overlap(local_needs, p_has) if len(ov) > 0: key = (ov.start, ov.stop) if key not in seen: @@ -985,8 +1082,8 @@ def build_engine_transfer_info( fa_source_ranks.append(p) if not fa_source_ranks: for p in range(remote_tp_size): - p_has = _physical_head_range(remote_tp_size, K, p) - ov = _range_overlap(local_needs, p_has) + p_has = self._physical_head_range(remote_tp_size, K, p) + ov = self._range_overlap(local_needs, p_has) if len(ov) > 0: key = (ov.start, ov.stop) if key not in seen: @@ -1054,70 +1151,6 @@ def build_engine_transfer_info( remote_physical_heads=remote_physical_heads, ) - # ------------------------------------------------------------------ - # Orchestration methods - # ------------------------------------------------------------------ - - def should_skip_fa(self, info: MambaEngineTransferInfo, remote_rank: int) -> bool: - """Whether to skip FA groups for this remote rank.""" - return remote_rank not in info.fa_source_set - - def fa_head_slot( - self, - info: MambaEngineTransferInfo, - remote_rank: int, - total_num_kv_heads: int, - ) -> int: - """Index into local FA block for this remote rank's head data. - - For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. - For ranks NOT in ``fa_source_ranks`` (replicated duplicates), - returns the slot of the matching source rank with the same head. 
- """ - fa_index = info.fa_source_indices - if remote_rank in fa_index: - return fa_index[remote_rank] - K = total_num_kv_heads - remote_tp = info.remote_tp_size - r_head = _physical_head_range(remote_tp, K, remote_rank) - for target in info.remote_fa_source_ranks: - t_head = _physical_head_range(remote_tp, K, target) - if _range_overlap(r_head, t_head): - return fa_index[target] - return 0 - - def fa_rank_offset( - self, - info: MambaEngineTransferInfo, - remote_kv_block_len: int, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - ) -> int: - """Byte offset into remote FA block for this local rank. - - When local TP is replicated (local_tp > K), multiple local ranks - share a head. Computes offset *relative to the target remote - rank's first head* so it works regardless of how many heads the - remote has. Returns 0 when local does not index into remote. - """ - tp_ratio = ( - tp_size // info.remote_tp_size - if tp_size >= info.remote_tp_size - else -(info.remote_tp_size // tp_size) - ) # noqa: E501 - if is_mla or tp_ratio <= 0: - return 0 - K = total_num_kv_heads - is_local_replicated = tp_size > K - if is_local_replicated: - local_head = tp_rank * K // tp_size - p_rank = info.remote_fa_source_ranks[0] - p_start = p_rank * K // info.remote_tp_size - return (local_head - p_start) * remote_kv_block_len - return tp_rank % tp_ratio * remote_kv_block_len - def needs_split_handles( self, info: MambaEngineTransferInfo, @@ -1153,9 +1186,9 @@ def compute_split_handle_data( all_handle_data: list[list[tuple[int, int, int]]] = [] for p_idx, p_rank in enumerate(info.remote_all_source_ranks): handle_data: list[tuple[int, int, int]] = [] - skip_fa = self.should_skip_fa(info, p_rank) + skip_fa = self._should_skip_fa(info, p_rank) fa_slot = ( - self.fa_head_slot(info, p_rank, total_num_kv_heads) + self._fa_head_slot(info, p_rank, total_num_kv_heads) if not skip_fa else 0 ) @@ -1183,7 +1216,7 @@ def filter_block_ids_for_rank( remote rank carries FA data 
for this local rank, returns the inputs unchanged. """ - if not self.should_skip_fa(info, remote_rank): + if not self._should_skip_fa(info, remote_rank): return local_ids, remote_ids num_groups = len(local_ids) filtered_local: list[list[int]] = [ @@ -1195,31 +1228,3 @@ def filter_block_ids_for_rank( for g in range(num_groups) ] return filtered_local, filtered_remote - - def describe_mamba( - self, - info: MambaEngineTransferInfo, - tp_rank: int, - tp_size: int, - total_num_kv_heads: int, - ) -> str: - """One-line summary of Mamba transfer config for logging.""" - tp_ratio = ( - tp_size // info.remote_tp_size - if tp_size >= info.remote_tp_size - else -(info.remote_tp_size // tp_size) - ) # noqa: E501 - return ( - f"MambaTransferPolicy(" - f"tp_ratio={tp_ratio}, " - f"K={total_num_kv_heads}, " - f"local_tp={tp_size}, " - f"remote_tp={info.remote_tp_size}, " - f"local_rank={tp_rank}, " - f"fa_reads={info.remote_num_fa_reads}, " - f"mamba_reads={info.remote_num_mamba_reads}, " - f"fa_sources={list(info.remote_fa_source_ranks)}, " - f"all_sources={list(info.remote_all_source_ranks)}, " - f"fa_desc_bytes={info.remote_fa_descriptor_bytes}, " - f"remote_block_len={info.remote_block_len})" - ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 4965ccf979c2..fad9b2b52e11 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -21,7 +21,6 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, EngineId, - MambaEngineTransferInfo, TransferTopology, get_current_attn_backends, kv_postprocess_blksize_and_layout_on_receive, @@ -31,7 +30,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import CopyBlocksOp from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( - 
MambaModelBlockTransferPolicy, ModelBlockTransferPolicy, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( @@ -692,7 +690,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len_per_layer = list[int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to - # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. + # that of FI, with block laid out as in `get_kv_block_len`. # However, physical page_size may differ when kernel requires a specific # block size. This leads to SSM and FA layers having different num_blocks. # `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this. From 3cea947ca7078db19dddc50b18d06a6cb13d8979 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 20:22:43 +0000 Subject: [PATCH 08/49] revert comment to reference get_backend_aware_kv_block_len Signed-off-by: Zhanqiu Hu --- vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index fad9b2b52e11..9f1a9ce25861 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -690,7 +690,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len_per_layer = list[int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to - # that of FI, with block laid out as in `get_kv_block_len`. + # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. # However, physical page_size may differ when kernel requires a specific # block size. This leads to SSM and FA layers having different num_blocks. 
# `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this. From 2a1e6c949ec9dc2909bf5860d86f70fac2dd71c5 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 20:23:48 +0000 Subject: [PATCH 09/49] always pass transfer_config to policy methods, remove has_mamba guards Signed-off-by: Zhanqiu Hu --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 9f1a9ce25861..8ce3c2f871fe 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1024,9 +1024,7 @@ def add_remote_agent( self.src_blocks_data, self.num_descs, abs_tp, - transfer_config=transfer_topo.get_engine_info(engine_id) - if self._has_mamba - else None, + transfer_config=transfer_topo.get_engine_info(engine_id), tp_size=transfer_topo.tp_size, is_mla=transfer_topo.is_mla, total_num_kv_heads=transfer_topo.total_num_kv_heads, @@ -1047,9 +1045,7 @@ def add_remote_agent( block_len_per_layer=self.block_len_per_layer, is_blocks_first=transfer_topo.is_kv_layout_blocks_first, indexes_into_remote=indexes_into_remote, - transfer_config=transfer_topo.get_engine_info(engine_id) - if self._has_mamba - else None, + transfer_config=transfer_topo.get_engine_info(engine_id), physical_blocks_per_logical=transfer_info.remote_physical_blocks_per_logical, tp_size=transfer_topo.tp_size, total_num_kv_heads=transfer_topo.total_num_kv_heads, @@ -1604,7 +1600,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): local_block_ids=meta.local_physical_block_ids, remote_block_ids=expanded_remote, remote_ranks=remote_ranks, - transfer_config=remote_info if self._has_mamba else None, + transfer_config=remote_info, ) # D may have to perform multiple reads from different remote ranks. 
From 2a2323ee06cee0609526bab6027faeef55e35852 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 20:30:10 +0000 Subject: [PATCH 10/49] remove abs_tp from build_src_split_handles args; compute from tp_size and remote_tp_size Signed-off-by: Zhanqiu Hu --- .../kv_connector/v1/nixl/block_transfer_policy.py | 10 ++++++---- .../kv_transfer/kv_connector/v1/nixl/worker.py | 2 -- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 088334d87981..0a24500ad28e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -238,7 +238,6 @@ def build_src_split_handles( self, src_blocks_data: list[tuple[int, int, int]], num_descs: int, - abs_tp: int, transfer_config: Any | None = None, tp_size: int = 1, is_mla: bool = False, @@ -521,13 +520,15 @@ def build_src_split_handles( self, src_blocks_data, num_descs, - abs_tp, transfer_config=None, tp_size: int = 1, is_mla: bool = False, total_num_kv_heads: int = 1, ): - _ = (num_descs, transfer_config, tp_size, is_mla, total_num_kv_heads) + _ = (num_descs, is_mla, total_num_kv_heads) + assert isinstance(transfer_config, EngineTransferInfo) + assert transfer_config.remote_tp_size > tp_size + abs_tp = transfer_config.remote_tp_size // tp_size result: list[list[tuple[int, int, int]]] = [] for i in range(abs_tp): blocks_data: list[tuple[int, int, int]] = [] @@ -963,7 +964,6 @@ def build_src_split_handles( self, src_blocks_data, num_descs, - abs_tp, transfer_config=None, tp_size: int = 1, is_mla: bool = False, @@ -971,6 +971,8 @@ def build_src_split_handles( ): assert isinstance(transfer_config, MambaEngineTransferInfo) info = transfer_config + assert info.remote_tp_size > tp_size + abs_tp = info.remote_tp_size // tp_size if self.needs_split_handles( info, 
tp_size=tp_size, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 8ce3c2f871fe..4efb5ef0a211 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1017,13 +1017,11 @@ def add_remote_agent( # Remote tp_size > local tp_size: read from multiple remote ranks. # Logically "split" own regions into |tp_ratio| chunks. Mind that # we only do this once per remote tp_size (replica-friendly). - abs_tp = -tp_ratio self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] for handle_data in self.transfer_policy.build_src_split_handles( self.src_blocks_data, self.num_descs, - abs_tp, transfer_config=transfer_topo.get_engine_info(engine_id), tp_size=transfer_topo.tp_size, is_mla=transfer_topo.is_mla, From fea54dea185d011615c364a242873723eb835ec4 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 20:36:31 +0000 Subject: [PATCH 11/49] reorder build_local_descs args: memory, block geometry, layout Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 27 ++++++++++--------- .../kv_connector/v1/nixl/worker.py | 7 +++-- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 0a24500ad28e..29638cc48e42 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -143,12 +143,15 @@ def get_block_descs_ids( @abstractmethod def build_local_descs( self, + # Memory base_addresses: list[int], - block_len_per_layer: list[int], + device_id: int, + # Block geometry num_blocks: int, logical_num_blocks: int, block_size_ratio: int, - device_id: int, + block_len_per_layer: list[int], + # Layout is_blocks_first: bool, ) -> 
list[tuple[int, int, int]]: """Build local (src) descriptor tuples for NIXL registration.""" @@ -157,10 +160,10 @@ def build_local_descs( def _build_fa_local_descs( self, base_addresses: list[int], - block_len_per_layer: list[int], + device_id: int, num_blocks: int, block_size_ratio: int, - device_id: int, + block_len_per_layer: list[int], is_blocks_first: bool, ) -> list[tuple[int, int, int]]: """Build FA local descriptors (shared by Dense and Mamba).""" @@ -423,19 +426,19 @@ def __init__( def build_local_descs( self, base_addresses, - block_len_per_layer, + device_id, num_blocks, logical_num_blocks, block_size_ratio, - device_id, + block_len_per_layer, is_blocks_first, ): return self._build_fa_local_descs( base_addresses, - block_len_per_layer, + device_id, num_blocks, block_size_ratio, - device_id, + block_len_per_layer, is_blocks_first, ) @@ -697,19 +700,19 @@ def get_block_descs_ids( def build_local_descs( self, base_addresses, - block_len_per_layer, + device_id, num_blocks, logical_num_blocks, block_size_ratio, - device_id, + block_len_per_layer, is_blocks_first, ): fa_descs = self._build_fa_local_descs( base_addresses, - block_len_per_layer, + device_id, num_blocks, block_size_ratio, - device_id, + block_len_per_layer, is_blocks_first, ) num_regions = len(base_addresses) * (2 if is_blocks_first else 1) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 4efb5ef0a211..f1e3fb9d3c9b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -862,12 +862,15 @@ def register_local_xfer_handler( local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] blocks_data = self.transfer_policy.build_local_descs( + # Memory base_addresses=local_base_addresses, - block_len_per_layer=self.block_len_per_layer, + device_id=self.device_id, + # Block geometry num_blocks=self.num_blocks, 
logical_num_blocks=self._logical_num_blocks, block_size_ratio=block_size_ratio, - device_id=self.device_id, + block_len_per_layer=self.block_len_per_layer, + # Layout is_blocks_first=transfer_topo.is_kv_layout_blocks_first, ) logger.debug( From 0dd4f9247db48a1808eea4b66b59c67f2b8fdfd6 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 20:39:22 +0000 Subject: [PATCH 12/49] clean up build_src_split_handles: rename transfer_config to remote_info, remove defaults, regroup args Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 37 ++++++++++--------- .../kv_connector/v1/nixl/worker.py | 5 ++- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 29638cc48e42..61c28044ef7c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -239,12 +239,15 @@ def build_remote_descs( @abstractmethod def build_src_split_handles( self, + # Local src data src_blocks_data: list[tuple[int, int, int]], num_descs: int, - transfer_config: Any | None = None, - tp_size: int = 1, - is_mla: bool = False, - total_num_kv_heads: int = 1, + # TP topology + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + # Remote engine info + remote_info: EngineTransferInfo, ) -> list[list[tuple[int, int, int]]]: """Build split handle data for P_TP > D_TP scenario.""" ... 
@@ -523,15 +526,15 @@ def build_src_split_handles( self, src_blocks_data, num_descs, - transfer_config=None, - tp_size: int = 1, - is_mla: bool = False, - total_num_kv_heads: int = 1, + tp_size, + is_mla, + total_num_kv_heads, + remote_info, ): _ = (num_descs, is_mla, total_num_kv_heads) - assert isinstance(transfer_config, EngineTransferInfo) - assert transfer_config.remote_tp_size > tp_size - abs_tp = transfer_config.remote_tp_size // tp_size + assert isinstance(remote_info, EngineTransferInfo) + assert remote_info.remote_tp_size > tp_size + abs_tp = remote_info.remote_tp_size // tp_size result: list[list[tuple[int, int, int]]] = [] for i in range(abs_tp): blocks_data: list[tuple[int, int, int]] = [] @@ -967,13 +970,13 @@ def build_src_split_handles( self, src_blocks_data, num_descs, - transfer_config=None, - tp_size: int = 1, - is_mla: bool = False, - total_num_kv_heads: int = 1, + tp_size, + is_mla, + total_num_kv_heads, + remote_info, ): - assert isinstance(transfer_config, MambaEngineTransferInfo) - info = transfer_config + assert isinstance(remote_info, MambaEngineTransferInfo) + info = remote_info assert info.remote_tp_size > tp_size abs_tp = info.remote_tp_size // tp_size if self.needs_split_handles( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index f1e3fb9d3c9b..069aa709970e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1023,12 +1023,15 @@ def add_remote_agent( self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] for handle_data in self.transfer_policy.build_src_split_handles( + # Local src data self.src_blocks_data, self.num_descs, - transfer_config=transfer_topo.get_engine_info(engine_id), + # TP topology tp_size=transfer_topo.tp_size, is_mla=transfer_topo.is_mla, total_num_kv_heads=transfer_topo.total_num_kv_heads, + # Remote engine info + 
remote_info=transfer_topo.get_engine_info(engine_id), ): descs = self.nixl_wrapper.get_xfer_descs( handle_data, self.nixl_memory_type From 93f21c45276aac6a629b2083f2168cc3ed87634a Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 20:49:26 +0000 Subject: [PATCH 13/49] clean up build_remote_descs: rename args, internalize indexes_into_remote and physical_blocks_per_logical, regroup Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 67 +++++++++---------- .../kv_connector/v1/nixl/worker.py | 23 +++---- 2 files changed, 43 insertions(+), 47 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 61c28044ef7c..76769c5e7a7e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -220,18 +220,20 @@ def _build_fa_local_descs( @abstractmethod def build_remote_descs( self, + # TP topology + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + tp_ratio: int, + # Remote engine info nixl_agent_meta: NixlAgentMetadata, + remote_info: EngineTransferInfo, + # Block geometry block_size_ratio: int, - tp_ratio: int, - tp_rank: int, - use_mla: bool, block_len_per_layer: list[int], + # Layout is_blocks_first: bool, - indexes_into_remote: bool, - transfer_config: Any | None = None, - physical_blocks_per_logical: int = 1, - tp_size: int = 1, - total_num_kv_heads: int = 1, ) -> list[tuple[int, int, int]]: """Build remote (dst) descriptor tuples.""" ... 
@@ -463,25 +465,27 @@ def get_block_descs_ids( def build_remote_descs( self, + tp_rank, + tp_size, + is_mla, + total_num_kv_heads, + tp_ratio, nixl_agent_meta, + remote_info, block_size_ratio, - tp_ratio, - tp_rank, - use_mla, block_len_per_layer, is_blocks_first, - indexes_into_remote, - transfer_config=None, - physical_blocks_per_logical=1, - tp_size: int = 1, - total_num_kv_heads: int = 1, ): # With homogeneous TP, D pulls the whole kv cache from corresponding # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. # Register all remote blocks, but only the corresponding kv heads. - _ = (tp_size, total_num_kv_heads, physical_blocks_per_logical, transfer_config) + assert isinstance(remote_info, EngineTransferInfo) + indexes_into_remote = ( + not (is_mla or remote_info.remote_tp_size > total_num_kv_heads) + and tp_ratio > 0 + ) result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate( nixl_agent_meta.kv_caches_base_addr, @@ -496,7 +500,7 @@ def build_remote_descs( remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: local_block_len = remote_kv_block_len - if tp_ratio < 0 and not use_mla: + if tp_ratio < 0 and not is_mla: # Remote tp is bigger: read a chunk of local region from remote local_block_len = local_block_len // (-tp_ratio) rank_offset = ( @@ -514,7 +518,7 @@ def build_remote_descs( block_len_per_layer, is_blocks_first, ) - if tp_ratio < 0 and not use_mla: + if tp_ratio < 0 and not is_mla: second_split = second_split // (-tp_ratio) for blk in range(num_blocks): addr = base_addr + blk * page_size + rank_offset @@ -786,24 +790,19 @@ def _build_mamba_local_descs( def build_remote_descs( self, + tp_rank, + tp_size, + is_mla, + total_num_kv_heads, + tp_ratio, nixl_agent_meta, + remote_info, block_size_ratio, - tp_ratio, - tp_rank, - use_mla, block_len_per_layer, is_blocks_first, - 
indexes_into_remote, - transfer_config=None, - physical_blocks_per_logical=1, - tp_size: int = 1, - total_num_kv_heads: int = 1, ): - # indexes_into_remote is not used for Mamba: FA offset is computed - # via fa_rank_offset which accounts for GQA/HMA head mapping. - _ = indexes_into_remote - assert isinstance(transfer_config, MambaEngineTransferInfo) - info = transfer_config + assert isinstance(remote_info, MambaEngineTransferInfo) + info = remote_info result: list[tuple[int, int, int]] = [] result.extend( self._build_fa_remote_descs( @@ -815,7 +814,7 @@ def build_remote_descs( total_num_kv_heads, block_size_ratio, is_blocks_first, - use_mla, + is_mla, block_len_per_layer, ) ) @@ -824,7 +823,7 @@ def build_remote_descs( nixl_agent_meta, tp_ratio, tp_rank, - physical_blocks_per_logical, + info.remote_physical_blocks_per_logical, ) ) return result diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 069aa709970e..df999eb093cb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -999,11 +999,6 @@ def add_remote_agent( # this is the ratio between the two sizes. tp_ratio = transfer_topo.tp_ratio(remote_tp_size) - # Handle tp_size>num_kv_heads: replicate KV cache. 
- indexes_into_remote = ( - not transfer_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 - ) - logger.debug( "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", engine_id, @@ -1041,18 +1036,20 @@ def add_remote_agent( ### Register remote agent memory regions blocks_data = self.transfer_policy.build_remote_descs( + # TP topology + tp_rank=transfer_topo.tp_rank, + tp_size=transfer_topo.tp_size, + is_mla=transfer_topo.is_mla, + total_num_kv_heads=transfer_topo.total_num_kv_heads, + tp_ratio=tp_ratio, + # Remote engine info nixl_agent_meta=nixl_agent_meta, + remote_info=transfer_topo.get_engine_info(engine_id), + # Block geometry block_size_ratio=block_size_ratio, - tp_ratio=tp_ratio, - tp_rank=self.tp_rank, - use_mla=self.use_mla, block_len_per_layer=self.block_len_per_layer, + # Layout is_blocks_first=transfer_topo.is_kv_layout_blocks_first, - indexes_into_remote=indexes_into_remote, - transfer_config=transfer_topo.get_engine_info(engine_id), - physical_blocks_per_logical=transfer_info.remote_physical_blocks_per_logical, - tp_size=transfer_topo.tp_size, - total_num_kv_heads=transfer_topo.total_num_kv_heads, ) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", From 4bdfe0ec7d9dc76bc4d3e675568a33b0cd899b5a Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 21:15:47 +0000 Subject: [PATCH 14/49] rename transfer_config to remote_info, inline _get_block_descs_ids, unpack read spec loop Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 15 +++-- .../kv_connector/v1/nixl/worker.py | 56 ++++++++----------- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 76769c5e7a7e..8c95023a9be8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import numpy as np import torch @@ -126,7 +126,9 @@ def get_kv_block_len( @abstractmethod def get_block_descs_ids( self, + # Input block_ids: BlockIds, + # Block geometry num_regions: int, dst_num_blocks: int, block_len_per_layer: list[int], @@ -264,7 +266,7 @@ def compute_read_specs( local_block_ids: BlockIds, remote_block_ids: BlockIds, remote_ranks: list[int], - transfer_config: Any | None = None, + remote_info: EngineTransferInfo, ) -> list[ReadSpec]: """Compute the full set of read operations needed for a request. @@ -559,8 +561,9 @@ def compute_read_specs( local_block_ids, remote_block_ids, remote_ranks, - transfer_config=None, + remote_info, ): + assert isinstance(remote_info, EngineTransferInfo) return [ ReadSpec( remote_rank=rank, @@ -1009,10 +1012,10 @@ def compute_read_specs( local_block_ids, remote_block_ids, remote_ranks, - transfer_config=None, + remote_info, ): - assert isinstance(transfer_config, MambaEngineTransferInfo) - info = transfer_config + assert isinstance(remote_info, MambaEngineTransferInfo) + info = remote_info specs: list[ReadSpec] = [] for rank in remote_ranks: filtered_local, filtered_remote = self.filter_block_ids_for_rank( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index df999eb093cb..518f4e139d9c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1589,19 +1589,21 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_info = self.transfer_topo.get_engine_info(engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) + # TODO (ZhanqiuHu): Unify logical_to_kernel_block_ids + # and 
logical_to_remote_kernel_block_ids. if self._has_mamba: # Expand remote logical → kernel block IDs. - expanded_remote = self._logical_to_remote_kernel_block_ids( + remote_block_ids = self._logical_to_remote_kernel_block_ids( meta.remote.block_ids, remote_info.remote_physical_blocks_per_logical, ) else: - expanded_remote = self._logical_to_kernel_block_ids(meta.remote.block_ids) + remote_block_ids = self._logical_to_kernel_block_ids(meta.remote.block_ids) read_specs = self.transfer_policy.compute_read_specs( local_block_ids=meta.local_physical_block_ids, - remote_block_ids=expanded_remote, + remote_block_ids=remote_block_ids, remote_ranks=remote_ranks, - transfer_config=remote_info, + remote_info=remote_info, ) # D may have to perform multiple reads from different remote ranks. @@ -1611,12 +1613,15 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): read_specs = read_specs[:1] for i, spec in enumerate(read_specs): + remote_rank = spec.remote_rank + local_block_ids = spec.local_block_ids + remote_block_ids = spec.remote_block_ids remote_block_size = remote_info.remote_block_size logger.debug( "Remote agent %s available, calling _read_blocks" " on remote rank %s with remote block size %s for req %s", meta.remote.engine_id, - spec.remote_rank, + remote_rank, remote_block_size, req_id, ) @@ -1635,16 +1640,16 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # Destination handle: remote_engine_id -> remote_rank -> handle. 
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ - spec.remote_rank + remote_rank ] self._read_blocks( request_id=req_id, dst_engine_id=meta.remote.engine_id, remote_request_id=meta.remote.request_id, - local_block_ids=spec.local_block_ids, - remote_block_ids=spec.remote_block_ids, - remote_rank=spec.remote_rank, + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + remote_rank=remote_rank, local_xfer_side_handle=local_xfer_side_handle, remote_xfer_side_handle=remote_xfer_side_handle, ) @@ -1760,14 +1765,18 @@ def _read_blocks( # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. - remote_block_descs_ids = self._get_block_descs_ids( - remote_block_ids, - self.dst_num_blocks[dst_engine_id], + remote_block_descs_ids = self.transfer_policy.get_block_descs_ids( + block_ids=remote_block_ids, + num_regions=self.num_regions, + dst_num_blocks=self.dst_num_blocks[dst_engine_id], + block_len_per_layer=self.block_len_per_layer, physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, ) - local_block_descs_ids = self._get_block_descs_ids( - local_block_ids, - self.dst_num_blocks[self.engine_id], + local_block_descs_ids = self.transfer_policy.get_block_descs_ids( + block_ids=local_block_ids, + num_regions=self.num_regions, + dst_num_blocks=self.dst_num_blocks[self.engine_id], + block_len_per_layer=self.block_len_per_layer, block_size_ratio=block_size_ratio, physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) @@ -1887,23 +1896,6 @@ def _logical_to_remote_kernel_block_ids( result.append(group) return result - def _get_block_descs_ids( - self, - block_ids: BlockIds, - dst_num_blocks: int, - block_size_ratio: float | None = None, - physical_blocks_per_logical: int = 1, - ) -> np.ndarray: - """Thin wrapper delegating to the block transfer policy.""" - return self.transfer_policy.get_block_descs_ids( - block_ids=block_ids, - num_regions=self.num_regions, - 
dst_num_blocks=dst_num_blocks, - block_len_per_layer=self.block_len_per_layer, - block_size_ratio=block_size_ratio, - physical_blocks_per_logical=physical_blocks_per_logical, - ) - def get_kv_connector_stats(self) -> KVConnectorStats | None: """ Get the KV transfer stats for the connector. From 7456cfd06dbcbd65abaa1782e3b05a3f53a9866e Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 21:23:33 +0000 Subject: [PATCH 15/49] extract _fa_descs_ids static helper to deduplicate FA descriptor ID computation Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 8c95023a9be8..c614fd9dd21b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -307,6 +307,16 @@ def _range_overlap(a: range, b: range) -> range: stop = min(a.stop, b.stop) return range(start, max(start, stop)) + @staticmethod + def _fa_descs_ids( + block_ids_arr: np.ndarray, + num_regions: int, + stride: int, + ) -> np.ndarray: + """FA descriptor IDs: region_id * stride + block_id.""" + region_ids = np.arange(num_regions)[:, None] + return (region_ids * stride + block_ids_arr[None, :]).flatten() + @staticmethod def _should_skip_fa(info: EngineTransferInfo, remote_rank: int) -> bool: """Whether to skip FA groups for this remote rank. 
@@ -461,9 +471,7 @@ def get_block_descs_ids( num_blocks = dst_num_blocks if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) - region_ids = np.arange(num_regions)[:, None] - block_ids_arr = np.concatenate(block_ids)[None, :] - return (region_ids * num_blocks + block_ids_arr).flatten() + return self._fa_descs_ids(np.concatenate(block_ids), num_regions, num_blocks) def build_remote_descs( self, @@ -678,7 +686,6 @@ def get_block_descs_ids( num_blocks = dst_num_blocks if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) - region_ids = np.arange(num_regions)[:, None] ratio = physical_blocks_per_logical logical_blocks = num_blocks // ratio num_fa_descs = num_regions * num_blocks @@ -695,16 +702,18 @@ def get_block_descs_ids( )[:, None] all_descs: list[np.ndarray] = [] for i, group in enumerate(block_ids): - group_arr = np.asarray(group)[None, :] + group_arr = np.asarray(group) if self._is_mamba_group[i]: # Mamba blocks are 1:1 logical-to-physical (no expansion). 
all_descs.append( ( - mamba_region_ids * logical_blocks + group_arr + num_fa_descs + mamba_region_ids * logical_blocks + + group_arr[None, :] + + num_fa_descs ).flatten() ) else: - all_descs.append((region_ids * num_blocks + group_arr).flatten()) + all_descs.append(self._fa_descs_ids(group_arr, num_regions, num_blocks)) return np.concatenate(all_descs) def build_local_descs( From b1765d0bded0bbcee6df5693bbf7b71bf6b8e8f3 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 21:32:47 +0000 Subject: [PATCH 16/49] add FA replication notes on ABC helpers and _build_fa_remote_descs Signed-off-by: Zhanqiu Hu --- .../kv_connector/v1/nixl/block_transfer_policy.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index c614fd9dd21b..31e61482acbc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -317,6 +317,13 @@ def _fa_descs_ids( region_ids = np.arange(num_regions)[:, None] return (region_ids * stride + block_ids_arr[None, :]).flatten() + # NOTE (ZhanqiuHu): The helpers below (_should_skip_fa, _fa_head_slot, + # _fa_rank_offset) handle FA replication, where num_kv_heads < tp_size + # and multiple ranks share the same physical KV head data. As of now + # they are only used for Mamba hybrid models because pure FA models + # typically have enough KV heads to avoid replication. We may modify + # and reuse these for Gemma4 hybrid cases. + @staticmethod def _should_skip_fa(info: EngineTransferInfo, remote_rank: int) -> bool: """Whether to skip FA groups for this remote rank. @@ -840,6 +847,8 @@ def build_remote_descs( ) return result + # NOTE (ZhanqiuHu): See ABC comment on _should_skip_fa for context. + # This method also handles FA replication (see ABC helpers above). 
def _build_fa_remote_descs( self, nixl_agent_meta, From 655a9ebb6257398b0daec3e9a146843ff8899572 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 21:38:20 +0000 Subject: [PATCH 17/49] fix unit tests: adapt to _get_block_descs_ids inlining and _physical_blocks_per_logical removal Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 3f5a9b9cc031..a87280bed577 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -153,14 +153,15 @@ def test_read_blocks_for_req_expands_remote_ids( ) remote_engine_id = "remote-engine" - if has_mamba: - worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio} # Mock transfer_topo: empty remote ranks skips the transfer machinery # entirely, isolating the block-ID expansion logic. 
worker.transfer_topo = MagicMock() worker.transfer_topo.target_remote_ranks.return_value = [] - worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1) + worker.transfer_topo.get_engine_info.return_value = MagicMock( + remote_tp_size=1, + remote_physical_blocks_per_logical=remote_ratio if has_mamba else 1, + ) worker.transfer_topo.tp_ratio.return_value = 1 metadata = NixlConnectorMetadata() @@ -308,29 +309,28 @@ def test_nixl_metadata_hma_block_ids_structure(): @pytest.mark.cpu_test def test_get_block_descs_ids_hybrid_ssm(): - """Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM - when ratio=1 (no kernel block size mismatch).""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( - NixlConnectorWorker, + """Test get_block_descs_ids uses per-group strides for hybrid + FA+SSM when ratio=1 (no kernel block size mismatch).""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 + MambaModelBlockTransferPolicy, ) - worker = object.__new__(NixlConnectorWorker) + policy = object.__new__(MambaModelBlockTransferPolicy) + policy._is_mamba_group = [False, True] num_blocks = 100 - engine_id = "test-engine" - worker.num_regions = 2 - worker.dst_num_blocks = {engine_id: num_blocks} - worker._has_mamba = True - worker._is_mamba_group = [False, True] - worker._physical_blocks_per_logical_kv_block = 1 - worker._physical_blocks_per_logical = {engine_id: 1} - worker.block_len_per_layer = [100] - # num_descs = num_regions * num_blocks (no blocks_first doubling) - worker.num_descs = 2 * num_blocks + num_regions = 2 + block_len_per_layer = [100] fa_blocks = [3, 5] ssm_blocks = [1, 2] - result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) + result = policy.get_block_descs_ids( + block_ids=(fa_blocks, ssm_blocks), + num_regions=num_regions, + dst_num_blocks=num_blocks, + block_len_per_layer=block_len_per_layer, + physical_blocks_per_logical=1, + ) # FA group: 
stride=num_blocks=100, offset=0 # region0: [3, 5], region1: [103, 105] @@ -344,30 +344,30 @@ def test_get_block_descs_ids_hybrid_ssm(): @pytest.mark.cpu_test def test_get_block_descs_ids_kernel_block_mismatch(): - """Test _get_block_descs_ids uses different strides for FA (kernel blocks) - vs SSM (logical blocks) when ratio > 1.""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( - NixlConnectorWorker, + """Test get_block_descs_ids uses different strides for FA + (kernel blocks) vs SSM (logical blocks) when ratio > 1.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 + MambaModelBlockTransferPolicy, ) - worker = object.__new__(NixlConnectorWorker) + policy = object.__new__(MambaModelBlockTransferPolicy) + policy._is_mamba_group = [False, True] ratio = 4 logical_blocks = 100 num_blocks = logical_blocks * ratio # 400 kernel blocks - engine_id = "test-engine" - worker.num_regions = 2 - worker.dst_num_blocks = {engine_id: num_blocks} - worker._has_mamba = True - worker._is_mamba_group = [False, True] - worker._physical_blocks_per_logical_kv_block = ratio - worker._physical_blocks_per_logical = {engine_id: ratio} - worker.block_len_per_layer = [100] - worker.num_descs = 2 * num_blocks # 800 + num_regions = 2 + block_len_per_layer = [100] fa_blocks = [3, 7] # kernel-level block IDs ssm_blocks = [1, 2] # logical block IDs - result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks)) + result = policy.get_block_descs_ids( + block_ids=(fa_blocks, ssm_blocks), + num_regions=num_regions, + dst_num_blocks=num_blocks, + block_len_per_layer=block_len_per_layer, + physical_blocks_per_logical=ratio, + ) # FA group: stride=num_blocks=400, offset=0 # region0: [3, 7], region1: [403, 407] From 6a6087db5746e1892f14019ca892a4fe8fc56cba Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 21 Apr 2026 21:53:08 +0000 Subject: [PATCH 18/49] fix unit tests: restore meta.remote.block_ids mutation, mock 
transfer_policy and use_mla, set num_descs Signed-off-by: Zhanqiu Hu --- tests/v1/kv_connector/unit/test_nixl_connector.py | 1 + tests/v1/kv_connector/unit/test_nixl_connector_hma.py | 3 +++ .../distributed/kv_transfer/kv_connector/v1/nixl/worker.py | 7 +++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index fb4b641e1376..d20f026da241 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -726,6 +726,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size( worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] + worker.num_descs = len(worker.src_blocks_data) def check_handshake(remote_tp_size: int): tp_ratio = remote_tp_size // local_tp_size diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index a87280bed577..267505546cda 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -163,6 +163,9 @@ def test_read_blocks_for_req_expands_remote_ids( remote_physical_blocks_per_logical=remote_ratio if has_mamba else 1, ) worker.transfer_topo.tp_ratio.return_value = 1 + worker.transfer_policy = MagicMock() + worker.transfer_policy.compute_read_specs.return_value = [] + worker.use_mla = False metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 518f4e139d9c..e0997e49ca77 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1593,12 +1593,15 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # 
and logical_to_remote_kernel_block_ids. if self._has_mamba: # Expand remote logical → kernel block IDs. - remote_block_ids = self._logical_to_remote_kernel_block_ids( + meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( meta.remote.block_ids, remote_info.remote_physical_blocks_per_logical, ) else: - remote_block_ids = self._logical_to_kernel_block_ids(meta.remote.block_ids) + meta.remote.block_ids = self._logical_to_kernel_block_ids( + meta.remote.block_ids + ) + remote_block_ids = meta.remote.block_ids read_specs = self.transfer_policy.compute_read_specs( local_block_ids=meta.local_physical_block_ids, remote_block_ids=remote_block_ids, From 6deae9f01be809ff998a9f7767fc9677196398ce Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Wed, 22 Apr 2026 18:49:43 +0000 Subject: [PATCH 19/49] consolidate local topology into TransferTopology for build_engine_transfer_info Add physical_blocks_per_logical to TransferTopology and pass transfer_topo directly to build_engine_transfer_info, reducing the method's parameter count from 10 to 6. 
Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_nixl_connector.py | 2 + .../kv_transfer/kv_connector/utils.py | 1 + .../v1/nixl/block_transfer_policy.py | 38 +++++++++---------- .../kv_connector/v1/nixl/worker.py | 12 +++--- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index d20f026da241..46f6ba706708 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -472,6 +472,7 @@ def __init__( is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, tensor_shape=test_shape, ) @@ -2436,6 +2437,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) is_mamba=False, total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), attn_backends=[backend], + physical_blocks_per_logical=decode_worker._physical_blocks_per_logical_kv_block, tensor_shape=test_shape, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index e2d6cab01185..1662824aee67 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -441,6 +441,7 @@ class TransferTopology: is_mamba: bool total_num_kv_heads: int attn_backends: list[type[AttentionBackend]] + physical_blocks_per_logical: int tensor_shape: torch.Size | None = None def __post_init__(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 31e61482acbc..de2f2a667385 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -25,6 
+25,7 @@ BlockIds, EngineTransferInfo, MambaEngineTransferInfo, + TransferTopology, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, @@ -77,17 +78,15 @@ def __init__( # Per-engine transfer info (data operations) # ------------------------------------------------------------------ - # TODO (ZhanqiuHu): Revisit data packing for local facts and remote facts. @abstractmethod def build_engine_transfer_info( self, *, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - is_kv_layout_blocks_first: bool, + # Local topology + transfer_topo: TransferTopology, + # Block geometry local_block_len: int, + # Remote facts (from NixlAgentMetadata handshake) remote_tp_size: int, remote_block_size: int, remote_block_len: int, @@ -591,17 +590,17 @@ def compute_read_specs( def build_engine_transfer_info( self, *, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - is_kv_layout_blocks_first: bool, + # Local topology + transfer_topo: TransferTopology, + # Block geometry local_block_len: int, + # Remote facts (from NixlAgentMetadata handshake) remote_tp_size: int, remote_block_size: int, remote_block_len: int, remote_physical_blocks_per_logical: int, ) -> EngineTransferInfo: + _ = (transfer_topo, local_block_len) return EngineTransferInfo( remote_tp_size=remote_tp_size, remote_block_len=remote_block_len, @@ -1054,20 +1053,21 @@ def compute_read_specs( def build_engine_transfer_info( self, *, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - is_kv_layout_blocks_first: bool, + # Local topology + transfer_topo: TransferTopology, + # Block geometry local_block_len: int, + # Remote facts (from NixlAgentMetadata handshake) remote_tp_size: int, remote_block_size: int, remote_block_len: int, remote_physical_blocks_per_logical: int, ) -> MambaEngineTransferInfo: - K = total_num_kv_heads - local_tp = tp_size - local_rank = tp_rank + K = transfer_topo.total_num_kv_heads 
+ local_tp = transfer_topo.tp_size + local_rank = transfer_topo.tp_rank + is_mla = transfer_topo.is_mla + is_kv_layout_blocks_first = transfer_topo.is_kv_layout_blocks_first is_remote_replicated = remote_tp_size > K remote_physical_heads = max(1, K // remote_tp_size) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index e0997e49ca77..375cc1709c0e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -639,6 +639,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, # SSM States come in tuples (ssm, conv) tensor_shape=next(iter(kv_caches.values())).shape if not self._has_mamba @@ -957,14 +958,11 @@ def add_remote_agent( else 1 ) transfer_info = self.transfer_policy.build_engine_transfer_info( - # Local facts (from TransferTopology). - tp_rank=transfer_topo.tp_rank, - tp_size=transfer_topo.tp_size, - is_mla=transfer_topo.is_mla, - total_num_kv_heads=transfer_topo.total_num_kv_heads, - is_kv_layout_blocks_first=transfer_topo.is_kv_layout_blocks_first, + # Local topology + transfer_topo=transfer_topo, + # Block geometry local_block_len=self.block_len_per_layer[0], - # Remote facts (from NixlAgentMetadata handshake). 
+ # Remote facts (from NixlAgentMetadata handshake) remote_tp_size=remote_tp_size, remote_block_size=nixl_agent_meta.block_size, remote_block_len=nixl_agent_meta.block_lens[0], From 4b72036a563b1b556a19e4f95809b77305a5e341 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Wed, 22 Apr 2026 18:58:02 +0000 Subject: [PATCH 20/49] =?UTF-8?q?consolidate=20build=5Fremote=5Fdescs=20pa?= =?UTF-8?q?rams=20via=20transfer=5Ftopo=20+=20engine=5Fid=20(11=E2=86=924)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 80 +++++++------------ .../kv_connector/v1/nixl/worker.py | 18 +---- 2 files changed, 35 insertions(+), 63 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index de2f2a667385..296ee8ecff46 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -23,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, + EngineId, EngineTransferInfo, MambaEngineTransferInfo, TransferTopology, @@ -221,20 +222,10 @@ def _build_fa_local_descs( @abstractmethod def build_remote_descs( self, - # TP topology - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - tp_ratio: int, - # Remote engine info + transfer_topo: TransferTopology, + engine_id: EngineId, nixl_agent_meta: NixlAgentMetadata, - remote_info: EngineTransferInfo, - # Block geometry - block_size_ratio: int, block_len_per_layer: list[int], - # Layout - is_blocks_first: bool, ) -> list[tuple[int, int, int]]: """Build remote (dst) descriptor tuples.""" ... 
@@ -481,43 +472,42 @@ def get_block_descs_ids( def build_remote_descs( self, - tp_rank, - tp_size, - is_mla, - total_num_kv_heads, - tp_ratio, + transfer_topo, + engine_id, nixl_agent_meta, - remote_info, - block_size_ratio, block_len_per_layer, - is_blocks_first, ): # With homogeneous TP, D pulls the whole kv cache from corresponding # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. # Register all remote blocks, but only the corresponding kv heads. + remote_info = transfer_topo.get_engine_info(engine_id) assert isinstance(remote_info, EngineTransferInfo) + tp_rank = transfer_topo.tp_rank + is_mla = transfer_topo.is_mla + is_blocks_first = transfer_topo.is_kv_layout_blocks_first + tp_ratio = transfer_topo.tp_ratio(remote_info.remote_tp_size) + block_size_ratio = transfer_topo.block_size_ratio(remote_info.remote_block_size) indexes_into_remote = ( - not (is_mla or remote_info.remote_tp_size > total_num_kv_heads) + not ( + is_mla or remote_info.remote_tp_size > transfer_topo.total_num_kv_heads + ) and tp_ratio > 0 ) result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate( nixl_agent_meta.kv_caches_base_addr, ): - # Read our whole local region size from remote. 
local_block_len = self.get_kv_block_len( i, block_len_per_layer, is_blocks_first, ) - # using remote kv_block_len as transfer unit remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: local_block_len = remote_kv_block_len if tp_ratio < 0 and not is_mla: - # Remote tp is bigger: read a chunk of local region from remote local_block_len = local_block_len // (-tp_ratio) rank_offset = ( tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 @@ -808,31 +798,22 @@ def _build_mamba_local_descs( def build_remote_descs( self, - tp_rank, - tp_size, - is_mla, - total_num_kv_heads, - tp_ratio, + transfer_topo, + engine_id, nixl_agent_meta, - remote_info, - block_size_ratio, block_len_per_layer, - is_blocks_first, ): + remote_info = transfer_topo.get_engine_info(engine_id) assert isinstance(remote_info, MambaEngineTransferInfo) info = remote_info + tp_ratio = transfer_topo.tp_ratio(info.remote_tp_size) result: list[tuple[int, int, int]] = [] result.extend( self._build_fa_remote_descs( + transfer_topo, nixl_agent_meta, info, tp_ratio, - tp_rank, - tp_size, - total_num_kv_heads, - block_size_ratio, - is_blocks_first, - is_mla, block_len_per_layer, ) ) @@ -840,7 +821,7 @@ def build_remote_descs( self._build_mamba_remote_descs( nixl_agent_meta, tp_ratio, - tp_rank, + transfer_topo.tp_rank, info.remote_physical_blocks_per_logical, ) ) @@ -850,19 +831,20 @@ def build_remote_descs( # This method also handles FA replication (see ABC helpers above). 
def _build_fa_remote_descs( self, - nixl_agent_meta, + transfer_topo: TransferTopology, + nixl_agent_meta: NixlAgentMetadata, info: MambaEngineTransferInfo, tp_ratio: int, - tp_rank: int, - tp_size: int, - total_num_kv_heads: int, - block_size_ratio: int, - is_blocks_first: bool, - use_mla: bool, block_len_per_layer: list[int], ): """Build remote FA descriptors for mamba models using transfer_cfg for GQA-aware sizing.""" + tp_rank = transfer_topo.tp_rank + tp_size = transfer_topo.tp_size + is_mla = transfer_topo.is_mla + total_num_kv_heads = transfer_topo.total_num_kv_heads + is_blocks_first = transfer_topo.is_kv_layout_blocks_first + block_size_ratio = transfer_topo.block_size_ratio(info.remote_block_size) assert block_size_ratio == 1, ( "Mamba 3-read transfer with block_size_ratio != 1 " f"is not tested. Got {block_size_ratio=}." @@ -879,14 +861,14 @@ def _build_fa_remote_descs( remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: local_block_len = remote_kv_block_len - if tp_ratio < 0 and not use_mla: + if tp_ratio < 0 and not is_mla: local_block_len = local_block_len // info.remote_num_fa_reads rank_offset = self._fa_rank_offset( info, remote_kv_block_len, tp_rank=tp_rank, tp_size=tp_size, - is_mla=use_mla, + is_mla=is_mla, total_num_kv_heads=total_num_kv_heads, ) num_blocks = nixl_agent_meta.num_blocks @@ -901,7 +883,7 @@ def _build_fa_remote_descs( block_len_per_layer, is_blocks_first, ) - if tp_ratio < 0 and not use_mla: + if tp_ratio < 0 and not is_mla: second_split = second_split // info.remote_num_fa_reads for blk in range(num_blocks): addr = base_addr + blk * page_size + rank_offset diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 375cc1709c0e..fba15b0762c6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1034,20 +1034,10 @@ def 
add_remote_agent( ### Register remote agent memory regions blocks_data = self.transfer_policy.build_remote_descs( - # TP topology - tp_rank=transfer_topo.tp_rank, - tp_size=transfer_topo.tp_size, - is_mla=transfer_topo.is_mla, - total_num_kv_heads=transfer_topo.total_num_kv_heads, - tp_ratio=tp_ratio, - # Remote engine info - nixl_agent_meta=nixl_agent_meta, - remote_info=transfer_topo.get_engine_info(engine_id), - # Block geometry - block_size_ratio=block_size_ratio, - block_len_per_layer=self.block_len_per_layer, - # Layout - is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + transfer_topo, + engine_id, + nixl_agent_meta, + self.block_len_per_layer, ) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", From 5b28f28e0dcd0e28945dbc440a6984d50be12e48 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Wed, 22 Apr 2026 19:24:44 +0000 Subject: [PATCH 21/49] =?UTF-8?q?consolidate=20build=5Fsrc=5Fsplit=5Fhandl?= =?UTF-8?q?es=20params=20via=20transfer=5Ftopo=20+=20engine=5Fid=20(7?= =?UTF-8?q?=E2=86=924)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 34 ++++++++----------- .../kv_connector/v1/nixl/worker.py | 9 ++--- 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 296ee8ecff46..ba7a9b82bcf2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -233,15 +233,10 @@ def build_remote_descs( @abstractmethod def build_src_split_handles( self, - # Local src data + transfer_topo: TransferTopology, + engine_id: EngineId, src_blocks_data: list[tuple[int, int, int]], num_descs: int, - # TP topology - tp_size: int, - is_mla: bool, - 
total_num_kv_heads: int, - # Remote engine info - remote_info: EngineTransferInfo, ) -> list[list[tuple[int, int, int]]]: """Build split handle data for P_TP > D_TP scenario.""" ... @@ -534,17 +529,16 @@ def build_remote_descs( def build_src_split_handles( self, + transfer_topo, + engine_id, src_blocks_data, num_descs, - tp_size, - is_mla, - total_num_kv_heads, - remote_info, ): - _ = (num_descs, is_mla, total_num_kv_heads) + remote_info = transfer_topo.get_engine_info(engine_id) assert isinstance(remote_info, EngineTransferInfo) - assert remote_info.remote_tp_size > tp_size - abs_tp = remote_info.remote_tp_size // tp_size + _ = num_descs + assert remote_info.remote_tp_size > transfer_topo.tp_size + abs_tp = remote_info.remote_tp_size // transfer_topo.tp_size result: list[list[tuple[int, int, int]]] = [] for i in range(abs_tp): blocks_data: list[tuple[int, int, int]] = [] @@ -969,21 +963,21 @@ def _build_mamba_remote_descs( def build_src_split_handles( self, + transfer_topo, + engine_id, src_blocks_data, num_descs, - tp_size, - is_mla, - total_num_kv_heads, - remote_info, ): + remote_info = transfer_topo.get_engine_info(engine_id) assert isinstance(remote_info, MambaEngineTransferInfo) info = remote_info + tp_size = transfer_topo.tp_size assert info.remote_tp_size > tp_size abs_tp = info.remote_tp_size // tp_size if self.needs_split_handles( info, tp_size=tp_size, - is_mla=is_mla, + is_mla=transfer_topo.is_mla, ): result = list( self.compute_split_handle_data( @@ -991,7 +985,7 @@ def build_src_split_handles( src_blocks_data, num_descs, abs_tp, - total_num_kv_heads=total_num_kv_heads, + total_num_kv_heads=transfer_topo.total_num_kv_heads, ) ) logger.info( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index fba15b0762c6..1d92185be544 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ 
-1016,15 +1016,10 @@ def add_remote_agent( self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] for handle_data in self.transfer_policy.build_src_split_handles( - # Local src data + transfer_topo, + engine_id, self.src_blocks_data, self.num_descs, - # TP topology - tp_size=transfer_topo.tp_size, - is_mla=transfer_topo.is_mla, - total_num_kv_heads=transfer_topo.total_num_kv_heads, - # Remote engine info - remote_info=transfer_topo.get_engine_info(engine_id), ): descs = self.nixl_wrapper.get_xfer_descs( handle_data, self.nixl_memory_type From 81e82fbea805a99896e9ceb3d8958cc8d85310ca Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Wed, 22 Apr 2026 19:35:12 +0000 Subject: [PATCH 22/49] fix mooncake: compute physical_blocks_per_logical for TransferTopology Signed-off-by: Zhanqiu Hu --- .../kv_connector/v1/mooncake/mooncake_connector.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 715fcbde16c9..b1b1cd27a5bc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -821,6 +821,7 @@ def __init__( self.cache_config = vllm_config.cache_config self.kv_cache_config = kv_cache_config self.use_mla = self.model_config.use_mla + self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() # Get the attention backend from the first layer @@ -841,6 +842,7 @@ def __init__( is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=[backend], + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) self.async_zmq_ctx = zmq.asyncio.Context() @@ -863,6 +865,9 @@ def _sync_block_size_with_kernel(self) -> None: kernel_block_size, ) assert self.block_size > kernel_block_size + self._physical_blocks_per_logical_kv_block = ( + 
self.block_size // kernel_block_size + ) self.block_size = kernel_block_size def __del__(self): From d08ab9b72708e224b44440c8f6e12c3bf44c6f27 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Wed, 22 Apr 2026 19:45:17 +0000 Subject: [PATCH 23/49] inline _fa_descs_ids, move static methods to module-level private utils Signed-off-by: Zhanqiu Hu --- .../v1/nixl/block_transfer_policy.py | 301 ++++++++---------- 1 file changed, 141 insertions(+), 160 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index ba7a9b82bcf2..17d91c86a96f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -58,6 +58,125 @@ class ReadSpec: remote_block_ids: BlockIds +# ------------------------------------------------------------------ +# Private module-level helpers +# ------------------------------------------------------------------ + + +def _get_kv_block_len( + layer_idx: int, + block_len_per_layer: list[int], + is_blocks_first: bool, +) -> int: + """Byte length of one K or V descriptor for a layer. + + For FA/KV layers only — Mamba SSM descriptors use ``_ssm_sizes`` + directly. When ``is_blocks_first``, K and V are interleaved so + the descriptor covers half the block. + """ + if is_blocks_first: + return block_len_per_layer[layer_idx] // 2 + return block_len_per_layer[layer_idx] + + +def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: + """Physical KV head range stored in a rank's KV cache tensor. + + When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. + When ``tp_size > num_heads``: 1 physical head per rank. Heads are + distributed **contiguously** (matching vLLM's GQA weight + partitioning): consecutive ranks share a head before moving to + the next one. 
+ """ + if tp_size <= num_heads: + assert num_heads % tp_size == 0 + per_rank = num_heads // tp_size + return range(rank * per_rank, (rank + 1) * per_rank) + else: + h = rank * num_heads // tp_size + return range(h, h + 1) + + +def _range_overlap(a: range, b: range) -> range: + start = max(a.start, b.start) + stop = min(a.stop, b.stop) + return range(start, max(start, stop)) + + +def _should_skip_fa(info: EngineTransferInfo, remote_rank: int) -> bool: + """Whether to skip FA groups for this remote rank. + + Returns False unless ``info`` carries FA source tracking + (``MambaEngineTransferInfo``) and the rank is a replicated + duplicate. + """ + if not isinstance(info, MambaEngineTransferInfo): + return False + return remote_rank not in info.fa_source_set + + +def _fa_head_slot( + info: EngineTransferInfo, + remote_rank: int, + total_num_kv_heads: int, +) -> int: + """Index into local FA block for this remote rank's head data. + + For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. + For ranks NOT in ``fa_source_ranks`` (replicated duplicates), + returns the slot of the matching source rank with the same head. + Returns 0 when ``info`` has no FA source tracking. + """ + if not isinstance(info, MambaEngineTransferInfo): + return 0 + fa_index = info.fa_source_indices + if remote_rank in fa_index: + return fa_index[remote_rank] + K = total_num_kv_heads + remote_tp = info.remote_tp_size + r_head = _physical_head_range(remote_tp, K, remote_rank) + for target in info.remote_fa_source_ranks: + t_head = _physical_head_range(remote_tp, K, target) + if _range_overlap(r_head, t_head): + return fa_index[target] + return 0 + + +def _fa_rank_offset( + info: EngineTransferInfo, + remote_kv_block_len: int, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, +) -> int: + """Byte offset into remote FA block for this local rank. + + When local TP is replicated (local_tp > K), multiple local ranks + share a head. 
Computes offset *relative to the target remote + rank's first head* so it works regardless of how many heads the + remote has. Returns 0 when ``info`` has no FA source tracking + or when local does not index into remote. + """ + if not isinstance(info, MambaEngineTransferInfo): + return 0 + tp_ratio = ( + tp_size // info.remote_tp_size + if tp_size >= info.remote_tp_size + else -(info.remote_tp_size // tp_size) + ) + if is_mla or tp_ratio <= 0: + return 0 + K = total_num_kv_heads + is_local_replicated = tp_size > K + if is_local_replicated: + local_head = tp_rank * K // tp_size + p_rank = info.remote_fa_source_ranks[0] + p_start = p_rank * K // info.remote_tp_size + return (local_head - p_start) * remote_kv_block_len + return tp_rank % tp_ratio * remote_kv_block_len + + class ModelBlockTransferPolicy(ABC): """Abstract base for model-specific block transfer logic. @@ -99,26 +218,6 @@ def build_engine_transfer_info( Mamba models return ``MambaEngineTransferInfo``. """ - # ------------------------------------------------------------------ - # KV block length helper (used by FA descriptor building) - # ------------------------------------------------------------------ - - @staticmethod - def get_kv_block_len( - layer_idx: int, - block_len_per_layer: list[int], - is_blocks_first: bool, - ) -> int: - """Byte length of one K or V descriptor for a layer. - - For FA/KV layers only — Mamba SSM descriptors use ``_ssm_sizes`` - directly. When ``is_blocks_first``, K and V are interleaved so - the descriptor covers half the block. 
- """ - if is_blocks_first: - return block_len_per_layer[layer_idx] // 2 - return block_len_per_layer[layer_idx] - # ------------------------------------------------------------------ # Descriptor ID computation (abstract — genuinely different per model) # ------------------------------------------------------------------ @@ -175,7 +274,7 @@ def _build_fa_local_descs( # The new block_len is using prefill block_len; # and num_blocks is multiple with N kv_block_len = ( - self.get_kv_block_len( + _get_kv_block_len( i, block_len_per_layer, is_blocks_first, @@ -195,7 +294,7 @@ def _build_fa_local_descs( ) ) if is_blocks_first: - second_split = self.get_kv_block_len( + second_split = _get_kv_block_len( i, block_len_per_layer, is_blocks_first, @@ -264,127 +363,6 @@ def compute_read_specs( """ ... - # ------------------------------------------------------------------ - # FA head replication helpers (hetero-TP: tp_size > num_kv_heads) - # ------------------------------------------------------------------ - - @staticmethod - def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: - """Physical KV head range stored in a rank's KV cache tensor. - - When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. - When ``tp_size > num_heads``: 1 physical head per rank. Heads are - distributed **contiguously** (matching vLLM's GQA weight - partitioning): consecutive ranks share a head before moving to - the next one. 
- """ - if tp_size <= num_heads: - assert num_heads % tp_size == 0 - per_rank = num_heads // tp_size - return range(rank * per_rank, (rank + 1) * per_rank) - else: - h = rank * num_heads // tp_size - return range(h, h + 1) - - @staticmethod - def _range_overlap(a: range, b: range) -> range: - start = max(a.start, b.start) - stop = min(a.stop, b.stop) - return range(start, max(start, stop)) - - @staticmethod - def _fa_descs_ids( - block_ids_arr: np.ndarray, - num_regions: int, - stride: int, - ) -> np.ndarray: - """FA descriptor IDs: region_id * stride + block_id.""" - region_ids = np.arange(num_regions)[:, None] - return (region_ids * stride + block_ids_arr[None, :]).flatten() - - # NOTE (ZhanqiuHu): The helpers below (_should_skip_fa, _fa_head_slot, - # _fa_rank_offset) handle FA replication, where num_kv_heads < tp_size - # and multiple ranks share the same physical KV head data. As of now - # they are only used for Mamba hybrid models because pure FA models - # typically have enough KV heads to avoid replication. We may modify - # and reuse these for Gemma4 hybrid cases. - - @staticmethod - def _should_skip_fa(info: EngineTransferInfo, remote_rank: int) -> bool: - """Whether to skip FA groups for this remote rank. - - Returns False unless ``info`` carries FA source tracking - (``MambaEngineTransferInfo``) and the rank is a replicated - duplicate. - """ - if not isinstance(info, MambaEngineTransferInfo): - return False - return remote_rank not in info.fa_source_set - - @staticmethod - def _fa_head_slot( - info: EngineTransferInfo, - remote_rank: int, - total_num_kv_heads: int, - ) -> int: - """Index into local FA block for this remote rank's head data. - - For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. - For ranks NOT in ``fa_source_ranks`` (replicated duplicates), - returns the slot of the matching source rank with the same head. - Returns 0 when ``info`` has no FA source tracking. 
- """ - if not isinstance(info, MambaEngineTransferInfo): - return 0 - fa_index = info.fa_source_indices - if remote_rank in fa_index: - return fa_index[remote_rank] - K = total_num_kv_heads - remote_tp = info.remote_tp_size - phr = ModelBlockTransferPolicy._physical_head_range - rov = ModelBlockTransferPolicy._range_overlap - r_head = phr(remote_tp, K, remote_rank) - for target in info.remote_fa_source_ranks: - t_head = phr(remote_tp, K, target) - if rov(r_head, t_head): - return fa_index[target] - return 0 - - @staticmethod - def _fa_rank_offset( - info: EngineTransferInfo, - remote_kv_block_len: int, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - ) -> int: - """Byte offset into remote FA block for this local rank. - - When local TP is replicated (local_tp > K), multiple local ranks - share a head. Computes offset *relative to the target remote - rank's first head* so it works regardless of how many heads the - remote has. Returns 0 when ``info`` has no FA source tracking - or when local does not index into remote. 
- """ - if not isinstance(info, MambaEngineTransferInfo): - return 0 - tp_ratio = ( - tp_size // info.remote_tp_size - if tp_size >= info.remote_tp_size - else -(info.remote_tp_size // tp_size) - ) # noqa: E501 - if is_mla or tp_ratio <= 0: - return 0 - K = total_num_kv_heads - is_local_replicated = tp_size > K - if is_local_replicated: - local_head = tp_rank * K // tp_size - p_rank = info.remote_fa_source_ranks[0] - p_start = p_rank * K // info.remote_tp_size - return (local_head - p_start) * remote_kv_block_len - return tp_rank % tp_ratio * remote_kv_block_len - # ------------------------------------------------------------------ # Factory # ------------------------------------------------------------------ @@ -463,7 +441,9 @@ def get_block_descs_ids( num_blocks = dst_num_blocks if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) - return self._fa_descs_ids(np.concatenate(block_ids), num_regions, num_blocks) + block_ids_arr = np.concatenate(block_ids) + region_ids = np.arange(num_regions)[:, None] + return (region_ids * num_blocks + block_ids_arr[None, :]).flatten() def build_remote_descs( self, @@ -494,7 +474,7 @@ def build_remote_descs( for i, base_addr in enumerate( nixl_agent_meta.kv_caches_base_addr, ): - local_block_len = self.get_kv_block_len( + local_block_len = _get_kv_block_len( i, block_len_per_layer, is_blocks_first, @@ -514,7 +494,7 @@ def build_remote_descs( addr = base_addr + blk * page_size + rank_offset result.append((addr, local_block_len, dev_id)) if is_blocks_first: - second_split = self.get_kv_block_len( + second_split = _get_kv_block_len( i, block_len_per_layer, is_blocks_first, @@ -703,7 +683,10 @@ def get_block_descs_ids( ).flatten() ) else: - all_descs.append(self._fa_descs_ids(group_arr, num_regions, num_blocks)) + region_ids = np.arange(num_regions)[:, None] + all_descs.append( + (region_ids * num_blocks + group_arr[None, :]).flatten() + ) return np.concatenate(all_descs) def build_local_descs( @@ -847,7 
+830,7 @@ def _build_fa_remote_descs( for i, base_addr in enumerate( nixl_agent_meta.kv_caches_base_addr, ): - local_block_len = self.get_kv_block_len( + local_block_len = _get_kv_block_len( i, block_len_per_layer, is_blocks_first, @@ -857,7 +840,7 @@ def _build_fa_remote_descs( local_block_len = remote_kv_block_len if tp_ratio < 0 and not is_mla: local_block_len = local_block_len // info.remote_num_fa_reads - rank_offset = self._fa_rank_offset( + rank_offset = _fa_rank_offset( info, remote_kv_block_len, tp_rank=tp_rank, @@ -872,7 +855,7 @@ def _build_fa_remote_descs( addr = base_addr + blk * page_size + rank_offset result.append((addr, local_block_len, dev_id)) if is_blocks_first: - second_split = self.get_kv_block_len( + second_split = _get_kv_block_len( i, block_len_per_layer, is_blocks_first, @@ -1070,15 +1053,15 @@ def build_engine_transfer_info( else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] ) else: - local_needs = self._physical_head_range(local_tp, K, local_rank) + local_needs = _physical_head_range(local_tp, K, local_rank) search_range = ( mamba_range if mamba_range is not None else range(remote_tp_size) ) seen: set[tuple[int, int]] = set() fa_source_ranks = [] for p in search_range: - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) if len(ov) > 0: key = (ov.start, ov.stop) if key not in seen: @@ -1086,8 +1069,8 @@ def build_engine_transfer_info( fa_source_ranks.append(p) if not fa_source_ranks: for p in range(remote_tp_size): - p_has = self._physical_head_range(remote_tp_size, K, p) - ov = self._range_overlap(local_needs, p_has) + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) if len(ov) > 0: key = (ov.start, ov.stop) if key not in seen: @@ -1190,11 +1173,9 @@ def compute_split_handle_data( all_handle_data: list[list[tuple[int, int, int]]] = [] for 
p_idx, p_rank in enumerate(info.remote_all_source_ranks): handle_data: list[tuple[int, int, int]] = [] - skip_fa = self._should_skip_fa(info, p_rank) + skip_fa = _should_skip_fa(info, p_rank) fa_slot = ( - self._fa_head_slot(info, p_rank, total_num_kv_heads) - if not skip_fa - else 0 + _fa_head_slot(info, p_rank, total_num_kv_heads) if not skip_fa else 0 ) for j, (addr, local_len, dev) in enumerate(src_blocks_data): if j < num_fa_descs: @@ -1220,7 +1201,7 @@ def filter_block_ids_for_rank( remote rank carries FA data for this local rank, returns the inputs unchanged. """ - if not self._should_skip_fa(info, remote_rank): + if not _should_skip_fa(info, remote_rank): return local_ids, remote_ids num_groups = len(local_ids) filtered_local: list[list[int]] = [ From 925a9bd4990e83e594c8923491fcaee05b673dde Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 02:15:41 +0000 Subject: [PATCH 24/49] [3/N] extract model-specific block expansion into plan-based logical_to_kernel_block_ids Move model-specific block ID expansion and trimming logic out of worker.py hot paths into a unified logical_to_kernel_block_ids() in transfer_plan.py. Per-request functions now consume only the pre-computed EngineTransferPlan (model-agnostic); model awareness is confined to init and plan generation (if/else dense vs mamba). - Add transfer_plan.py with plan generators, executors, and logical_to_kernel_block_ids (per-group physical_per_logical). - Generate EngineTransferPlan during handshake (generate_dense_plan or generate_mamba_plan), stored in _transfer_plans dict. - Replace _logical_to_remote_kernel_block_ids in _read_blocks_for_req with plan-based logical_to_kernel_block_ids (no model branching). - Make block trimming in _read_blocks unconditional (SSM groups are no-op due to shared block table). - Thin-wrap _logical_to_kernel_block_ids for local expansion, delegating to the same unified function. - Add _conv_decomp to worker init for mamba plan generation. 
Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 70 +- .../kv_connector/unit/test_transfer_plan.py | 783 +++++++++++++++ .../kv_connector/v1/nixl/transfer_plan.py | 920 ++++++++++++++++++ .../kv_connector/v1/nixl/worker.py | 148 +-- 4 files changed, 1808 insertions(+), 113 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_transfer_plan.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 267505546cda..0a5b491ba586 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -62,28 +62,24 @@ def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes): @pytest.mark.cpu_test def test_logical_to_kernel_block_ids_with_hma(): - """Test _logical_to_kernel_block_ids expands blocks when HMA is enabled. + """Test logical_to_kernel_block_ids expands blocks when HMA is enabled. When HMA is enabled, the logical block size may differ from the kernel block size. Each logical block maps to multiple kernel blocks. 
""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( - NixlConnectorWorker, + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + logical_to_kernel_block_ids, ) - # Create a mock worker with just the required attributes - # (use __new__ to skip __init__) - worker = object.__new__(NixlConnectorWorker) - # Simulate HMA scenario: logical block size = 32, kernel block size = 16 # So each logical block maps to 2 kernel blocks eg [0]->[0,1] - worker._physical_blocks_per_logical_kv_block = 2 # FA + SW groups (neither is MambaSpec, so both get expanded) - worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True) - - # Test conversion: FA + SW group logical_block_ids = [[0, 1, 2], [3, 4]] - kernel_block_ids = worker._logical_to_kernel_block_ids(logical_block_ids) + physical_per_logical = (2, 2) + + kernel_block_ids = logical_to_kernel_block_ids( + logical_block_ids, physical_per_logical + ) expected_kernel_block_ids = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]] assert kernel_block_ids == expected_kernel_block_ids, ( @@ -93,64 +89,48 @@ def test_logical_to_kernel_block_ids_with_hma(): @pytest.mark.cpu_test @pytest.mark.parametrize( - "has_mamba,swa_enabled,mamba_enabled,remote_ratio," - "remote_block_ids,expected_remote_block_ids", + "physical_per_logical,remote_block_ids,expected_remote_block_ids", [ - # Non-mamba (FA+SWA): both groups expanded via _logical_to_kernel_block_ids. + # Non-mamba (FA+SWA): both groups expanded by ratio 2. # Regression for https://github.com/vllm-project/vllm/pull/39724 ( - False, - True, - False, - 1, + (2, 2), ([0, 1, 2], [3, 4]), [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]], ), - # Mamba (FA+Mamba): FA expanded via _logical_to_remote_kernel_block_ids, - # Mamba passed through unchanged. - # remote_ratio=261 (Nemotron 30B TP=1) != local_ratio=2 so that using - # the wrong conversion method produces different FA results. 
+ # Mamba (FA+Mamba): FA expanded by ratio 261, Mamba (ratio=1) passthrough. ( - True, - False, - True, - 261, + (261, 1), ([0, 1, 2], [10, 11]), - [[0, 1, 261, 262, 522, 523], [10, 11]], + [ + list(range(0, 261)) + list(range(261, 522)) + list(range(522, 783)), + [10, 11], + ], ), ], ids=["non_mamba_fa_swa", "mamba_fa_ssm"], ) def test_read_blocks_for_req_expands_remote_ids( - has_mamba, - swa_enabled, - mamba_enabled, - remote_ratio, + physical_per_logical, remote_block_ids, expected_remote_block_ids, ): """_read_blocks_for_req must expand remote logical block IDs to kernel - block IDs when kernel block size != logical block size. - - Non-mamba path uses _logical_to_kernel_block_ids (all groups expanded). - Mamba path uses _logical_to_remote_kernel_block_ids (FA expanded, Mamba - passed through). + block IDs via plan.physical_per_logical (model-agnostic). """ from unittest.mock import MagicMock from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( NixlConnectorMetadata, ) + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + EngineTransferPlan, + ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, ) worker = object.__new__(NixlConnectorWorker) - worker._has_mamba = has_mamba - worker._physical_blocks_per_logical_kv_block = 2 - worker.kv_cache_config = make_kv_cache_config( - block_size=16, swa_enabled=swa_enabled, mamba_enabled=mamba_enabled - ) remote_engine_id = "remote-engine" @@ -160,13 +140,17 @@ def test_read_blocks_for_req_expands_remote_ids( worker.transfer_topo.target_remote_ranks.return_value = [] worker.transfer_topo.get_engine_info.return_value = MagicMock( remote_tp_size=1, - remote_physical_blocks_per_logical=remote_ratio if has_mamba else 1, ) worker.transfer_topo.tp_ratio.return_value = 1 worker.transfer_policy = MagicMock() worker.transfer_policy.compute_read_specs.return_value = [] worker.use_mla = False + # Mock the plan with the physical_per_logical tuple 
+ mock_plan = MagicMock(spec=EngineTransferPlan) + mock_plan.physical_per_logical = physical_per_logical + worker._transfer_plans = {remote_engine_id: mock_plan} + metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( request_id="test-req", diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py new file mode 100644 index 000000000000..04c3906434b9 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -0,0 +1,783 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Equivalence tests: plan-based executors vs current ABC policy. + +These tests verify that the new plan-based design produces identical +outputs (descriptor tuples, descriptor IDs, read specs) to the current +ModelBlockTransferPolicy ABC hierarchy. No GPU or NIXL required. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import pytest + +from vllm.distributed.kv_transfer.kv_connector.utils import ( + TransferTopology, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( + DenseModelBlockTransferPolicy, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + EngineTransferPlan, + RegionKind, + RegionPlan, + build_local_splits_from_plan, + build_remote_descs_from_plan, + compute_desc_ids_from_plan, + compute_read_specs_from_plan, + generate_dense_plan, + visualize_plan, +) + +# ====================================================================== +# Test fixtures / helpers +# ====================================================================== + +ENGINE_ID = "remote_engine" +LOCAL_ENGINE_ID = "local_engine" + + +@dataclass +class FakeNixlAgentMeta: + """Minimal mock of NixlAgentMetadata for testing.""" + + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + device_id: int + num_blocks: int + block_lens: list[int] + 
kv_cache_layout: str + block_size: int + ssm_sizes: tuple[int, int] + attn_backend_name: str + + +def _make_kv_cache_config( + block_size: int = 16, + num_blocks: int = 256, + num_layers: int = 2, + head_size: int = 128, + num_kv_heads: int = 8, +): + """Create a minimal KVCacheConfig for Dense models.""" + import torch + + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + ) + + spec = FullAttentionSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=torch.float16, + ) + layers = [f"layer_{i}" for i in range(num_layers)] + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[KVCacheGroupSpec(layers, spec)], + ) + + +def _make_transfer_topo( + tp_rank: int = 0, + tp_size: int = 1, + block_size: int = 16, + is_mla: bool = False, + num_kv_heads: int = 8, +): + """Create a TransferTopology for testing without real attention backend.""" + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + return TransferTopology( + tp_rank=tp_rank, + tp_size=tp_size, + block_size=block_size, + engine_id=LOCAL_ENGINE_ID, + is_mla=is_mla, + is_mamba=False, + total_num_kv_heads=num_kv_heads, + attn_backends=[FlashAttentionBackend], + physical_blocks_per_logical=1, + ) + + +def _common_plan_params( + tp_rank: int = 0, + tp_size: int = 1, + is_mla: bool = False, + num_kv_heads: int = 8, + block_size: int = 16, + is_blocks_first: bool = False, + block_len_per_layer: list[int] | None = None, + remote_tp_size: int = 1, + remote_block_size: int = 16, + remote_num_blocks: int = 256, + remote_block_lens: list[int] | None = None, + remote_physical_blocks_per_logical: int = 1, +) -> dict: + """Build common kwargs for plan generators.""" + if block_len_per_layer is None: + slot_size = num_kv_heads * 128 * 2 # num_heads * head_size * dtype_bytes + block_len_per_layer = [slot_size * block_size] * 2 + if remote_block_lens is None: + remote_block_lens = 
list(block_len_per_layer) + return dict( + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=is_mla, + total_num_kv_heads=num_kv_heads, + is_blocks_first=is_blocks_first, + block_len_per_layer=block_len_per_layer, + block_size=block_size, + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + remote_num_blocks=remote_num_blocks, + remote_block_lens=remote_block_lens, + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ) + + +def _make_nixl_meta( + base_addrs: list[int], + num_blocks: int, + block_lens: list[int], + device_id: int = 0, + block_size: int = 16, +) -> FakeNixlAgentMeta: + return FakeNixlAgentMeta( + engine_id=ENGINE_ID, + agent_metadata=b"", + kv_caches_base_addr=base_addrs, + device_id=device_id, + num_blocks=num_blocks, + block_lens=block_lens, + kv_cache_layout="HND", + block_size=block_size, + ssm_sizes=(0, 0), + attn_backend_name="FlashAttentionBackend", + ) + + +# ====================================================================== +# Dense equivalence tests +# ====================================================================== + + +class TestDensePlanEquivalence: + """Verify plan-based outputs match current DenseModelBlockTransferPolicy.""" + + @pytest.mark.parametrize( + "tp_size,remote_tp_size", + [ + (1, 1), # homogeneous + (2, 1), # D_TP > P_TP + (4, 2), # D_TP > P_TP (larger) + (1, 2), # P_TP > D_TP + (2, 4), # P_TP > D_TP (larger) + ], + ) + @pytest.mark.parametrize("tp_rank_frac", [0.0, 0.5]) + def test_build_remote_descs( + self, + tp_size, + remote_tp_size, + tp_rank_frac, + ): + tp_rank = int(tp_rank_frac * (tp_size - 1)) if tp_size > 1 else 0 + num_kv_heads = 8 + head_size = 128 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads * head_size * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + # Adjust remote block_lens for hetero TP + if tp_size >= remote_tp_size: + tp_ratio = tp_size // remote_tp_size + remote_block_lens = 
[bl * tp_ratio for bl in block_len_per_layer] + else: + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] + + base_addrs = [0x1000 * (i + 1) for i in range(num_layers)] + + # ---- Old path ---- + kv_config = _make_kv_cache_config( + block_size=block_size, + num_blocks=num_blocks, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_size=head_size, + ) + policy = DenseModelBlockTransferPolicy(kv_config, 1) + topo = _make_transfer_topo( + tp_rank=tp_rank, + tp_size=tp_size, + block_size=block_size, + num_kv_heads=num_kv_heads, + ) + is_blocks_first = topo.is_kv_layout_blocks_first + transfer_info = policy.build_engine_transfer_info( + transfer_topo=topo, + local_block_len=block_len_per_layer[0], + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_block_len=remote_block_lens[0], + remote_physical_blocks_per_logical=1, + ) + topo.register_remote_engine(ENGINE_ID, transfer_info) + meta = _make_nixl_meta( + base_addrs, + num_blocks, + remote_block_lens, + block_size=block_size, + ) + old_descs = policy.build_remote_descs( + topo, + ENGINE_ID, + meta, + block_len_per_layer, + ) + + # ---- New path ---- + plan = generate_dense_plan( + **_common_plan_params( + tp_rank=tp_rank, + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + is_blocks_first=is_blocks_first, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + new_descs = build_remote_descs_from_plan(plan, meta) + + assert old_descs == new_descs, ( + f"Descriptor mismatch for tp={tp_size}/{remote_tp_size}, " + f"rank={tp_rank}.\nOld: {old_descs[:5]}...\nNew: {new_descs[:5]}..." 
+ ) + + @pytest.mark.parametrize( + "tp_size,remote_tp_size", + [ + (1, 1), + (2, 1), + (1, 2), + ], + ) + def test_compute_desc_ids(self, tp_size, remote_tp_size): + num_kv_heads = 8 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads * 128 * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + if tp_size >= remote_tp_size: + tp_ratio = tp_size // remote_tp_size + remote_block_lens = [bl * tp_ratio for bl in block_len_per_layer] + else: + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] + + topo = _make_transfer_topo( + tp_size=tp_size, + block_size=block_size, + num_kv_heads=num_kv_heads, + ) + is_blocks_first = topo.is_kv_layout_blocks_first + + kv_config = _make_kv_cache_config( + block_size=block_size, + num_blocks=num_blocks, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + ) + policy = DenseModelBlockTransferPolicy(kv_config, 1) + plan = generate_dense_plan( + **_common_plan_params( + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + is_blocks_first=is_blocks_first, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + + num_regions = len(plan.fa_regions) + block_ids = ([1, 5, 10, 20],) + + old_ids = policy.get_block_descs_ids( + block_ids=block_ids, + num_regions=num_regions, + dst_num_blocks=num_blocks, + block_len_per_layer=block_len_per_layer, + ) + new_ids = compute_desc_ids_from_plan( + plan, + block_ids, + dst_num_blocks=num_blocks, + ) + + np.testing.assert_array_equal(old_ids, new_ids) + + @pytest.mark.parametrize( + "tp_size,remote_tp_size", + [ + (1, 1), + (2, 1), + (1, 2), + ], + ) + def test_compute_read_specs(self, tp_size, remote_tp_size): + tp_rank = 0 + num_kv_heads = 8 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads 
* 128 * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + if tp_size >= remote_tp_size: + tp_ratio = tp_size // remote_tp_size + remote_block_lens = [bl * tp_ratio for bl in block_len_per_layer] + else: + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] + + kv_config = _make_kv_cache_config( + block_size=block_size, + num_blocks=num_blocks, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + ) + policy = DenseModelBlockTransferPolicy(kv_config, 1) + topo = _make_transfer_topo( + tp_rank=tp_rank, + tp_size=tp_size, + block_size=block_size, + num_kv_heads=num_kv_heads, + ) + is_blocks_first = topo.is_kv_layout_blocks_first + transfer_info = policy.build_engine_transfer_info( + transfer_topo=topo, + local_block_len=block_len_per_layer[0], + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_block_len=remote_block_lens[0], + remote_physical_blocks_per_logical=1, + ) + topo.register_remote_engine(ENGINE_ID, transfer_info) + remote_ranks = topo.target_remote_ranks(ENGINE_ID) + + plan = generate_dense_plan( + **_common_plan_params( + tp_rank=tp_rank, + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + is_blocks_first=is_blocks_first, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + + local_ids = ([1, 2, 3],) + remote_ids = ([4, 5, 6],) + + old_specs = policy.compute_read_specs( + local_ids, + remote_ids, + remote_ranks, + transfer_info, + ) + new_specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) + + assert len(old_specs) == len(new_specs) + for old, new in zip(old_specs, new_specs): + assert old.remote_rank == new.remote_rank + assert list(old.local_block_ids[0]) == list(new.local_block_ids[0]) + assert list(old.remote_block_ids[0]) == list(new.remote_block_ids[0]) 
+ + @pytest.mark.parametrize("remote_tp_size", [2, 4]) + def test_build_src_split_handles(self, remote_tp_size): + tp_rank = 0 + tp_size = 1 + num_kv_heads = 8 + block_size = 16 + num_blocks = 64 + num_layers = 2 + slot_size = num_kv_heads * 128 * 2 + block_len = slot_size * block_size + block_len_per_layer = [block_len] * num_layers + + tp_ratio_neg = remote_tp_size // tp_size + remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] + + kv_config = _make_kv_cache_config( + block_size=block_size, + num_blocks=num_blocks, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + ) + policy = DenseModelBlockTransferPolicy(kv_config, 1) + topo = _make_transfer_topo( + tp_rank=tp_rank, + tp_size=tp_size, + block_size=block_size, + num_kv_heads=num_kv_heads, + ) + is_blocks_first = topo.is_kv_layout_blocks_first + transfer_info = policy.build_engine_transfer_info( + transfer_topo=topo, + local_block_len=block_len_per_layer[0], + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_block_len=remote_block_lens[0], + remote_physical_blocks_per_logical=1, + ) + topo.register_remote_engine(ENGINE_ID, transfer_info) + + plan = generate_dense_plan( + **_common_plan_params( + tp_rank=tp_rank, + tp_size=tp_size, + num_kv_heads=num_kv_heads, + block_size=block_size, + is_blocks_first=is_blocks_first, + block_len_per_layer=block_len_per_layer, + remote_tp_size=remote_tp_size, + remote_block_size=block_size, + remote_num_blocks=num_blocks, + remote_block_lens=remote_block_lens, + ), + ) + + src_blocks_data = [(0x2000 + i * 1024, 1024, 0) for i in range(8)] + num_descs = len(src_blocks_data) + + old_splits = policy.build_src_split_handles( + topo, + ENGINE_ID, + src_blocks_data, + num_descs, + ) + new_splits = build_local_splits_from_plan( + plan, + src_blocks_data, + num_descs, + ) + + assert len(old_splits) == len(new_splits), ( + f"Split count mismatch: {len(old_splits)} vs {len(new_splits)}" + ) + for i, (old, new) in enumerate(zip(old_splits, 
new_splits)): + assert old == new, f"Split {i} mismatch" + + +class TestDensePlanVisualization: + def test_visualize_produces_output(self): + plan = generate_dense_plan( + **_common_plan_params(), + ) + output = visualize_plan(plan) + assert "FA regions" in output + assert "fa_k" in output + + +class TestDensePlanStructure: + def test_source_ranks_homogeneous(self): + plan = generate_dense_plan( + **_common_plan_params(tp_size=2, tp_rank=1, remote_tp_size=2), + ) + assert plan.all_source_ranks == (1,) + + def test_source_ranks_d_gt_p(self): + plan = generate_dense_plan( + **_common_plan_params(tp_size=4, tp_rank=2, remote_tp_size=2), + ) + assert plan.all_source_ranks == (1,) + + def test_source_ranks_p_gt_d(self): + plan = generate_dense_plan( + **_common_plan_params(tp_size=1, tp_rank=0, remote_tp_size=2), + ) + assert plan.all_source_ranks == (0, 1) + + def test_no_ssm_regions(self): + plan = generate_dense_plan(**_common_plan_params()) + assert plan.ssm_regions == () + assert plan.is_mamba_group == (False,) + + def test_blocks_first_has_k_and_v(self): + plan = generate_dense_plan( + **_common_plan_params(is_blocks_first=True), + ) + kinds = [r.kind.value for r in plan.fa_regions] + assert "fa_k" in kinds + assert "fa_v" in kinds + + def test_not_blocks_first_has_only_k(self): + plan = generate_dense_plan( + **_common_plan_params(is_blocks_first=False), + ) + kinds = [r.kind.value for r in plan.fa_regions] + assert "fa_k" in kinds + assert "fa_v" not in kinds + + +# ====================================================================== +# Mamba equivalence tests +# ====================================================================== + + +def _make_mamba_plan_for_desc_ids( + num_fa_regions: int, + num_ssm_regions: int, + is_mamba_group: list[bool], + fa_num_blocks: int = 100, + ssm_num_blocks: int = 100, +) -> EngineTransferPlan: + """Build a minimal plan with enough structure for compute_desc_ids.""" + fa_regions = tuple( + RegionPlan( + kind=RegionKind.FA_K, 
+ layer_idx=i, + descriptor_bytes=100, + offset_in_page=0, + page_stride=100, + num_blocks=fa_num_blocks, + physical_per_logical=1, + ) + for i in range(num_fa_regions) + ) + ssm_regions = tuple( + RegionPlan( + kind=RegionKind.SSM_CONV_X, + layer_idx=i % (num_ssm_regions // 4) if num_ssm_regions >= 4 else 0, + descriptor_bytes=50, + offset_in_page=0, + page_stride=200, + num_blocks=ssm_num_blocks, + physical_per_logical=1, + ) + for i in range(num_ssm_regions) + ) + physical_per_logical = tuple(1 for _ in is_mamba_group) + return EngineTransferPlan( + fa_regions=fa_regions, + ssm_regions=ssm_regions, + physical_per_logical=physical_per_logical, + is_mamba_group=tuple(is_mamba_group), + all_source_ranks=(0,), + fa_source_ranks=(0,), + fa_source_set=frozenset({0}), + num_fa_reads=1, + num_mamba_reads=1, + fa_head_slots={0: 0}, + remote_tp_size=1, + remote_block_size=16, + remote_block_len=0, + remote_physical_blocks_per_logical=1, + ) + + +class TestMambaPlanDescIds: + """Verify plan-based desc IDs match MambaModelBlockTransferPolicy.""" + + def test_hybrid_ssm_ratio_1(self): + """Equivalent to test_get_block_descs_ids_hybrid_ssm.""" + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, # 4 regions per layer, 1 layer + is_mamba_group=[False, True], + fa_num_blocks=100, + ssm_num_blocks=100, + ) + + fa_blocks = [3, 5] + ssm_blocks = [1, 2] + + result = compute_desc_ids_from_plan( + plan, + block_ids=(fa_blocks, ssm_blocks), + dst_num_blocks=100, + physical_blocks_per_logical=1, + ) + + expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502] + assert list(result) == expected, f"Expected {expected}, got {list(result)}" + + def test_kernel_block_mismatch(self): + """Equivalent to test_get_block_descs_ids_kernel_block_mismatch.""" + ratio = 4 + logical_blocks = 100 + num_blocks = logical_blocks * ratio # 400 + + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, + is_mamba_group=[False, True],
+ fa_num_blocks=num_blocks, + ssm_num_blocks=logical_blocks, + ) + + fa_blocks = [3, 7] + ssm_blocks = [1, 2] + + result = compute_desc_ids_from_plan( + plan, + block_ids=(fa_blocks, ssm_blocks), + dst_num_blocks=num_blocks, + physical_blocks_per_logical=ratio, + ) + + expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102] + assert list(result) == expected, f"Expected {expected}, got {list(result)}" + + +class TestMambaPlanReadSpecs: + """Verify plan-based read specs handle FA group filtering correctly.""" + + def test_all_source_ranks_serve_fa(self): + """When all ranks are FA sources, no filtering happens.""" + plan = EngineTransferPlan( + fa_regions=(), + ssm_regions=(), + physical_per_logical=(1, 1), + is_mamba_group=(False, True), + all_source_ranks=(0, 1), + fa_source_ranks=(0, 1), + fa_source_set=frozenset({0, 1}), + num_fa_reads=2, + num_mamba_reads=2, + fa_head_slots={0: 0, 1: 1}, + remote_tp_size=2, + remote_block_size=16, + remote_block_len=0, + remote_physical_blocks_per_logical=1, + ) + + local_ids = ([1, 2], [3, 4]) + remote_ids = ([5, 6], [7, 8]) + + specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) + assert len(specs) == 2 + for spec in specs: + assert list(spec.local_block_ids[0]) == [1, 2] + assert list(spec.local_block_ids[1]) == [3, 4] + + def test_non_fa_rank_skips_fa_groups(self): + """Ranks not in fa_source_set get FA groups zeroed out.""" + plan = EngineTransferPlan( + fa_regions=(), + ssm_regions=(), + physical_per_logical=(1, 1), + is_mamba_group=(False, True), + all_source_ranks=(0, 1, 2), + fa_source_ranks=(0,), + fa_source_set=frozenset({0}), + num_fa_reads=1, + num_mamba_reads=3, + fa_head_slots={0: 0}, + remote_tp_size=3, + remote_block_size=16, + remote_block_len=0, + remote_physical_blocks_per_logical=1, + ) + + local_ids = ([1, 2], [3, 4]) + remote_ids = ([5, 6], [7, 8]) + + specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) + assert len(specs) == 3 + + # Rank 0 (FA source): gets all 
groups + assert list(specs[0].local_block_ids[0]) == [1, 2] + assert list(specs[0].local_block_ids[1]) == [3, 4] + + # Rank 1 (not FA): FA group zeroed, Mamba group preserved + assert specs[1].local_block_ids[0] == [] + assert list(specs[1].local_block_ids[1]) == [3, 4] + + # Rank 2 (not FA): same + assert specs[2].local_block_ids[0] == [] + assert list(specs[2].local_block_ids[1]) == [3, 4] + + +class TestMambaPlanSplitHandles: + """Verify plan-based split handles for Mamba with FA/SSM distinction.""" + + def test_fa_and_ssm_different_split_factors(self): + """FA descs split by num_fa_reads, SSM descs split by abs_tp.""" + plan = EngineTransferPlan( + fa_regions=(), + ssm_regions=( + RegionPlan( + kind=RegionKind.SSM_STATE, + layer_idx=0, + descriptor_bytes=100, + offset_in_page=0, + page_stride=100, + num_blocks=10, + physical_per_logical=1, + ), + ), + physical_per_logical=(1, 1), + is_mamba_group=(False, True), + all_source_ranks=(0, 1), + fa_source_ranks=(0,), + fa_source_set=frozenset({0}), + num_fa_reads=1, + num_mamba_reads=2, + fa_head_slots={0: 0}, + remote_tp_size=2, + remote_block_size=16, + remote_block_len=0, + remote_physical_blocks_per_logical=1, + ) + + # 2 FA descs + 1 SSM desc + src_blocks_data = [ + (1000, 200, 0), # FA desc 0 + (2000, 200, 0), # FA desc 1 + (3000, 400, 0), # SSM desc 0 + ] + num_fa_descs = 2 + + splits = build_local_splits_from_plan(plan, src_blocks_data, num_fa_descs) + + assert len(splits) == 2 # 2 source ranks + + # Rank 0 (FA source, p_idx=0): + # FA: chunk=200//1=200, slot=0 → (1000, 200, 0), (2000, 200, 0) + # SSM: chunk=400//2=200, idx=0 → (3000, 200, 0) + assert splits[0] == [(1000, 200, 0), (2000, 200, 0), (3000, 200, 0)] + + # Rank 1 (not FA source, p_idx=1): + # FA: chunk=200//1=200, slot=0 (skip_fa) → (1000, 200, 0), (2000, 200, 0) + # SSM: chunk=400//2=200, idx=1 → (3200, 200, 0) + assert splits[1] == [(1000, 200, 0), (2000, 200, 0), (3200, 200, 0)] diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py new file mode 100644 index 000000000000..c9ccd2a467a7 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -0,0 +1,920 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Plan-based transfer design for NIXL connector. + +Instead of an ABC hierarchy with duplicated Dense/Mamba implementations, +we pre-generate a flat transfer plan per remote engine during handshake. +All downstream operations become generic plan executors with zero model +branching. + +Architecture: + 1. Plan generators (generate_dense_plan, generate_mamba_plan) + — the ONLY model-specific code. + 2. Generic executors (build_remote_descs_from_plan, etc.) + — consume plans without model branching. + 3. Visualization (visualize_plan). +""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np + +from vllm.distributed.kv_transfer.kv_connector.utils import BlockIds +from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( + MambaConvSplitInfo, +) + +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NixlAgentMetadata, + ) + + +# ====================================================================== +# 1. Data structures +# ====================================================================== + + +@dataclass(frozen=True) +class ReadSpec: + """Specification for a single remote block read operation.""" + + remote_rank: int + local_block_ids: BlockIds + remote_block_ids: BlockIds + + +class RegionKind(enum.Enum): + """Descriptor region type. 
Used for visualization/debugging only; + executors never branch on this value.""" + + FA_K = "fa_k" + FA_V = "fa_v" + SSM_CONV_X = "ssm_conv_x" + SSM_CONV_B = "ssm_conv_b" + SSM_CONV_C = "ssm_conv_c" + SSM_STATE = "ssm_state" + + +@dataclass(frozen=True) +class RegionPlan: + """Pre-computed plan for one descriptor region. + + Everything needed to build NIXL descriptors and compute descriptor + IDs is baked in — no runtime model branching. The executor plugs + in per-rank ``base_addr`` and ``device_id`` from NixlAgentMetadata. + """ + + kind: RegionKind + layer_idx: int + + # Descriptor geometry + descriptor_bytes: int + offset_in_page: int + page_stride: int + num_blocks: int + + # Block ID expansion (HMA / kernel block mismatch) + physical_per_logical: int + + +@dataclass(frozen=True) +class EngineTransferPlan: + """Complete transfer plan for one remote engine. + + Generated once during handshake. Stored alongside (or replacing) + ``EngineTransferInfo`` on ``TransferTopology``. + + Regions are split into ``fa_regions`` and ``ssm_regions`` matching + the descriptor handle layout: [FA descriptors | SSM descriptors]. + ``is_mamba_group`` maps kv_cache_groups to the correct section. + """ + + # Regions in descriptor handle order + fa_regions: tuple[RegionPlan, ...] + ssm_regions: tuple[RegionPlan, ...] + + # Per-group geometric properties (worker-facing, model-agnostic) + physical_per_logical: tuple[int, ...] + + # kv_cache_group mapping (internal to transfer_plan, worker should not use) + is_mamba_group: tuple[bool, ...] + + # Source rank routing + all_source_ranks: tuple[int, ...] + fa_source_ranks: tuple[int, ...] 
+ fa_source_set: frozenset[int] + + # Split handle parameters + num_fa_reads: int + num_mamba_reads: int + fa_head_slots: dict[int, int] + + # Remote engine facts (needed by worker at read time) + remote_tp_size: int + remote_block_size: int + remote_block_len: int + remote_physical_blocks_per_logical: int + + @property + def all_regions(self) -> tuple[RegionPlan, ...]: + return self.fa_regions + self.ssm_regions + + +# ====================================================================== +# 2. Internal helpers +# ====================================================================== + + +def _get_kv_block_len( + layer_idx: int, + block_len_per_layer: list[int], + is_blocks_first: bool, +) -> int: + if is_blocks_first: + return block_len_per_layer[layer_idx] // 2 + return block_len_per_layer[layer_idx] + + +def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: + if tp_size <= num_heads: + assert num_heads % tp_size == 0 + per_rank = num_heads // tp_size + return range(rank * per_rank, (rank + 1) * per_rank) + else: + h = rank * num_heads // tp_size + return range(h, h + 1) + + +def _range_overlap(a: range, b: range) -> range: + start = max(a.start, b.start) + stop = min(a.stop, b.stop) + return range(start, max(start, stop)) + + +def _compute_tp_ratio(tp_size: int, remote_tp_size: int) -> int: + if tp_size >= remote_tp_size: + assert tp_size % remote_tp_size == 0 + return tp_size // remote_tp_size + assert remote_tp_size % tp_size == 0 + return -(remote_tp_size // tp_size) + + +def _compute_fa_source_ranks( + tp_rank: int, + tp_size: int, + remote_tp_size: int, + is_mla: bool, + total_num_kv_heads: int, +) -> tuple[list[int], list[int], int, int]: + """Compute FA and all source ranks for Mamba models. + + Returns (fa_source_ranks, all_source_ranks, num_fa_reads, num_mamba_reads). + Mirrors the logic in MambaModelBlockTransferPolicy.build_engine_transfer_info. 
+ """ + K = total_num_kv_heads + tp_ratio = _compute_tp_ratio(tp_size, remote_tp_size) + abs_tp = -tp_ratio if tp_ratio < 0 else 1 + mamba_range: range | None = None + if tp_ratio < 0: + mamba_range = range(tp_rank * abs_tp, (tp_rank + 1) * abs_tp) + + fa_source_ranks: list[int] + if is_mla or tp_ratio >= 0: + num_fa_reads = 1 + if is_mla: + fa_source_ranks = [0] + elif tp_ratio > 0: + fa_source_ranks = [tp_rank // tp_ratio] + else: + fa_source_ranks = [tp_rank] + else: + local_needs = _physical_head_range(tp_size, K, tp_rank) + search_range = mamba_range if mamba_range is not None else range(remote_tp_size) + seen: set[tuple[int, int]] = set() + fa_source_ranks = [] + for p in search_range: + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + if not fa_source_ranks: + for p in range(remote_tp_size): + p_has = _physical_head_range(remote_tp_size, K, p) + ov = _range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + num_fa_reads = len(fa_source_ranks) + + if mamba_range is not None and abs_tp > num_fa_reads: + num_mamba_reads = abs_tp + all_source_ranks = list(mamba_range) + else: + num_mamba_reads = num_fa_reads + all_source_ranks = list(fa_source_ranks) + + return fa_source_ranks, all_source_ranks, num_fa_reads, num_mamba_reads + + +def _compute_fa_head_slots( + fa_source_ranks: list[int], + all_source_ranks: list[int], + remote_tp_size: int, + total_num_kv_heads: int, +) -> dict[int, int]: + """Pre-compute the FA head slot for each source rank. + + Mirrors _fa_head_slot from block_transfer_policy.py but pre-computes + all values at plan generation time. 
+ """ + fa_index = {r: i for i, r in enumerate(fa_source_ranks)} + K = total_num_kv_heads + result: dict[int, int] = {} + for rank in all_source_ranks: + if rank in fa_index: + result[rank] = fa_index[rank] + else: + r_head = _physical_head_range(remote_tp_size, K, rank) + for target in fa_source_ranks: + t_head = _physical_head_range(remote_tp_size, K, target) + if _range_overlap(r_head, t_head): + result[rank] = fa_index[target] + break + else: + result[rank] = 0 + return result + + +def _compute_fa_rank_offset( + tp_rank: int, + tp_size: int, + tp_ratio: int, + is_mla: bool, + total_num_kv_heads: int, + remote_tp_size: int, + fa_source_ranks: list[int], + remote_kv_block_len: int, +) -> int: + """Byte offset into remote FA block for this local rank. + + Mirrors _fa_rank_offset from block_transfer_policy.py, but takes + raw parameters instead of MambaEngineTransferInfo. + """ + if is_mla or tp_ratio <= 0: + return 0 + K = total_num_kv_heads + is_local_replicated = tp_size > K + if is_local_replicated: + local_head = tp_rank * K // tp_size + p_rank = fa_source_ranks[0] + p_start = p_rank * K // remote_tp_size + return (local_head - p_start) * remote_kv_block_len + return tp_rank % tp_ratio * remote_kv_block_len + + +# ====================================================================== +# 3. Plan generators — the ONLY model-specific code +# ====================================================================== + + +def generate_dense_plan( + *, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + is_blocks_first: bool, + block_len_per_layer: list[int], + block_size: int, + remote_tp_size: int, + remote_block_size: int, + remote_num_blocks: int, + remote_block_lens: list[int], + remote_physical_blocks_per_logical: int, +) -> EngineTransferPlan: + """Generate transfer plan for dense (FA-only) models. 
+ + Mirrors the combined logic of: + - DenseModelBlockTransferPolicy.build_engine_transfer_info() + - DenseModelBlockTransferPolicy.build_remote_descs() + """ + tp_ratio = _compute_tp_ratio(tp_size, remote_tp_size) + block_size_ratio = block_size // remote_block_size + indexes_into_remote = ( + not (is_mla or remote_tp_size > total_num_kv_heads) and tp_ratio > 0 + ) + + # Source ranks — mirrors TransferTopology.target_remote_ranks for dense + if tp_ratio > 0: + all_source_ranks: tuple[int, ...] = (tp_rank // tp_ratio,) + else: + abs_ratio = -tp_ratio + all_source_ranks = tuple(tp_rank * abs_ratio + i for i in range(abs_ratio)) + + # Build FA regions — one (K, optionally V) per layer + fa_regions: list[RegionPlan] = [] + for i in range(len(remote_block_lens)): + local_block_len = _get_kv_block_len(i, block_len_per_layer, is_blocks_first) + remote_kv_block_len = local_block_len // block_size_ratio + + k_desc_bytes = local_block_len + if block_size_ratio > 1: + k_desc_bytes = remote_kv_block_len + if tp_ratio < 0 and not is_mla: + k_desc_bytes = k_desc_bytes // (-tp_ratio) + + rank_offset = ( + tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 + ) + page_stride = remote_block_lens[i] + + fa_regions.append( + RegionPlan( + kind=RegionKind.FA_K, + layer_idx=i, + descriptor_bytes=k_desc_bytes, + offset_in_page=rank_offset, + page_stride=page_stride, + num_blocks=remote_num_blocks, + physical_per_logical=remote_physical_blocks_per_logical, + ) + ) + + if is_blocks_first: + # V half shares K's geometry: reuse the K size so that both the + # block_size_ratio and hetero-TP adjustments apply to V as well. + v_desc_bytes = k_desc_bytes + + fa_regions.append( + RegionPlan( + kind=RegionKind.FA_V, + layer_idx=i, + descriptor_bytes=v_desc_bytes, + offset_in_page=rank_offset + page_stride // 2, + page_stride=page_stride, + num_blocks=remote_num_blocks, + physical_per_logical=remote_physical_blocks_per_logical, + ) + ) + + # For dense split handles: fa_head_slots maps 
rank → index, + # so the executor uniformly splits all descs by abs_tp. + fa_head_slots = {r: i for i, r in enumerate(all_source_ranks)} + + return EngineTransferPlan( + fa_regions=tuple(fa_regions), + ssm_regions=(), + physical_per_logical=(remote_physical_blocks_per_logical,), + is_mamba_group=(False,), + all_source_ranks=all_source_ranks, + fa_source_ranks=all_source_ranks, + fa_source_set=frozenset(all_source_ranks), + num_fa_reads=len(all_source_ranks), + num_mamba_reads=0, + fa_head_slots=fa_head_slots, + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + remote_block_len=remote_block_lens[0], + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ) + + +def generate_mamba_plan( + *, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + is_blocks_first: bool, + block_len_per_layer: list[int], + block_size: int, + remote_tp_size: int, + remote_block_size: int, + remote_num_blocks: int, + remote_block_lens: list[int], + remote_physical_blocks_per_logical: int, + is_mamba_group: list[bool], + conv_decomp: MambaConvSplitInfo, + ssm_sizes: tuple[int, int], + remote_ssm_sizes: tuple[int, int], +) -> EngineTransferPlan: + """Generate transfer plan for hybrid Mamba (SSM + FA) models. + + Mirrors the combined logic of: + - MambaModelBlockTransferPolicy.build_engine_transfer_info() + - MambaModelBlockTransferPolicy._build_fa_remote_descs() + - MambaModelBlockTransferPolicy._build_mamba_remote_descs() + """ + tp_ratio = _compute_tp_ratio(tp_size, remote_tp_size) + block_size_ratio = block_size // remote_block_size + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 " + f"is not tested. Got {block_size_ratio=}." 
+ ) + + # ---- Source rank computation ---- + ( + fa_source_ranks, + all_source_ranks, + num_fa_reads, + num_mamba_reads, + ) = _compute_fa_source_ranks( + tp_rank, + tp_size, + remote_tp_size, + is_mla, + total_num_kv_heads, + ) + + # ---- FA head slots (for split handles) ---- + fa_head_slots = _compute_fa_head_slots( + fa_source_ranks, + all_source_ranks, + remote_tp_size, + total_num_kv_heads, + ) + + # ---- FA regions ---- + fa_regions: list[RegionPlan] = [] + for i in range(len(remote_block_lens)): + local_block_len = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + remote_kv_block_len = local_block_len // block_size_ratio + + k_desc_bytes = local_block_len + if block_size_ratio > 1: + k_desc_bytes = remote_kv_block_len + if tp_ratio < 0 and not is_mla: + k_desc_bytes = k_desc_bytes // num_fa_reads + + rank_offset = _compute_fa_rank_offset( + tp_rank, + tp_size, + tp_ratio, + is_mla, + total_num_kv_heads, + remote_tp_size, + fa_source_ranks, + remote_kv_block_len, + ) + + page_stride = remote_block_lens[i] + + fa_regions.append( + RegionPlan( + kind=RegionKind.FA_K, + layer_idx=i, + descriptor_bytes=k_desc_bytes, + offset_in_page=rank_offset, + page_stride=page_stride, + num_blocks=remote_num_blocks, + physical_per_logical=remote_physical_blocks_per_logical, + ) + ) + + if is_blocks_first: + v_desc_bytes = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + if tp_ratio < 0 and not is_mla: + v_desc_bytes = v_desc_bytes // num_fa_reads + + fa_regions.append( + RegionPlan( + kind=RegionKind.FA_V, + layer_idx=i, + descriptor_bytes=v_desc_bytes, + offset_in_page=rank_offset + page_stride // 2, + page_stride=page_stride, + num_blocks=remote_num_blocks, + physical_per_logical=remote_physical_blocks_per_logical, + ) + ) + + # ---- SSM regions ---- + effective_ratio = max(tp_ratio, 1) + local_offset = tp_rank % effective_ratio + conv_size_remote = remote_ssm_sizes[0] + remote_ratio = remote_physical_blocks_per_logical + 
ssm_num_blocks = remote_num_blocks // remote_ratio + + if tp_ratio >= 1: + conv_offsets = conv_decomp.remote_conv_offsets( + local_offset, + effective_ratio, + ) + ssm_read_size = ssm_sizes[1] + else: + abs_ratio = -tp_ratio + xb_p = conv_decomp.x_bytes // abs_ratio + bb_p = conv_decomp.b_bytes // abs_ratio + conv_offsets = [ + (0, xb_p), + (xb_p, bb_p), + (xb_p + bb_p, bb_p), + ] + ssm_read_size = remote_ssm_sizes[1] + + conv_kinds = [RegionKind.SSM_CONV_X, RegionKind.SSM_CONV_B, RegionKind.SSM_CONV_C] + ssm_regions: list[RegionPlan] = [] + for i in range(len(remote_block_lens)): + page_stride = remote_block_lens[i] * remote_ratio + + for kind, (off, sz) in zip(conv_kinds, conv_offsets): + ssm_regions.append( + RegionPlan( + kind=kind, + layer_idx=i, + descriptor_bytes=sz, + offset_in_page=off, + page_stride=page_stride, + num_blocks=ssm_num_blocks, + physical_per_logical=1, + ) + ) + + ssm_regions.append( + RegionPlan( + kind=RegionKind.SSM_STATE, + layer_idx=i, + descriptor_bytes=ssm_read_size, + offset_in_page=conv_size_remote + local_offset * ssm_read_size, + page_stride=page_stride, + num_blocks=ssm_num_blocks, + physical_per_logical=1, + ) + ) + + physical_per_logical_per_group = tuple( + 1 if m else remote_physical_blocks_per_logical for m in is_mamba_group + ) + return EngineTransferPlan( + fa_regions=tuple(fa_regions), + ssm_regions=tuple(ssm_regions), + physical_per_logical=physical_per_logical_per_group, + is_mamba_group=tuple(is_mamba_group), + all_source_ranks=tuple(all_source_ranks), + fa_source_ranks=tuple(fa_source_ranks), + fa_source_set=frozenset(fa_source_ranks), + num_fa_reads=num_fa_reads, + num_mamba_reads=num_mamba_reads, + fa_head_slots=fa_head_slots, + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + remote_block_len=remote_block_lens[0], + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ) + + +# ====================================================================== +# 4. 
Generic executors — identical for ALL models +# ====================================================================== + + +def logical_to_kernel_block_ids( + block_ids: BlockIds, + physical_per_logical: tuple[int, ...], +) -> BlockIds: + """Convert logical block IDs to kernel-level physical block IDs. + + Each group has its own ratio in ``physical_per_logical``. + Groups with ratio == 1 are passed through unchanged. + """ + if all(r == 1 for r in physical_per_logical): + return block_ids + result: list[list[int]] = [] + for i, group in enumerate(block_ids): + ratio = physical_per_logical[i] + if ratio == 1: + result.append(group) + else: + arr = np.array(group).reshape(-1, 1) + arange = np.arange(ratio).reshape(1, -1) + result.append((arr * ratio + arange).flatten().tolist()) + return result + + +def build_remote_descs_from_plan( + plan: EngineTransferPlan, + nixl_agent_meta: NixlAgentMetadata, +) -> list[tuple[int, int, int]]: + """Build (addr, len, dev_id) descriptor tuples from plan. + + Replaces DenseModelBlockTransferPolicy.build_remote_descs() and + MambaModelBlockTransferPolicy.build_remote_descs(). + """ + result: list[tuple[int, int, int]] = [] + dev_id = nixl_agent_meta.device_id + + for region in plan.all_regions: + base_addr = nixl_agent_meta.kv_caches_base_addr[region.layer_idx] + for blk in range(region.num_blocks): + addr = base_addr + blk * region.page_stride + region.offset_in_page + result.append((addr, region.descriptor_bytes, dev_id)) + + return result + + +def compute_desc_ids_from_plan( + plan: EngineTransferPlan, + block_ids: BlockIds, + dst_num_blocks: int, + block_size_ratio: float | None = None, + physical_blocks_per_logical: int = 1, +) -> np.ndarray: + """Compute NIXL descriptor IDs for given block IDs. + + Replaces DenseModelBlockTransferPolicy.get_block_descs_ids() and + MambaModelBlockTransferPolicy.get_block_descs_ids(). 
+ """ + num_fa_regions = len(plan.fa_regions) + num_ssm_regions = len(plan.ssm_regions) + + num_blocks = dst_num_blocks + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) + ratio = physical_blocks_per_logical + logical_blocks = num_blocks // ratio + + num_fa_descs = num_fa_regions * num_blocks + + all_descs: list[np.ndarray] = [] + for i, group in enumerate(block_ids): + group_arr = np.asarray(group) + if plan.is_mamba_group[i]: + ssm_region_ids = np.arange(num_ssm_regions)[:, None] + all_descs.append( + ( + ssm_region_ids * logical_blocks + group_arr[None, :] + num_fa_descs + ).flatten() + ) + else: + fa_region_ids = np.arange(num_fa_regions)[:, None] + all_descs.append( + (fa_region_ids * num_blocks + group_arr[None, :]).flatten() + ) + + return np.concatenate(all_descs) + + +def compute_read_specs_from_plan( + plan: EngineTransferPlan, + local_block_ids: BlockIds, + remote_block_ids: BlockIds, +) -> list[ReadSpec]: + """Compute read specs from plan. + + Replaces compute_read_specs() + filter_block_ids_for_rank(). + No _should_skip_fa — the plan structurally encodes which ranks + serve which groups via fa_source_set. 
+ """ + specs: list[ReadSpec] = [] + for rank in plan.all_source_ranks: + skip_fa = rank not in plan.fa_source_set + if not skip_fa: + specs.append( + ReadSpec( + remote_rank=rank, + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + ) + ) + else: + num_groups = len(local_block_ids) + filtered_local: list[list[int]] = [ + [] if not plan.is_mamba_group[g] else list(local_block_ids[g]) + for g in range(num_groups) + ] + filtered_remote: list[list[int]] = [ + [] if not plan.is_mamba_group[g] else list(remote_block_ids[g]) + for g in range(num_groups) + ] + specs.append( + ReadSpec( + remote_rank=rank, + local_block_ids=filtered_local, + remote_block_ids=filtered_remote, + ) + ) + return specs + + +def build_local_splits_from_plan( + plan: EngineTransferPlan, + src_blocks_data: list[tuple[int, int, int]], + num_fa_descs: int, +) -> list[list[tuple[int, int, int]]]: + """Build split handle data for P_TP > D_TP scenario. + + Replaces DenseModelBlockTransferPolicy.build_src_split_handles() and + MambaModelBlockTransferPolicy.build_src_split_handles() + + compute_split_handle_data(). + + When num_ssm_regions == 0 (dense), all descs are FA and the split + is uniform. When SSM regions exist, FA and SSM descs get different + split factors. 
+ """ + abs_tp = len(plan.all_source_ranks) + result: list[list[tuple[int, int, int]]] = [] + + for p_idx, p_rank in enumerate(plan.all_source_ranks): + skip_fa = p_rank not in plan.fa_source_set + fa_slot = plan.fa_head_slots.get(p_rank, 0) if not skip_fa else 0 + + handle: list[tuple[int, int, int]] = [] + for j, (addr, local_len, dev) in enumerate(src_blocks_data): + if j < num_fa_descs: + assert plan.num_fa_reads >= 1 + fa_chunk = local_len // plan.num_fa_reads + handle.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) + else: + mamba_chunk = local_len // abs_tp + handle.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) + result.append(handle) + + return result + + +# ====================================================================== +# 5. Local descriptor building (no plan needed — purely local geometry) +# ====================================================================== + + +def build_fa_local_descs( + base_addresses: list[int], + device_id: int, + num_blocks: int, + block_size_ratio: int, + block_len_per_layer: list[int], + is_blocks_first: bool, +) -> list[tuple[int, int, int]]: + """Build FA local descriptors for NIXL registration.""" + result: list[tuple[int, int, int]] = [] + n_blocks = num_blocks * block_size_ratio + for i, base_addr in enumerate(base_addresses): + kv_block_len = ( + _get_kv_block_len(i, block_len_per_layer, is_blocks_first) + // block_size_ratio + ) + page_stride = block_len_per_layer[i] // block_size_ratio + for block_id in range(n_blocks): + result.append( + ( + base_addr + block_id * page_stride, + kv_block_len, + device_id, + ) + ) + if is_blocks_first: + second_split = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + for block_id in range(n_blocks): + v_addr = base_addr + block_id * page_stride + kv_block_len + result.append((v_addr, second_split, device_id)) + return result + + +def build_mamba_local_descs( + base_addresses: list[int], + block_len_per_layer: list[int], + logical_num_blocks: 
int, + block_size_ratio: int, + device_id: int, + conv_decomp: MambaConvSplitInfo, + ssm_sizes: tuple[int, int], + physical_blocks_per_logical: int, +) -> list[tuple[int, int, int]]: + """Build 4 SSM descriptor regions (x, B, C, ssm) per layer.""" + assert block_size_ratio == 1, ( + "Mamba 3-read transfer with block_size_ratio != 1 " + f"is not tested. Got {block_size_ratio=}." + ) + conv_offsets = conv_decomp.local_conv_offsets + conv_size, ssm_size = ssm_sizes + n_blocks = logical_num_blocks * block_size_ratio + phys_ratio = physical_blocks_per_logical + + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate(base_addresses): + page_stride = block_len_per_layer[i] // block_size_ratio * phys_ratio + for off, sz in conv_offsets: + for blk in range(n_blocks): + result.append( + ( + base_addr + blk * page_stride + off, + sz, + device_id, + ) + ) + for blk in range(n_blocks): + result.append( + ( + base_addr + blk * page_stride + conv_size, + ssm_size, + device_id, + ) + ) + return result + + +def build_local_descs( + *, + has_mamba: bool, + conv_decomp: MambaConvSplitInfo | None, + ssm_sizes: tuple[int, int], + base_addresses: list[int], + device_id: int, + num_blocks: int, + logical_num_blocks: int, + block_size_ratio: int, + block_len_per_layer: list[int], + is_blocks_first: bool, + physical_blocks_per_logical: int = 1, +) -> list[tuple[int, int, int]]: + """Build local (src) descriptor tuples for NIXL registration.""" + fa_descs = build_fa_local_descs( + base_addresses, + device_id, + num_blocks, + block_size_ratio, + block_len_per_layer, + is_blocks_first, + ) + if not has_mamba: + return fa_descs + assert conv_decomp is not None + mamba_descs = build_mamba_local_descs( + base_addresses, + block_len_per_layer, + logical_num_blocks, + block_size_ratio, + device_id, + conv_decomp, + ssm_sizes, + physical_blocks_per_logical, + ) + return fa_descs + mamba_descs + + +# ====================================================================== +# 6. 
Visualization +# ====================================================================== + + +def visualize_plan(plan: EngineTransferPlan) -> str: + """Human-readable transfer plan for logging and debugging.""" + lines = [ + f"EngineTransferPlan(remote_tp={plan.remote_tp_size}, " + f"remote_bs={plan.remote_block_size}):", + f" Source ranks: all={list(plan.all_source_ranks)}, " + f"fa={list(plan.fa_source_ranks)}", + ] + total_descs = 0 + + if plan.fa_regions: + lines.append(f" FA regions ({len(plan.fa_regions)}):") + for idx, r in enumerate(plan.fa_regions): + ratio_str = ( + f", p/l={r.physical_per_logical}" if r.physical_per_logical > 1 else "" + ) + lines.append( + f" [{idx}] {r.kind.value:12s} L{r.layer_idx} " + f"{r.descriptor_bytes:6d}B x {r.num_blocks:4d} blks " + f"stride={r.page_stride:6d} " + f"off={r.offset_in_page:6d}" + f"{ratio_str}" + ) + total_descs += r.num_blocks + + if plan.ssm_regions: + lines.append(f" SSM regions ({len(plan.ssm_regions)}):") + for idx, r in enumerate(plan.ssm_regions): + lines.append( + f" [{idx}] {r.kind.value:12s} L{r.layer_idx} " + f"{r.descriptor_bytes:6d}B x {r.num_blocks:4d} blks " + f"stride={r.page_stride:6d} " + f"off={r.offset_in_page:6d}" + ) + total_descs += r.num_blocks + + lines.append(f" Groups: {['SSM' if m else 'FA' for m in plan.is_mamba_group]}") + lines.append(f" Total descriptors: {total_descs}") + return "\n".join(lines) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 1d92185be544..a760ee96134f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -45,12 +45,20 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import ( NixlKVConnectorStats, ) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + EngineTransferPlan, + generate_dense_plan, + generate_mamba_plan, + 
logical_to_kernel_block_ids, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( _NIXL_SUPPORTED_DEVICE, zmq_ctx, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( + MambaConvSplitInfo, compute_physical_blocks_per_logical, + derive_mamba_conv_split, ) from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config from vllm.distributed.parallel_state import ( @@ -66,7 +74,6 @@ MambaSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.utils import select_common_block_size if TYPE_CHECKING: @@ -124,9 +131,18 @@ def __init__( for group in kv_cache_config.kv_cache_groups ] mamba_ssm_size = (0, 0) + self._conv_decomp: MambaConvSplitInfo | None = None self._has_mamba = any(self._is_mamba_group) if self._has_mamba: assert self._is_hma_required + from vllm.model_executor.layers.mamba.mamba_utils import ( + is_conv_state_dim_first, + ) + + assert is_conv_state_dim_first(), ( + "3-read Mamba conv transfer requires DS conv state layout. " + "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" + ) mamba_spec = next( spec for spec in self._layer_specs.values() @@ -144,6 +160,10 @@ def __init__( conv_shape.numel() * conv_nbytes, ssm_shape.numel() * ssm_nbytes, ) + self._conv_decomp = derive_mamba_conv_split( + mamba_spec, + vllm_config.parallel_config.tensor_parallel_size, + ) self._mamba_ssm_size = mamba_ssm_size # Agent. @@ -317,6 +337,10 @@ def __init__( tp_size=vllm_config.parallel_config.tensor_parallel_size, ) + # Per-engine transfer plans. Generated during handshake, used by + # per-request hot path (model-agnostic). 
+ self._transfer_plans: dict[EngineId, EngineTransferPlan] = {} + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True ) @@ -971,6 +995,37 @@ def add_remote_agent( transfer_topo.register_remote_engine(engine_id, transfer_info) logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) + # Generate the pre-computed transfer plan for this remote engine. + # Plan generation is model-aware (if/else), but the per-request + # hot path only consumes the plan (model-agnostic). + plan_common = dict( + tp_rank=self.tp_rank, + tp_size=self.world_size, + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + block_len_per_layer=self.block_len_per_layer, + block_size=self.block_size, + remote_tp_size=remote_tp_size, + remote_block_size=nixl_agent_meta.block_size, + remote_num_blocks=nixl_agent_meta.num_blocks, + remote_block_lens=nixl_agent_meta.block_lens, + remote_physical_blocks_per_logical=physical_blocks_per_logical, + ) + if self._has_mamba: + assert self._conv_decomp is not None + self._transfer_plans[engine_id] = generate_mamba_plan( + **plan_common, + is_mamba_group=self._is_mamba_group, + conv_decomp=self._conv_decomp, + ssm_sizes=self._mamba_ssm_size, + remote_ssm_sizes=nixl_agent_meta.ssm_sizes, + ) + else: + self._transfer_plans[engine_id] = generate_dense_plan( + **plan_common, + ) + remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata ) @@ -1572,18 +1627,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_info = self.transfer_topo.get_engine_info(engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) - # TODO (ZhanqiuHu): Unify logical_to_kernel_block_ids - # and logical_to_remote_kernel_block_ids. - if self._has_mamba: - # Expand remote logical → kernel block IDs. 
- meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( - meta.remote.block_ids, - remote_info.remote_physical_blocks_per_logical, - ) - else: - meta.remote.block_ids = self._logical_to_kernel_block_ids( - meta.remote.block_ids - ) + plan = self._transfer_plans[engine_id] + meta.remote.block_ids = self._logical_to_kernel_block_ids( + meta.remote.block_ids, plan.physical_per_logical + ) remote_block_ids = meta.remote.block_ids read_specs = self.transfer_policy.compute_read_specs( local_block_ids=meta.local_physical_block_ids, @@ -1734,16 +1781,13 @@ def _read_blocks( == len(local_block_ids) == len(self.kv_cache_config.kv_cache_groups) ) + # Partial prefix cache hit: trim remote blocks to match local count. + # SSM groups share the block table so counts always match (no-op trim). remote_block_ids = list(remote_block_ids) for i, remote_group in enumerate(remote_block_ids): - num_remote_blocks = len(remote_group) num_local_blocks = len(local_block_ids[i]) - if not self._is_mamba_group[i]: - assert num_local_blocks <= num_remote_blocks - # Partial prefix cache hit: just read uncomputed blocks. - # Skip mamba groups — their blocks represent full state (conv+ssm), - # not per-token data, so trimming would corrupt the transfer. - if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: + assert num_local_blocks <= len(remote_group) + if num_local_blocks < len(remote_group): remote_block_ids[i] = remote_group[-num_local_blocks:] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from @@ -1826,61 +1870,25 @@ def get_mapped_blocks( return mapped_2d.flatten().astype(np.int64) - def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: - """ - Convert logical block ids to kernel physical block ids. - This is required when the logical block size (the one set by the user) - does not match the one required by the attn backend. 
- """ - if self._physical_blocks_per_logical_kv_block == 1: - # Noop when physical and logical block sizes are the same - return block_ids - block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( - 1, -1 - ) - # Mamba blocks have no logical<>physical discrepancy - group_specs = self.kv_cache_config.kv_cache_groups - return [ - BlockTable.map_to_kernel_blocks( - np.array(group), - self._physical_blocks_per_logical_kv_block, - block_arange, - ).tolist() - if not isinstance(group_specs[i].kv_cache_spec, MambaSpec) - else group - for i, group in enumerate(block_ids) - ] - - def _logical_to_remote_kernel_block_ids( - self, block_ids: BlockIds, remote_physical_per_logical: int + def _logical_to_kernel_block_ids( + self, + block_ids: BlockIds, + physical_per_logical: tuple[int, ...] | None = None, ) -> BlockIds: - """Map logical block IDs to physical kernel block IDs on the remote. + """Convert logical block IDs to kernel physical block IDs. Args: block_ids: per-group lists of logical block IDs. - remote_physical_per_logical: remote engine's physical blocks - per logical block. - - Returns: - Same structure with FA groups expanded (each logical block L - becomes kernel blocks [L*ratio .. L*ratio + local_ratio - 1]). - Mamba groups are passed through unchanged. + physical_per_logical: per-group expansion ratios. When *None* + (local expansion), FA groups use the local kernel ratio + and Mamba groups are 1:1. 
""" - local_ratio = self._physical_blocks_per_logical_kv_block - if remote_physical_per_logical == 1: - return block_ids - local_arange = np.arange(local_ratio).reshape(1, -1) - group_specs = self.kv_cache_config.kv_cache_groups - result: list[list[int]] = [] - for i, group in enumerate(block_ids): - if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): - arr = np.array(group).reshape(-1, 1) - expanded = (arr * remote_physical_per_logical + local_arange).flatten() - result.append(expanded.tolist()) - else: - # Mamba blocks are 1:1 logical-to-physical (no expansion). - result.append(group) - return result + if physical_per_logical is None: + blk_ratio = self._physical_blocks_per_logical_kv_block + physical_per_logical = tuple( + 1 if m else blk_ratio for m in self._is_mamba_group + ) + return logical_to_kernel_block_ids(block_ids, physical_per_logical) def get_kv_connector_stats(self) -> KVConnectorStats | None: """ From 7bca1cab74347d7e79560cdc237b55638e00fbfa Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 02:35:57 +0000 Subject: [PATCH 25/49] restore _logical_to_remote_kernel_block_ids as separate method The unified logical_to_kernel_block_ids had two bugs: 1. Dense remote: used ratio=1 instead of local kernel ratio 2. Mamba FA remote: used same value for stride and count, but old code used remote_ratio as stride, local_ratio as count Restore the original two-method design: - _logical_to_kernel_block_ids: local expansion (same as main) - _logical_to_remote_kernel_block_ids: remote mamba expansion (same as main) Only difference from main: remote_ratio comes from plan.remote_physical_blocks_per_logical instead of self._mamba_phys_ratio[engine_id]. 
Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 56 ++++++++++------ .../kv_connector/v1/nixl/worker.py | 66 ++++++++++++++----- 2 files changed, 83 insertions(+), 39 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 0a5b491ba586..555cf7ee8f71 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -62,24 +62,25 @@ def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes): @pytest.mark.cpu_test def test_logical_to_kernel_block_ids_with_hma(): - """Test logical_to_kernel_block_ids expands blocks when HMA is enabled. + """Test _logical_to_kernel_block_ids expands blocks when HMA is enabled. When HMA is enabled, the logical block size may differ from the kernel block size. Each logical block maps to multiple kernel blocks. """ - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( - logical_to_kernel_block_ids, + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, ) + worker = object.__new__(NixlConnectorWorker) + # Simulate HMA scenario: logical block size = 32, kernel block size = 16 # So each logical block maps to 2 kernel blocks eg [0]->[0,1] + worker._physical_blocks_per_logical_kv_block = 2 # FA + SW groups (neither is MambaSpec, so both get expanded) - logical_block_ids = [[0, 1, 2], [3, 4]] - physical_per_logical = (2, 2) + worker._is_mamba_group = [False, False] - kernel_block_ids = logical_to_kernel_block_ids( - logical_block_ids, physical_per_logical - ) + logical_block_ids = [[0, 1, 2], [3, 4]] + kernel_block_ids = worker._logical_to_kernel_block_ids(logical_block_ids) expected_kernel_block_ids = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]] assert kernel_block_ids == expected_kernel_block_ids, ( @@ -89,34 +90,44 @@ def test_logical_to_kernel_block_ids_with_hma(): @pytest.mark.cpu_test @pytest.mark.parametrize( - 
"physical_per_logical,remote_block_ids,expected_remote_block_ids", + "has_mamba,is_mamba_group,remote_ratio,remote_block_ids,expected_remote_block_ids", [ - # Non-mamba (FA+SWA): both groups expanded by ratio 2. + # Non-mamba (FA+SWA): both groups expanded via _logical_to_kernel_block_ids. # Regression for https://github.com/vllm-project/vllm/pull/39724 ( - (2, 2), + False, + [False, False], + 1, ([0, 1, 2], [3, 4]), [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]], ), - # Mamba (FA+Mamba): FA expanded by ratio 261, Mamba (ratio=1) passthrough. + # Mamba (FA+Mamba): FA expanded via _logical_to_remote_kernel_block_ids, + # Mamba passed through unchanged. + # remote_ratio=261 (Nemotron 30B TP=1) != local_ratio=2 so that using + # the wrong conversion method produces different FA results. ( - (261, 1), + True, + [False, True], + 261, ([0, 1, 2], [10, 11]), - [ - list(range(0, 261)) + list(range(261, 522)) + list(range(522, 783)), - [10, 11], - ], + [[0, 1, 261, 262, 522, 523], [10, 11]], ), ], ids=["non_mamba_fa_swa", "mamba_fa_ssm"], ) def test_read_blocks_for_req_expands_remote_ids( - physical_per_logical, + has_mamba, + is_mamba_group, + remote_ratio, remote_block_ids, expected_remote_block_ids, ): """_read_blocks_for_req must expand remote logical block IDs to kernel - block IDs via plan.physical_per_logical (model-agnostic). + block IDs when kernel block size != logical block size. + + Non-mamba path uses _logical_to_kernel_block_ids (all groups expanded). + Mamba path uses _logical_to_remote_kernel_block_ids (FA expanded, Mamba + passed through). 
""" from unittest.mock import MagicMock @@ -131,6 +142,9 @@ def test_read_blocks_for_req_expands_remote_ids( ) worker = object.__new__(NixlConnectorWorker) + worker._has_mamba = has_mamba + worker._physical_blocks_per_logical_kv_block = 2 + worker._is_mamba_group = is_mamba_group remote_engine_id = "remote-engine" @@ -146,9 +160,9 @@ def test_read_blocks_for_req_expands_remote_ids( worker.transfer_policy.compute_read_specs.return_value = [] worker.use_mla = False - # Mock the plan with the physical_per_logical tuple + # Mock the plan with the remote physical blocks ratio mock_plan = MagicMock(spec=EngineTransferPlan) - mock_plan.physical_per_logical = physical_per_logical + mock_plan.remote_physical_blocks_per_logical = remote_ratio worker._transfer_plans = {remote_engine_id: mock_plan} metadata = NixlConnectorMetadata() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index a760ee96134f..12ed7d1c779a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -49,7 +49,6 @@ EngineTransferPlan, generate_dense_plan, generate_mamba_plan, - logical_to_kernel_block_ids, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( _NIXL_SUPPORTED_DEVICE, @@ -1628,9 +1627,15 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) plan = self._transfer_plans[engine_id] - meta.remote.block_ids = self._logical_to_kernel_block_ids( - meta.remote.block_ids, plan.physical_per_logical - ) + if self._has_mamba: + meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( + meta.remote.block_ids, + plan.remote_physical_blocks_per_logical, + ) + else: + meta.remote.block_ids = self._logical_to_kernel_block_ids( + meta.remote.block_ids + ) remote_block_ids = meta.remote.block_ids read_specs = 
self.transfer_policy.compute_read_specs( local_block_ids=meta.local_physical_block_ids, @@ -1870,25 +1875,50 @@ def get_mapped_blocks( return mapped_2d.flatten().astype(np.int64) - def _logical_to_kernel_block_ids( - self, - block_ids: BlockIds, - physical_per_logical: tuple[int, ...] | None = None, - ) -> BlockIds: + def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: """Convert logical block IDs to kernel physical block IDs. + Required when the logical block size (set by the user) does not match + the one required by the attention backend. + """ + if self._physical_blocks_per_logical_kv_block == 1: + return block_ids + ratio = self._physical_blocks_per_logical_kv_block + arange = np.arange(ratio).reshape(1, -1) + return [ + (np.array(group).reshape(-1, 1) * ratio + arange).flatten().tolist() + if not self._is_mamba_group[i] + else group + for i, group in enumerate(block_ids) + ] + + def _logical_to_remote_kernel_block_ids( + self, block_ids: BlockIds, remote_ratio: int + ) -> BlockIds: + """Map logical block IDs to physical kernel block IDs on the remote. + Args: block_ids: per-group lists of logical block IDs. - physical_per_logical: per-group expansion ratios. When *None* - (local expansion), FA groups use the local kernel ratio - and Mamba groups are 1:1. + remote_ratio: remote engine's physical blocks per logical block. + + Returns: + Same structure with FA groups expanded (each logical block L + becomes kernel blocks [L*remote_ratio .. L*remote_ratio + + local_ratio - 1]). Mamba groups are passed through unchanged. 
""" - if physical_per_logical is None: - blk_ratio = self._physical_blocks_per_logical_kv_block - physical_per_logical = tuple( - 1 if m else blk_ratio for m in self._is_mamba_group - ) - return logical_to_kernel_block_ids(block_ids, physical_per_logical) + local_ratio = self._physical_blocks_per_logical_kv_block + if remote_ratio == 1: + return block_ids + local_arange = np.arange(local_ratio).reshape(1, -1) + result: list[list[int]] = [] + for i, group in enumerate(block_ids): + if not self._is_mamba_group[i]: + arr = np.array(group).reshape(-1, 1) + expanded = (arr * remote_ratio + local_arange).flatten() + result.append(expanded.tolist()) + else: + result.append(group) + return result def get_kv_connector_stats(self) -> KVConnectorStats | None: """ From e7c59c808d679dda897964f3e0552307c6058bd4 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 02:51:24 +0000 Subject: [PATCH 26/49] eliminate if _has_mamba branch from hot path via remote_expansion_stride Add remote_expansion_stride to EngineTransferPlan so the per-request hot path always calls _logical_to_remote_kernel_block_ids with the plan's stride, removing the last model-specific branch from _read_blocks_for_req. Dense plan: stride = local_physical_blocks_per_logical (stride == count). Mamba plan: stride = remote_physical_blocks_per_logical (stride != count). 
Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 30 ++++++++----------- .../kv_connector/unit/test_transfer_plan.py | 6 ++++ .../kv_connector/v1/nixl/transfer_plan.py | 8 +++++ .../kv_connector/v1/nixl/worker.py | 16 +++++----- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 555cf7ee8f71..b40de929ea2b 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -90,44 +90,40 @@ def test_logical_to_kernel_block_ids_with_hma(): @pytest.mark.cpu_test @pytest.mark.parametrize( - "has_mamba,is_mamba_group,remote_ratio,remote_block_ids,expected_remote_block_ids", + "is_mamba_group,expansion_stride,remote_block_ids,expected_remote_block_ids", [ - # Non-mamba (FA+SWA): both groups expanded via _logical_to_kernel_block_ids. + # Dense (FA+SWA): stride == local_ratio, all groups expanded. # Regression for https://github.com/vllm-project/vllm/pull/39724 ( - False, [False, False], - 1, + 2, ([0, 1, 2], [3, 4]), [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]], ), - # Mamba (FA+Mamba): FA expanded via _logical_to_remote_kernel_block_ids, - # Mamba passed through unchanged. - # remote_ratio=261 (Nemotron 30B TP=1) != local_ratio=2 so that using - # the wrong conversion method produces different FA results. + # Mamba (FA+Mamba): stride == remote_physical_blocks_per_logical, + # FA expanded, Mamba passed through unchanged. + # stride=261 (Nemotron 30B TP=1) != local_ratio=2 so that using + # the wrong stride produces different FA results. 
( - True, [False, True], 261, ([0, 1, 2], [10, 11]), [[0, 1, 261, 262, 522, 523], [10, 11]], ), ], - ids=["non_mamba_fa_swa", "mamba_fa_ssm"], + ids=["dense_fa_swa", "mamba_fa_ssm"], ) def test_read_blocks_for_req_expands_remote_ids( - has_mamba, is_mamba_group, - remote_ratio, + expansion_stride, remote_block_ids, expected_remote_block_ids, ): """_read_blocks_for_req must expand remote logical block IDs to kernel block IDs when kernel block size != logical block size. - Non-mamba path uses _logical_to_kernel_block_ids (all groups expanded). - Mamba path uses _logical_to_remote_kernel_block_ids (FA expanded, Mamba - passed through). + The hot path always calls _logical_to_remote_kernel_block_ids with + plan.remote_expansion_stride (model-agnostic). """ from unittest.mock import MagicMock @@ -142,7 +138,6 @@ def test_read_blocks_for_req_expands_remote_ids( ) worker = object.__new__(NixlConnectorWorker) - worker._has_mamba = has_mamba worker._physical_blocks_per_logical_kv_block = 2 worker._is_mamba_group = is_mamba_group @@ -160,9 +155,8 @@ def test_read_blocks_for_req_expands_remote_ids( worker.transfer_policy.compute_read_specs.return_value = [] worker.use_mla = False - # Mock the plan with the remote physical blocks ratio mock_plan = MagicMock(spec=EngineTransferPlan) - mock_plan.remote_physical_blocks_per_logical = remote_ratio + mock_plan.remote_expansion_stride = expansion_stride worker._transfer_plans = {remote_engine_id: mock_plan} metadata = NixlConnectorMetadata() diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 04c3906434b9..713ee9ea7f0f 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -122,6 +122,7 @@ def _common_plan_params( remote_num_blocks: int = 256, remote_block_lens: list[int] | None = None, remote_physical_blocks_per_logical: int = 1, + local_physical_blocks_per_logical: int = 1, ) -> dict: """Build common 
kwargs for plan generators.""" if block_len_per_layer is None: @@ -142,6 +143,7 @@ def _common_plan_params( remote_num_blocks=remote_num_blocks, remote_block_lens=remote_block_lens, remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + local_physical_blocks_per_logical=local_physical_blocks_per_logical, ) @@ -602,6 +604,7 @@ def _make_mamba_plan_for_desc_ids( remote_block_size=16, remote_block_len=0, remote_physical_blocks_per_logical=1, + remote_expansion_stride=1, ) @@ -679,6 +682,7 @@ def test_all_source_ranks_serve_fa(self): remote_block_size=16, remote_block_len=0, remote_physical_blocks_per_logical=1, + remote_expansion_stride=1, ) local_ids = ([1, 2], [3, 4]) @@ -707,6 +711,7 @@ def test_non_fa_rank_skips_fa_groups(self): remote_block_size=16, remote_block_len=0, remote_physical_blocks_per_logical=1, + remote_expansion_stride=1, ) local_ids = ([1, 2], [3, 4]) @@ -758,6 +763,7 @@ def test_fa_and_ssm_different_split_factors(self): remote_block_size=16, remote_block_len=0, remote_physical_blocks_per_logical=1, + remote_expansion_stride=1, ) # 2 FA descs + 1 SSM desc diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index c9ccd2a467a7..b34dfeef4402 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -120,6 +120,11 @@ class EngineTransferPlan: remote_block_len: int remote_physical_blocks_per_logical: int + # Stride for expanding remote logical block IDs to kernel block IDs. + # Dense: equals local physical_blocks_per_logical (stride == count). + # Mamba: equals remote_physical_blocks_per_logical (stride != count). 
+ remote_expansion_stride: int + @property def all_regions(self) -> tuple[RegionPlan, ...]: return self.fa_regions + self.ssm_regions @@ -301,6 +306,7 @@ def generate_dense_plan( remote_num_blocks: int, remote_block_lens: list[int], remote_physical_blocks_per_logical: int, + local_physical_blocks_per_logical: int, ) -> EngineTransferPlan: """Generate transfer plan for dense (FA-only) models. @@ -386,6 +392,7 @@ def generate_dense_plan( remote_block_size=remote_block_size, remote_block_len=remote_block_lens[0], remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + remote_expansion_stride=local_physical_blocks_per_logical, ) @@ -578,6 +585,7 @@ def generate_mamba_plan( remote_block_size=remote_block_size, remote_block_len=remote_block_lens[0], remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + remote_expansion_stride=remote_physical_blocks_per_logical, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 12ed7d1c779a..2cf6772d0e63 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1023,6 +1023,9 @@ def add_remote_agent( else: self._transfer_plans[engine_id] = generate_dense_plan( **plan_common, + local_physical_blocks_per_logical=( + self._physical_blocks_per_logical_kv_block + ), ) remote_agent_name = self.nixl_wrapper.add_remote_agent( @@ -1627,15 +1630,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) plan = self._transfer_plans[engine_id] - if self._has_mamba: - meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( - meta.remote.block_ids, - plan.remote_physical_blocks_per_logical, - ) - else: - meta.remote.block_ids = self._logical_to_kernel_block_ids( - meta.remote.block_ids - ) + meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( + 
meta.remote.block_ids, + plan.remote_expansion_stride, + ) remote_block_ids = meta.remote.block_ids read_specs = self.transfer_policy.compute_read_specs( local_block_ids=meta.local_physical_block_ids, From a129ae886a317be8e9a9d541c515f7a865657861 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 13:41:38 +0000 Subject: [PATCH 27/49] introduce GroupKind enum to replace is_mamba_group boolean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the bool-based is_mamba_group with a GroupKind enum (FA, SWA, MAMBA, GDN) so the transfer layer can dispatch on group type without model-specific branching. Shared behavior is captured by properties (is_attention, is_ssm) — no code duplication when adding new group types. Unsupported KVCacheSpec types raise NotImplementedError. Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 28 ++++++++---- .../kv_connector/unit/test_transfer_plan.py | 19 ++++---- .../v1/nixl/block_transfer_policy.py | 17 ++++--- .../kv_connector/v1/nixl/transfer_plan.py | 45 ++++++++++++++----- .../kv_connector/v1/nixl/worker.py | 32 +++++++++---- 5 files changed, 97 insertions(+), 44 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index b40de929ea2b..46aba4838af3 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -75,9 +75,12 @@ def test_logical_to_kernel_block_ids_with_hma(): # Simulate HMA scenario: logical block size = 32, kernel block size = 16 # So each logical block maps to 2 kernel blocks eg [0]->[0,1] + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + GroupKind, + ) + worker._physical_blocks_per_logical_kv_block = 2 - # FA + SW groups (neither is MambaSpec, so both get expanded) - worker._is_mamba_group = [False, False] + worker._group_kinds = (GroupKind.FA, GroupKind.SWA) 
logical_block_ids = [[0, 1, 2], [3, 4]] kernel_block_ids = worker._logical_to_kernel_block_ids(logical_block_ids) @@ -90,12 +93,12 @@ def test_logical_to_kernel_block_ids_with_hma(): @pytest.mark.cpu_test @pytest.mark.parametrize( - "is_mamba_group,expansion_stride,remote_block_ids,expected_remote_block_ids", + "group_kinds,expansion_stride,remote_block_ids,expected_remote_block_ids", [ # Dense (FA+SWA): stride == local_ratio, all groups expanded. # Regression for https://github.com/vllm-project/vllm/pull/39724 ( - [False, False], + ("FA", "SWA"), 2, ([0, 1, 2], [3, 4]), [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]], @@ -105,7 +108,7 @@ def test_logical_to_kernel_block_ids_with_hma(): # stride=261 (Nemotron 30B TP=1) != local_ratio=2 so that using # the wrong stride produces different FA results. ( - [False, True], + ("FA", "MAMBA"), 261, ([0, 1, 2], [10, 11]), [[0, 1, 261, 262, 522, 523], [10, 11]], @@ -114,7 +117,7 @@ def test_logical_to_kernel_block_ids_with_hma(): ids=["dense_fa_swa", "mamba_fa_ssm"], ) def test_read_blocks_for_req_expands_remote_ids( - is_mamba_group, + group_kinds, expansion_stride, remote_block_ids, expected_remote_block_ids, @@ -132,6 +135,7 @@ def test_read_blocks_for_req_expands_remote_ids( ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, + GroupKind, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, @@ -139,7 +143,7 @@ def test_read_blocks_for_req_expands_remote_ids( worker = object.__new__(NixlConnectorWorker) worker._physical_blocks_per_logical_kv_block = 2 - worker._is_mamba_group = is_mamba_group + worker._group_kinds = tuple(GroupKind[k] for k in group_kinds) remote_engine_id = "remote-engine" @@ -309,9 +313,12 @@ def test_get_block_descs_ids_hybrid_ssm(): from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 MambaModelBlockTransferPolicy, ) + from 
vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + GroupKind, + ) policy = object.__new__(MambaModelBlockTransferPolicy) - policy._is_mamba_group = [False, True] + policy._group_kinds = (GroupKind.FA, GroupKind.MAMBA) num_blocks = 100 num_regions = 2 @@ -344,9 +351,12 @@ def test_get_block_descs_ids_kernel_block_mismatch(): from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 MambaModelBlockTransferPolicy, ) + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + GroupKind, + ) policy = object.__new__(MambaModelBlockTransferPolicy) - policy._is_mamba_group = [False, True] + policy._group_kinds = (GroupKind.FA, GroupKind.MAMBA) ratio = 4 logical_blocks = 100 diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 713ee9ea7f0f..4d3e2aeaad67 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -22,6 +22,7 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, + GroupKind, RegionKind, RegionPlan, build_local_splits_from_plan, @@ -532,7 +533,7 @@ def test_source_ranks_p_gt_d(self): def test_no_ssm_regions(self): plan = generate_dense_plan(**_common_plan_params()) assert plan.ssm_regions == () - assert plan.is_mamba_group == (False,) + assert plan.group_kinds == (GroupKind.FA,) def test_blocks_first_has_k_and_v(self): plan = generate_dense_plan( @@ -559,7 +560,7 @@ def test_not_blocks_first_has_only_k(self): def _make_mamba_plan_for_desc_ids( num_fa_regions: int, num_ssm_regions: int, - is_mamba_group: list[bool], + group_kinds: tuple[GroupKind, ...], fa_num_blocks: int = 100, ssm_num_blocks: int = 100, ) -> EngineTransferPlan: @@ -588,12 +589,12 @@ def _make_mamba_plan_for_desc_ids( ) for i in range(num_ssm_regions) ) - physical_per_logical = tuple(1 if m else 1 for m in is_mamba_group) + 
physical_per_logical = tuple(1 for _ in group_kinds) return EngineTransferPlan( fa_regions=fa_regions, ssm_regions=ssm_regions, physical_per_logical=physical_per_logical, - is_mamba_group=tuple(is_mamba_group), + group_kinds=group_kinds, all_source_ranks=(0,), fa_source_ranks=(0,), fa_source_set=frozenset({0}), @@ -616,7 +617,7 @@ def test_hybrid_ssm_ratio_1(self): plan = _make_mamba_plan_for_desc_ids( num_fa_regions=2, num_ssm_regions=4, # 4 regions per layer, 1 layer - is_mamba_group=[False, True], + group_kinds=(GroupKind.FA, GroupKind.MAMBA), fa_num_blocks=100, ssm_num_blocks=100, ) @@ -643,7 +644,7 @@ def test_kernel_block_mismatch(self): plan = _make_mamba_plan_for_desc_ids( num_fa_regions=2, num_ssm_regions=4, - is_mamba_group=[False, True], + group_kinds=(GroupKind.FA, GroupKind.MAMBA), fa_num_blocks=num_blocks, ssm_num_blocks=logical_blocks, ) @@ -671,7 +672,7 @@ def test_all_source_ranks_serve_fa(self): fa_regions=(), ssm_regions=(), physical_per_logical=(1, 1), - is_mamba_group=(False, True), + group_kinds=(GroupKind.FA, GroupKind.MAMBA), all_source_ranks=(0, 1), fa_source_ranks=(0, 1), fa_source_set=frozenset({0, 1}), @@ -700,7 +701,7 @@ def test_non_fa_rank_skips_fa_groups(self): fa_regions=(), ssm_regions=(), physical_per_logical=(1, 1), - is_mamba_group=(False, True), + group_kinds=(GroupKind.FA, GroupKind.MAMBA), all_source_ranks=(0, 1, 2), fa_source_ranks=(0,), fa_source_set=frozenset({0}), @@ -752,7 +753,7 @@ def test_fa_and_ssm_different_split_factors(self): ), ), physical_per_logical=(1, 1), - is_mamba_group=(False, True), + group_kinds=(GroupKind.FA, GroupKind.MAMBA), all_source_ranks=(0, 1), fa_source_ranks=(0,), fa_source_set=frozenset({0}), diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py index 17d91c86a96f..1d4eed7b0dfe 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py @@ -28,6 +28,9 @@ MambaEngineTransferInfo, TransferTopology, ) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( + GroupKind, +) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, derive_mamba_conv_split, @@ -595,10 +598,12 @@ def __init__( physical_blocks_per_logical: int, ): super().__init__(kv_cache_config, physical_blocks_per_logical) - self._is_mamba_group = [ - isinstance(group.kv_cache_spec, MambaSpec) + self._group_kinds = tuple( + GroupKind.MAMBA + if isinstance(group.kv_cache_spec, MambaSpec) + else GroupKind.FA for group in kv_cache_config.kv_cache_groups - ] + ) mamba_spec = next( spec for spec in layer_specs.values() if isinstance(spec, MambaSpec) @@ -673,7 +678,7 @@ def get_block_descs_ids( all_descs: list[np.ndarray] = [] for i, group in enumerate(block_ids): group_arr = np.asarray(group) - if self._is_mamba_group[i]: + if self._group_kinds[i].is_ssm: # Mamba blocks are 1:1 logical-to-physical (no expansion). 
all_descs.append( ( @@ -1205,11 +1210,11 @@ def filter_block_ids_for_rank( return local_ids, remote_ids num_groups = len(local_ids) filtered_local: list[list[int]] = [ - [] if not self._is_mamba_group[g] else local_ids[g] + local_ids[g] if self._group_kinds[g].is_ssm else [] for g in range(num_groups) ] filtered_remote: list[list[int]] = [ - [] if not self._is_mamba_group[g] else remote_ids[g] + remote_ids[g] if self._group_kinds[g].is_ssm else [] for g in range(num_groups) ] return filtered_local, filtered_remote diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index b34dfeef4402..548e6c9bff0d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -48,6 +48,29 @@ class ReadSpec: remote_block_ids: BlockIds +class GroupKind(enum.Enum): + """KV cache group type for transfer purposes. + + Used by ``EngineTransferPlan`` and block expansion functions to + dispatch per-group behavior without model-specific branching. + """ + + FA = "fa" + SWA = "swa" + MAMBA = "mamba" + GDN = "gdn" + + @property + def is_attention(self) -> bool: + """FA and SWA both need block expansion and standard descriptors.""" + return self in (GroupKind.FA, GroupKind.SWA) + + @property + def is_ssm(self) -> bool: + """MAMBA and GDN have state descriptors instead of KV pages.""" + return self in (GroupKind.MAMBA, GroupKind.GDN) + + class RegionKind(enum.Enum): """Descriptor region type. Used for visualization/debugging only; executors never branch on this value.""" @@ -91,7 +114,7 @@ class EngineTransferPlan: Regions are split into ``fa_regions`` and ``ssm_regions`` matching the descriptor handle layout: [FA descriptors | SSM descriptors]. - ``is_mamba_group`` maps kv_cache_groups to the correct section. + ``group_kinds`` maps each kv_cache_group to its type. 
""" # Regions in descriptor handle order @@ -101,8 +124,8 @@ class EngineTransferPlan: # Per-group geometric properties (worker-facing, model-agnostic) physical_per_logical: tuple[int, ...] - # kv_cache_group mapping (internal to transfer_plan, worker should not use) - is_mamba_group: tuple[bool, ...] + # Per-group type (FA, SWA, MAMBA, GDN). + group_kinds: tuple[GroupKind, ...] # Source rank routing all_source_ranks: tuple[int, ...] @@ -381,7 +404,7 @@ def generate_dense_plan( fa_regions=tuple(fa_regions), ssm_regions=(), physical_per_logical=(remote_physical_blocks_per_logical,), - is_mamba_group=(False,), + group_kinds=(GroupKind.FA,), all_source_ranks=all_source_ranks, fa_source_ranks=all_source_ranks, fa_source_set=frozenset(all_source_ranks), @@ -410,7 +433,7 @@ def generate_mamba_plan( remote_num_blocks: int, remote_block_lens: list[int], remote_physical_blocks_per_logical: int, - is_mamba_group: list[bool], + group_kinds: tuple[GroupKind, ...], conv_decomp: MambaConvSplitInfo, ssm_sizes: tuple[int, int], remote_ssm_sizes: tuple[int, int], @@ -568,13 +591,13 @@ def generate_mamba_plan( ) physical_per_logical_per_group = tuple( - 1 if m else remote_physical_blocks_per_logical for m in is_mamba_group + 1 if k.is_ssm else remote_physical_blocks_per_logical for k in group_kinds ) return EngineTransferPlan( fa_regions=tuple(fa_regions), ssm_regions=tuple(ssm_regions), physical_per_logical=physical_per_logical_per_group, - is_mamba_group=tuple(is_mamba_group), + group_kinds=group_kinds, all_source_ranks=tuple(all_source_ranks), fa_source_ranks=tuple(fa_source_ranks), fa_source_set=frozenset(fa_source_ranks), @@ -664,7 +687,7 @@ def compute_desc_ids_from_plan( all_descs: list[np.ndarray] = [] for i, group in enumerate(block_ids): group_arr = np.asarray(group) - if plan.is_mamba_group[i]: + if plan.group_kinds[i].is_ssm: ssm_region_ids = np.arange(num_ssm_regions)[:, None] all_descs.append( ( @@ -705,11 +728,11 @@ def compute_read_specs_from_plan( else: num_groups = 
len(local_block_ids) filtered_local: list[list[int]] = [ - [] if not plan.is_mamba_group[g] else list(local_block_ids[g]) + list(local_block_ids[g]) if plan.group_kinds[g].is_ssm else [] for g in range(num_groups) ] filtered_remote: list[list[int]] = [ - [] if not plan.is_mamba_group[g] else list(remote_block_ids[g]) + list(remote_block_ids[g]) if plan.group_kinds[g].is_ssm else [] for g in range(num_groups) ] specs.append( @@ -923,6 +946,6 @@ def visualize_plan(plan: EngineTransferPlan) -> str: ) total_descs += r.num_blocks - lines.append(f" Groups: {['SSM' if m else 'FA' for m in plan.is_mamba_group]}") + lines.append(f" Groups: {[k.value for k in plan.group_kinds]}") lines.append(f" Total descriptors: {total_descs}") return "\n".join(lines) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 2cf6772d0e63..411b81f2336d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -47,6 +47,7 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, + GroupKind, generate_dense_plan, generate_mamba_plan, ) @@ -71,13 +72,14 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, MambaSpec, + SlidingWindowSpec, UniformTypeKVCacheSpecs, ) from vllm.v1.worker.utils import select_common_block_size if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec logger = init_logger(__name__) @@ -85,6 +87,18 @@ class NixlConnectorWorker: """Implementation of Worker side methods""" + @staticmethod + def _spec_to_group_kind(spec: "KVCacheSpec") -> GroupKind: + if isinstance(spec, MambaSpec): + return GroupKind.MAMBA + if isinstance(spec, SlidingWindowSpec): + return GroupKind.SWA + if isinstance(spec, FullAttentionSpec): + return GroupKind.FA + 
raise NotImplementedError( + f"Unsupported KVCacheSpec type for NIXL transfer: {type(spec)}" + ) + def __init__( self, vllm_config: "VllmConfig", @@ -124,14 +138,14 @@ def __init__( } self.hma_group_size = len(kv_cache_config.kv_cache_tensors) - # ---- Mamba model state (derived from model config) ---- - self._is_mamba_group = [ - isinstance(group.kv_cache_spec, MambaSpec) + # ---- Group kinds and model state (derived from model config) ---- + self._group_kinds = tuple( + self._spec_to_group_kind(group.kv_cache_spec) for group in kv_cache_config.kv_cache_groups - ] + ) mamba_ssm_size = (0, 0) self._conv_decomp: MambaConvSplitInfo | None = None - self._has_mamba = any(self._is_mamba_group) + self._has_mamba = any(k.is_ssm for k in self._group_kinds) if self._has_mamba: assert self._is_hma_required from vllm.model_executor.layers.mamba.mamba_utils import ( @@ -1015,7 +1029,7 @@ def add_remote_agent( assert self._conv_decomp is not None self._transfer_plans[engine_id] = generate_mamba_plan( **plan_common, - is_mamba_group=self._is_mamba_group, + group_kinds=self._group_kinds, conv_decomp=self._conv_decomp, ssm_sizes=self._mamba_ssm_size, remote_ssm_sizes=nixl_agent_meta.ssm_sizes, @@ -1885,7 +1899,7 @@ def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: arange = np.arange(ratio).reshape(1, -1) return [ (np.array(group).reshape(-1, 1) * ratio + arange).flatten().tolist() - if not self._is_mamba_group[i] + if self._group_kinds[i].is_attention else group for i, group in enumerate(block_ids) ] @@ -1910,7 +1924,7 @@ def _logical_to_remote_kernel_block_ids( local_arange = np.arange(local_ratio).reshape(1, -1) result: list[list[int]] = [] for i, group in enumerate(block_ids): - if not self._is_mamba_group[i]: + if self._group_kinds[i].is_attention: arr = np.array(group).reshape(-1, 1) expanded = (arr * remote_ratio + local_arange).flatten() result.append(expanded.tolist()) From f03a17f82a9d919ad4a53400221c77af50629501 Mon Sep 17 00:00:00 2001 From: 
Zhanqiu Hu Date: Thu, 23 Apr 2026 19:52:48 +0000 Subject: [PATCH 28/49] refactor Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 77 +- .../kv_connector/unit/test_transfer_plan.py | 328 +---- .../v1/nixl/block_transfer_policy.py | 1220 ----------------- .../kv_connector/v1/nixl/transfer_plan.py | 554 +++----- .../kv_connector/v1/nixl/worker.py | 89 +- 5 files changed, 338 insertions(+), 1930 deletions(-) delete mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 46aba4838af3..57bde517125a 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -147,20 +147,15 @@ def test_read_blocks_for_req_expands_remote_ids( remote_engine_id = "remote-engine" - # Mock transfer_topo: empty remote ranks skips the transfer machinery - # entirely, isolating the block-ID expansion logic. 
worker.transfer_topo = MagicMock() - worker.transfer_topo.target_remote_ranks.return_value = [] - worker.transfer_topo.get_engine_info.return_value = MagicMock( - remote_tp_size=1, - ) worker.transfer_topo.tp_ratio.return_value = 1 - worker.transfer_policy = MagicMock() - worker.transfer_policy.compute_read_specs.return_value = [] worker.use_mla = False mock_plan = MagicMock(spec=EngineTransferPlan) mock_plan.remote_expansion_stride = expansion_stride + mock_plan.remote_tp_size = 1 + mock_plan.all_source_ranks = () + mock_plan.source_ranks_per_group = () worker._transfer_plans = {remote_engine_id: mock_plan} metadata = NixlConnectorMetadata() @@ -308,78 +303,68 @@ def test_nixl_metadata_hma_block_ids_structure(): @pytest.mark.cpu_test def test_get_block_descs_ids_hybrid_ssm(): - """Test get_block_descs_ids uses per-group strides for hybrid + """Test compute_desc_ids_from_plan uses per-group strides for hybrid FA+SSM when ratio=1 (no kernel block size mismatch).""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 - MambaModelBlockTransferPolicy, - ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( GroupKind, + compute_desc_ids_from_plan, ) - policy = object.__new__(MambaModelBlockTransferPolicy) - policy._group_kinds = (GroupKind.FA, GroupKind.MAMBA) + from .test_transfer_plan import _make_mamba_plan_for_desc_ids - num_blocks = 100 - num_regions = 2 - block_len_per_layer = [100] + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, + group_kinds=(GroupKind.FA, GroupKind.MAMBA), + fa_num_blocks=100, + ssm_num_blocks=100, + ) fa_blocks = [3, 5] ssm_blocks = [1, 2] - result = policy.get_block_descs_ids( + result = compute_desc_ids_from_plan( + plan, block_ids=(fa_blocks, ssm_blocks), - num_regions=num_regions, - dst_num_blocks=num_blocks, - block_len_per_layer=block_len_per_layer, + dst_num_blocks=100, physical_blocks_per_logical=1, ) - # FA group: 
stride=num_blocks=100, offset=0 - # region0: [3, 5], region1: [103, 105] - # SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1), - # offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm) - # region0: [201, 202], region1: [301, 302], - # region2: [401, 402], region3: [501, 502] expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502] assert list(result) == expected, f"Expected {expected}, got {list(result)}" @pytest.mark.cpu_test def test_get_block_descs_ids_kernel_block_mismatch(): - """Test get_block_descs_ids uses different strides for FA + """Test compute_desc_ids_from_plan uses different strides for FA (kernel blocks) vs SSM (logical blocks) when ratio > 1.""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( # noqa: E501 - MambaModelBlockTransferPolicy, - ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( GroupKind, + compute_desc_ids_from_plan, ) - policy = object.__new__(MambaModelBlockTransferPolicy) - policy._group_kinds = (GroupKind.FA, GroupKind.MAMBA) + from .test_transfer_plan import _make_mamba_plan_for_desc_ids ratio = 4 logical_blocks = 100 num_blocks = logical_blocks * ratio # 400 kernel blocks - num_regions = 2 - block_len_per_layer = [100] - fa_blocks = [3, 7] # kernel-level block IDs - ssm_blocks = [1, 2] # logical block IDs - result = policy.get_block_descs_ids( + plan = _make_mamba_plan_for_desc_ids( + num_fa_regions=2, + num_ssm_regions=4, + group_kinds=(GroupKind.FA, GroupKind.MAMBA), + fa_num_blocks=num_blocks, + ssm_num_blocks=logical_blocks, + ) + + fa_blocks = [3, 7] + ssm_blocks = [1, 2] + result = compute_desc_ids_from_plan( + plan, block_ids=(fa_blocks, ssm_blocks), - num_regions=num_regions, dst_num_blocks=num_blocks, - block_len_per_layer=block_len_per_layer, physical_blocks_per_logical=ratio, ) - # FA group: stride=num_blocks=400, offset=0 - # region0: [3, 7], region1: [403, 407] - # SSM group: stride=logical_blocks=400//4=100, 
offset=num_fa_descs=800, - # 4 regions per Mamba layer (x, B, C, ssm) - # region0: [801, 802], region1: [901, 902], - # region2: [1001, 1002], region3: [1101, 1102] expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102] assert list(result) == expected, f"Expected {expected}, got {list(result)}" diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 4d3e2aeaad67..295501579673 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -1,25 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Equivalence tests: plan-based executors vs current ABC policy. +"""Tests for plan-based transfer executors. -These tests verify that the new plan-based design produces identical -outputs (descriptor tuples, descriptor IDs, read specs) to the current -ModelBlockTransferPolicy ABC hierarchy. No GPU or NIXL required. +These tests verify that the plan-based design produces correct +outputs (descriptor tuples, descriptor IDs, read specs, split handles). +No GPU or NIXL required. 
""" from __future__ import annotations from dataclasses import dataclass -import numpy as np import pytest -from vllm.distributed.kv_transfer.kv_connector.utils import ( - TransferTopology, -) -from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( - DenseModelBlockTransferPolicy, -) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, GroupKind, @@ -38,7 +31,6 @@ # ====================================================================== ENGINE_ID = "remote_engine" -LOCAL_ENGINE_ID = "local_engine" @dataclass @@ -57,59 +49,6 @@ class FakeNixlAgentMeta: attn_backend_name: str -def _make_kv_cache_config( - block_size: int = 16, - num_blocks: int = 256, - num_layers: int = 2, - head_size: int = 128, - num_kv_heads: int = 8, -): - """Create a minimal KVCacheConfig for Dense models.""" - import torch - - from vllm.v1.kv_cache_interface import ( - FullAttentionSpec, - KVCacheConfig, - KVCacheGroupSpec, - ) - - spec = FullAttentionSpec( - block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=torch.float16, - ) - layers = [f"layer_{i}" for i in range(num_layers)] - return KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=[], - kv_cache_groups=[KVCacheGroupSpec(layers, spec)], - ) - - -def _make_transfer_topo( - tp_rank: int = 0, - tp_size: int = 1, - block_size: int = 16, - is_mla: bool = False, - num_kv_heads: int = 8, -): - """Create a TransferTopology for testing without real attention backend.""" - from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend - - return TransferTopology( - tp_rank=tp_rank, - tp_size=tp_size, - block_size=block_size, - engine_id=LOCAL_ENGINE_ID, - is_mla=is_mla, - is_mamba=False, - total_num_kv_heads=num_kv_heads, - attn_backends=[FlashAttentionBackend], - physical_blocks_per_logical=1, - ) - - def _common_plan_params( tp_rank: int = 0, tp_size: int = 1, @@ -174,37 +113,30 @@ def _make_nixl_meta( # 
====================================================================== -class TestDensePlanEquivalence: - """Verify plan-based outputs match current DenseModelBlockTransferPolicy.""" +class TestDensePlanExecutors: + """Verify plan-based executors produce correct outputs for dense models.""" @pytest.mark.parametrize( "tp_size,remote_tp_size", [ - (1, 1), # homogeneous - (2, 1), # D_TP > P_TP - (4, 2), # D_TP > P_TP (larger) - (1, 2), # P_TP > D_TP - (2, 4), # P_TP > D_TP (larger) + (1, 1), + (2, 1), + (4, 2), + (1, 2), + (2, 4), ], ) @pytest.mark.parametrize("tp_rank_frac", [0.0, 0.5]) - def test_build_remote_descs( - self, - tp_size, - remote_tp_size, - tp_rank_frac, - ): + def test_build_remote_descs(self, tp_size, remote_tp_size, tp_rank_frac): tp_rank = int(tp_rank_frac * (tp_size - 1)) if tp_size > 1 else 0 num_kv_heads = 8 - head_size = 128 block_size = 16 num_blocks = 64 num_layers = 2 - slot_size = num_kv_heads * head_size * 2 + slot_size = num_kv_heads * 128 * 2 block_len = slot_size * block_size block_len_per_layer = [block_len] * num_layers - # Adjust remote block_lens for hetero TP if tp_size >= remote_tp_size: tp_ratio = tp_size // remote_tp_size remote_block_lens = [bl * tp_ratio for bl in block_len_per_layer] @@ -213,53 +145,12 @@ def test_build_remote_descs( remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] base_addrs = [0x1000 * (i + 1) for i in range(num_layers)] - - # ---- Old path ---- - kv_config = _make_kv_cache_config( - block_size=block_size, - num_blocks=num_blocks, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_size=head_size, - ) - policy = DenseModelBlockTransferPolicy(kv_config, 1) - topo = _make_transfer_topo( - tp_rank=tp_rank, - tp_size=tp_size, - block_size=block_size, - num_kv_heads=num_kv_heads, - ) - is_blocks_first = topo.is_kv_layout_blocks_first - transfer_info = policy.build_engine_transfer_info( - transfer_topo=topo, - local_block_len=block_len_per_layer[0], - remote_tp_size=remote_tp_size, 
- remote_block_size=block_size, - remote_block_len=remote_block_lens[0], - remote_physical_blocks_per_logical=1, - ) - topo.register_remote_engine(ENGINE_ID, transfer_info) - meta = _make_nixl_meta( - base_addrs, - num_blocks, - remote_block_lens, - block_size=block_size, - ) - old_descs = policy.build_remote_descs( - topo, - ENGINE_ID, - meta, - block_len_per_layer, - ) - - # ---- New path ---- plan = generate_dense_plan( **_common_plan_params( tp_rank=tp_rank, tp_size=tp_size, num_kv_heads=num_kv_heads, block_size=block_size, - is_blocks_first=is_blocks_first, block_len_per_layer=block_len_per_layer, remote_tp_size=remote_tp_size, remote_block_size=block_size, @@ -267,20 +158,20 @@ def test_build_remote_descs( remote_block_lens=remote_block_lens, ), ) - new_descs = build_remote_descs_from_plan(plan, meta) - - assert old_descs == new_descs, ( - f"Descriptor mismatch for tp={tp_size}/{remote_tp_size}, " - f"rank={tp_rank}.\nOld: {old_descs[:5]}...\nNew: {new_descs[:5]}..." + meta = _make_nixl_meta( + base_addrs, num_blocks, remote_block_lens, block_size=block_size ) + descs = build_remote_descs_from_plan(plan, meta) + + expected_count = len(plan.fa_regions) * num_blocks + assert len(descs) == expected_count + for addr, length, dev in descs: + assert length > 0 + assert dev == 0 @pytest.mark.parametrize( "tp_size,remote_tp_size", - [ - (1, 1), - (2, 1), - (1, 2), - ], + [(1, 1), (2, 1), (1, 2)], ) def test_compute_desc_ids(self, tp_size, remote_tp_size): num_kv_heads = 8 @@ -298,26 +189,11 @@ def test_compute_desc_ids(self, tp_size, remote_tp_size): tp_ratio_neg = remote_tp_size // tp_size remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] - topo = _make_transfer_topo( - tp_size=tp_size, - block_size=block_size, - num_kv_heads=num_kv_heads, - ) - is_blocks_first = topo.is_kv_layout_blocks_first - - kv_config = _make_kv_cache_config( - block_size=block_size, - num_blocks=num_blocks, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - ) - 
policy = DenseModelBlockTransferPolicy(kv_config, 1) plan = generate_dense_plan( **_common_plan_params( tp_size=tp_size, num_kv_heads=num_kv_heads, block_size=block_size, - is_blocks_first=is_blocks_first, block_len_per_layer=block_len_per_layer, remote_tp_size=remote_tp_size, remote_block_size=block_size, @@ -326,33 +202,18 @@ def test_compute_desc_ids(self, tp_size, remote_tp_size): ), ) - num_regions = len(plan.fa_regions) block_ids = ([1, 5, 10, 20],) + ids = compute_desc_ids_from_plan(plan, block_ids, dst_num_blocks=num_blocks) - old_ids = policy.get_block_descs_ids( - block_ids=block_ids, - num_regions=num_regions, - dst_num_blocks=num_blocks, - block_len_per_layer=block_len_per_layer, - ) - new_ids = compute_desc_ids_from_plan( - plan, - block_ids, - dst_num_blocks=num_blocks, - ) - - np.testing.assert_array_equal(old_ids, new_ids) + num_regions = len(plan.fa_regions) + assert len(ids) == num_regions * len(block_ids[0]) + assert ids[0] == 1 @pytest.mark.parametrize( "tp_size,remote_tp_size", - [ - (1, 1), - (2, 1), - (1, 2), - ], + [(1, 1), (2, 1), (1, 2)], ) def test_compute_read_specs(self, tp_size, remote_tp_size): - tp_rank = 0 num_kv_heads = 8 block_size = 16 num_blocks = 64 @@ -368,38 +229,11 @@ def test_compute_read_specs(self, tp_size, remote_tp_size): tp_ratio_neg = remote_tp_size // tp_size remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] - kv_config = _make_kv_cache_config( - block_size=block_size, - num_blocks=num_blocks, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - ) - policy = DenseModelBlockTransferPolicy(kv_config, 1) - topo = _make_transfer_topo( - tp_rank=tp_rank, - tp_size=tp_size, - block_size=block_size, - num_kv_heads=num_kv_heads, - ) - is_blocks_first = topo.is_kv_layout_blocks_first - transfer_info = policy.build_engine_transfer_info( - transfer_topo=topo, - local_block_len=block_len_per_layer[0], - remote_tp_size=remote_tp_size, - remote_block_size=block_size, - remote_block_len=remote_block_lens[0], 
- remote_physical_blocks_per_logical=1, - ) - topo.register_remote_engine(ENGINE_ID, transfer_info) - remote_ranks = topo.target_remote_ranks(ENGINE_ID) - plan = generate_dense_plan( **_common_plan_params( - tp_rank=tp_rank, tp_size=tp_size, num_kv_heads=num_kv_heads, block_size=block_size, - is_blocks_first=is_blocks_first, block_len_per_layer=block_len_per_layer, remote_tp_size=remote_tp_size, remote_block_size=block_size, @@ -410,20 +244,12 @@ def test_compute_read_specs(self, tp_size, remote_tp_size): local_ids = ([1, 2, 3],) remote_ids = ([4, 5, 6],) + specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) - old_specs = policy.compute_read_specs( - local_ids, - remote_ids, - remote_ranks, - transfer_info, - ) - new_specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) - - assert len(old_specs) == len(new_specs) - for old, new in zip(old_specs, new_specs): - assert old.remote_rank == new.remote_rank - assert list(old.local_block_ids[0]) == list(new.local_block_ids[0]) - assert list(old.remote_block_ids[0]) == list(new.remote_block_ids[0]) + assert len(specs) == len(plan.all_source_ranks) + for spec in specs: + assert list(spec.local_block_ids[0]) == [1, 2, 3] + assert list(spec.remote_block_ids[0]) == [4, 5, 6] @pytest.mark.parametrize("remote_tp_size", [2, 4]) def test_build_src_split_handles(self, remote_tp_size): @@ -440,37 +266,12 @@ def test_build_src_split_handles(self, remote_tp_size): tp_ratio_neg = remote_tp_size // tp_size remote_block_lens = [bl // tp_ratio_neg for bl in block_len_per_layer] - kv_config = _make_kv_cache_config( - block_size=block_size, - num_blocks=num_blocks, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - ) - policy = DenseModelBlockTransferPolicy(kv_config, 1) - topo = _make_transfer_topo( - tp_rank=tp_rank, - tp_size=tp_size, - block_size=block_size, - num_kv_heads=num_kv_heads, - ) - is_blocks_first = topo.is_kv_layout_blocks_first - transfer_info = policy.build_engine_transfer_info( - 
transfer_topo=topo, - local_block_len=block_len_per_layer[0], - remote_tp_size=remote_tp_size, - remote_block_size=block_size, - remote_block_len=remote_block_lens[0], - remote_physical_blocks_per_logical=1, - ) - topo.register_remote_engine(ENGINE_ID, transfer_info) - plan = generate_dense_plan( **_common_plan_params( tp_rank=tp_rank, tp_size=tp_size, num_kv_heads=num_kv_heads, block_size=block_size, - is_blocks_first=is_blocks_first, block_len_per_layer=block_len_per_layer, remote_tp_size=remote_tp_size, remote_block_size=block_size, @@ -481,24 +282,17 @@ def test_build_src_split_handles(self, remote_tp_size): src_blocks_data = [(0x2000 + i * 1024, 1024, 0) for i in range(8)] num_descs = len(src_blocks_data) - - old_splits = policy.build_src_split_handles( - topo, - ENGINE_ID, - src_blocks_data, - num_descs, - ) - new_splits = build_local_splits_from_plan( + splits = build_local_splits_from_plan( plan, src_blocks_data, num_descs, ) - assert len(old_splits) == len(new_splits), ( - f"Split count mismatch: {len(old_splits)} vs {len(new_splits)}" - ) - for i, (old, new) in enumerate(zip(old_splits, new_splits)): - assert old == new, f"Split {i} mismatch" + assert len(splits) == remote_tp_size + for handle in splits: + assert len(handle) == len(src_blocks_data) + for _, length, _ in handle: + assert length == 1024 // remote_tp_size class TestDensePlanVisualization: @@ -590,17 +384,16 @@ def _make_mamba_plan_for_desc_ids( for i in range(num_ssm_regions) ) physical_per_logical = tuple(1 for _ in group_kinds) + all_ranks = (0,) + source_ranks_per_group = tuple(all_ranks for _ in group_kinds) return EngineTransferPlan( fa_regions=fa_regions, ssm_regions=ssm_regions, physical_per_logical=physical_per_logical, group_kinds=group_kinds, + source_ranks_per_group=source_ranks_per_group, all_source_ranks=(0,), - fa_source_ranks=(0,), - fa_source_set=frozenset({0}), - num_fa_reads=1, - num_mamba_reads=1, - fa_head_slots={0: 0}, + rank_to_attention_slot={0: 0}, remote_tp_size=1, 
remote_block_size=16, remote_block_len=0, @@ -610,7 +403,7 @@ def _make_mamba_plan_for_desc_ids( class TestMambaPlanDescIds: - """Verify plan-based desc IDs match MambaModelBlockTransferPolicy.""" + """Verify plan-based desc IDs for hybrid FA+SSM models.""" def test_hybrid_ssm_ratio_1(self): """Equivalent to test_get_block_descs_ids_hybrid_ssm.""" @@ -668,17 +461,15 @@ class TestMambaPlanReadSpecs: def test_all_source_ranks_serve_fa(self): """When all ranks are FA sources, no filtering happens.""" + both = (0, 1) plan = EngineTransferPlan( fa_regions=(), ssm_regions=(), physical_per_logical=(1, 1), group_kinds=(GroupKind.FA, GroupKind.MAMBA), + source_ranks_per_group=(both, both), all_source_ranks=(0, 1), - fa_source_ranks=(0, 1), - fa_source_set=frozenset({0, 1}), - num_fa_reads=2, - num_mamba_reads=2, - fa_head_slots={0: 0, 1: 1}, + rank_to_attention_slot={0: 0, 1: 1}, remote_tp_size=2, remote_block_size=16, remote_block_len=0, @@ -696,18 +487,17 @@ def test_all_source_ranks_serve_fa(self): assert list(spec.local_block_ids[1]) == [3, 4] def test_non_fa_rank_skips_fa_groups(self): - """Ranks not in fa_source_set get FA groups zeroed out.""" + """Ranks not in source_ranks_per_group get groups zeroed out.""" + fa_readers = (0,) + ssm_readers = (0, 1, 2) plan = EngineTransferPlan( fa_regions=(), ssm_regions=(), physical_per_logical=(1, 1), group_kinds=(GroupKind.FA, GroupKind.MAMBA), + source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1, 2), - fa_source_ranks=(0,), - fa_source_set=frozenset({0}), - num_fa_reads=1, - num_mamba_reads=3, - fa_head_slots={0: 0}, + rank_to_attention_slot={0: 0}, remote_tp_size=3, remote_block_size=16, remote_block_len=0, @@ -738,7 +528,9 @@ class TestMambaPlanSplitHandles: """Verify plan-based split handles for Mamba with FA/SSM distinction.""" def test_fa_and_ssm_different_split_factors(self): - """FA descs split by num_fa_reads, SSM descs split by abs_tp.""" + """Section 0 split by num_attn_reads, section 1 by 
abs_tp.""" + fa_readers = (0,) + ssm_readers = (0, 1) plan = EngineTransferPlan( fa_regions=(), ssm_regions=( @@ -754,12 +546,9 @@ def test_fa_and_ssm_different_split_factors(self): ), physical_per_logical=(1, 1), group_kinds=(GroupKind.FA, GroupKind.MAMBA), + source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1), - fa_source_ranks=(0,), - fa_source_set=frozenset({0}), - num_fa_reads=1, - num_mamba_reads=2, - fa_head_slots={0: 0}, + rank_to_attention_slot={0: 0, 1: 0}, remote_tp_size=2, remote_block_size=16, remote_block_len=0, @@ -773,9 +562,8 @@ def test_fa_and_ssm_different_split_factors(self): (2000, 200, 0), # FA desc 1 (3000, 400, 0), # SSM desc 0 ] - num_fa_descs = 2 - splits = build_local_splits_from_plan(plan, src_blocks_data, num_fa_descs) + splits = build_local_splits_from_plan(plan, src_blocks_data, 2) assert len(splits) == 2 # 2 source ranks diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py deleted file mode 100644 index 1d4eed7b0dfe..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/block_transfer_policy.py +++ /dev/null @@ -1,1220 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""ModelBlockTransferPolicy: model-specific transfer intelligence. - -This module defines the ``ModelBlockTransferPolicy`` ABC and its concrete -implementations for Dense and Mamba models. The policy encapsulates all -model-specific logic that was previously scattered across ``worker.py`` -and ``TransferTopology``, making both of those model-agnostic. - -The policy is an *immutable config holder*: its state is set once at -``__init__`` and never mutated. It computes results but stores no -per-engine state (that lives on ``TransferTopology._engines``). 
-""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import numpy as np -import torch - -from vllm.distributed.kv_transfer.kv_connector.utils import ( - BlockIds, - EngineId, - EngineTransferInfo, - MambaEngineTransferInfo, - TransferTopology, -) -from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( - GroupKind, -) -from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( - MambaConvSplitInfo, - derive_mamba_conv_split, -) -from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first -from vllm.v1.kv_cache_interface import MambaSpec - -if TYPE_CHECKING: - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( - NixlAgentMetadata, - ) - from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec - -logger = init_logger(__name__) - - -@dataclass(frozen=True) -class ReadSpec: - """Specification for a single remote block read operation. - - Computed upfront by ``compute_read_specs`` so that the worker's read - loop is a simple iteration with no model-specific branching. - """ - - remote_rank: int - local_block_ids: BlockIds - remote_block_ids: BlockIds - - -# ------------------------------------------------------------------ -# Private module-level helpers -# ------------------------------------------------------------------ - - -def _get_kv_block_len( - layer_idx: int, - block_len_per_layer: list[int], - is_blocks_first: bool, -) -> int: - """Byte length of one K or V descriptor for a layer. - - For FA/KV layers only — Mamba SSM descriptors use ``_ssm_sizes`` - directly. When ``is_blocks_first``, K and V are interleaved so - the descriptor covers half the block. 
- """ - if is_blocks_first: - return block_len_per_layer[layer_idx] // 2 - return block_len_per_layer[layer_idx] - - -def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: - """Physical KV head range stored in a rank's KV cache tensor. - - When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. - When ``tp_size > num_heads``: 1 physical head per rank. Heads are - distributed **contiguously** (matching vLLM's GQA weight - partitioning): consecutive ranks share a head before moving to - the next one. - """ - if tp_size <= num_heads: - assert num_heads % tp_size == 0 - per_rank = num_heads // tp_size - return range(rank * per_rank, (rank + 1) * per_rank) - else: - h = rank * num_heads // tp_size - return range(h, h + 1) - - -def _range_overlap(a: range, b: range) -> range: - start = max(a.start, b.start) - stop = min(a.stop, b.stop) - return range(start, max(start, stop)) - - -def _should_skip_fa(info: EngineTransferInfo, remote_rank: int) -> bool: - """Whether to skip FA groups for this remote rank. - - Returns False unless ``info`` carries FA source tracking - (``MambaEngineTransferInfo``) and the rank is a replicated - duplicate. - """ - if not isinstance(info, MambaEngineTransferInfo): - return False - return remote_rank not in info.fa_source_set - - -def _fa_head_slot( - info: EngineTransferInfo, - remote_rank: int, - total_num_kv_heads: int, -) -> int: - """Index into local FA block for this remote rank's head data. - - For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. - For ranks NOT in ``fa_source_ranks`` (replicated duplicates), - returns the slot of the matching source rank with the same head. - Returns 0 when ``info`` has no FA source tracking. 
- """ - if not isinstance(info, MambaEngineTransferInfo): - return 0 - fa_index = info.fa_source_indices - if remote_rank in fa_index: - return fa_index[remote_rank] - K = total_num_kv_heads - remote_tp = info.remote_tp_size - r_head = _physical_head_range(remote_tp, K, remote_rank) - for target in info.remote_fa_source_ranks: - t_head = _physical_head_range(remote_tp, K, target) - if _range_overlap(r_head, t_head): - return fa_index[target] - return 0 - - -def _fa_rank_offset( - info: EngineTransferInfo, - remote_kv_block_len: int, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, -) -> int: - """Byte offset into remote FA block for this local rank. - - When local TP is replicated (local_tp > K), multiple local ranks - share a head. Computes offset *relative to the target remote - rank's first head* so it works regardless of how many heads the - remote has. Returns 0 when ``info`` has no FA source tracking - or when local does not index into remote. - """ - if not isinstance(info, MambaEngineTransferInfo): - return 0 - tp_ratio = ( - tp_size // info.remote_tp_size - if tp_size >= info.remote_tp_size - else -(info.remote_tp_size // tp_size) - ) - if is_mla or tp_ratio <= 0: - return 0 - K = total_num_kv_heads - is_local_replicated = tp_size > K - if is_local_replicated: - local_head = tp_rank * K // tp_size - p_rank = info.remote_fa_source_ranks[0] - p_start = p_rank * K // info.remote_tp_size - return (local_head - p_start) * remote_kv_block_len - return tp_rank % tp_ratio * remote_kv_block_len - - -class ModelBlockTransferPolicy(ABC): - """Abstract base for model-specific block transfer logic. - - Encapsulates genuinely model-specific algorithms: descriptor building, - transfer info computation, split handles, read spec filtering, and - orchestration. Simple per-layer branches (``isinstance(MambaSpec)``) - and block-ID mapping remain on ``worker.py``. 
- """ - - def __init__( - self, - kv_cache_config: KVCacheConfig, - physical_blocks_per_logical: int, - ): - self._kv_cache_config = kv_cache_config - self._physical_blocks_per_logical = physical_blocks_per_logical - - # ------------------------------------------------------------------ - # Per-engine transfer info (data operations) - # ------------------------------------------------------------------ - - @abstractmethod - def build_engine_transfer_info( - self, - *, - # Local topology - transfer_topo: TransferTopology, - # Block geometry - local_block_len: int, - # Remote facts (from NixlAgentMetadata handshake) - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - ) -> EngineTransferInfo: - """Compute transfer info for a remote engine. - - Dense models return ``EngineTransferInfo``. - Mamba models return ``MambaEngineTransferInfo``. - """ - - # ------------------------------------------------------------------ - # Descriptor ID computation (abstract — genuinely different per model) - # ------------------------------------------------------------------ - - @abstractmethod - def get_block_descs_ids( - self, - # Input - block_ids: BlockIds, - # Block geometry - num_regions: int, - dst_num_blocks: int, - block_len_per_layer: list[int], - block_size_ratio: float | None = None, - physical_blocks_per_logical: int = 1, - ) -> np.ndarray: - """Compute NIXL descriptor IDs for a set of block IDs.""" - ... 
- - # ------------------------------------------------------------------ - # Local descriptor building (concrete default = FA-only) - # ------------------------------------------------------------------ - - @abstractmethod - def build_local_descs( - self, - # Memory - base_addresses: list[int], - device_id: int, - # Block geometry - num_blocks: int, - logical_num_blocks: int, - block_size_ratio: int, - block_len_per_layer: list[int], - # Layout - is_blocks_first: bool, - ) -> list[tuple[int, int, int]]: - """Build local (src) descriptor tuples for NIXL registration.""" - ... - - def _build_fa_local_descs( - self, - base_addresses: list[int], - device_id: int, - num_blocks: int, - block_size_ratio: int, - block_len_per_layer: list[int], - is_blocks_first: bool, - ) -> list[tuple[int, int, int]]: - """Build FA local descriptors (shared by Dense and Mamba).""" - result: list[tuple[int, int, int]] = [] - n_blocks = num_blocks * block_size_ratio - for i, base_addr in enumerate(base_addresses): - # The new block_len is using prefill block_len; - # and num_blocks is multiple with N - kv_block_len = ( - _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - // block_size_ratio - ) - # Jump one page_size, but ssm page_size may be bigger when kernel - # locks block size to a specific value. - page_stride = block_len_per_layer[i] // block_size_ratio - for block_id in range(n_blocks): - # (addr, len, device id) - result.append( - ( - base_addr + block_id * page_stride, - kv_block_len, - device_id, - ) - ) - if is_blocks_first: - second_split = _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. - for block_id in range(n_blocks): - # Register addresses for V cache (K registered first). 
- v_addr = base_addr + block_id * page_stride + kv_block_len - result.append( - ( - v_addr, - second_split, - device_id, - ) - ) - return result - - # ------------------------------------------------------------------ - # Remote descriptor building (abstract — genuinely different) - # ------------------------------------------------------------------ - - @abstractmethod - def build_remote_descs( - self, - transfer_topo: TransferTopology, - engine_id: EngineId, - nixl_agent_meta: NixlAgentMetadata, - block_len_per_layer: list[int], - ) -> list[tuple[int, int, int]]: - """Build remote (dst) descriptor tuples.""" - ... - - @abstractmethod - def build_src_split_handles( - self, - transfer_topo: TransferTopology, - engine_id: EngineId, - src_blocks_data: list[tuple[int, int, int]], - num_descs: int, - ) -> list[list[tuple[int, int, int]]]: - """Build split handle data for P_TP > D_TP scenario.""" - ... - - # ------------------------------------------------------------------ - # Read spec computation - # ------------------------------------------------------------------ - - @abstractmethod - def compute_read_specs( - self, - local_block_ids: BlockIds, - remote_block_ids: BlockIds, - remote_ranks: list[int], - remote_info: EngineTransferInfo, - ) -> list[ReadSpec]: - """Compute the full set of read operations needed for a request. - - Returns one ``ReadSpec`` per remote rank that requires a read. - The worker iterates the result without model-specific branching. - MLA trimming (keeping only the first spec) is handled by the worker. - - Block ID expansion (logical→kernel) is done by the worker before - calling this method. - """ - ... 
- - # ------------------------------------------------------------------ - # Factory - # ------------------------------------------------------------------ - - @staticmethod - def create( - kv_cache_config: KVCacheConfig, - layer_specs: dict[str, KVCacheSpec], - physical_blocks_per_logical: int, - tp_size: int, - ) -> ModelBlockTransferPolicy: - """Create the appropriate policy based on model architecture.""" - has_mamba = any( - isinstance(group.kv_cache_spec, MambaSpec) - for group in kv_cache_config.kv_cache_groups - ) - if has_mamba: - return MambaModelBlockTransferPolicy( - kv_cache_config=kv_cache_config, - tp_size=tp_size, - layer_specs=layer_specs, - physical_blocks_per_logical=physical_blocks_per_logical, - ) - return DenseModelBlockTransferPolicy( - kv_cache_config, - physical_blocks_per_logical, - ) - - -# ====================================================================== -# Dense (pure-attention) policy -# ====================================================================== - - -class DenseModelBlockTransferPolicy(ModelBlockTransferPolicy): - """Policy for pure-attention (dense) models. - - Inherits ``get_block_len`` from the ABC. 
- """ - - def __init__( - self, - kv_cache_config: KVCacheConfig, - physical_blocks_per_logical: int, - ): - super().__init__(kv_cache_config, physical_blocks_per_logical) - - def build_local_descs( - self, - base_addresses, - device_id, - num_blocks, - logical_num_blocks, - block_size_ratio, - block_len_per_layer, - is_blocks_first, - ): - return self._build_fa_local_descs( - base_addresses, - device_id, - num_blocks, - block_size_ratio, - block_len_per_layer, - is_blocks_first, - ) - - def get_block_descs_ids( - self, - block_ids, - num_regions, - dst_num_blocks, - block_len_per_layer, - block_size_ratio=None, - physical_blocks_per_logical=1, - ): - num_blocks = dst_num_blocks - if block_size_ratio is not None: - num_blocks = int(num_blocks * block_size_ratio) - block_ids_arr = np.concatenate(block_ids) - region_ids = np.arange(num_regions)[:, None] - return (region_ids * num_blocks + block_ids_arr[None, :]).flatten() - - def build_remote_descs( - self, - transfer_topo, - engine_id, - nixl_agent_meta, - block_len_per_layer, - ): - # With homogeneous TP, D pulls the whole kv cache from corresponding - # rank. With heterogeneous TP, prepare the descriptors by splitting the - # P KV cache along kv_head dim, of D worker's kv_head size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - # Register all remote blocks, but only the corresponding kv heads. 
- remote_info = transfer_topo.get_engine_info(engine_id) - assert isinstance(remote_info, EngineTransferInfo) - tp_rank = transfer_topo.tp_rank - is_mla = transfer_topo.is_mla - is_blocks_first = transfer_topo.is_kv_layout_blocks_first - tp_ratio = transfer_topo.tp_ratio(remote_info.remote_tp_size) - block_size_ratio = transfer_topo.block_size_ratio(remote_info.remote_block_size) - indexes_into_remote = ( - not ( - is_mla or remote_info.remote_tp_size > transfer_topo.total_num_kv_heads - ) - and tp_ratio > 0 - ) - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate( - nixl_agent_meta.kv_caches_base_addr, - ): - local_block_len = _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - local_block_len = remote_kv_block_len - if tp_ratio < 0 and not is_mla: - local_block_len = local_block_len // (-tp_ratio) - rank_offset = ( - tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 - ) - num_blocks = nixl_agent_meta.num_blocks - page_size = nixl_agent_meta.block_lens[i] - dev_id = nixl_agent_meta.device_id - for blk in range(num_blocks): - addr = base_addr + blk * page_size + rank_offset - result.append((addr, local_block_len, dev_id)) - if is_blocks_first: - second_split = _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - if tp_ratio < 0 and not is_mla: - second_split = second_split // (-tp_ratio) - for blk in range(num_blocks): - addr = base_addr + blk * page_size + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - result.append((v_addr, second_split, dev_id)) - return result - - def build_src_split_handles( - self, - transfer_topo, - engine_id, - src_blocks_data, - num_descs, - ): - remote_info = transfer_topo.get_engine_info(engine_id) - assert isinstance(remote_info, EngineTransferInfo) - _ = num_descs - assert remote_info.remote_tp_size > transfer_topo.tp_size - abs_tp = 
remote_info.remote_tp_size // transfer_topo.tp_size - result: list[list[tuple[int, int, int]]] = [] - for i in range(abs_tp): - blocks_data: list[tuple[int, int, int]] = [] - for addr, local_block_len, own_rank in src_blocks_data: - remote_block_len = local_block_len // abs_tp - blocks_data.append( - ( - addr + i * remote_block_len, - remote_block_len, - own_rank, - ) - ) - result.append(blocks_data) - return result - - def compute_read_specs( - self, - local_block_ids, - remote_block_ids, - remote_ranks, - remote_info, - ): - assert isinstance(remote_info, EngineTransferInfo) - return [ - ReadSpec( - remote_rank=rank, - local_block_ids=local_block_ids, - remote_block_ids=remote_block_ids, - ) - for rank in remote_ranks - ] - - def build_engine_transfer_info( - self, - *, - # Local topology - transfer_topo: TransferTopology, - # Block geometry - local_block_len: int, - # Remote facts (from NixlAgentMetadata handshake) - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - ) -> EngineTransferInfo: - _ = (transfer_topo, local_block_len) - return EngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, - ) - - -# ====================================================================== -# Mamba (hybrid SSM+Attention) policy -# ====================================================================== - - -class MambaModelBlockTransferPolicy(ModelBlockTransferPolicy): - """Policy for hybrid Mamba+Attention (SSM) models. - - Stores Mamba-specific state (SSM sizes, conv decomposition, - per-group flags) and overrides methods that differ from the - FA-only base: descriptor building, split handles, read specs, - registration helpers for MambaSpec layers, and orchestration. 
- """ - - def __init__( - self, - kv_cache_config: KVCacheConfig, - tp_size: int, - layer_specs: dict[str, KVCacheSpec], - physical_blocks_per_logical: int, - ): - super().__init__(kv_cache_config, physical_blocks_per_logical) - self._group_kinds = tuple( - GroupKind.MAMBA - if isinstance(group.kv_cache_spec, MambaSpec) - else GroupKind.FA - for group in kv_cache_config.kv_cache_groups - ) - - mamba_spec = next( - spec for spec in layer_specs.values() if isinstance(spec, MambaSpec) - ) - conv_nbytes = torch.tensor( - [], - dtype=mamba_spec.dtypes[0], # type: ignore[misc] - ).element_size() - ssm_nbytes = torch.tensor( - [], - dtype=mamba_spec.dtypes[1], # type: ignore[misc] - ).element_size() - conv_shape = torch.Size(mamba_spec.shapes[0]) - ssm_shape = torch.Size(mamba_spec.shapes[1]) - self._ssm_sizes = ( - conv_shape.numel() * conv_nbytes, - ssm_shape.numel() * ssm_nbytes, - ) - - assert is_conv_state_dim_first(), ( - "3-read Mamba conv transfer requires DS conv state layout. " - "Set VLLM_SSM_CONV_STATE_LAYOUT=DS" - ) - self._conv_decomp = derive_mamba_conv_split(mamba_spec, tp_size) - - @property - def ssm_sizes(self) -> tuple[int, int]: - """(conv_state_bytes, ssm_state_bytes) per logical block.""" - return self._ssm_sizes - - @property - def conv_decomp(self) -> MambaConvSplitInfo: - """Conv-state sub-projection decomposition.""" - return self._conv_decomp - - def get_block_descs_ids( - self, - block_ids, - num_regions, - dst_num_blocks, - block_len_per_layer, - block_size_ratio=None, - physical_blocks_per_logical=1, - ): - # NOTE (NickLucche) With HMA, every kv group has the same number of - # layers and layers from different groups share the same kv tensor. - # eg block_ids=[[1,2],[3]]->blocks [1,2] need to be read across all - # regions, same for [3], but group0-group1 blocks will always differ - # (different areas). Therefore we can just flatten the block_ids and - # compute the descs ids for all groups at once. 
- - # Compute desc ids per group using the right stride: FA descs have - # num_blocks entries per region (kernel granularity), SSM descs have - # logical_blocks entries per region (no kernel splitting). - num_blocks = dst_num_blocks - if block_size_ratio is not None: - num_blocks = int(num_blocks * block_size_ratio) - ratio = physical_blocks_per_logical - logical_blocks = num_blocks // ratio - num_fa_descs = num_regions * num_blocks - # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged - # arbitrarily by manager. Therefore, descs are duplicated for SSM and - # Attention like so: - # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)]. - # This is like having two "low-level views" of the same storage. - # `num_fa_descs` offset must be computed per-engine since P and D can - # have different num_blocks (and thus different FA descs counts). - # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). - mamba_region_ids = np.arange( - len(block_len_per_layer) * 4, - )[:, None] - all_descs: list[np.ndarray] = [] - for i, group in enumerate(block_ids): - group_arr = np.asarray(group) - if self._group_kinds[i].is_ssm: - # Mamba blocks are 1:1 logical-to-physical (no expansion). 
- all_descs.append( - ( - mamba_region_ids * logical_blocks - + group_arr[None, :] - + num_fa_descs - ).flatten() - ) - else: - region_ids = np.arange(num_regions)[:, None] - all_descs.append( - (region_ids * num_blocks + group_arr[None, :]).flatten() - ) - return np.concatenate(all_descs) - - def build_local_descs( - self, - base_addresses, - device_id, - num_blocks, - logical_num_blocks, - block_size_ratio, - block_len_per_layer, - is_blocks_first, - ): - fa_descs = self._build_fa_local_descs( - base_addresses, - device_id, - num_blocks, - block_size_ratio, - block_len_per_layer, - is_blocks_first, - ) - num_regions = len(base_addresses) * (2 if is_blocks_first else 1) - assert len(fa_descs) == num_regions * num_blocks - # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read - # split is unnecessary — a single conv desc per block suffices. - # Consider adding a fast path that falls back to the standard 2-region - # registration when no hetero-TP remote has been seen. Currently we - # always register 4 regions because local descs are created before - # knowing the remote TP. - logger.debug("Registering local Mamba descriptors (4 regions/layer)") - mamba_descs = self._build_mamba_local_descs( - base_addresses, - block_len_per_layer, - logical_num_blocks, - block_size_ratio, - device_id, - ) - return fa_descs + mamba_descs - - def _build_mamba_local_descs( - self, - base_addresses: list[int], - block_len_per_layer: list[int], - logical_num_blocks: int, - block_size_ratio: int, - device_id: int, - ) -> list[tuple[int, int, int]]: - """Build 4 desc regions (x, B, C, ssm) per layer for local - mamba blocks, enabling the 3-read transfer with DS conv layout. - - Conv state sub-projection decomposition requires DS (dim, state_len) - conv layout so that x/B/C sub-projections are contiguous in memory. - """ - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 " - f"is not tested. Got block_size_ratio={block_size_ratio}." 
- ) - conv_offsets = self._conv_decomp.local_conv_offsets - # SSM States come in tuples (conv_size, ssm_state_size) - conv_size, ssm_size = self._ssm_sizes - n_blocks = logical_num_blocks * block_size_ratio - phys_ratio = self._physical_blocks_per_logical - - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(base_addresses): - page_stride = block_len_per_layer[i] // block_size_ratio * phys_ratio - # 3 conv sub-projection regions (x, B, C) - for off, sz in conv_offsets: - for blk in range(n_blocks): - result.append( - ( - base_addr + blk * page_stride + off, - sz, - device_id, - ) - ) - # SSM temporal state follows the conv state. - for blk in range(n_blocks): - result.append( - ( - base_addr + blk * page_stride + conv_size, - ssm_size, - device_id, - ) - ) - return result - - def build_remote_descs( - self, - transfer_topo, - engine_id, - nixl_agent_meta, - block_len_per_layer, - ): - remote_info = transfer_topo.get_engine_info(engine_id) - assert isinstance(remote_info, MambaEngineTransferInfo) - info = remote_info - tp_ratio = transfer_topo.tp_ratio(info.remote_tp_size) - result: list[tuple[int, int, int]] = [] - result.extend( - self._build_fa_remote_descs( - transfer_topo, - nixl_agent_meta, - info, - tp_ratio, - block_len_per_layer, - ) - ) - result.extend( - self._build_mamba_remote_descs( - nixl_agent_meta, - tp_ratio, - transfer_topo.tp_rank, - info.remote_physical_blocks_per_logical, - ) - ) - return result - - # NOTE (ZhanqiuHu): See ABC comment on _should_skip_fa for context. - # This method also handles FA replication (see ABC helpers above). 
- def _build_fa_remote_descs( - self, - transfer_topo: TransferTopology, - nixl_agent_meta: NixlAgentMetadata, - info: MambaEngineTransferInfo, - tp_ratio: int, - block_len_per_layer: list[int], - ): - """Build remote FA descriptors for mamba models using - transfer_cfg for GQA-aware sizing.""" - tp_rank = transfer_topo.tp_rank - tp_size = transfer_topo.tp_size - is_mla = transfer_topo.is_mla - total_num_kv_heads = transfer_topo.total_num_kv_heads - is_blocks_first = transfer_topo.is_kv_layout_blocks_first - block_size_ratio = transfer_topo.block_size_ratio(info.remote_block_size) - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 " - f"is not tested. Got {block_size_ratio=}." - ) - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate( - nixl_agent_meta.kv_caches_base_addr, - ): - local_block_len = _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - local_block_len = remote_kv_block_len - if tp_ratio < 0 and not is_mla: - local_block_len = local_block_len // info.remote_num_fa_reads - rank_offset = _fa_rank_offset( - info, - remote_kv_block_len, - tp_rank=tp_rank, - tp_size=tp_size, - is_mla=is_mla, - total_num_kv_heads=total_num_kv_heads, - ) - num_blocks = nixl_agent_meta.num_blocks - page_size = nixl_agent_meta.block_lens[i] - dev_id = nixl_agent_meta.device_id - for blk in range(num_blocks): - addr = base_addr + blk * page_size + rank_offset - result.append((addr, local_block_len, dev_id)) - if is_blocks_first: - second_split = _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - if tp_ratio < 0 and not is_mla: - second_split = second_split // info.remote_num_fa_reads - for blk in range(num_blocks): - addr = base_addr + blk * page_size + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - result.append( - ( - v_addr, - second_split, - dev_id, - ) - ) - return result - - 
def _build_mamba_remote_descs( - self, - nixl_agent_meta, - tp_ratio, - tp_rank, - physical_blocks_per_logical, - ): - """Build 4 remote desc regions (x, B, C, ssm) per layer - for the 3-read transfer. - - Mamba conv state is always TP-sharded, even when attention KV - is replicated (num_kv_heads < tp_size). - """ - effective_ratio = max(tp_ratio, 1) - local_offset = tp_rank % effective_ratio - conv_size_remote = nixl_agent_meta.ssm_sizes[0] - - if tp_ratio >= 1: - # D_TP >= P_TP: P page is larger, D reads its slice. - conv_offsets = self._conv_decomp.remote_conv_offsets( - local_offset, - effective_ratio, - ) - # SSM temporal state is also TP-sharded on the heads dimension. - ssm_read_size = self._ssm_sizes[1] - else: - # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages - # are smaller than D's. self._conv_decomp has D-sized dimensions, - # but we need P-sized offsets. Scale down by |tp_ratio|. - abs_ratio = -tp_ratio - xb_p = self._conv_decomp.x_bytes // abs_ratio - bb_p = self._conv_decomp.b_bytes // abs_ratio - conv_offsets = [ - (0, xb_p), - (xb_p, bb_p), - (xb_p + bb_p, bb_p), - ] - ssm_read_size = nixl_agent_meta.ssm_sizes[1] - - # Assume same num_blocks for mamba and fa - remote_ratio = physical_blocks_per_logical - num_blocks = nixl_agent_meta.num_blocks // remote_ratio - dev_id = nixl_agent_meta.device_id - - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate( - nixl_agent_meta.kv_caches_base_addr, - ): - # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case - # block lengths vary across layers (e.g. MLA). 
- page_stride = nixl_agent_meta.block_lens[i] * remote_ratio - for off, sz in conv_offsets: - for blk in range(num_blocks): - result.append( - ( - base_addr + blk * page_stride + off, - sz, - dev_id, - ) - ) - for blk in range(num_blocks): - ssm_addr = ( - base_addr - + blk * page_stride - + conv_size_remote - + local_offset * ssm_read_size - ) - result.append((ssm_addr, ssm_read_size, dev_id)) - return result - - def build_src_split_handles( - self, - transfer_topo, - engine_id, - src_blocks_data, - num_descs, - ): - remote_info = transfer_topo.get_engine_info(engine_id) - assert isinstance(remote_info, MambaEngineTransferInfo) - info = remote_info - tp_size = transfer_topo.tp_size - assert info.remote_tp_size > tp_size - abs_tp = info.remote_tp_size // tp_size - if self.needs_split_handles( - info, - tp_size=tp_size, - is_mla=transfer_topo.is_mla, - ): - result = list( - self.compute_split_handle_data( - info, - src_blocks_data, - num_descs, - abs_tp, - total_num_kv_heads=transfer_topo.total_num_kv_heads, - ) - ) - logger.info( - "Mamba-HMA split handles: targets=%s, fa_reads=%s, " - "fa_entry=%s, mamba_reads=%s, num_descs=%s", - info.remote_all_source_ranks, - info.remote_num_fa_reads, - info.remote_fa_descriptor_bytes, - info.remote_num_mamba_reads, - num_descs, - ) - return result - return [] - - def compute_read_specs( - self, - local_block_ids, - remote_block_ids, - remote_ranks, - remote_info, - ): - assert isinstance(remote_info, MambaEngineTransferInfo) - info = remote_info - specs: list[ReadSpec] = [] - for rank in remote_ranks: - filtered_local, filtered_remote = self.filter_block_ids_for_rank( - info, - rank, - local_block_ids, - remote_block_ids, - ) - specs.append( - ReadSpec( - remote_rank=rank, - local_block_ids=filtered_local, - remote_block_ids=filtered_remote, - ) - ) - return specs - - def build_engine_transfer_info( - self, - *, - # Local topology - transfer_topo: TransferTopology, - # Block geometry - local_block_len: int, - # Remote facts 
(from NixlAgentMetadata handshake) - remote_tp_size: int, - remote_block_size: int, - remote_block_len: int, - remote_physical_blocks_per_logical: int, - ) -> MambaEngineTransferInfo: - K = transfer_topo.total_num_kv_heads - local_tp = transfer_topo.tp_size - local_rank = transfer_topo.tp_rank - is_mla = transfer_topo.is_mla - is_kv_layout_blocks_first = transfer_topo.is_kv_layout_blocks_first - - is_remote_replicated = remote_tp_size > K - remote_physical_heads = max(1, K // remote_tp_size) - - if local_tp >= remote_tp_size: - assert local_tp % remote_tp_size == 0 - tp_ratio = local_tp // remote_tp_size - else: - assert remote_tp_size % local_tp == 0 - tp_ratio = -(remote_tp_size // local_tp) - - abs_tp = -tp_ratio if tp_ratio < 0 else 1 - - mamba_range: range | None = None - if tp_ratio < 0: - mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp) - - # ---- FA read targets ---- - if is_mla or tp_ratio >= 0: - num_fa_reads = 1 - fa_source_ranks: list[int] = ( - [0] - if is_mla - else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] - ) - else: - local_needs = _physical_head_range(local_tp, K, local_rank) - search_range = ( - mamba_range if mamba_range is not None else range(remote_tp_size) - ) - seen: set[tuple[int, int]] = set() - fa_source_ranks = [] - for p in search_range: - p_has = _physical_head_range(remote_tp_size, K, p) - ov = _range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - if not fa_source_ranks: - for p in range(remote_tp_size): - p_has = _physical_head_range(remote_tp_size, K, p) - ov = _range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - num_fa_reads = len(fa_source_ranks) - - # ---- All source ranks (mamba + FA) ---- - if mamba_range is not None and abs_tp > num_fa_reads: - num_mamba_reads = abs_tp - all_source_ranks = 
list(mamba_range) - else: - num_mamba_reads = num_fa_reads - all_source_ranks = list(fa_source_ranks) - - # ---- FA descriptor bytes ---- - effective_block_len = min(local_block_len, remote_block_len) - if is_kv_layout_blocks_first: - fa_descriptor_bytes = effective_block_len // 2 - else: - fa_descriptor_bytes = effective_block_len - - # ---- Validation ---- - is_local_replicated = local_tp > K - if is_local_replicated and is_remote_replicated and tp_ratio > 0: - logger.info( - "Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.", - local_tp, - remote_tp_size, - K, - ) - tt_set = set(all_source_ranks) - for t in fa_source_ranks: - if t not in tt_set: - logger.error( - "FA source rank %d NOT in all_source_ranks %s.", - t, - all_source_ranks, - ) - if is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0: - local_k_half = local_block_len // 2 - remote_k_half = remote_block_len // 2 - expected = local_k_half // num_fa_reads - if expected != remote_k_half: - logger.warning( - "FA size mismatch: local_k_half=%d / reads=%d = %d, " - "but remote_k_half=%d.", - local_k_half, - num_fa_reads, - expected, - remote_k_half, - ) - - return MambaEngineTransferInfo( - remote_tp_size=remote_tp_size, - remote_block_len=remote_block_len, - remote_block_size=remote_block_size, - remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, - remote_fa_source_ranks=tuple(fa_source_ranks), - remote_all_source_ranks=tuple(all_source_ranks), - remote_num_fa_reads=num_fa_reads, - remote_num_mamba_reads=num_mamba_reads, - remote_fa_descriptor_bytes=fa_descriptor_bytes, - is_remote_replicated=is_remote_replicated, - remote_physical_heads=remote_physical_heads, - ) - - def needs_split_handles( - self, - info: MambaEngineTransferInfo, - tp_size: int, - is_mla: bool, - ) -> bool: - """Whether per-remote-rank split handles are needed. - - True when FA and mamba have different read counts, requiring - different splitting factors in the local handle. 
- """ - tp_ratio = ( - tp_size // info.remote_tp_size - if tp_size >= info.remote_tp_size - else -(info.remote_tp_size // tp_size) - ) # noqa: E501 - return tp_ratio < 0 and not is_mla and len(info.remote_all_source_ranks) > 1 - - def compute_split_handle_data( - self, - info: MambaEngineTransferInfo, - src_blocks_data: list[tuple[int, int, int]], - num_fa_descs: int, - abs_tp: int, - total_num_kv_heads: int, - ) -> list[list[tuple[int, int, int]]]: - """Per-remote-rank (addr, len, dev) triples for split handles. - - FA descriptors (indices < num_fa_descs) are sliced by - ``remote_num_fa_reads``; mamba descriptors are sliced uniformly - by ``abs_tp``. - """ - all_handle_data: list[list[tuple[int, int, int]]] = [] - for p_idx, p_rank in enumerate(info.remote_all_source_ranks): - handle_data: list[tuple[int, int, int]] = [] - skip_fa = _should_skip_fa(info, p_rank) - fa_slot = ( - _fa_head_slot(info, p_rank, total_num_kv_heads) if not skip_fa else 0 - ) - for j, (addr, local_len, dev) in enumerate(src_blocks_data): - if j < num_fa_descs: - assert info.remote_num_fa_reads >= 1 - fa_chunk = local_len // info.remote_num_fa_reads - handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) - else: - mamba_chunk = local_len // abs_tp - handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) - all_handle_data.append(handle_data) - return all_handle_data - - def filter_block_ids_for_rank( - self, - info: MambaEngineTransferInfo, - remote_rank: int, - local_ids: BlockIds, - remote_ids: BlockIds, - ) -> tuple[BlockIds, BlockIds]: - """Zero out FA groups for remote ranks outside ``fa_source_ranks``. - - Returns (filtered_local_ids, filtered_remote_ids). When the - remote rank carries FA data for this local rank, returns the - inputs unchanged. 
- """ - if not _should_skip_fa(info, remote_rank): - return local_ids, remote_ids - num_groups = len(local_ids) - filtered_local: list[list[int]] = [ - local_ids[g] if self._group_kinds[g].is_ssm else [] - for g in range(num_groups) - ] - filtered_remote: list[list[int]] = [ - remote_ids[g] if self._group_kinds[g].is_ssm else [] - for g in range(num_groups) - ] - return filtered_local, filtered_remote diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 548e6c9bff0d..29cbcc4817a8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -114,7 +114,9 @@ class EngineTransferPlan: Regions are split into ``fa_regions`` and ``ssm_regions`` matching the descriptor handle layout: [FA descriptors | SSM descriptors]. - ``group_kinds`` maps each kv_cache_group to its type. + ``group_kinds`` maps each kv_cache_group to its type for descriptor + indexing. ``source_ranks_per_group`` encodes which ranks read each + group — executors use this instead of group_kinds for rank routing. """ # Regions in descriptor handle order @@ -124,18 +126,17 @@ class EngineTransferPlan: # Per-group geometric properties (worker-facing, model-agnostic) physical_per_logical: tuple[int, ...] - # Per-group type (FA, SWA, MAMBA, GDN). + # Per-group type — used only for descriptor indexing (save path). group_kinds: tuple[GroupKind, ...] - # Source rank routing + # Per-group ordered source ranks. Position = local piece index. + source_ranks_per_group: tuple[tuple[int, ...], ...] + + # Superset of all source ranks (union of all groups). all_source_ranks: tuple[int, ...] - fa_source_ranks: tuple[int, ...] - fa_source_set: frozenset[int] - # Split handle parameters - num_fa_reads: int - num_mamba_reads: int - fa_head_slots: dict[int, int] + # Maps each source rank to its FA head slot index. 
+ rank_to_attention_slot: dict[int, int] # Remote engine facts (needed by worker at read time) remote_tp_size: int @@ -144,8 +145,8 @@ class EngineTransferPlan: remote_physical_blocks_per_logical: int # Stride for expanding remote logical block IDs to kernel block IDs. - # Dense: equals local physical_blocks_per_logical (stride == count). - # Mamba: equals remote_physical_blocks_per_logical (stride != count). + # Dense: local_physical_blocks_per_logical. + # Mamba: remote_physical_blocks_per_logical. remote_expansion_stride: int @property @@ -168,203 +169,119 @@ def _get_kv_block_len( return block_len_per_layer[layer_idx] -def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: - if tp_size <= num_heads: - assert num_heads % tp_size == 0 - per_rank = num_heads // tp_size - return range(rank * per_rank, (rank + 1) * per_rank) - else: - h = rank * num_heads // tp_size - return range(h, h + 1) - - -def _range_overlap(a: range, b: range) -> range: - start = max(a.start, b.start) - stop = min(a.stop, b.stop) - return range(start, max(start, stop)) - +@dataclass(frozen=True) +class TPMapping: + """Complete local-to-remote TP mapping for one remote engine.""" -def _compute_tp_ratio(tp_size: int, remote_tp_size: int) -> int: - if tp_size >= remote_tp_size: - assert tp_size % remote_tp_size == 0 - return tp_size // remote_tp_size - assert remote_tp_size % tp_size == 0 - return -(remote_tp_size // tp_size) + source_ranks_per_group: tuple[tuple[int, ...], ...] + all_source_ranks: tuple[int, ...] + rank_to_attention_slot: dict[int, int] + rank_offset_factor: int -def _compute_fa_source_ranks( +def _compute_tp_mapping( tp_rank: int, tp_size: int, remote_tp_size: int, is_mla: bool, total_num_kv_heads: int, -) -> tuple[list[int], list[int], int, int]: - """Compute FA and all source ranks for Mamba models. + group_kinds: tuple[GroupKind, ...], +) -> TPMapping: + """Build the complete local-to-remote TP mapping. 
- Returns (fa_source_ranks, all_source_ranks, num_fa_reads, num_mamba_reads). - Mirrors the logic in MambaModelBlockTransferPolicy.build_engine_transfer_info. + Computes source ranks, head slot assignments, and the rank offset + factor in a single pass. Both generators call this and unpack. """ K = total_num_kv_heads - tp_ratio = _compute_tp_ratio(tp_size, remote_tp_size) - abs_tp = -tp_ratio if tp_ratio < 0 else 1 - mamba_range: range | None = None - if tp_ratio < 0: - mamba_range = range(tp_rank * abs_tp, (tp_rank + 1) * abs_tp) - - fa_source_ranks: list[int] - if is_mla or tp_ratio >= 0: - num_fa_reads = 1 - if is_mla: - fa_source_ranks = [0] - elif tp_ratio > 0: - fa_source_ranks = [tp_rank // tp_ratio] - else: - fa_source_ranks = [tp_rank] + + # --- Attention source ranks --- + if is_mla: + attn_ranks = [0] + elif tp_size >= remote_tp_size: + attn_ranks = [tp_rank * remote_tp_size // tp_size] else: - local_needs = _physical_head_range(tp_size, K, tp_rank) - search_range = mamba_range if mamba_range is not None else range(remote_tp_size) - seen: set[tuple[int, int]] = set() - fa_source_ranks = [] - for p in search_range: - p_has = _physical_head_range(remote_tp_size, K, p) - ov = _range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - if not fa_source_ranks: - for p in range(remote_tp_size): - p_has = _physical_head_range(remote_tp_size, K, p) - ov = _range_overlap(local_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - fa_source_ranks.append(p) - num_fa_reads = len(fa_source_ranks) - - if mamba_range is not None and abs_tp > num_fa_reads: - num_mamba_reads = abs_tp - all_source_ranks = list(mamba_range) + # P > D: one local rank reads from multiple remote ranks. + # GQA dedup: when K < remote_tp_size, several remote ranks + # hold the same KV head. 
np.unique keeps only the first + # rank per unique head so we don't issue redundant reads. + abs_tp = remote_tp_size // tp_size + start = tp_rank * abs_tp + heads = np.arange(start, start + abs_tp) * K // remote_tp_size + _, unique_idx = np.unique(heads, return_index=True) + attn_ranks = (start + np.sort(unique_idx)).tolist() + + # --- All source ranks (expand for SSM if needed) --- + has_ssm = any(k.is_ssm for k in group_kinds) + if not has_ssm or tp_size >= remote_tp_size: + all_ranks = list(attn_ranks) else: - num_mamba_reads = num_fa_reads - all_source_ranks = list(fa_source_ranks) - - return fa_source_ranks, all_source_ranks, num_fa_reads, num_mamba_reads - - -def _compute_fa_head_slots( - fa_source_ranks: list[int], - all_source_ranks: list[int], - remote_tp_size: int, - total_num_kv_heads: int, -) -> dict[int, int]: - """Pre-compute the FA head slot for each source rank. - - Mirrors _fa_head_slot from block_transfer_policy.py but pre-computes - all values at plan generation time. - """ - fa_index = {r: i for i, r in enumerate(fa_source_ranks)} - K = total_num_kv_heads - result: dict[int, int] = {} - for rank in all_source_ranks: - if rank in fa_index: - result[rank] = fa_index[rank] + abs_tp = remote_tp_size // tp_size + if abs_tp > len(attn_ranks): + all_ranks = list( + range( + tp_rank * abs_tp, + (tp_rank + 1) * abs_tp, + ) + ) else: - r_head = _physical_head_range(remote_tp_size, K, rank) - for target in fa_source_ranks: - t_head = _physical_head_range(remote_tp_size, K, target) - if _range_overlap(r_head, t_head): - result[rank] = fa_index[target] - break - else: - result[rank] = 0 - return result - + all_ranks = list(attn_ranks) -def _compute_fa_rank_offset( - tp_rank: int, - tp_size: int, - tp_ratio: int, - is_mla: bool, - total_num_kv_heads: int, - remote_tp_size: int, - fa_source_ranks: list[int], - remote_kv_block_len: int, -) -> int: - """Byte offset into remote FA block for this local rank. 
+ # --- Per-group ordered source ranks --- + attn_tuple = tuple(attn_ranks) + all_tuple = tuple(all_ranks) + source_ranks_per_group = tuple( + all_tuple if k.is_ssm else attn_tuple for k in group_kinds + ) - Mirrors _fa_rank_offset from block_transfer_policy.py, but takes - raw parameters instead of MambaEngineTransferInfo. - """ - if is_mla or tp_ratio <= 0: - return 0 - K = total_num_kv_heads - is_local_replicated = tp_size > K - if is_local_replicated: + # --- Attention head slots --- + head_to_slot: dict[int, int] = {} + for i, r in enumerate(attn_ranks): + head_to_slot[r * K // remote_tp_size] = i + rank_to_attention_slot = { + r: head_to_slot.get(r * K // remote_tp_size, 0) for r in all_ranks + } + + # --- Rank offset factor --- + if is_mla or tp_size <= remote_tp_size: + rank_offset_factor = 0 + elif tp_size > K: local_head = tp_rank * K // tp_size - p_rank = fa_source_ranks[0] - p_start = p_rank * K // remote_tp_size - return (local_head - p_start) * remote_kv_block_len - return tp_rank % tp_ratio * remote_kv_block_len - + p_start = attn_ranks[0] * K // remote_tp_size + rank_offset_factor = local_head - p_start + else: + rank_offset_factor = tp_rank % (tp_size // remote_tp_size) -# ====================================================================== -# 3. 
Plan generators — the ONLY model-specific code -# ====================================================================== + return TPMapping( + source_ranks_per_group=source_ranks_per_group, + all_source_ranks=all_tuple, + rank_to_attention_slot=rank_to_attention_slot, + rank_offset_factor=rank_offset_factor, + ) -def generate_dense_plan( +def _build_fa_regions( *, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - is_blocks_first: bool, block_len_per_layer: list[int], - block_size: int, - remote_tp_size: int, - remote_block_size: int, - remote_num_blocks: int, remote_block_lens: list[int], + is_blocks_first: bool, + block_size_ratio: int, + num_attn_reads: int, + rank_offset_factor: int, + remote_num_blocks: int, remote_physical_blocks_per_logical: int, - local_physical_blocks_per_logical: int, -) -> EngineTransferPlan: - """Generate transfer plan for dense (FA-only) models. +) -> list[RegionPlan]: + """Build FA (attention) regions for the transfer plan. - Mirrors the combined logic of: - - DenseModelBlockTransferPolicy.build_engine_transfer_info() - - DenseModelBlockTransferPolicy.build_remote_descs() + K bytes = remote_kv_block_len / num_attn_reads. + V bytes = local_block_len / num_attn_reads (no block_size_ratio). + Offset = rank_offset_factor * remote_kv_block_len per layer. """ - tp_ratio = _compute_tp_ratio(tp_size, remote_tp_size) - block_size_ratio = block_size // remote_block_size - indexes_into_remote = ( - not (is_mla or remote_tp_size > total_num_kv_heads) and tp_ratio > 0 - ) - - # Source ranks — mirrors TransferTopology.target_remote_ranks for dense - if tp_ratio > 0: - all_source_ranks: tuple[int, ...] 
= (tp_rank // tp_ratio,) - else: - abs_ratio = -tp_ratio - all_source_ranks = tuple(tp_rank * abs_ratio + i for i in range(abs_ratio)) - - # Build FA regions — one (K, optionally V) per layer fa_regions: list[RegionPlan] = [] for i in range(len(remote_block_lens)): local_block_len = _get_kv_block_len(i, block_len_per_layer, is_blocks_first) remote_kv_block_len = local_block_len // block_size_ratio - - k_desc_bytes = local_block_len - if block_size_ratio > 1: - k_desc_bytes = remote_kv_block_len - if tp_ratio < 0 and not is_mla: - k_desc_bytes = k_desc_bytes // (-tp_ratio) - - rank_offset = ( - tp_rank % tp_ratio * remote_kv_block_len if indexes_into_remote else 0 - ) + k_desc_bytes = remote_kv_block_len // num_attn_reads + rank_offset = rank_offset_factor * remote_kv_block_len page_stride = remote_block_lens[i] fa_regions.append( @@ -380,10 +297,7 @@ def generate_dense_plan( ) if is_blocks_first: - v_desc_bytes = _get_kv_block_len(i, block_len_per_layer, is_blocks_first) - if tp_ratio < 0 and not is_mla: - v_desc_bytes = v_desc_bytes // (-tp_ratio) - + v_desc_bytes = local_block_len // num_attn_reads fa_regions.append( RegionPlan( kind=RegionKind.FA_V, @@ -396,21 +310,61 @@ def generate_dense_plan( ) ) - # For dense split handles: fa_head_slots maps rank → index, - # so the executor uniformly splits all descs by abs_tp. - fa_head_slots = {r: i for i, r in enumerate(all_source_ranks)} + return fa_regions + + +# ====================================================================== +# 3. 
Plan generators — the ONLY model-specific code +# ====================================================================== + + +def generate_dense_plan( + *, + tp_rank: int, + tp_size: int, + is_mla: bool, + total_num_kv_heads: int, + is_blocks_first: bool, + block_len_per_layer: list[int], + block_size: int, + remote_tp_size: int, + remote_block_size: int, + remote_num_blocks: int, + remote_block_lens: list[int], + remote_physical_blocks_per_logical: int, + local_physical_blocks_per_logical: int, +) -> EngineTransferPlan: + """Generate transfer plan for dense (FA-only) models.""" + block_size_ratio = block_size // remote_block_size + + m = _compute_tp_mapping( + tp_rank, + tp_size, + remote_tp_size, + is_mla, + total_num_kv_heads, + group_kinds=(GroupKind.FA,), + ) + + fa_regions = _build_fa_regions( + block_len_per_layer=block_len_per_layer, + remote_block_lens=remote_block_lens, + is_blocks_first=is_blocks_first, + block_size_ratio=block_size_ratio, + num_attn_reads=len(m.source_ranks_per_group[0]), + rank_offset_factor=m.rank_offset_factor, + remote_num_blocks=remote_num_blocks, + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ) return EngineTransferPlan( fa_regions=tuple(fa_regions), ssm_regions=(), physical_per_logical=(remote_physical_blocks_per_logical,), group_kinds=(GroupKind.FA,), - all_source_ranks=all_source_ranks, - fa_source_ranks=all_source_ranks, - fa_source_set=frozenset(all_source_ranks), - num_fa_reads=len(all_source_ranks), - num_mamba_reads=0, - fa_head_slots=fa_head_slots, + source_ranks_per_group=m.source_ranks_per_group, + all_source_ranks=m.all_source_ranks, + rank_to_attention_slot=m.rank_to_attention_slot, remote_tp_size=remote_tp_size, remote_block_size=remote_block_size, remote_block_len=remote_block_lens[0], @@ -438,119 +392,49 @@ def generate_mamba_plan( ssm_sizes: tuple[int, int], remote_ssm_sizes: tuple[int, int], ) -> EngineTransferPlan: - """Generate transfer plan for hybrid Mamba (SSM + FA) models. 
- - Mirrors the combined logic of: - - MambaModelBlockTransferPolicy.build_engine_transfer_info() - - MambaModelBlockTransferPolicy._build_fa_remote_descs() - - MambaModelBlockTransferPolicy._build_mamba_remote_descs() - """ - tp_ratio = _compute_tp_ratio(tp_size, remote_tp_size) + """Generate transfer plan for hybrid Mamba (SSM + FA) models.""" block_size_ratio = block_size // remote_block_size assert block_size_ratio == 1, ( "Mamba 3-read transfer with block_size_ratio != 1 " f"is not tested. Got {block_size_ratio=}." ) - # ---- Source rank computation ---- - ( - fa_source_ranks, - all_source_ranks, - num_fa_reads, - num_mamba_reads, - ) = _compute_fa_source_ranks( + m = _compute_tp_mapping( tp_rank, tp_size, remote_tp_size, is_mla, total_num_kv_heads, - ) - - # ---- FA head slots (for split handles) ---- - fa_head_slots = _compute_fa_head_slots( - fa_source_ranks, - all_source_ranks, - remote_tp_size, - total_num_kv_heads, + group_kinds, ) # ---- FA regions ---- - fa_regions: list[RegionPlan] = [] - for i in range(len(remote_block_lens)): - local_block_len = _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - remote_kv_block_len = local_block_len // block_size_ratio - - k_desc_bytes = local_block_len - if block_size_ratio > 1: - k_desc_bytes = remote_kv_block_len - if tp_ratio < 0 and not is_mla: - k_desc_bytes = k_desc_bytes // num_fa_reads - - rank_offset = _compute_fa_rank_offset( - tp_rank, - tp_size, - tp_ratio, - is_mla, - total_num_kv_heads, - remote_tp_size, - fa_source_ranks, - remote_kv_block_len, - ) - - page_stride = remote_block_lens[i] - - fa_regions.append( - RegionPlan( - kind=RegionKind.FA_K, - layer_idx=i, - descriptor_bytes=k_desc_bytes, - offset_in_page=rank_offset, - page_stride=page_stride, - num_blocks=remote_num_blocks, - physical_per_logical=remote_physical_blocks_per_logical, - ) - ) - - if is_blocks_first: - v_desc_bytes = _get_kv_block_len( - i, - block_len_per_layer, - is_blocks_first, - ) - if tp_ratio < 0 and 
not is_mla: - v_desc_bytes = v_desc_bytes // num_fa_reads - - fa_regions.append( - RegionPlan( - kind=RegionKind.FA_V, - layer_idx=i, - descriptor_bytes=v_desc_bytes, - offset_in_page=rank_offset + page_stride // 2, - page_stride=page_stride, - num_blocks=remote_num_blocks, - physical_per_logical=remote_physical_blocks_per_logical, - ) - ) + fa_regions = _build_fa_regions( + block_len_per_layer=block_len_per_layer, + remote_block_lens=remote_block_lens, + is_blocks_first=is_blocks_first, + block_size_ratio=block_size_ratio, + num_attn_reads=len(m.source_ranks_per_group[0]), + rank_offset_factor=m.rank_offset_factor, + remote_num_blocks=remote_num_blocks, + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ) # ---- SSM regions ---- - effective_ratio = max(tp_ratio, 1) - local_offset = tp_rank % effective_ratio + effective_ratio = tp_size // remote_tp_size if tp_size >= remote_tp_size else 1 + local_offset = tp_rank % max(effective_ratio, 1) conv_size_remote = remote_ssm_sizes[0] remote_ratio = remote_physical_blocks_per_logical ssm_num_blocks = remote_num_blocks // remote_ratio - if tp_ratio >= 1: + if tp_size >= remote_tp_size: conv_offsets = conv_decomp.remote_conv_offsets( local_offset, effective_ratio, ) ssm_read_size = ssm_sizes[1] else: - abs_ratio = -tp_ratio + abs_ratio = remote_tp_size // tp_size xb_p = conv_decomp.x_bytes // abs_ratio bb_p = conv_decomp.b_bytes // abs_ratio conv_offsets = [ @@ -560,7 +444,11 @@ def generate_mamba_plan( ] ssm_read_size = remote_ssm_sizes[1] - conv_kinds = [RegionKind.SSM_CONV_X, RegionKind.SSM_CONV_B, RegionKind.SSM_CONV_C] + conv_kinds = [ + RegionKind.SSM_CONV_X, + RegionKind.SSM_CONV_B, + RegionKind.SSM_CONV_C, + ] ssm_regions: list[RegionPlan] = [] for i in range(len(remote_block_lens)): page_stride = remote_block_lens[i] * remote_ratio @@ -598,12 +486,9 @@ def generate_mamba_plan( ssm_regions=tuple(ssm_regions), physical_per_logical=physical_per_logical_per_group, group_kinds=group_kinds, - 
all_source_ranks=tuple(all_source_ranks), - fa_source_ranks=tuple(fa_source_ranks), - fa_source_set=frozenset(fa_source_ranks), - num_fa_reads=num_fa_reads, - num_mamba_reads=num_mamba_reads, - fa_head_slots=fa_head_slots, + source_ranks_per_group=m.source_ranks_per_group, + all_source_ranks=m.all_source_ranks, + rank_to_attention_slot=m.rank_to_attention_slot, remote_tp_size=remote_tp_size, remote_block_size=remote_block_size, remote_block_len=remote_block_lens[0], @@ -646,8 +531,7 @@ def build_remote_descs_from_plan( ) -> list[tuple[int, int, int]]: """Build (addr, len, dev_id) descriptor tuples from plan. - Replaces DenseModelBlockTransferPolicy.build_remote_descs() and - MambaModelBlockTransferPolicy.build_remote_descs(). + Builds remote descriptors from a pre-computed plan. """ result: list[tuple[int, int, int]] = [] dev_id = nixl_agent_meta.device_id @@ -670,8 +554,7 @@ def compute_desc_ids_from_plan( ) -> np.ndarray: """Compute NIXL descriptor IDs for given block IDs. - Replaces DenseModelBlockTransferPolicy.get_block_descs_ids() and - MambaModelBlockTransferPolicy.get_block_descs_ids(). + Computes descriptor indices from a pre-computed plan. 
""" num_fa_regions = len(plan.fa_regions) num_ssm_regions = len(plan.ssm_regions) @@ -687,7 +570,12 @@ def compute_desc_ids_from_plan( all_descs: list[np.ndarray] = [] for i, group in enumerate(block_ids): group_arr = np.asarray(group) - if plan.group_kinds[i].is_ssm: + if plan.group_kinds[i].is_attention: + fa_region_ids = np.arange(num_fa_regions)[:, None] + all_descs.append( + (fa_region_ids * num_blocks + group_arr[None, :]).flatten() + ) + elif plan.group_kinds[i].is_ssm: ssm_region_ids = np.arange(num_ssm_regions)[:, None] all_descs.append( ( @@ -695,10 +583,7 @@ def compute_desc_ids_from_plan( ).flatten() ) else: - fa_region_ids = np.arange(num_fa_regions)[:, None] - all_descs.append( - (fa_region_ids * num_blocks + group_arr[None, :]).flatten() - ) + raise ValueError(f"Unknown group kind {plan.group_kinds[i]} at index {i}") return np.concatenate(all_descs) @@ -710,39 +595,28 @@ def compute_read_specs_from_plan( ) -> list[ReadSpec]: """Compute read specs from plan. - Replaces compute_read_specs() + filter_block_ids_for_rank(). - No _should_skip_fa — the plan structurally encodes which ranks - serve which groups via fa_source_set. + For each source rank, includes only the groups whose + source_ranks_per_group contains that rank. 
""" - specs: list[ReadSpec] = [] - for rank in plan.all_source_ranks: - skip_fa = rank not in plan.fa_source_set - if not skip_fa: - specs.append( - ReadSpec( - remote_rank=rank, - local_block_ids=local_block_ids, - remote_block_ids=remote_block_ids, - ) - ) - else: - num_groups = len(local_block_ids) - filtered_local: list[list[int]] = [ - list(local_block_ids[g]) if plan.group_kinds[g].is_ssm else [] + num_groups = len(local_block_ids) + return [ + ReadSpec( + remote_rank=rank, + local_block_ids=[ + list(local_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] for g in range(num_groups) - ] - filtered_remote: list[list[int]] = [ - list(remote_block_ids[g]) if plan.group_kinds[g].is_ssm else [] + ], + remote_block_ids=[ + list(remote_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] for g in range(num_groups) - ] - specs.append( - ReadSpec( - remote_rank=rank, - local_block_ids=filtered_local, - remote_block_ids=filtered_remote, - ) - ) - return specs + ], + ) + for rank in plan.all_source_ranks + ] def build_local_splits_from_plan( @@ -752,30 +626,29 @@ def build_local_splits_from_plan( ) -> list[list[tuple[int, int, int]]]: """Build split handle data for P_TP > D_TP scenario. - Replaces DenseModelBlockTransferPolicy.build_src_split_handles() and - MambaModelBlockTransferPolicy.build_src_split_handles() + - compute_split_handle_data(). - - When num_ssm_regions == 0 (dense), all descs are FA and the split - is uniform. When SSM regions exist, FA and SSM descs get different - split factors. + num_fa_descs is the boundary between FA and SSM descriptors. + Split counts are derived from source_ranks_per_group lengths. + FA uses rank_to_attention_slot for the slot offset; + SSM uses the rank's positional index. 
""" - abs_tp = len(plan.all_source_ranks) + fa_num_splits = len(plan.source_ranks_per_group[0]) + + has_ssm_descs = num_fa_descs < len(src_blocks_data) + ssm_num_splits = len(plan.source_ranks_per_group[-1]) if has_ssm_descs else 0 + result: list[list[tuple[int, int, int]]] = [] for p_idx, p_rank in enumerate(plan.all_source_ranks): - skip_fa = p_rank not in plan.fa_source_set - fa_slot = plan.fa_head_slots.get(p_rank, 0) if not skip_fa else 0 + fa_slot = plan.rank_to_attention_slot.get(p_rank, 0) handle: list[tuple[int, int, int]] = [] for j, (addr, local_len, dev) in enumerate(src_blocks_data): if j < num_fa_descs: - assert plan.num_fa_reads >= 1 - fa_chunk = local_len // plan.num_fa_reads - handle.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) + chunk = local_len // fa_num_splits + handle.append((addr + fa_slot * chunk, chunk, dev)) else: - mamba_chunk = local_len // abs_tp - handle.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) + chunk = local_len // ssm_num_splits + handle.append((addr + p_idx * chunk, chunk, dev)) result.append(handle) return result @@ -915,8 +788,7 @@ def visualize_plan(plan: EngineTransferPlan) -> str: lines = [ f"EngineTransferPlan(remote_tp={plan.remote_tp_size}, " f"remote_bs={plan.remote_block_size}):", - f" Source ranks: all={list(plan.all_source_ranks)}, " - f"fa={list(plan.fa_source_ranks)}", + f" Source ranks: all={list(plan.all_source_ranks)}", ] total_descs = 0 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 411b81f2336d..3a905f571bbc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -21,6 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, EngineId, + EngineTransferInfo, TransferTopology, get_current_attn_backends, kv_postprocess_blksize_and_layout_on_receive, @@ -29,9 +30,6 @@ ) from 
vllm.distributed.kv_transfer.kv_connector.v1.base import CopyBlocksOp from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats -from vllm.distributed.kv_transfer.kv_connector.v1.nixl.block_transfer_policy import ( - ModelBlockTransferPolicy, -) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( GET_META_MSG, NixlAgentMetadata, @@ -48,6 +46,11 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, GroupKind, + build_local_descs, + build_local_splits_from_plan, + build_remote_descs_from_plan, + compute_desc_ids_from_plan, + compute_read_specs_from_plan, generate_dense_plan, generate_mamba_plan, ) @@ -343,13 +346,6 @@ def __init__( self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() - self.transfer_policy = ModelBlockTransferPolicy.create( - kv_cache_config=kv_cache_config, - layer_specs=self._layer_specs, - physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, - tp_size=vllm_config.parallel_config.tensor_parallel_size, - ) - # Per-engine transfer plans. Generated during handshake, used by # per-request hot path (model-agnostic). 
self._transfer_plans: dict[EngineId, EngineTransferPlan] = {} @@ -899,17 +895,18 @@ def register_local_xfer_handler( block_size_ratio = self.block_size // block_size local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] - blocks_data = self.transfer_policy.build_local_descs( - # Memory + blocks_data = build_local_descs( + has_mamba=self._has_mamba, + conv_decomp=self._conv_decomp, + ssm_sizes=self._mamba_ssm_size, base_addresses=local_base_addresses, device_id=self.device_id, - # Block geometry num_blocks=self.num_blocks, logical_num_blocks=self._logical_num_blocks, block_size_ratio=block_size_ratio, block_len_per_layer=self.block_len_per_layer, - # Layout is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) logger.debug( "Created %s blocks for src engine %s and rank %s on device id %s", @@ -994,12 +991,7 @@ def add_remote_agent( if self._has_mamba else 1 ) - transfer_info = self.transfer_policy.build_engine_transfer_info( - # Local topology - transfer_topo=transfer_topo, - # Block geometry - local_block_len=self.block_len_per_layer[0], - # Remote facts (from NixlAgentMetadata handshake) + transfer_info = EngineTransferInfo( remote_tp_size=remote_tp_size, remote_block_size=nixl_agent_meta.block_size, remote_block_len=nixl_agent_meta.block_lens[0], @@ -1075,6 +1067,8 @@ def add_remote_agent( tp_ratio, ) + plan = self._transfer_plans[engine_id] + ### (Optional) Register local agent memory regions. MLA is not split. if ( tp_ratio < 0 @@ -1086,9 +1080,8 @@ def add_remote_agent( # we only do this once per remote tp_size (replica-friendly). 
self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - for handle_data in self.transfer_policy.build_src_split_handles( - transfer_topo, - engine_id, + for handle_data in build_local_splits_from_plan( + plan, self.src_blocks_data, self.num_descs, ): @@ -1099,12 +1092,7 @@ def add_remote_agent( self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) ### Register remote agent memory regions - blocks_data = self.transfer_policy.build_remote_descs( - transfer_topo, - engine_id, - nixl_agent_meta, - self.block_len_per_layer, - ) + blocks_data = build_remote_descs_from_plan(plan, nixl_agent_meta) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), @@ -1138,8 +1126,8 @@ def _validate_remote_agent_handshake( remote_engine_id = nixl_agent_meta.engine_id assert self.transfer_topo is not None - remote_info = self.transfer_topo.get_engine_info(remote_engine_id) - assert remote_info.remote_tp_size == remote_tp_size + plan = self._transfer_plans[remote_engine_id] + assert plan.remote_tp_size == remote_tp_size tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) block_size_ratio = self.transfer_topo.block_size_ratio( @@ -1424,9 +1412,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.sync_recved_kv_to_device(req_id, meta) # post processing for heteroblocksize - remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id) + plan = self._transfer_plans[meta.remote.engine_id] block_size_ratio = self.transfer_topo.block_size_ratio( - remote_info.remote_block_size + plan.remote_block_size ) if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv @@ -1639,21 +1627,18 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.transfer_topo is not None engine_id = meta.remote.engine_id - remote_ranks = self.transfer_topo.target_remote_ranks(engine_id) - remote_info = 
self.transfer_topo.get_engine_info(engine_id) - tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) - plan = self._transfer_plans[engine_id] + tp_ratio = self.transfer_topo.tp_ratio(plan.remote_tp_size) + meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( meta.remote.block_ids, plan.remote_expansion_stride, ) remote_block_ids = meta.remote.block_ids - read_specs = self.transfer_policy.compute_read_specs( + read_specs = compute_read_specs_from_plan( + plan, local_block_ids=meta.local_physical_block_ids, remote_block_ids=remote_block_ids, - remote_ranks=remote_ranks, - remote_info=remote_info, ) # D may have to perform multiple reads from different remote ranks. @@ -1666,7 +1651,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_rank = spec.remote_rank local_block_ids = spec.local_block_ids remote_block_ids = spec.remote_block_ids - remote_block_size = remote_info.remote_block_size + remote_block_size = plan.remote_block_size logger.debug( "Remote agent %s available, calling _read_blocks" " on remote rank %s with remote block size %s for req %s", @@ -1730,10 +1715,8 @@ def _read_blocks( a single remote worker. """ assert self.transfer_topo is not None - remote_info = self.transfer_topo.get_engine_info(dst_engine_id) - block_size_ratio = self.transfer_topo.block_size_ratio( - remote_info.remote_block_size - ) + plan = self._transfer_plans[dst_engine_id] + block_size_ratio = self.transfer_topo.block_size_ratio(plan.remote_block_size) if block_size_ratio > 1: # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. assert not self._is_hma_required @@ -1811,19 +1794,19 @@ def _read_blocks( # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. - # Get descs ids. - remote_block_descs_ids = self.transfer_policy.get_block_descs_ids( + # Get descs ids. 
Both calls use the same plan since region counts + # (len(fa_regions), len(ssm_regions)) are model-determined and + # identical across engines. + remote_block_descs_ids = compute_desc_ids_from_plan( + plan, block_ids=remote_block_ids, - num_regions=self.num_regions, dst_num_blocks=self.dst_num_blocks[dst_engine_id], - block_len_per_layer=self.block_len_per_layer, - physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, + physical_blocks_per_logical=plan.remote_physical_blocks_per_logical, ) - local_block_descs_ids = self.transfer_policy.get_block_descs_ids( + local_block_descs_ids = compute_desc_ids_from_plan( + plan, block_ids=local_block_ids, - num_regions=self.num_regions, dst_num_blocks=self.dst_num_blocks[self.engine_id], - block_len_per_layer=self.block_len_per_layer, block_size_ratio=block_size_ratio, physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) From d0c4802eb9128a1574f143d16b1ba011a661eaae Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 20:26:55 +0000 Subject: [PATCH 29/49] remove dead fields, visualization, and revert block ID helpers to main Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 10 +- .../kv_connector/unit/test_transfer_plan.py | 35 ------ .../kv_connector/v1/nixl/transfer_plan.py | 101 ------------------ .../kv_connector/v1/nixl/worker.py | 49 ++++++--- 4 files changed, 37 insertions(+), 158 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 57bde517125a..5ac2aaef2cba 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -71,17 +71,17 @@ def test_logical_to_kernel_block_ids_with_hma(): NixlConnectorWorker, ) + # Create a mock worker with just the required attributes + # (use __new__ to skip __init__) worker = object.__new__(NixlConnectorWorker) # Simulate HMA scenario: logical block size = 32, 
kernel block size = 16 # So each logical block maps to 2 kernel blocks eg [0]->[0,1] - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( - GroupKind, - ) - worker._physical_blocks_per_logical_kv_block = 2 - worker._group_kinds = (GroupKind.FA, GroupKind.SWA) + # FA + SW groups (neither is MambaSpec, so both get expanded) + worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True) + # Test conversion: FA + SW group logical_block_ids = [[0, 1, 2], [3, 4]] kernel_block_ids = worker._logical_to_kernel_block_ids(logical_block_ids) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 295501579673..8db8e01daa10 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -23,7 +23,6 @@ compute_desc_ids_from_plan, compute_read_specs_from_plan, generate_dense_plan, - visualize_plan, ) # ====================================================================== @@ -295,16 +294,6 @@ def test_build_src_split_handles(self, remote_tp_size): assert length == 1024 // remote_tp_size -class TestDensePlanVisualization: - def test_visualize_produces_output(self): - plan = generate_dense_plan( - **_common_plan_params(), - ) - output = visualize_plan(plan) - assert "FA regions" in output - assert "fa_k" in output - - class TestDensePlanStructure: def test_source_ranks_homogeneous(self): plan = generate_dense_plan( @@ -367,7 +356,6 @@ def _make_mamba_plan_for_desc_ids( offset_in_page=0, page_stride=100, num_blocks=fa_num_blocks, - physical_per_logical=1, ) for i in range(num_fa_regions) ) @@ -379,25 +367,18 @@ def _make_mamba_plan_for_desc_ids( offset_in_page=0, page_stride=200, num_blocks=ssm_num_blocks, - physical_per_logical=1, ) for i in range(num_ssm_regions) ) - physical_per_logical = tuple(1 for _ in group_kinds) all_ranks = (0,) source_ranks_per_group = tuple(all_ranks for _ in group_kinds) return 
EngineTransferPlan( fa_regions=fa_regions, ssm_regions=ssm_regions, - physical_per_logical=physical_per_logical, group_kinds=group_kinds, source_ranks_per_group=source_ranks_per_group, all_source_ranks=(0,), rank_to_attention_slot={0: 0}, - remote_tp_size=1, - remote_block_size=16, - remote_block_len=0, - remote_physical_blocks_per_logical=1, remote_expansion_stride=1, ) @@ -465,15 +446,10 @@ def test_all_source_ranks_serve_fa(self): plan = EngineTransferPlan( fa_regions=(), ssm_regions=(), - physical_per_logical=(1, 1), group_kinds=(GroupKind.FA, GroupKind.MAMBA), source_ranks_per_group=(both, both), all_source_ranks=(0, 1), rank_to_attention_slot={0: 0, 1: 1}, - remote_tp_size=2, - remote_block_size=16, - remote_block_len=0, - remote_physical_blocks_per_logical=1, remote_expansion_stride=1, ) @@ -493,15 +469,10 @@ def test_non_fa_rank_skips_fa_groups(self): plan = EngineTransferPlan( fa_regions=(), ssm_regions=(), - physical_per_logical=(1, 1), group_kinds=(GroupKind.FA, GroupKind.MAMBA), source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1, 2), rank_to_attention_slot={0: 0}, - remote_tp_size=3, - remote_block_size=16, - remote_block_len=0, - remote_physical_blocks_per_logical=1, remote_expansion_stride=1, ) @@ -541,18 +512,12 @@ def test_fa_and_ssm_different_split_factors(self): offset_in_page=0, page_stride=100, num_blocks=10, - physical_per_logical=1, ), ), - physical_per_logical=(1, 1), group_kinds=(GroupKind.FA, GroupKind.MAMBA), source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1), rank_to_attention_slot={0: 0, 1: 0}, - remote_tp_size=2, - remote_block_size=16, - remote_block_len=0, - remote_physical_blocks_per_logical=1, remote_expansion_stride=1, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 29cbcc4817a8..67f9289f0b93 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -12,7 +12,6 @@ — the ONLY model-specific code. 2. Generic executors (build_remote_descs_from_plan, etc.) — consume plans without model branching. - 3. Visualization (visualize_plan). """ from __future__ import annotations @@ -101,9 +100,6 @@ class RegionPlan: page_stride: int num_blocks: int - # Block ID expansion (HMA / kernel block mismatch) - physical_per_logical: int - @dataclass(frozen=True) class EngineTransferPlan: @@ -123,9 +119,6 @@ class EngineTransferPlan: fa_regions: tuple[RegionPlan, ...] ssm_regions: tuple[RegionPlan, ...] - # Per-group geometric properties (worker-facing, model-agnostic) - physical_per_logical: tuple[int, ...] - # Per-group type — used only for descriptor indexing (save path). group_kinds: tuple[GroupKind, ...] @@ -138,12 +131,6 @@ class EngineTransferPlan: # Maps each source rank to its FA head slot index. rank_to_attention_slot: dict[int, int] - # Remote engine facts (needed by worker at read time) - remote_tp_size: int - remote_block_size: int - remote_block_len: int - remote_physical_blocks_per_logical: int - # Stride for expanding remote logical block IDs to kernel block IDs. # Dense: local_physical_blocks_per_logical. # Mamba: remote_physical_blocks_per_logical. @@ -268,7 +255,6 @@ def _build_fa_regions( num_attn_reads: int, rank_offset_factor: int, remote_num_blocks: int, - remote_physical_blocks_per_logical: int, ) -> list[RegionPlan]: """Build FA (attention) regions for the transfer plan. 
@@ -292,7 +278,6 @@ def _build_fa_regions( offset_in_page=rank_offset, page_stride=page_stride, num_blocks=remote_num_blocks, - physical_per_logical=remote_physical_blocks_per_logical, ) ) @@ -306,7 +291,6 @@ def _build_fa_regions( offset_in_page=rank_offset + page_stride // 2, page_stride=page_stride, num_blocks=remote_num_blocks, - physical_per_logical=remote_physical_blocks_per_logical, ) ) @@ -354,21 +338,15 @@ def generate_dense_plan( num_attn_reads=len(m.source_ranks_per_group[0]), rank_offset_factor=m.rank_offset_factor, remote_num_blocks=remote_num_blocks, - remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, ) return EngineTransferPlan( fa_regions=tuple(fa_regions), ssm_regions=(), - physical_per_logical=(remote_physical_blocks_per_logical,), group_kinds=(GroupKind.FA,), source_ranks_per_group=m.source_ranks_per_group, all_source_ranks=m.all_source_ranks, rank_to_attention_slot=m.rank_to_attention_slot, - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - remote_block_len=remote_block_lens[0], - remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, remote_expansion_stride=local_physical_blocks_per_logical, ) @@ -417,7 +395,6 @@ def generate_mamba_plan( num_attn_reads=len(m.source_ranks_per_group[0]), rank_offset_factor=m.rank_offset_factor, remote_num_blocks=remote_num_blocks, - remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, ) # ---- SSM regions ---- @@ -462,7 +439,6 @@ def generate_mamba_plan( offset_in_page=off, page_stride=page_stride, num_blocks=ssm_num_blocks, - physical_per_logical=1, ) ) @@ -474,25 +450,16 @@ def generate_mamba_plan( offset_in_page=conv_size_remote + local_offset * ssm_read_size, page_stride=page_stride, num_blocks=ssm_num_blocks, - physical_per_logical=1, ) ) - physical_per_logical_per_group = tuple( - 1 if k.is_ssm else remote_physical_blocks_per_logical for k in group_kinds - ) return EngineTransferPlan( fa_regions=tuple(fa_regions), 
ssm_regions=tuple(ssm_regions), - physical_per_logical=physical_per_logical_per_group, group_kinds=group_kinds, source_ranks_per_group=m.source_ranks_per_group, all_source_ranks=m.all_source_ranks, rank_to_attention_slot=m.rank_to_attention_slot, - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - remote_block_len=remote_block_lens[0], - remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, remote_expansion_stride=remote_physical_blocks_per_logical, ) @@ -502,29 +469,6 @@ def generate_mamba_plan( # ====================================================================== -def logical_to_kernel_block_ids( - block_ids: BlockIds, - physical_per_logical: tuple[int, ...], -) -> BlockIds: - """Convert logical block IDs to kernel-level physical block IDs. - - Each group has its own ratio in ``physical_per_logical``. - Groups with ratio == 1 are passed through unchanged. - """ - if all(r == 1 for r in physical_per_logical): - return block_ids - result: list[list[int]] = [] - for i, group in enumerate(block_ids): - ratio = physical_per_logical[i] - if ratio == 1: - result.append(group) - else: - arr = np.array(group).reshape(-1, 1) - arange = np.arange(ratio).reshape(1, -1) - result.append((arr * ratio + arange).flatten().tolist()) - return result - - def build_remote_descs_from_plan( plan: EngineTransferPlan, nixl_agent_meta: NixlAgentMetadata, @@ -776,48 +720,3 @@ def build_local_descs( physical_blocks_per_logical, ) return fa_descs + mamba_descs - - -# ====================================================================== -# 6. 
Visualization -# ====================================================================== - - -def visualize_plan(plan: EngineTransferPlan) -> str: - """Human-readable transfer plan for logging and debugging.""" - lines = [ - f"EngineTransferPlan(remote_tp={plan.remote_tp_size}, " - f"remote_bs={plan.remote_block_size}):", - f" Source ranks: all={list(plan.all_source_ranks)}", - ] - total_descs = 0 - - if plan.fa_regions: - lines.append(f" FA regions ({len(plan.fa_regions)}):") - for idx, r in enumerate(plan.fa_regions): - ratio_str = ( - f", p/l={r.physical_per_logical}" if r.physical_per_logical > 1 else "" - ) - lines.append( - f" [{idx}] {r.kind.value:12s} L{r.layer_idx} " - f"{r.descriptor_bytes:6d}B x {r.num_blocks:4d} blks " - f"stride={r.page_stride:6d} " - f"off={r.offset_in_page:6d}" - f"{ratio_str}" - ) - total_descs += r.num_blocks - - if plan.ssm_regions: - lines.append(f" SSM regions ({len(plan.ssm_regions)}):") - for idx, r in enumerate(plan.ssm_regions): - lines.append( - f" [{idx}] {r.kind.value:12s} L{r.layer_idx} " - f"{r.descriptor_bytes:6d}B x {r.num_blocks:4d} blks " - f"stride={r.page_stride:6d} " - f"off={r.offset_in_page:6d}" - ) - total_descs += r.num_blocks - - lines.append(f" Groups: {[k.value for k in plan.group_kinds]}") - lines.append(f" Total descriptors: {total_descs}") - return "\n".join(lines) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 3a905f571bbc..78bc411dc639 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -78,6 +78,7 @@ SlidingWindowSpec, UniformTypeKVCacheSpecs, ) +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.utils import select_common_block_size if TYPE_CHECKING: @@ -1126,8 +1127,8 @@ def _validate_remote_agent_handshake( remote_engine_id = nixl_agent_meta.engine_id assert self.transfer_topo is not None - plan = 
self._transfer_plans[remote_engine_id] - assert plan.remote_tp_size == remote_tp_size + remote_info = self.transfer_topo.get_engine_info(remote_engine_id) + assert remote_info.remote_tp_size == remote_tp_size tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) block_size_ratio = self.transfer_topo.block_size_ratio( @@ -1412,9 +1413,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.sync_recved_kv_to_device(req_id, meta) # post processing for heteroblocksize - plan = self._transfer_plans[meta.remote.engine_id] + remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id) block_size_ratio = self.transfer_topo.block_size_ratio( - plan.remote_block_size + remote_info.remote_block_size ) if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv @@ -1628,7 +1629,8 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.transfer_topo is not None engine_id = meta.remote.engine_id plan = self._transfer_plans[engine_id] - tp_ratio = self.transfer_topo.tp_ratio(plan.remote_tp_size) + remote_info = self.transfer_topo.get_engine_info(engine_id) + tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( meta.remote.block_ids, @@ -1651,7 +1653,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_rank = spec.remote_rank local_block_ids = spec.local_block_ids remote_block_ids = spec.remote_block_ids - remote_block_size = plan.remote_block_size + remote_block_size = remote_info.remote_block_size logger.debug( "Remote agent %s available, calling _read_blocks" " on remote rank %s with remote block size %s for req %s", @@ -1716,7 +1718,10 @@ def _read_blocks( """ assert self.transfer_topo is not None plan = self._transfer_plans[dst_engine_id] - block_size_ratio = self.transfer_topo.block_size_ratio(plan.remote_block_size) + remote_info = self.transfer_topo.get_engine_info(dst_engine_id) + 
block_size_ratio = self.transfer_topo.block_size_ratio( + remote_info.remote_block_size + ) if block_size_ratio > 1: # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. assert not self._is_hma_required @@ -1801,7 +1806,7 @@ def _read_blocks( plan, block_ids=remote_block_ids, dst_num_blocks=self.dst_num_blocks[dst_engine_id], - physical_blocks_per_logical=plan.remote_physical_blocks_per_logical, + physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, ) local_block_descs_ids = compute_desc_ids_from_plan( plan, @@ -1871,18 +1876,26 @@ def get_mapped_blocks( return mapped_2d.flatten().astype(np.int64) def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: - """Convert logical block IDs to kernel physical block IDs. - - Required when the logical block size (set by the user) does not match - the one required by the attention backend. + """ + Convert logical block ids to kernel physical block ids. + This is required when the logical block size (the one set by the user) + does not match the one required by the attn backend. 
""" if self._physical_blocks_per_logical_kv_block == 1: + # Noop when physical and logical block sizes are the same return block_ids - ratio = self._physical_blocks_per_logical_kv_block - arange = np.arange(ratio).reshape(1, -1) + block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( + 1, -1 + ) + # Mamba blocks have no logical<>physical discrepancy + group_specs = self.kv_cache_config.kv_cache_groups return [ - (np.array(group).reshape(-1, 1) * ratio + arange).flatten().tolist() - if self._group_kinds[i].is_attention + BlockTable.map_to_kernel_blocks( + np.array(group), + self._physical_blocks_per_logical_kv_block, + block_arange, + ).tolist() + if not isinstance(group_specs[i].kv_cache_spec, MambaSpec) else group for i, group in enumerate(block_ids) ] @@ -1905,13 +1918,15 @@ def _logical_to_remote_kernel_block_ids( if remote_ratio == 1: return block_ids local_arange = np.arange(local_ratio).reshape(1, -1) + group_specs = self.kv_cache_config.kv_cache_groups result: list[list[int]] = [] for i, group in enumerate(block_ids): - if self._group_kinds[i].is_attention: + if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): arr = np.array(group).reshape(-1, 1) expanded = (arr * remote_ratio + local_arange).flatten() result.append(expanded.tolist()) else: + # Mamba blocks are 1:1 logical-to-physical (no expansion). 
result.append(group) return result From e9e8a96ab85cb0a3ed9424237f24cadf0c586e30 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 20:45:51 +0000 Subject: [PATCH 30/49] pass remote info/meta objects to plan generators and remove dead code Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_transfer_plan.py | 18 ++++-- .../kv_transfer/kv_connector/utils.py | 62 +------------------ .../kv_connector/v1/nixl/transfer_plan.py | 44 ++++++------- .../kv_connector/v1/nixl/worker.py | 22 +++---- 4 files changed, 47 insertions(+), 99 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 8db8e01daa10..8a700d57e62a 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -13,6 +13,7 @@ import pytest +from vllm.distributed.kv_transfer.kv_connector.utils import EngineTransferInfo from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, GroupKind, @@ -77,11 +78,18 @@ def _common_plan_params( is_blocks_first=is_blocks_first, block_len_per_layer=block_len_per_layer, block_size=block_size, - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - remote_num_blocks=remote_num_blocks, - remote_block_lens=remote_block_lens, - remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + remote_info=EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + remote_block_len=remote_block_lens[0], + remote_physical_blocks_per_logical=remote_physical_blocks_per_logical, + ), + remote_meta=_make_nixl_meta( + base_addrs=[0] * len(block_len_per_layer), + num_blocks=remote_num_blocks, + block_lens=remote_block_lens, + block_size=remote_block_size, + ), local_physical_blocks_per_logical=local_physical_blocks_per_logical, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py 
b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1662824aee67..b85416ab3071 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -386,46 +386,6 @@ class EngineTransferInfo: """Physical blocks per logical block.""" -@dataclass(frozen=True) -class MambaEngineTransferInfo(EngineTransferInfo): - """Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry. - - For hybrid SSM+Attention models, FA and Mamba layers may require - different numbers of reads from different remote ranks. This - dataclass captures that per-engine transfer plan. - """ - - remote_fa_source_ranks: tuple[int, ...] - """Remote ranks carrying unique FA heads for this local rank.""" - - remote_all_source_ranks: tuple[int, ...] - """All remote ranks this local rank reads from (FA + Mamba).""" - - remote_num_fa_reads: int - """Number of distinct remote ranks needed for FA data.""" - - remote_num_mamba_reads: int - """Number of distinct remote ranks needed for Mamba data.""" - - remote_fa_descriptor_bytes: int - """Byte size of one FA K (or V) descriptor entry.""" - - is_remote_replicated: bool - """Whether the remote engine has replicated KV heads - (remote_tp_size > total_num_kv_heads).""" - - remote_physical_heads: int - """Physical KV heads stored per remote rank.""" - - @property - def fa_source_set(self) -> frozenset[int]: - return frozenset(self.remote_fa_source_ranks) - - @property - def fa_source_indices(self) -> dict[int, int]: - return {r: i for i, r in enumerate(self.remote_fa_source_ranks)} - - # ---- Transfer topology ---- @@ -441,7 +401,6 @@ class TransferTopology: is_mamba: bool total_num_kv_heads: int attn_backends: list[type[AttentionBackend]] - physical_blocks_per_logical: int tensor_shape: torch.Size | None = None def __post_init__(self): @@ -594,14 +553,8 @@ def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]: """Get the remote TP rank(s) that the current local TP rank will 
read from. When remote tp_size > local tp_size, reads from multiple remote ranks. - - For Mamba models, returns the precomputed ``all_source_ranks`` - (FA + Mamba union). """ info = self._engines[remote_engine_id] - if isinstance(info, MambaEngineTransferInfo): - return list(info.remote_all_source_ranks) - tp_ratio = self.tp_ratio(info.remote_tp_size) if tp_ratio > 0: return [self.tp_rank // tp_ratio] @@ -637,21 +590,12 @@ def get_transfer_cache_regions( def describe(self, remote_engine_id: EngineId) -> str: """One-line summary of transfer config for logging.""" info = self._engines[remote_engine_id] - base = ( + return ( + f"TransferTopology(" f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, " f"K={self.total_num_kv_heads}, " f"local_tp={self.tp_size}, " f"remote_tp={info.remote_tp_size}, " f"local_rank={self.tp_rank}, " - f"remote_block_len={info.remote_block_len}" + f"remote_block_len={info.remote_block_len})" ) - if isinstance(info, MambaEngineTransferInfo): - return ( - f"TransferTopology.mamba({base}, " - f"fa_reads={info.remote_num_fa_reads}, " - f"mamba_reads={info.remote_num_mamba_reads}, " - f"fa_sources={list(info.remote_fa_source_ranks)}, " - f"all_sources={list(info.remote_all_source_ranks)}, " - f"fa_desc_bytes={info.remote_fa_descriptor_bytes})" - ) - return f"TransferTopology({base})" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 67f9289f0b93..b29d59b05398 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -22,7 +22,10 @@ import numpy as np -from vllm.distributed.kv_transfer.kv_connector.utils import BlockIds +from vllm.distributed.kv_transfer.kv_connector.utils import ( + BlockIds, + EngineTransferInfo, +) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, ) @@ -311,20 +314,17 @@ def 
generate_dense_plan( is_blocks_first: bool, block_len_per_layer: list[int], block_size: int, - remote_tp_size: int, - remote_block_size: int, - remote_num_blocks: int, - remote_block_lens: list[int], - remote_physical_blocks_per_logical: int, + remote_info: EngineTransferInfo, + remote_meta: NixlAgentMetadata, local_physical_blocks_per_logical: int, ) -> EngineTransferPlan: """Generate transfer plan for dense (FA-only) models.""" - block_size_ratio = block_size // remote_block_size + block_size_ratio = block_size // remote_info.remote_block_size m = _compute_tp_mapping( tp_rank, tp_size, - remote_tp_size, + remote_info.remote_tp_size, is_mla, total_num_kv_heads, group_kinds=(GroupKind.FA,), @@ -332,12 +332,12 @@ def generate_dense_plan( fa_regions = _build_fa_regions( block_len_per_layer=block_len_per_layer, - remote_block_lens=remote_block_lens, + remote_block_lens=remote_meta.block_lens, is_blocks_first=is_blocks_first, block_size_ratio=block_size_ratio, num_attn_reads=len(m.source_ranks_per_group[0]), rank_offset_factor=m.rank_offset_factor, - remote_num_blocks=remote_num_blocks, + remote_num_blocks=remote_meta.num_blocks, ) return EngineTransferPlan( @@ -360,18 +360,19 @@ def generate_mamba_plan( is_blocks_first: bool, block_len_per_layer: list[int], block_size: int, - remote_tp_size: int, - remote_block_size: int, - remote_num_blocks: int, - remote_block_lens: list[int], - remote_physical_blocks_per_logical: int, + remote_info: EngineTransferInfo, + remote_meta: NixlAgentMetadata, group_kinds: tuple[GroupKind, ...], conv_decomp: MambaConvSplitInfo, ssm_sizes: tuple[int, int], - remote_ssm_sizes: tuple[int, int], ) -> EngineTransferPlan: """Generate transfer plan for hybrid Mamba (SSM + FA) models.""" - block_size_ratio = block_size // remote_block_size + remote_tp_size = remote_info.remote_tp_size + remote_phys_ratio = remote_info.remote_physical_blocks_per_logical + remote_block_lens = remote_meta.block_lens + remote_ssm_sizes = remote_meta.ssm_sizes + + 
block_size_ratio = block_size // remote_info.remote_block_size assert block_size_ratio == 1, ( "Mamba 3-read transfer with block_size_ratio != 1 " f"is not tested. Got {block_size_ratio=}." @@ -394,15 +395,14 @@ def generate_mamba_plan( block_size_ratio=block_size_ratio, num_attn_reads=len(m.source_ranks_per_group[0]), rank_offset_factor=m.rank_offset_factor, - remote_num_blocks=remote_num_blocks, + remote_num_blocks=remote_meta.num_blocks, ) # ---- SSM regions ---- effective_ratio = tp_size // remote_tp_size if tp_size >= remote_tp_size else 1 local_offset = tp_rank % max(effective_ratio, 1) conv_size_remote = remote_ssm_sizes[0] - remote_ratio = remote_physical_blocks_per_logical - ssm_num_blocks = remote_num_blocks // remote_ratio + ssm_num_blocks = remote_meta.num_blocks // remote_phys_ratio if tp_size >= remote_tp_size: conv_offsets = conv_decomp.remote_conv_offsets( @@ -428,7 +428,7 @@ def generate_mamba_plan( ] ssm_regions: list[RegionPlan] = [] for i in range(len(remote_block_lens)): - page_stride = remote_block_lens[i] * remote_ratio + page_stride = remote_block_lens[i] * remote_phys_ratio for kind, (off, sz) in zip(conv_kinds, conv_offsets): ssm_regions.append( @@ -460,7 +460,7 @@ def generate_mamba_plan( source_ranks_per_group=m.source_ranks_per_group, all_source_ranks=m.all_source_ranks, rank_to_attention_slot=m.rank_to_attention_slot, - remote_expansion_stride=remote_physical_blocks_per_logical, + remote_expansion_stride=remote_phys_ratio, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 78bc411dc639..6637efea12b3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -673,7 +673,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, - 
physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, # SSM States come in tuples (ssm, conv) tensor_shape=next(iter(kv_caches.values())).shape if not self._has_mamba @@ -1012,11 +1011,8 @@ def add_remote_agent( is_blocks_first=transfer_topo.is_kv_layout_blocks_first, block_len_per_layer=self.block_len_per_layer, block_size=self.block_size, - remote_tp_size=remote_tp_size, - remote_block_size=nixl_agent_meta.block_size, - remote_num_blocks=nixl_agent_meta.num_blocks, - remote_block_lens=nixl_agent_meta.block_lens, - remote_physical_blocks_per_logical=physical_blocks_per_logical, + remote_info=transfer_info, + remote_meta=nixl_agent_meta, ) if self._has_mamba: assert self._conv_decomp is not None @@ -1025,7 +1021,6 @@ def add_remote_agent( group_kinds=self._group_kinds, conv_decomp=self._conv_decomp, ssm_sizes=self._mamba_ssm_size, - remote_ssm_sizes=nixl_agent_meta.ssm_sizes, ) else: self._transfer_plans[engine_id] = generate_dense_plan( @@ -1901,21 +1896,22 @@ def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: ] def _logical_to_remote_kernel_block_ids( - self, block_ids: BlockIds, remote_ratio: int + self, block_ids: BlockIds, remote_physical_per_logical: int ) -> BlockIds: """Map logical block IDs to physical kernel block IDs on the remote. Args: block_ids: per-group lists of logical block IDs. - remote_ratio: remote engine's physical blocks per logical block. + remote_physical_per_logical: remote engine's physical blocks + per logical block. Returns: Same structure with FA groups expanded (each logical block L - becomes kernel blocks [L*remote_ratio .. L*remote_ratio + - local_ratio - 1]). Mamba groups are passed through unchanged. + becomes kernel blocks [L*ratio .. L*ratio + local_ratio - 1]). + Mamba groups are passed through unchanged. 
""" local_ratio = self._physical_blocks_per_logical_kv_block - if remote_ratio == 1: + if remote_physical_per_logical == 1: return block_ids local_arange = np.arange(local_ratio).reshape(1, -1) group_specs = self.kv_cache_config.kv_cache_groups @@ -1923,7 +1919,7 @@ def _logical_to_remote_kernel_block_ids( for i, group in enumerate(block_ids): if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): arr = np.array(group).reshape(-1, 1) - expanded = (arr * remote_ratio + local_arange).flatten() + expanded = (arr * remote_physical_per_logical + local_arange).flatten() result.append(expanded.tolist()) else: # Mamba blocks are 1:1 logical-to-physical (no expansion). From 93fa814907e01465c9fa217706f96618047f78d1 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 20:52:09 +0000 Subject: [PATCH 31/49] fix mypy: explicit args instead of dict unpacking, remove stale mooncake field Signed-off-by: Zhanqiu Hu --- .../v1/mooncake/mooncake_connector.py | 1 - .../kv_connector/v1/nixl/worker.py | 31 +++++++++++-------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index b1b1cd27a5bc..5a94070ebde7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -842,7 +842,6 @@ def __init__( is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=[backend], - physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) self.async_zmq_ctx = zmq.asyncio.Context() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 6637efea12b3..738430b65d7f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1003,28 +1003,33 @@ def add_remote_agent( # Generate the pre-computed transfer plan for this remote engine. # Plan generation is model-aware (if/else), but the per-request # hot path only consumes the plan (model-agnostic). - plan_common = dict( - tp_rank=self.tp_rank, - tp_size=self.world_size, - is_mla=self.use_mla, - total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - is_blocks_first=transfer_topo.is_kv_layout_blocks_first, - block_len_per_layer=self.block_len_per_layer, - block_size=self.block_size, - remote_info=transfer_info, - remote_meta=nixl_agent_meta, - ) if self._has_mamba: assert self._conv_decomp is not None self._transfer_plans[engine_id] = generate_mamba_plan( - **plan_common, + tp_rank=self.tp_rank, + tp_size=self.world_size, + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + block_len_per_layer=self.block_len_per_layer, + block_size=self.block_size, + remote_info=transfer_info, + remote_meta=nixl_agent_meta, group_kinds=self._group_kinds, conv_decomp=self._conv_decomp, ssm_sizes=self._mamba_ssm_size, ) else: self._transfer_plans[engine_id] = generate_dense_plan( - **plan_common, + tp_rank=self.tp_rank, + tp_size=self.world_size, + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + block_len_per_layer=self.block_len_per_layer, + block_size=self.block_size, + remote_info=transfer_info, + remote_meta=nixl_agent_meta, local_physical_blocks_per_logical=( self._physical_blocks_per_logical_kv_block ), From 00dce93a446815b303fd49af5f2258af66cb66c5 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 21:05:50 +0000 Subject: [PATCH 32/49] pass transfer_topo to plan generators, fix dead code and stale docstring Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 
1 - .../kv_connector/unit/test_transfer_plan.py | 39 +++++++++++++--- .../kv_connector/v1/nixl/transfer_plan.py | 44 +++++++++---------- .../kv_connector/v1/nixl/worker.py | 14 +----- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 5ac2aaef2cba..7c4f5f31401d 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -153,7 +153,6 @@ def test_read_blocks_for_req_expands_remote_ids( mock_plan = MagicMock(spec=EngineTransferPlan) mock_plan.remote_expansion_stride = expansion_stride - mock_plan.remote_tp_size = 1 mock_plan.all_source_ranks = () mock_plan.source_ranks_per_group = () worker._transfer_plans = {remote_engine_id: mock_plan} diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 8a700d57e62a..316f75548d94 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -10,10 +10,14 @@ from __future__ import annotations from dataclasses import dataclass +from unittest.mock import MagicMock import pytest -from vllm.distributed.kv_transfer.kv_connector.utils import EngineTransferInfo +from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineTransferInfo, + TransferTopology, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, GroupKind, @@ -49,6 +53,25 @@ class FakeNixlAgentMeta: attn_backend_name: str +def _make_fake_topo( + tp_rank: int = 0, + tp_size: int = 1, + is_mla: bool = False, + total_num_kv_heads: int = 8, + block_size: int = 16, + is_blocks_first: bool = False, +) -> TransferTopology: + """Build a lightweight TransferTopology mock (skips __post_init__).""" + topo = MagicMock(spec=TransferTopology) + topo.tp_rank = tp_rank + topo.tp_size = tp_size + topo.is_mla = is_mla + 
topo.total_num_kv_heads = total_num_kv_heads + topo.block_size = block_size + topo.is_kv_layout_blocks_first = is_blocks_first + return topo + + def _common_plan_params( tp_rank: int = 0, tp_size: int = 1, @@ -71,13 +94,15 @@ def _common_plan_params( if remote_block_lens is None: remote_block_lens = list(block_len_per_layer) return dict( - tp_rank=tp_rank, - tp_size=tp_size, - is_mla=is_mla, - total_num_kv_heads=num_kv_heads, - is_blocks_first=is_blocks_first, + transfer_topo=_make_fake_topo( + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=is_mla, + total_num_kv_heads=num_kv_heads, + block_size=block_size, + is_blocks_first=is_blocks_first, + ), block_len_per_layer=block_len_per_layer, - block_size=block_size, remote_info=EngineTransferInfo( remote_tp_size=remote_tp_size, remote_block_size=remote_block_size, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index b29d59b05398..9941687f448c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -25,6 +25,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, EngineTransferInfo, + TransferTopology, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, @@ -74,8 +75,8 @@ def is_ssm(self) -> bool: class RegionKind(enum.Enum): - """Descriptor region type. Used for visualization/debugging only; - executors never branch on this value.""" + """Descriptor region type used for region categorization during plan + construction. 
Executors never branch on this value.""" FA_K = "fa_k" FA_V = "fa_v" @@ -307,33 +308,28 @@ def _build_fa_regions( def generate_dense_plan( *, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - is_blocks_first: bool, + transfer_topo: TransferTopology, block_len_per_layer: list[int], - block_size: int, remote_info: EngineTransferInfo, remote_meta: NixlAgentMetadata, local_physical_blocks_per_logical: int, ) -> EngineTransferPlan: """Generate transfer plan for dense (FA-only) models.""" - block_size_ratio = block_size // remote_info.remote_block_size + block_size_ratio = transfer_topo.block_size // remote_info.remote_block_size m = _compute_tp_mapping( - tp_rank, - tp_size, + transfer_topo.tp_rank, + transfer_topo.tp_size, remote_info.remote_tp_size, - is_mla, - total_num_kv_heads, + transfer_topo.is_mla, + transfer_topo.total_num_kv_heads, group_kinds=(GroupKind.FA,), ) fa_regions = _build_fa_regions( block_len_per_layer=block_len_per_layer, remote_block_lens=remote_meta.block_lens, - is_blocks_first=is_blocks_first, + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, block_size_ratio=block_size_ratio, num_attn_reads=len(m.source_ranks_per_group[0]), rank_offset_factor=m.rank_offset_factor, @@ -353,13 +349,8 @@ def generate_dense_plan( def generate_mamba_plan( *, - tp_rank: int, - tp_size: int, - is_mla: bool, - total_num_kv_heads: int, - is_blocks_first: bool, + transfer_topo: TransferTopology, block_len_per_layer: list[int], - block_size: int, remote_info: EngineTransferInfo, remote_meta: NixlAgentMetadata, group_kinds: tuple[GroupKind, ...], @@ -367,12 +358,17 @@ def generate_mamba_plan( ssm_sizes: tuple[int, int], ) -> EngineTransferPlan: """Generate transfer plan for hybrid Mamba (SSM + FA) models.""" + assert group_kinds[0].is_attention, ( + f"First group must be an attention group (FA/SWA), got {group_kinds[0]}" + ) + tp_rank = transfer_topo.tp_rank + tp_size = transfer_topo.tp_size remote_tp_size = remote_info.remote_tp_size 
remote_phys_ratio = remote_info.remote_physical_blocks_per_logical remote_block_lens = remote_meta.block_lens remote_ssm_sizes = remote_meta.ssm_sizes - block_size_ratio = block_size // remote_info.remote_block_size + block_size_ratio = transfer_topo.block_size // remote_info.remote_block_size assert block_size_ratio == 1, ( "Mamba 3-read transfer with block_size_ratio != 1 " f"is not tested. Got {block_size_ratio=}." @@ -382,8 +378,8 @@ def generate_mamba_plan( tp_rank, tp_size, remote_tp_size, - is_mla, - total_num_kv_heads, + transfer_topo.is_mla, + transfer_topo.total_num_kv_heads, group_kinds, ) @@ -391,7 +387,7 @@ def generate_mamba_plan( fa_regions = _build_fa_regions( block_len_per_layer=block_len_per_layer, remote_block_lens=remote_block_lens, - is_blocks_first=is_blocks_first, + is_blocks_first=transfer_topo.is_kv_layout_blocks_first, block_size_ratio=block_size_ratio, num_attn_reads=len(m.source_ranks_per_group[0]), rank_offset_factor=m.rank_offset_factor, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 738430b65d7f..4cca1588ce95 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1006,13 +1006,8 @@ def add_remote_agent( if self._has_mamba: assert self._conv_decomp is not None self._transfer_plans[engine_id] = generate_mamba_plan( - tp_rank=self.tp_rank, - tp_size=self.world_size, - is_mla=self.use_mla, - total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + transfer_topo=transfer_topo, block_len_per_layer=self.block_len_per_layer, - block_size=self.block_size, remote_info=transfer_info, remote_meta=nixl_agent_meta, group_kinds=self._group_kinds, @@ -1021,13 +1016,8 @@ def add_remote_agent( ) else: self._transfer_plans[engine_id] = generate_dense_plan( - tp_rank=self.tp_rank, - tp_size=self.world_size, - 
is_mla=self.use_mla, - total_num_kv_heads=self.model_config.get_total_num_kv_heads(), - is_blocks_first=transfer_topo.is_kv_layout_blocks_first, + transfer_topo=transfer_topo, block_len_per_layer=self.block_len_per_layer, - block_size=self.block_size, remote_info=transfer_info, remote_meta=nixl_agent_meta, local_physical_blocks_per_logical=( From 558d528d78c84386232118eaa12ca0b6eb8cd663 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Thu, 23 Apr 2026 21:22:19 +0000 Subject: [PATCH 33/49] fix test: set kv_cache_config on mock worker for remote block ID expansion Signed-off-by: Zhanqiu Hu --- tests/v1/kv_connector/unit/test_nixl_connector_hma.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 7c4f5f31401d..ca754e84480e 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -145,6 +145,12 @@ def test_read_blocks_for_req_expands_remote_ids( worker._physical_blocks_per_logical_kv_block = 2 worker._group_kinds = tuple(GroupKind[k] for k in group_kinds) + has_mamba = any(k == "MAMBA" for k in group_kinds) + has_swa = any(k == "SWA" for k in group_kinds) + worker.kv_cache_config = make_kv_cache_config( + block_size=16, swa_enabled=has_swa, mamba_enabled=has_mamba + ) + remote_engine_id = "remote-engine" worker.transfer_topo = MagicMock() From be301ac45458a1e3933b919445814dea8b6afae7 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Fri, 24 Apr 2026 14:19:43 -0400 Subject: [PATCH 34/49] update Signed-off-by: Zhanqiu Hu --- .../unit/test_nixl_connector_hma.py | 66 ++-- .../kv_connector/unit/test_transfer_plan.py | 65 ++-- .../kv_connector/v1/nixl/transfer_plan.py | 326 +++--------------- .../kv_connector/v1/nixl/worker.py | 238 ++++++++++--- 4 files changed, 310 insertions(+), 385 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py 
b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index ca754e84480e..6c1fdfd511ea 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -93,31 +93,27 @@ def test_logical_to_kernel_block_ids_with_hma(): @pytest.mark.cpu_test @pytest.mark.parametrize( - "group_kinds,expansion_stride,remote_block_ids,expected_remote_block_ids", + "group_spec_types,expansion_stride,remote_block_ids," + "expected_remote_block_ids", [ - # Dense (FA+SWA): stride == local_ratio, all groups expanded. - # Regression for https://github.com/vllm-project/vllm/pull/39724 - ( - ("FA", "SWA"), + pytest.param( + ("FullAttentionSpec", "SlidingWindowSpec"), 2, ([0, 1, 2], [3, 4]), [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]], + id="dense_fa_swa", ), - # Mamba (FA+Mamba): stride == remote_physical_blocks_per_logical, - # FA expanded, Mamba passed through unchanged. - # stride=261 (Nemotron 30B TP=1) != local_ratio=2 so that using - # the wrong stride produces different FA results. 
- ( - ("FA", "MAMBA"), + pytest.param( + ("FullAttentionSpec", "MambaSpec"), 261, ([0, 1, 2], [10, 11]), [[0, 1, 261, 262, 522, 523], [10, 11]], + id="mamba_fa_ssm", ), ], - ids=["dense_fa_swa", "mamba_fa_ssm"], ) def test_read_blocks_for_req_expands_remote_ids( - group_kinds, + group_spec_types, expansion_stride, remote_block_ids, expected_remote_block_ids, @@ -135,18 +131,28 @@ def test_read_blocks_for_req_expands_remote_ids( ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, - GroupKind, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, ) + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + MambaSpec, + SlidingWindowSpec, + ) + + spec_name_to_type = { + "FullAttentionSpec": FullAttentionSpec, + "SlidingWindowSpec": SlidingWindowSpec, + "MambaSpec": MambaSpec, + } + resolved_types = tuple(spec_name_to_type[n] for n in group_spec_types) worker = object.__new__(NixlConnectorWorker) worker._physical_blocks_per_logical_kv_block = 2 - worker._group_kinds = tuple(GroupKind[k] for k in group_kinds) - has_mamba = any(k == "MAMBA" for k in group_kinds) - has_swa = any(k == "SWA" for k in group_kinds) + has_mamba = any(t is MambaSpec for t in resolved_types) + has_swa = any(t is SlidingWindowSpec for t in resolved_types) worker.kv_cache_config = make_kv_cache_config( block_size=16, swa_enabled=has_swa, mamba_enabled=has_mamba ) @@ -308,29 +314,30 @@ def test_nixl_metadata_hma_block_ids_structure(): @pytest.mark.cpu_test def test_get_block_descs_ids_hybrid_ssm(): - """Test compute_desc_ids_from_plan uses per-group strides for hybrid + """Test _compute_desc_ids uses per-group strides for hybrid FA+SSM when ratio=1 (no kernel block size mismatch).""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( - GroupKind, - compute_desc_ids_from_plan, + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, ) + 
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec from .test_transfer_plan import _make_mamba_plan_for_desc_ids plan = _make_mamba_plan_for_desc_ids( num_fa_regions=2, num_ssm_regions=4, - group_kinds=(GroupKind.FA, GroupKind.MAMBA), + group_spec_types=(FullAttentionSpec, MambaSpec), fa_num_blocks=100, ssm_num_blocks=100, ) fa_blocks = [3, 5] ssm_blocks = [1, 2] - result = compute_desc_ids_from_plan( + result = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=(fa_blocks, ssm_blocks), dst_num_blocks=100, + block_size_ratio=None, physical_blocks_per_logical=1, ) @@ -340,12 +347,12 @@ def test_get_block_descs_ids_hybrid_ssm(): @pytest.mark.cpu_test def test_get_block_descs_ids_kernel_block_mismatch(): - """Test compute_desc_ids_from_plan uses different strides for FA + """Test _compute_desc_ids uses different strides for FA (kernel blocks) vs SSM (logical blocks) when ratio > 1.""" - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( - GroupKind, - compute_desc_ids_from_plan, + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, ) + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec from .test_transfer_plan import _make_mamba_plan_for_desc_ids @@ -356,17 +363,18 @@ def test_get_block_descs_ids_kernel_block_mismatch(): plan = _make_mamba_plan_for_desc_ids( num_fa_regions=2, num_ssm_regions=4, - group_kinds=(GroupKind.FA, GroupKind.MAMBA), + group_spec_types=(FullAttentionSpec, MambaSpec), fa_num_blocks=num_blocks, ssm_num_blocks=logical_blocks, ) fa_blocks = [3, 7] ssm_blocks = [1, 2] - result = compute_desc_ids_from_plan( + result = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=(fa_blocks, ssm_blocks), dst_num_blocks=num_blocks, + block_size_ratio=None, physical_blocks_per_logical=ratio, ) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 
316f75548d94..4938019ee57e 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -20,15 +20,13 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, - GroupKind, - RegionKind, RegionPlan, - build_local_splits_from_plan, - build_remote_descs_from_plan, - compute_desc_ids_from_plan, - compute_read_specs_from_plan, generate_dense_plan, ) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec # ====================================================================== # Test fixtures / helpers @@ -115,6 +113,7 @@ def _common_plan_params( block_lens=remote_block_lens, block_size=remote_block_size, ), + group_spec_types=(FullAttentionSpec,), local_physical_blocks_per_logical=local_physical_blocks_per_logical, ) @@ -193,7 +192,7 @@ def test_build_remote_descs(self, tp_size, remote_tp_size, tp_rank_frac): meta = _make_nixl_meta( base_addrs, num_blocks, remote_block_lens, block_size=block_size ) - descs = build_remote_descs_from_plan(plan, meta) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) expected_count = len(plan.fa_regions) * num_blocks assert len(descs) == expected_count @@ -235,7 +234,10 @@ def test_compute_desc_ids(self, tp_size, remote_tp_size): ) block_ids = ([1, 5, 10, 20],) - ids = compute_desc_ids_from_plan(plan, block_ids, dst_num_blocks=num_blocks) + ids = NixlConnectorWorker._compute_desc_ids_from_plan( + plan, block_ids, dst_num_blocks=num_blocks, + block_size_ratio=None, physical_blocks_per_logical=1, + ) num_regions = len(plan.fa_regions) assert len(ids) == num_regions * len(block_ids[0]) @@ -276,7 +278,7 @@ def test_compute_read_specs(self, tp_size, remote_tp_size): local_ids = ([1, 2, 3],) remote_ids = ([4, 5, 6],) - specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) + specs = 
NixlConnectorWorker._compute_read_specs_from_plan(plan, local_ids, remote_ids) assert len(specs) == len(plan.all_source_ranks) for spec in specs: @@ -314,7 +316,7 @@ def test_build_src_split_handles(self, remote_tp_size): src_blocks_data = [(0x2000 + i * 1024, 1024, 0) for i in range(8)] num_descs = len(src_blocks_data) - splits = build_local_splits_from_plan( + splits = NixlConnectorWorker._build_local_splits_from_plan( plan, src_blocks_data, num_descs, @@ -349,23 +351,21 @@ def test_source_ranks_p_gt_d(self): def test_no_ssm_regions(self): plan = generate_dense_plan(**_common_plan_params()) assert plan.ssm_regions == () - assert plan.group_kinds == (GroupKind.FA,) + assert plan.group_spec_types == (FullAttentionSpec,) def test_blocks_first_has_k_and_v(self): plan = generate_dense_plan( **_common_plan_params(is_blocks_first=True), ) - kinds = [r.kind.value for r in plan.fa_regions] - assert "fa_k" in kinds - assert "fa_v" in kinds + num_layers = 2 + assert len(plan.fa_regions) == num_layers * 2 # K + V per layer def test_not_blocks_first_has_only_k(self): plan = generate_dense_plan( **_common_plan_params(is_blocks_first=False), ) - kinds = [r.kind.value for r in plan.fa_regions] - assert "fa_k" in kinds - assert "fa_v" not in kinds + num_layers = 2 + assert len(plan.fa_regions) == num_layers # K only per layer # ====================================================================== @@ -376,14 +376,13 @@ def test_not_blocks_first_has_only_k(self): def _make_mamba_plan_for_desc_ids( num_fa_regions: int, num_ssm_regions: int, - group_kinds: tuple[GroupKind, ...], + group_spec_types: tuple[type, ...], fa_num_blocks: int = 100, ssm_num_blocks: int = 100, ) -> EngineTransferPlan: """Build a minimal plan with enough structure for compute_desc_ids.""" fa_regions = tuple( RegionPlan( - kind=RegionKind.FA_K, layer_idx=i, descriptor_bytes=100, offset_in_page=0, @@ -394,7 +393,6 @@ def _make_mamba_plan_for_desc_ids( ) ssm_regions = tuple( RegionPlan( - 
kind=RegionKind.SSM_CONV_X, layer_idx=i % (num_ssm_regions // 4) if num_ssm_regions >= 4 else 0, descriptor_bytes=50, offset_in_page=0, @@ -404,11 +402,11 @@ def _make_mamba_plan_for_desc_ids( for i in range(num_ssm_regions) ) all_ranks = (0,) - source_ranks_per_group = tuple(all_ranks for _ in group_kinds) + source_ranks_per_group = tuple(all_ranks for _ in group_spec_types) return EngineTransferPlan( fa_regions=fa_regions, ssm_regions=ssm_regions, - group_kinds=group_kinds, + group_spec_types=group_spec_types, source_ranks_per_group=source_ranks_per_group, all_source_ranks=(0,), rank_to_attention_slot={0: 0}, @@ -424,7 +422,7 @@ def test_hybrid_ssm_ratio_1(self): plan = _make_mamba_plan_for_desc_ids( num_fa_regions=2, num_ssm_regions=4, # 4 regions per layer, 1 layer - group_kinds=(GroupKind.FA, GroupKind.MAMBA), + group_spec_types=(FullAttentionSpec, MambaSpec), fa_num_blocks=100, ssm_num_blocks=100, ) @@ -432,10 +430,11 @@ def test_hybrid_ssm_ratio_1(self): fa_blocks = [3, 5] ssm_blocks = [1, 2] - result = compute_desc_ids_from_plan( + result = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=(fa_blocks, ssm_blocks), dst_num_blocks=100, + block_size_ratio=None, physical_blocks_per_logical=1, ) @@ -451,7 +450,7 @@ def test_kernel_block_mismatch(self): plan = _make_mamba_plan_for_desc_ids( num_fa_regions=2, num_ssm_regions=4, - group_kinds=(GroupKind.FA, GroupKind.MAMBA), + group_spec_types=(FullAttentionSpec, MambaSpec), fa_num_blocks=num_blocks, ssm_num_blocks=logical_blocks, ) @@ -459,10 +458,11 @@ def test_kernel_block_mismatch(self): fa_blocks = [3, 7] ssm_blocks = [1, 2] - result = compute_desc_ids_from_plan( + result = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=(fa_blocks, ssm_blocks), dst_num_blocks=num_blocks, + block_size_ratio=None, physical_blocks_per_logical=ratio, ) @@ -479,7 +479,7 @@ def test_all_source_ranks_serve_fa(self): plan = EngineTransferPlan( fa_regions=(), ssm_regions=(), - group_kinds=(GroupKind.FA, 
GroupKind.MAMBA), + group_spec_types=(FullAttentionSpec, MambaSpec), source_ranks_per_group=(both, both), all_source_ranks=(0, 1), rank_to_attention_slot={0: 0, 1: 1}, @@ -489,7 +489,7 @@ def test_all_source_ranks_serve_fa(self): local_ids = ([1, 2], [3, 4]) remote_ids = ([5, 6], [7, 8]) - specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) + specs = NixlConnectorWorker._compute_read_specs_from_plan(plan, local_ids, remote_ids) assert len(specs) == 2 for spec in specs: assert list(spec.local_block_ids[0]) == [1, 2] @@ -502,7 +502,7 @@ def test_non_fa_rank_skips_fa_groups(self): plan = EngineTransferPlan( fa_regions=(), ssm_regions=(), - group_kinds=(GroupKind.FA, GroupKind.MAMBA), + group_spec_types=(FullAttentionSpec, MambaSpec), source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1, 2), rank_to_attention_slot={0: 0}, @@ -512,7 +512,7 @@ def test_non_fa_rank_skips_fa_groups(self): local_ids = ([1, 2], [3, 4]) remote_ids = ([5, 6], [7, 8]) - specs = compute_read_specs_from_plan(plan, local_ids, remote_ids) + specs = NixlConnectorWorker._compute_read_specs_from_plan(plan, local_ids, remote_ids) assert len(specs) == 3 # Rank 0 (FA source): gets all groups @@ -539,7 +539,6 @@ def test_fa_and_ssm_different_split_factors(self): fa_regions=(), ssm_regions=( RegionPlan( - kind=RegionKind.SSM_STATE, layer_idx=0, descriptor_bytes=100, offset_in_page=0, @@ -547,7 +546,7 @@ def test_fa_and_ssm_different_split_factors(self): num_blocks=10, ), ), - group_kinds=(GroupKind.FA, GroupKind.MAMBA), + group_spec_types=(FullAttentionSpec, MambaSpec), source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1), rank_to_attention_slot={0: 0, 1: 0}, @@ -561,7 +560,7 @@ def test_fa_and_ssm_different_split_factors(self): (3000, 400, 0), # SSM desc 0 ] - splits = build_local_splits_from_plan(plan, src_blocks_data, 2) + splits = NixlConnectorWorker._build_local_splits_from_plan(plan, src_blocks_data, 2) assert len(splits) == 2 # 2 source ranks 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 9941687f448c..ed5fc85d494c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -2,21 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Plan-based transfer design for NIXL connector. -Instead of an ABC hierarchy with duplicated Dense/Mamba implementations, -we pre-generate a flat transfer plan per remote engine during handshake. -All downstream operations become generic plan executors with zero model -branching. - -Architecture: - 1. Plan generators (generate_dense_plan, generate_mamba_plan) - — the ONLY model-specific code. - 2. Generic executors (build_remote_descs_from_plan, etc.) - — consume plans without model branching. +Data structures, plan generators, and local descriptor builders +for NIXL KV cache transfers. """ from __future__ import annotations -import enum from dataclasses import dataclass from typing import TYPE_CHECKING @@ -30,6 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, ) +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec, MambaSpec if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( @@ -51,39 +43,12 @@ class ReadSpec: remote_block_ids: BlockIds -class GroupKind(enum.Enum): - """KV cache group type for transfer purposes. +def _is_attention_spec(spec_type: type[KVCacheSpec]) -> bool: + return issubclass(spec_type, AttentionSpec) - Used by ``EngineTransferPlan`` and block expansion functions to - dispatch per-group behavior without model-specific branching. 
- """ - - FA = "fa" - SWA = "swa" - MAMBA = "mamba" - GDN = "gdn" - @property - def is_attention(self) -> bool: - """FA and SWA both need block expansion and standard descriptors.""" - return self in (GroupKind.FA, GroupKind.SWA) - - @property - def is_ssm(self) -> bool: - """MAMBA and GDN have state descriptors instead of KV pages.""" - return self in (GroupKind.MAMBA, GroupKind.GDN) - - -class RegionKind(enum.Enum): - """Descriptor region type used for region categorization during plan - construction. Executors never branch on this value.""" - - FA_K = "fa_k" - FA_V = "fa_v" - SSM_CONV_X = "ssm_conv_x" - SSM_CONV_B = "ssm_conv_b" - SSM_CONV_C = "ssm_conv_c" - SSM_STATE = "ssm_state" +def _is_ssm_spec(spec_type: type[KVCacheSpec]) -> bool: + return issubclass(spec_type, MambaSpec) @dataclass(frozen=True) @@ -91,11 +56,10 @@ class RegionPlan: """Pre-computed plan for one descriptor region. Everything needed to build NIXL descriptors and compute descriptor - IDs is baked in — no runtime model branching. The executor plugs - in per-rank ``base_addr`` and ``device_id`` from NixlAgentMetadata. + IDs is baked in. The executor plugs in per-rank ``base_addr`` and + ``device_id`` from NixlAgentMetadata. """ - kind: RegionKind layer_idx: int # Descriptor geometry @@ -109,22 +73,17 @@ class RegionPlan: class EngineTransferPlan: """Complete transfer plan for one remote engine. - Generated once during handshake. Stored alongside (or replacing) - ``EngineTransferInfo`` on ``TransferTopology``. - - Regions are split into ``fa_regions`` and ``ssm_regions`` matching - the descriptor handle layout: [FA descriptors | SSM descriptors]. - ``group_kinds`` maps each kv_cache_group to its type for descriptor - indexing. ``source_ranks_per_group`` encodes which ranks read each - group — executors use this instead of group_kinds for rank routing. + Generated once during handshake. Regions are split into + ``fa_regions`` and ``ssm_regions`` matching the descriptor + handle layout. 
""" # Regions in descriptor handle order fa_regions: tuple[RegionPlan, ...] ssm_regions: tuple[RegionPlan, ...] - # Per-group type — used only for descriptor indexing (save path). - group_kinds: tuple[GroupKind, ...] + # Per-group KVCacheSpec type — used for descriptor indexing. + group_spec_types: tuple[type[KVCacheSpec], ...] # Per-group ordered source ranks. Position = local piece index. source_ranks_per_group: tuple[tuple[int, ...], ...] @@ -176,75 +135,73 @@ def _compute_tp_mapping( remote_tp_size: int, is_mla: bool, total_num_kv_heads: int, - group_kinds: tuple[GroupKind, ...], + group_spec_types: tuple[type[KVCacheSpec], ...], ) -> TPMapping: """Build the complete local-to-remote TP mapping. Computes source ranks, head slot assignments, and the rank offset factor in a single pass. Both generators call this and unpack. """ - K = total_num_kv_heads - # --- Attention source ranks --- if is_mla: + # All heads replicated across all ranks. attn_ranks = [0] elif tp_size >= remote_tp_size: attn_ranks = [tp_rank * remote_tp_size // tp_size] else: - # P > D: one local rank reads from multiple remote ranks. + # P (remote TP size) > D (local TP size): one local rank reads from multiple remote ranks. # GQA dedup: when K < remote_tp_size, several remote ranks # hold the same KV head. np.unique keeps only the first # rank per unique head so we don't issue redundant reads. 
abs_tp = remote_tp_size // tp_size start = tp_rank * abs_tp - heads = np.arange(start, start + abs_tp) * K // remote_tp_size + heads = (np.arange(start, start + abs_tp) + * total_num_kv_heads // remote_tp_size) _, unique_idx = np.unique(heads, return_index=True) attn_ranks = (start + np.sort(unique_idx)).tolist() - # --- All source ranks (expand for SSM if needed) --- - has_ssm = any(k.is_ssm for k in group_kinds) - if not has_ssm or tp_size >= remote_tp_size: - all_ranks = list(attn_ranks) - else: - abs_tp = remote_tp_size // tp_size - if abs_tp > len(attn_ranks): - all_ranks = list( - range( - tp_rank * abs_tp, - (tp_rank + 1) * abs_tp, - ) - ) + # --- SSM source ranks --- + has_ssm = any(_is_ssm_spec(t) for t in group_spec_types) + if has_ssm: + if tp_size < remote_tp_size: + abs_tp = remote_tp_size // tp_size + ssm_ranks = list(range(tp_rank * abs_tp, (tp_rank + 1) * abs_tp)) else: - all_ranks = list(attn_ranks) + ssm_ranks = list(attn_ranks) + else: + ssm_ranks = [] + + all_ranks = sorted(set(attn_ranks) | set(ssm_ranks)) # --- Per-group ordered source ranks --- - attn_tuple = tuple(attn_ranks) - all_tuple = tuple(all_ranks) source_ranks_per_group = tuple( - all_tuple if k.is_ssm else attn_tuple for k in group_kinds + tuple(ssm_ranks) if _is_ssm_spec(t) else tuple(attn_ranks) + for t in group_spec_types ) # --- Attention head slots --- head_to_slot: dict[int, int] = {} for i, r in enumerate(attn_ranks): - head_to_slot[r * K // remote_tp_size] = i + head_to_slot[r * total_num_kv_heads // remote_tp_size] = i rank_to_attention_slot = { - r: head_to_slot.get(r * K // remote_tp_size, 0) for r in all_ranks + r: head_to_slot.get(r * total_num_kv_heads // remote_tp_size, 0) + for r in all_ranks } # --- Rank offset factor --- if is_mla or tp_size <= remote_tp_size: rank_offset_factor = 0 - elif tp_size > K: - local_head = tp_rank * K // tp_size - p_start = attn_ranks[0] * K // remote_tp_size + elif tp_size > total_num_kv_heads: + local_head = tp_rank * 
total_num_kv_heads // tp_size + p_start = (attn_ranks[0] * total_num_kv_heads + // remote_tp_size) rank_offset_factor = local_head - p_start else: rank_offset_factor = tp_rank % (tp_size // remote_tp_size) return TPMapping( source_ranks_per_group=source_ranks_per_group, - all_source_ranks=all_tuple, + all_source_ranks=tuple(all_ranks), rank_to_attention_slot=rank_to_attention_slot, rank_offset_factor=rank_offset_factor, ) @@ -276,7 +233,6 @@ def _build_fa_regions( fa_regions.append( RegionPlan( - kind=RegionKind.FA_K, layer_idx=i, descriptor_bytes=k_desc_bytes, offset_in_page=rank_offset, @@ -289,7 +245,6 @@ def _build_fa_regions( v_desc_bytes = local_block_len // num_attn_reads fa_regions.append( RegionPlan( - kind=RegionKind.FA_V, layer_idx=i, descriptor_bytes=v_desc_bytes, offset_in_page=rank_offset + page_stride // 2, @@ -302,7 +257,7 @@ def _build_fa_regions( # ====================================================================== -# 3. Plan generators — the ONLY model-specific code +# 3. 
Plan generators # ====================================================================== @@ -312,9 +267,10 @@ def generate_dense_plan( block_len_per_layer: list[int], remote_info: EngineTransferInfo, remote_meta: NixlAgentMetadata, + group_spec_types: tuple[type[KVCacheSpec], ...], local_physical_blocks_per_logical: int, ) -> EngineTransferPlan: - """Generate transfer plan for dense (FA-only) models.""" + """Generate transfer plan for dense (attention-only) models.""" block_size_ratio = transfer_topo.block_size // remote_info.remote_block_size m = _compute_tp_mapping( @@ -323,7 +279,7 @@ def generate_dense_plan( remote_info.remote_tp_size, transfer_topo.is_mla, transfer_topo.total_num_kv_heads, - group_kinds=(GroupKind.FA,), + group_spec_types=group_spec_types, ) fa_regions = _build_fa_regions( @@ -339,7 +295,7 @@ def generate_dense_plan( return EngineTransferPlan( fa_regions=tuple(fa_regions), ssm_regions=(), - group_kinds=(GroupKind.FA,), + group_spec_types=group_spec_types, source_ranks_per_group=m.source_ranks_per_group, all_source_ranks=m.all_source_ranks, rank_to_attention_slot=m.rank_to_attention_slot, @@ -353,13 +309,13 @@ def generate_mamba_plan( block_len_per_layer: list[int], remote_info: EngineTransferInfo, remote_meta: NixlAgentMetadata, - group_kinds: tuple[GroupKind, ...], + group_spec_types: tuple[type[KVCacheSpec], ...], conv_decomp: MambaConvSplitInfo, ssm_sizes: tuple[int, int], ) -> EngineTransferPlan: """Generate transfer plan for hybrid Mamba (SSM + FA) models.""" - assert group_kinds[0].is_attention, ( - f"First group must be an attention group (FA/SWA), got {group_kinds[0]}" + assert _is_attention_spec(group_spec_types[0]), ( + f"First group must be an attention spec, got {group_spec_types[0]}" ) tp_rank = transfer_topo.tp_rank tp_size = transfer_topo.tp_size @@ -380,7 +336,7 @@ def generate_mamba_plan( remote_tp_size, transfer_topo.is_mla, transfer_topo.total_num_kv_heads, - group_kinds, + group_spec_types, ) # ---- FA regions ---- @@ 
-417,19 +373,13 @@ def generate_mamba_plan( ] ssm_read_size = remote_ssm_sizes[1] - conv_kinds = [ - RegionKind.SSM_CONV_X, - RegionKind.SSM_CONV_B, - RegionKind.SSM_CONV_C, - ] ssm_regions: list[RegionPlan] = [] for i in range(len(remote_block_lens)): page_stride = remote_block_lens[i] * remote_phys_ratio - for kind, (off, sz) in zip(conv_kinds, conv_offsets): + for off, sz in conv_offsets: ssm_regions.append( RegionPlan( - kind=kind, layer_idx=i, descriptor_bytes=sz, offset_in_page=off, @@ -440,7 +390,6 @@ def generate_mamba_plan( ssm_regions.append( RegionPlan( - kind=RegionKind.SSM_STATE, layer_idx=i, descriptor_bytes=ssm_read_size, offset_in_page=conv_size_remote + local_offset * ssm_read_size, @@ -452,7 +401,7 @@ def generate_mamba_plan( return EngineTransferPlan( fa_regions=tuple(fa_regions), ssm_regions=tuple(ssm_regions), - group_kinds=group_kinds, + group_spec_types=group_spec_types, source_ranks_per_group=m.source_ranks_per_group, all_source_ranks=m.all_source_ranks, rank_to_attention_slot=m.rank_to_attention_slot, @@ -461,141 +410,7 @@ def generate_mamba_plan( # ====================================================================== -# 4. Generic executors — identical for ALL models -# ====================================================================== - - -def build_remote_descs_from_plan( - plan: EngineTransferPlan, - nixl_agent_meta: NixlAgentMetadata, -) -> list[tuple[int, int, int]]: - """Build (addr, len, dev_id) descriptor tuples from plan. - - Builds remote descriptors from a pre-computed plan. 
- """ - result: list[tuple[int, int, int]] = [] - dev_id = nixl_agent_meta.device_id - - for region in plan.all_regions: - base_addr = nixl_agent_meta.kv_caches_base_addr[region.layer_idx] - for blk in range(region.num_blocks): - addr = base_addr + blk * region.page_stride + region.offset_in_page - result.append((addr, region.descriptor_bytes, dev_id)) - - return result - - -def compute_desc_ids_from_plan( - plan: EngineTransferPlan, - block_ids: BlockIds, - dst_num_blocks: int, - block_size_ratio: float | None = None, - physical_blocks_per_logical: int = 1, -) -> np.ndarray: - """Compute NIXL descriptor IDs for given block IDs. - - Computes descriptor indices from a pre-computed plan. - """ - num_fa_regions = len(plan.fa_regions) - num_ssm_regions = len(plan.ssm_regions) - - num_blocks = dst_num_blocks - if block_size_ratio is not None: - num_blocks = int(num_blocks * block_size_ratio) - ratio = physical_blocks_per_logical - logical_blocks = num_blocks // ratio - - num_fa_descs = num_fa_regions * num_blocks - - all_descs: list[np.ndarray] = [] - for i, group in enumerate(block_ids): - group_arr = np.asarray(group) - if plan.group_kinds[i].is_attention: - fa_region_ids = np.arange(num_fa_regions)[:, None] - all_descs.append( - (fa_region_ids * num_blocks + group_arr[None, :]).flatten() - ) - elif plan.group_kinds[i].is_ssm: - ssm_region_ids = np.arange(num_ssm_regions)[:, None] - all_descs.append( - ( - ssm_region_ids * logical_blocks + group_arr[None, :] + num_fa_descs - ).flatten() - ) - else: - raise ValueError(f"Unknown group kind {plan.group_kinds[i]} at index {i}") - - return np.concatenate(all_descs) - - -def compute_read_specs_from_plan( - plan: EngineTransferPlan, - local_block_ids: BlockIds, - remote_block_ids: BlockIds, -) -> list[ReadSpec]: - """Compute read specs from plan. - - For each source rank, includes only the groups whose - source_ranks_per_group contains that rank. 
- """ - num_groups = len(local_block_ids) - return [ - ReadSpec( - remote_rank=rank, - local_block_ids=[ - list(local_block_ids[g]) - if rank in plan.source_ranks_per_group[g] - else [] - for g in range(num_groups) - ], - remote_block_ids=[ - list(remote_block_ids[g]) - if rank in plan.source_ranks_per_group[g] - else [] - for g in range(num_groups) - ], - ) - for rank in plan.all_source_ranks - ] - - -def build_local_splits_from_plan( - plan: EngineTransferPlan, - src_blocks_data: list[tuple[int, int, int]], - num_fa_descs: int, -) -> list[list[tuple[int, int, int]]]: - """Build split handle data for P_TP > D_TP scenario. - - num_fa_descs is the boundary between FA and SSM descriptors. - Split counts are derived from source_ranks_per_group lengths. - FA uses rank_to_attention_slot for the slot offset; - SSM uses the rank's positional index. - """ - fa_num_splits = len(plan.source_ranks_per_group[0]) - - has_ssm_descs = num_fa_descs < len(src_blocks_data) - ssm_num_splits = len(plan.source_ranks_per_group[-1]) if has_ssm_descs else 0 - - result: list[list[tuple[int, int, int]]] = [] - - for p_idx, p_rank in enumerate(plan.all_source_ranks): - fa_slot = plan.rank_to_attention_slot.get(p_rank, 0) - - handle: list[tuple[int, int, int]] = [] - for j, (addr, local_len, dev) in enumerate(src_blocks_data): - if j < num_fa_descs: - chunk = local_len // fa_num_splits - handle.append((addr + fa_slot * chunk, chunk, dev)) - else: - chunk = local_len // ssm_num_splits - handle.append((addr + p_idx * chunk, chunk, dev)) - result.append(handle) - - return result - - -# ====================================================================== -# 5. Local descriptor building (no plan needed — purely local geometry) +# 4. 
Local descriptor building # ====================================================================== @@ -679,40 +494,3 @@ def build_mamba_local_descs( return result -def build_local_descs( - *, - has_mamba: bool, - conv_decomp: MambaConvSplitInfo | None, - ssm_sizes: tuple[int, int], - base_addresses: list[int], - device_id: int, - num_blocks: int, - logical_num_blocks: int, - block_size_ratio: int, - block_len_per_layer: list[int], - is_blocks_first: bool, - physical_blocks_per_logical: int = 1, -) -> list[tuple[int, int, int]]: - """Build local (src) descriptor tuples for NIXL registration.""" - fa_descs = build_fa_local_descs( - base_addresses, - device_id, - num_blocks, - block_size_ratio, - block_len_per_layer, - is_blocks_first, - ) - if not has_mamba: - return fa_descs - assert conv_decomp is not None - mamba_descs = build_mamba_local_descs( - base_addresses, - block_len_per_layer, - logical_num_blocks, - block_size_ratio, - device_id, - conv_decomp, - ssm_sizes, - physical_blocks_per_logical, - ) - return fa_descs + mamba_descs diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 4cca1588ce95..757e7a3bc29e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -45,12 +45,11 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, - GroupKind, - build_local_descs, - build_local_splits_from_plan, - build_remote_descs_from_plan, - compute_desc_ids_from_plan, - compute_read_specs_from_plan, + ReadSpec, + _is_attention_spec, + _is_ssm_spec, + build_fa_local_descs, + build_mamba_local_descs, generate_dense_plan, generate_mamba_plan, ) @@ -75,7 +74,6 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, MambaSpec, - SlidingWindowSpec, UniformTypeKVCacheSpecs, ) from vllm.v1.worker.block_table import BlockTable @@ -91,17 +89,167 @@ class 
NixlConnectorWorker: """Implementation of Worker side methods""" + # ------------------------------------------------------------------ + # Plan executors (pure functions, no self access) + # ------------------------------------------------------------------ + @staticmethod - def _spec_to_group_kind(spec: "KVCacheSpec") -> GroupKind: - if isinstance(spec, MambaSpec): - return GroupKind.MAMBA - if isinstance(spec, SlidingWindowSpec): - return GroupKind.SWA - if isinstance(spec, FullAttentionSpec): - return GroupKind.FA - raise NotImplementedError( - f"Unsupported KVCacheSpec type for NIXL transfer: {type(spec)}" + def _build_remote_descs_from_plan( + plan: EngineTransferPlan, + nixl_agent_meta: "NixlAgentMetadata", + ) -> list[tuple[int, int, int]]: + """Build (addr, len, dev_id) descriptor tuples from plan.""" + result: list[tuple[int, int, int]] = [] + dev_id = nixl_agent_meta.device_id + + for region in plan.all_regions: + base_addr = nixl_agent_meta.kv_caches_base_addr[region.layer_idx] + for blk in range(region.num_blocks): + addr = (base_addr + blk * region.page_stride + + region.offset_in_page) + result.append((addr, region.descriptor_bytes, dev_id)) + + return result + + @staticmethod + def _compute_desc_ids_from_plan( + plan: EngineTransferPlan, + block_ids: BlockIds, + dst_num_blocks: int, + block_size_ratio: float | None, + physical_blocks_per_logical: int, + ) -> np.ndarray: + """Compute NIXL descriptor IDs for given block IDs.""" + num_fa_regions = len(plan.fa_regions) + num_ssm_regions = len(plan.ssm_regions) + + num_blocks = dst_num_blocks + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) + ratio = physical_blocks_per_logical + logical_blocks = num_blocks // ratio + + num_fa_descs = num_fa_regions * num_blocks + + all_descs: list[np.ndarray] = [] + for i, group in enumerate(block_ids): + group_arr = np.asarray(group) + spec_type = plan.group_spec_types[i] + if _is_attention_spec(spec_type): + fa_region_ids = 
np.arange(num_fa_regions)[:, None] + all_descs.append( + (fa_region_ids * num_blocks + + group_arr[None, :]).flatten() + ) + elif _is_ssm_spec(spec_type): + ssm_region_ids = np.arange(num_ssm_regions)[:, None] + all_descs.append( + (ssm_region_ids * logical_blocks + + group_arr[None, :] + + num_fa_descs).flatten() + ) + else: + raise ValueError( + f"Unknown spec type {spec_type} at index {i}") + + return np.concatenate(all_descs) + + @staticmethod + def _compute_read_specs_from_plan( + plan: EngineTransferPlan, + local_block_ids: BlockIds, + remote_block_ids: BlockIds, + ) -> list[ReadSpec]: + """Compute read specs from plan. + + For each source rank, includes only the groups whose + source_ranks_per_group contains that rank. + """ + num_groups = len(local_block_ids) + return [ + ReadSpec( + remote_rank=rank, + local_block_ids=[ + list(local_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] + for g in range(num_groups) + ], + remote_block_ids=[ + list(remote_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] + for g in range(num_groups) + ], + ) + for rank in plan.all_source_ranks + ] + + @staticmethod + def _build_local_splits_from_plan( + plan: EngineTransferPlan, + src_blocks_data: list[tuple[int, int, int]], + num_fa_descs: int, + ) -> list[list[tuple[int, int, int]]]: + """Build split handle data for P_TP > D_TP scenario. + + num_fa_descs is the boundary between FA and SSM descriptors. + Split counts are derived from source_ranks_per_group lengths. + FA uses rank_to_attention_slot for the slot offset; + SSM uses the rank's positional index. 
+ """ + fa_num_splits = len(plan.source_ranks_per_group[0]) + + has_ssm_descs = num_fa_descs < len(src_blocks_data) + ssm_num_splits = (len(plan.source_ranks_per_group[-1]) + if has_ssm_descs else 0) + + result: list[list[tuple[int, int, int]]] = [] + + for p_idx, p_rank in enumerate(plan.all_source_ranks): + fa_slot = plan.rank_to_attention_slot.get(p_rank, 0) + + handle: list[tuple[int, int, int]] = [] + for j, (addr, local_len, dev) in enumerate(src_blocks_data): + if j < num_fa_descs: + chunk = local_len // fa_num_splits + handle.append((addr + fa_slot * chunk, chunk, dev)) + else: + chunk = local_len // ssm_num_splits + handle.append((addr + p_idx * chunk, chunk, dev)) + result.append(handle) + + return result + + def _build_local_descs( + self, + base_addresses: list[int], + block_size_ratio: int, + ) -> list[tuple[int, int, int]]: + """Build local (src) descriptor tuples for NIXL registration.""" + assert self.transfer_topo is not None + fa_descs = build_fa_local_descs( + base_addresses, + self.device_id, + self.num_blocks, + block_size_ratio, + self.block_len_per_layer, + self.transfer_topo.is_kv_layout_blocks_first, ) + if not self._has_mamba: + return fa_descs + assert self._conv_decomp is not None + mamba_descs = build_mamba_local_descs( + base_addresses, + self.block_len_per_layer, + self._logical_num_blocks, + block_size_ratio, + self.device_id, + self._conv_decomp, + self._mamba_ssm_size, + self._physical_blocks_per_logical_kv_block, + ) + return fa_descs + mamba_descs def __init__( self, @@ -142,14 +290,13 @@ def __init__( } self.hma_group_size = len(kv_cache_config.kv_cache_tensors) - # ---- Group kinds and model state (derived from model config) ---- - self._group_kinds = tuple( - self._spec_to_group_kind(group.kv_cache_spec) - for group in kv_cache_config.kv_cache_groups - ) + # ---- Model state (derived from model config) ---- mamba_ssm_size = (0, 0) self._conv_decomp: MambaConvSplitInfo | None = None - self._has_mamba = any(k.is_ssm for k in 
self._group_kinds) + self._has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) + for g in kv_cache_config.kv_cache_groups + ) if self._has_mamba: assert self._is_hma_required from vllm.model_executor.layers.mamba.mamba_utils import ( @@ -895,19 +1042,7 @@ def register_local_xfer_handler( block_size_ratio = self.block_size // block_size local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] - blocks_data = build_local_descs( - has_mamba=self._has_mamba, - conv_decomp=self._conv_decomp, - ssm_sizes=self._mamba_ssm_size, - base_addresses=local_base_addresses, - device_id=self.device_id, - num_blocks=self.num_blocks, - logical_num_blocks=self._logical_num_blocks, - block_size_ratio=block_size_ratio, - block_len_per_layer=self.block_len_per_layer, - is_blocks_first=transfer_topo.is_kv_layout_blocks_first, - physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, - ) + blocks_data = self._build_local_descs(local_base_addresses, block_size_ratio) logger.debug( "Created %s blocks for src engine %s and rank %s on device id %s", len(blocks_data), @@ -1000,9 +1135,7 @@ def add_remote_agent( transfer_topo.register_remote_engine(engine_id, transfer_info) logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) - # Generate the pre-computed transfer plan for this remote engine. - # Plan generation is model-aware (if/else), but the per-request - # hot path only consumes the plan (model-agnostic). + # Generate the transfer plan for this remote engine. 
if self._has_mamba: assert self._conv_decomp is not None self._transfer_plans[engine_id] = generate_mamba_plan( @@ -1010,7 +1143,10 @@ def add_remote_agent( block_len_per_layer=self.block_len_per_layer, remote_info=transfer_info, remote_meta=nixl_agent_meta, - group_kinds=self._group_kinds, + group_spec_types=tuple( + type(g.kv_cache_spec) + for g in self.kv_cache_config.kv_cache_groups + ), conv_decomp=self._conv_decomp, ssm_sizes=self._mamba_ssm_size, ) @@ -1020,6 +1156,10 @@ def add_remote_agent( block_len_per_layer=self.block_len_per_layer, remote_info=transfer_info, remote_meta=nixl_agent_meta, + group_spec_types=tuple( + type(g.kv_cache_spec) + for g in self.kv_cache_config.kv_cache_groups + ), local_physical_blocks_per_logical=( self._physical_blocks_per_logical_kv_block ), @@ -1071,7 +1211,7 @@ def add_remote_agent( # we only do this once per remote tp_size (replica-friendly). self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - for handle_data in build_local_splits_from_plan( + for handle_data in self._build_local_splits_from_plan( plan, self.src_blocks_data, self.num_descs, @@ -1083,7 +1223,7 @@ def add_remote_agent( self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) ### Register remote agent memory regions - blocks_data = build_remote_descs_from_plan(plan, nixl_agent_meta) + blocks_data = self._build_remote_descs_from_plan(plan, nixl_agent_meta) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), @@ -1627,7 +1767,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): plan.remote_expansion_stride, ) remote_block_ids = meta.remote.block_ids - read_specs = compute_read_specs_from_plan( + read_specs = self._compute_read_specs_from_plan( plan, local_block_ids=meta.local_physical_block_ids, remote_block_ids=remote_block_ids, @@ -1776,8 +1916,9 @@ def _read_blocks( == len(local_block_ids) == len(self.kv_cache_config.kv_cache_groups) ) - # Partial prefix cache hit: trim remote blocks to 
match local count. - # SSM groups share the block table so counts always match (no-op trim). + # Partial prefix cache hit: just read uncomputed blocks. + # Skip mamba groups — their blocks represent full state (conv+ssm), + # not per-token data, so trimming would corrupt the transfer. remote_block_ids = list(remote_block_ids) for i, remote_group in enumerate(remote_block_ids): num_local_blocks = len(local_block_ids[i]) @@ -1789,16 +1930,15 @@ def _read_blocks( # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. - # Get descs ids. Both calls use the same plan since region counts - # (len(fa_regions), len(ssm_regions)) are model-determined and - # identical across engines. - remote_block_descs_ids = compute_desc_ids_from_plan( + # Get descs ids. + remote_block_descs_ids = self._compute_desc_ids_from_plan( plan, block_ids=remote_block_ids, dst_num_blocks=self.dst_num_blocks[dst_engine_id], + block_size_ratio=None, physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, ) - local_block_descs_ids = compute_desc_ids_from_plan( + local_block_descs_ids = self._compute_desc_ids_from_plan( plan, block_ids=local_block_ids, dst_num_blocks=self.dst_num_blocks[self.engine_id], From 9d4ffbe386f30d012606f4ba915ac2609b3bd160 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Fri, 24 Apr 2026 14:23:37 -0400 Subject: [PATCH 35/49] fix: pre-commit lint (unused var, line length, formatting) Signed-off-by: ZhanqiuHu --- .../unit/test_nixl_connector_hma.py | 3 +- .../kv_connector/unit/test_transfer_plan.py | 23 ++++++++++---- .../kv_connector/v1/nixl/transfer_plan.py | 11 +++---- .../kv_connector/v1/nixl/worker.py | 30 ++++++++----------- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 6c1fdfd511ea..127db16f2eb5 100644 --- 
a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -93,8 +93,7 @@ def test_logical_to_kernel_block_ids_with_hma(): @pytest.mark.cpu_test @pytest.mark.parametrize( - "group_spec_types,expansion_stride,remote_block_ids," - "expected_remote_block_ids", + "group_spec_types,expansion_stride,remote_block_ids,expected_remote_block_ids", [ pytest.param( ("FullAttentionSpec", "SlidingWindowSpec"), diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 4938019ee57e..377bc773cbc8 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -235,8 +235,11 @@ def test_compute_desc_ids(self, tp_size, remote_tp_size): block_ids = ([1, 5, 10, 20],) ids = NixlConnectorWorker._compute_desc_ids_from_plan( - plan, block_ids, dst_num_blocks=num_blocks, - block_size_ratio=None, physical_blocks_per_logical=1, + plan, + block_ids, + dst_num_blocks=num_blocks, + block_size_ratio=None, + physical_blocks_per_logical=1, ) num_regions = len(plan.fa_regions) @@ -278,7 +281,9 @@ def test_compute_read_specs(self, tp_size, remote_tp_size): local_ids = ([1, 2, 3],) remote_ids = ([4, 5, 6],) - specs = NixlConnectorWorker._compute_read_specs_from_plan(plan, local_ids, remote_ids) + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, local_ids, remote_ids + ) assert len(specs) == len(plan.all_source_ranks) for spec in specs: @@ -489,7 +494,9 @@ def test_all_source_ranks_serve_fa(self): local_ids = ([1, 2], [3, 4]) remote_ids = ([5, 6], [7, 8]) - specs = NixlConnectorWorker._compute_read_specs_from_plan(plan, local_ids, remote_ids) + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, local_ids, remote_ids + ) assert len(specs) == 2 for spec in specs: assert list(spec.local_block_ids[0]) == [1, 2] @@ -512,7 +519,9 @@ def test_non_fa_rank_skips_fa_groups(self): local_ids = ([1, 2], [3, 
4]) remote_ids = ([5, 6], [7, 8]) - specs = NixlConnectorWorker._compute_read_specs_from_plan(plan, local_ids, remote_ids) + specs = NixlConnectorWorker._compute_read_specs_from_plan( + plan, local_ids, remote_ids + ) assert len(specs) == 3 # Rank 0 (FA source): gets all groups @@ -560,7 +569,9 @@ def test_fa_and_ssm_different_split_factors(self): (3000, 400, 0), # SSM desc 0 ] - splits = NixlConnectorWorker._build_local_splits_from_plan(plan, src_blocks_data, 2) + splits = NixlConnectorWorker._build_local_splits_from_plan( + plan, src_blocks_data, 2 + ) assert len(splits) == 2 # 2 source ranks diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index ed5fc85d494c..003019a63afc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -149,14 +149,14 @@ def _compute_tp_mapping( elif tp_size >= remote_tp_size: attn_ranks = [tp_rank * remote_tp_size // tp_size] else: - # P (remote TP size) > D (local TP size): one local rank reads from multiple remote ranks. + # P (remote TP) > D (local TP): one local rank + # reads from multiple remote ranks. # GQA dedup: when K < remote_tp_size, several remote ranks # hold the same KV head. np.unique keeps only the first # rank per unique head so we don't issue redundant reads. 
abs_tp = remote_tp_size // tp_size start = tp_rank * abs_tp - heads = (np.arange(start, start + abs_tp) - * total_num_kv_heads // remote_tp_size) + heads = np.arange(start, start + abs_tp) * total_num_kv_heads // remote_tp_size _, unique_idx = np.unique(heads, return_index=True) attn_ranks = (start + np.sort(unique_idx)).tolist() @@ -193,8 +193,7 @@ def _compute_tp_mapping( rank_offset_factor = 0 elif tp_size > total_num_kv_heads: local_head = tp_rank * total_num_kv_heads // tp_size - p_start = (attn_ranks[0] * total_num_kv_heads - // remote_tp_size) + p_start = attn_ranks[0] * total_num_kv_heads // remote_tp_size rank_offset_factor = local_head - p_start else: rank_offset_factor = tp_rank % (tp_size // remote_tp_size) @@ -492,5 +491,3 @@ def build_mamba_local_descs( ) ) return result - - diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 757e7a3bc29e..a501b2ea2310 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -81,7 +81,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec + from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -105,8 +105,7 @@ def _build_remote_descs_from_plan( for region in plan.all_regions: base_addr = nixl_agent_meta.kv_caches_base_addr[region.layer_idx] for blk in range(region.num_blocks): - addr = (base_addr + blk * region.page_stride - + region.offset_in_page) + addr = base_addr + blk * region.page_stride + region.offset_in_page result.append((addr, region.descriptor_bytes, dev_id)) return result @@ -138,19 +137,19 @@ def _compute_desc_ids_from_plan( if _is_attention_spec(spec_type): fa_region_ids = np.arange(num_fa_regions)[:, None] all_descs.append( - (fa_region_ids * num_blocks - + group_arr[None, :]).flatten() + (fa_region_ids * num_blocks + 
group_arr[None, :]).flatten() ) elif _is_ssm_spec(spec_type): ssm_region_ids = np.arange(num_ssm_regions)[:, None] all_descs.append( - (ssm_region_ids * logical_blocks - + group_arr[None, :] - + num_fa_descs).flatten() + ( + ssm_region_ids * logical_blocks + + group_arr[None, :] + + num_fa_descs + ).flatten() ) else: - raise ValueError( - f"Unknown spec type {spec_type} at index {i}") + raise ValueError(f"Unknown spec type {spec_type} at index {i}") return np.concatenate(all_descs) @@ -201,8 +200,7 @@ def _build_local_splits_from_plan( fa_num_splits = len(plan.source_ranks_per_group[0]) has_ssm_descs = num_fa_descs < len(src_blocks_data) - ssm_num_splits = (len(plan.source_ranks_per_group[-1]) - if has_ssm_descs else 0) + ssm_num_splits = len(plan.source_ranks_per_group[-1]) if has_ssm_descs else 0 result: list[list[tuple[int, int, int]]] = [] @@ -1037,8 +1035,6 @@ def register_local_xfer_handler( data copy correctness. """ assert self.transfer_topo is not None - transfer_topo = self.transfer_topo - block_size_ratio = self.block_size // block_size local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] @@ -1144,8 +1140,7 @@ def add_remote_agent( remote_info=transfer_info, remote_meta=nixl_agent_meta, group_spec_types=tuple( - type(g.kv_cache_spec) - for g in self.kv_cache_config.kv_cache_groups + type(g.kv_cache_spec) for g in self.kv_cache_config.kv_cache_groups ), conv_decomp=self._conv_decomp, ssm_sizes=self._mamba_ssm_size, @@ -1157,8 +1152,7 @@ def add_remote_agent( remote_info=transfer_info, remote_meta=nixl_agent_meta, group_spec_types=tuple( - type(g.kv_cache_spec) - for g in self.kv_cache_config.kv_cache_groups + type(g.kv_cache_spec) for g in self.kv_cache_config.kv_cache_groups ), local_physical_blocks_per_logical=( self._physical_blocks_per_logical_kv_block From 7b8922bd23112267873042f5e9901df136cd619e Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Fri, 24 Apr 2026 14:44:30 -0400 Subject: [PATCH 36/49] clean Signed-off-by: 
ZhanqiuHu --- .../kv_connector/v1/nixl/transfer_plan.py | 100 ++++++++++-------- .../kv_connector/v1/nixl/worker.py | 80 ++++++++------ .../v1/ssm_conv_transfer_utils.py | 13 ++- 3 files changed, 116 insertions(+), 77 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 003019a63afc..25edb41fd901 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -53,11 +53,11 @@ def _is_ssm_spec(spec_type: type[KVCacheSpec]) -> bool: @dataclass(frozen=True) class RegionPlan: - """Pre-computed plan for one descriptor region. + """Geometry for one descriptor region. Everything needed to build NIXL descriptors and compute descriptor - IDs is baked in. The executor plugs in per-rank ``base_addr`` and - ``device_id`` from NixlAgentMetadata. + IDs is baked in. The caller plugs in ``base_addr`` and + ``device_id`` when constructing the final descriptor tuples. """ layer_idx: int @@ -355,13 +355,19 @@ def generate_mamba_plan( conv_size_remote = remote_ssm_sizes[0] ssm_num_blocks = remote_meta.num_blocks // remote_phys_ratio + # Mamba conv state is always TP-sharded, even when attention KV + # is replicated (num_kv_heads < tp_size). if tp_size >= remote_tp_size: + # D_TP >= P_TP: P page is larger, D reads its slice. conv_offsets = conv_decomp.remote_conv_offsets( local_offset, effective_ratio, ) ssm_read_size = ssm_sizes[1] else: + # NOTE (ZhanqiuHu): P_TP > D_TP, so P pages are smaller + # than D's. conv_decomp has D-sized dimensions, but we + # need P-sized offsets. Scale down by abs_ratio. 
abs_ratio = remote_tp_size // tp_size xb_p = conv_decomp.x_bytes // abs_ratio bb_p = conv_decomp.b_bytes // abs_ratio @@ -372,6 +378,8 @@ def generate_mamba_plan( ] ssm_read_size = remote_ssm_sizes[1] + # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], + # in case block lengths vary across layers (e.g. MLA). ssm_regions: list[RegionPlan] = [] for i in range(len(remote_block_lens)): page_stride = remote_block_lens[i] * remote_phys_ratio @@ -413,54 +421,57 @@ def generate_mamba_plan( # ====================================================================== -def build_fa_local_descs( - base_addresses: list[int], - device_id: int, +def build_fa_local_regions( num_blocks: int, block_size_ratio: int, block_len_per_layer: list[int], is_blocks_first: bool, -) -> list[tuple[int, int, int]]: - """Build FA local descriptors for NIXL registration.""" - result: list[tuple[int, int, int]] = [] +) -> list[RegionPlan]: + """Build FA local region specs for NIXL registration.""" + regions: list[RegionPlan] = [] n_blocks = num_blocks * block_size_ratio - for i, base_addr in enumerate(base_addresses): + for i in range(len(block_len_per_layer)): kv_block_len = ( _get_kv_block_len(i, block_len_per_layer, is_blocks_first) // block_size_ratio ) page_stride = block_len_per_layer[i] // block_size_ratio - for block_id in range(n_blocks): - result.append( - ( - base_addr + block_id * page_stride, - kv_block_len, - device_id, - ) + regions.append( + RegionPlan( + layer_idx=i, + descriptor_bytes=kv_block_len, + offset_in_page=0, + page_stride=page_stride, + num_blocks=n_blocks, ) + ) if is_blocks_first: second_split = _get_kv_block_len( i, block_len_per_layer, is_blocks_first, ) - for block_id in range(n_blocks): - v_addr = base_addr + block_id * page_stride + kv_block_len - result.append((v_addr, second_split, device_id)) - return result + regions.append( + RegionPlan( + layer_idx=i, + descriptor_bytes=second_split, + offset_in_page=kv_block_len, + page_stride=page_stride, + 
num_blocks=n_blocks, + ) + ) + return regions -def build_mamba_local_descs( - base_addresses: list[int], +def build_mamba_local_regions( block_len_per_layer: list[int], logical_num_blocks: int, block_size_ratio: int, - device_id: int, conv_decomp: MambaConvSplitInfo, ssm_sizes: tuple[int, int], physical_blocks_per_logical: int, -) -> list[tuple[int, int, int]]: - """Build 4 SSM descriptor regions (x, B, C, ssm) per layer.""" +) -> list[RegionPlan]: + """Build 4 SSM region specs (x, B, C, ssm) per layer.""" assert block_size_ratio == 1, ( "Mamba 3-read transfer with block_size_ratio != 1 " f"is not tested. Got {block_size_ratio=}." @@ -470,24 +481,27 @@ def build_mamba_local_descs( n_blocks = logical_num_blocks * block_size_ratio phys_ratio = physical_blocks_per_logical - result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(base_addresses): + regions: list[RegionPlan] = [] + for i in range(len(block_len_per_layer)): page_stride = block_len_per_layer[i] // block_size_ratio * phys_ratio for off, sz in conv_offsets: - for blk in range(n_blocks): - result.append( - ( - base_addr + blk * page_stride + off, - sz, - device_id, - ) - ) - for blk in range(n_blocks): - result.append( - ( - base_addr + blk * page_stride + conv_size, - ssm_size, - device_id, + regions.append( + RegionPlan( + layer_idx=i, + descriptor_bytes=sz, + offset_in_page=off, + page_stride=page_stride, + num_blocks=n_blocks, ) ) - return result + # SSM temporal state follows the conv state. 
+ regions.append( + RegionPlan( + layer_idx=i, + descriptor_bytes=ssm_size, + offset_in_page=conv_size, + page_stride=page_stride, + num_blocks=n_blocks, + ) + ) + return regions diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index a501b2ea2310..5c08e7ee5aa9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -48,8 +48,8 @@ ReadSpec, _is_attention_spec, _is_ssm_spec, - build_fa_local_descs, - build_mamba_local_descs, + build_fa_local_regions, + build_mamba_local_regions, generate_dense_plan, generate_mamba_plan, ) @@ -90,7 +90,7 @@ class NixlConnectorWorker: """Implementation of Worker side methods""" # ------------------------------------------------------------------ - # Plan executors (pure functions, no self access) + # Plan executors (static — no self access) # ------------------------------------------------------------------ @staticmethod @@ -130,6 +130,13 @@ def _compute_desc_ids_from_plan( num_fa_descs = num_fa_regions * num_blocks + # NOTE (NickLucche) With HMA, every kv group has the same number + # of layers and layers from different groups share the same kv + # tensor. Therefore we compute desc IDs per group using the + # right stride: + # FA descs have num_blocks entries per region (kernel granularity), + # SSM descs have logical_blocks entries per region (no kernel + # splitting). all_descs: list[np.ndarray] = [] for i, group in enumerate(block_ids): group_arr = np.asarray(group) @@ -140,6 +147,13 @@ def _compute_desc_ids_from_plan( (fa_region_ids * num_blocks + group_arr[None, :]).flatten() ) elif _is_ssm_spec(spec_type): + # NOTE (NickLucche) SSM and Attention block regions can + # be exchanged arbitrarily by manager. Therefore, descs + # are laid out as: + # [descs_fa (all regions) | descs_ssm (all regions)]. 
+ # num_fa_descs offset must be computed per-engine since + # P and D can have different num_blocks (and thus + # different FA desc counts). ssm_region_ids = np.arange(num_ssm_regions)[:, None] all_descs.append( ( @@ -197,6 +211,7 @@ def _build_local_splits_from_plan( FA uses rank_to_attention_slot for the slot offset; SSM uses the rank's positional index. """ + # Mamba-HMA: FA and Mamba use different split factors. fa_num_splits = len(plan.source_ranks_per_group[0]) has_ssm_descs = num_fa_descs < len(src_blocks_data) @@ -226,28 +241,37 @@ def _build_local_descs( ) -> list[tuple[int, int, int]]: """Build local (src) descriptor tuples for NIXL registration.""" assert self.transfer_topo is not None - fa_descs = build_fa_local_descs( - base_addresses, - self.device_id, + fa_regions = build_fa_local_regions( self.num_blocks, block_size_ratio, self.block_len_per_layer, self.transfer_topo.is_kv_layout_blocks_first, ) - if not self._has_mamba: - return fa_descs - assert self._conv_decomp is not None - mamba_descs = build_mamba_local_descs( - base_addresses, - self.block_len_per_layer, - self._logical_num_blocks, - block_size_ratio, - self.device_id, - self._conv_decomp, - self._mamba_ssm_size, - self._physical_blocks_per_logical_kv_block, - ) - return fa_descs + mamba_descs + if self._has_mamba: + # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the + # 3-read split is unnecessary — a single conv desc per block + # suffices. Consider adding a fast path. Currently we always + # register 4 regions because local descs are created before + # knowing the remote TP. 
+ assert self._conv_decomp is not None + mamba_regions = build_mamba_local_regions( + self.block_len_per_layer, + self._logical_num_blocks, + block_size_ratio, + self._conv_decomp, + self._mamba_ssm_size, + self._physical_blocks_per_logical_kv_block, + ) + else: + mamba_regions = [] + + result: list[tuple[int, int, int]] = [] + for region in fa_regions + mamba_regions: + base = base_addresses[region.layer_idx] + for blk in range(region.num_blocks): + addr = base + blk * region.page_stride + region.offset_in_page + result.append((addr, region.descriptor_bytes, self.device_id)) + return result def __init__( self, @@ -290,6 +314,9 @@ def __init__( # ---- Model state (derived from model config) ---- mamba_ssm_size = (0, 0) + # Conv state sub-projection decomposition (None when no Mamba). + # The 3-read transfer requires DS (dim, state_len) conv layout so + # that x/B/C sub-projections are contiguous in memory. self._conv_decomp: MambaConvSplitInfo | None = None self._has_mamba = any( isinstance(g.kv_cache_spec, MambaSpec) @@ -310,22 +337,11 @@ def __init__( for spec in self._layer_specs.values() if isinstance(spec, MambaSpec) ) - conv_nbytes, ssm_nbytes = ( - torch.tensor([], dtype=mamba_spec.dtypes[0]).element_size(), # type: ignore[misc] - torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(), # type: ignore[misc] - ) - conv_shape, ssm_shape = ( - torch.Size(mamba_spec.shapes[0]), - torch.Size(mamba_spec.shapes[1]), - ) - mamba_ssm_size = ( - conv_shape.numel() * conv_nbytes, - ssm_shape.numel() * ssm_nbytes, - ) self._conv_decomp = derive_mamba_conv_split( mamba_spec, vllm_config.parallel_config.tensor_parallel_size, ) + mamba_ssm_size = self._conv_decomp.ssm_sizes self._mamba_ssm_size = mamba_ssm_size # Agent. 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py index 309426814c68..00b8e2bb7275 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py @@ -31,6 +31,7 @@ class MambaConvSplitInfo: x_local: int # intermediate_size / TP (columns for x) b_local: int # groups_ss / TP (columns for B; C is same size) conv_dtype_size: int # bytes per element (e.g. 2 for float16) + ssm_sizes: tuple[int, int] # (conv_state_bytes, ssm_state_bytes) @property def conv_dim_local(self) -> int: @@ -99,8 +100,8 @@ def derive_mamba_conv_split( local_tp: this engine's tensor-parallel size. Returns: - MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, and - conv_dtype_size. + MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, + conv_dtype_size, and ssm_sizes (conv_state_bytes, ssm_state_bytes). """ if mamba_spec.mamba_type != "mamba2": raise NotImplementedError( @@ -142,12 +143,20 @@ def derive_mamba_conv_split( dtype=mamba_spec.dtypes[0], # type: ignore[misc] ).element_size() + ssm_dtype_size = torch.tensor( + [], + dtype=mamba_spec.dtypes[1], # type: ignore[misc] + ).element_size() + conv_state_bytes = torch.Size(mamba_spec.shapes[0]).numel() * conv_dtype_size + ssm_state_bytes = torch.Size(mamba_spec.shapes[1]).numel() * ssm_dtype_size + # Divide by TP to get per-rank column counts. 
return MambaConvSplitInfo( conv_rows=conv_rows, x_local=intermediate_size // local_tp, b_local=groups_ss // local_tp, conv_dtype_size=conv_dtype_size, + ssm_sizes=(conv_state_bytes, ssm_state_bytes), ) From 72ece8228c7ff27d3315e17ff61ddb9441b8e6d6 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Fri, 24 Apr 2026 14:47:47 -0400 Subject: [PATCH 37/49] rename Signed-off-by: ZhanqiuHu --- .../kv_connector/v1/nixl/transfer_plan.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 25edb41fd901..441d549f89a9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -272,7 +272,7 @@ def generate_dense_plan( """Generate transfer plan for dense (attention-only) models.""" block_size_ratio = transfer_topo.block_size // remote_info.remote_block_size - m = _compute_tp_mapping( + tp_mapping = _compute_tp_mapping( transfer_topo.tp_rank, transfer_topo.tp_size, remote_info.remote_tp_size, @@ -286,8 +286,8 @@ def generate_dense_plan( remote_block_lens=remote_meta.block_lens, is_blocks_first=transfer_topo.is_kv_layout_blocks_first, block_size_ratio=block_size_ratio, - num_attn_reads=len(m.source_ranks_per_group[0]), - rank_offset_factor=m.rank_offset_factor, + num_attn_reads=len(tp_mapping.source_ranks_per_group[0]), + rank_offset_factor=tp_mapping.rank_offset_factor, remote_num_blocks=remote_meta.num_blocks, ) @@ -295,9 +295,9 @@ def generate_dense_plan( fa_regions=tuple(fa_regions), ssm_regions=(), group_spec_types=group_spec_types, - source_ranks_per_group=m.source_ranks_per_group, - all_source_ranks=m.all_source_ranks, - rank_to_attention_slot=m.rank_to_attention_slot, + source_ranks_per_group=tp_mapping.source_ranks_per_group, + all_source_ranks=tp_mapping.all_source_ranks, + 
rank_to_attention_slot=tp_mapping.rank_to_attention_slot, remote_expansion_stride=local_physical_blocks_per_logical, ) @@ -329,7 +329,7 @@ def generate_mamba_plan( f"is not tested. Got {block_size_ratio=}." ) - m = _compute_tp_mapping( + tp_mapping = _compute_tp_mapping( tp_rank, tp_size, remote_tp_size, @@ -344,8 +344,8 @@ def generate_mamba_plan( remote_block_lens=remote_block_lens, is_blocks_first=transfer_topo.is_kv_layout_blocks_first, block_size_ratio=block_size_ratio, - num_attn_reads=len(m.source_ranks_per_group[0]), - rank_offset_factor=m.rank_offset_factor, + num_attn_reads=len(tp_mapping.source_ranks_per_group[0]), + rank_offset_factor=tp_mapping.rank_offset_factor, remote_num_blocks=remote_meta.num_blocks, ) @@ -409,9 +409,9 @@ def generate_mamba_plan( fa_regions=tuple(fa_regions), ssm_regions=tuple(ssm_regions), group_spec_types=group_spec_types, - source_ranks_per_group=m.source_ranks_per_group, - all_source_ranks=m.all_source_ranks, - rank_to_attention_slot=m.rank_to_attention_slot, + source_ranks_per_group=tp_mapping.source_ranks_per_group, + all_source_ranks=tp_mapping.all_source_ranks, + rank_to_attention_slot=tp_mapping.rank_to_attention_slot, remote_expansion_stride=remote_phys_ratio, ) From 44caacdfc7734ea3e3fb211f4c7a3e6e78e2b9d9 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Sun, 26 Apr 2026 02:47:01 -0400 Subject: [PATCH 38/49] test case Signed-off-by: ZhanqiuHu --- tests/v1/kv_connector/unit/test_nixl_connector.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 46f6ba706708..fb4b641e1376 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -472,7 +472,6 @@ def __init__( is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, - physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, 
tensor_shape=test_shape, ) @@ -727,7 +726,6 @@ def test_prefill_tp_size_greater_than_decode_tp_size( worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] - worker.num_descs = len(worker.src_blocks_data) def check_handshake(remote_tp_size: int): tp_ratio = remote_tp_size // local_tp_size @@ -2437,7 +2435,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) is_mamba=False, total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), attn_backends=[backend], - physical_blocks_per_logical=decode_worker._physical_blocks_per_logical_kv_block, tensor_shape=test_shape, ) From c3a5c65eceedacbb2609c88ee5e1dce4bcae827b Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Sun, 26 Apr 2026 02:51:30 -0400 Subject: [PATCH 39/49] update test Signed-off-by: ZhanqiuHu --- tests/v1/kv_connector/unit/test_nixl_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index fb4b641e1376..d20f026da241 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -726,6 +726,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size( worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] + worker.num_descs = len(worker.src_blocks_data) def check_handshake(remote_tp_size: int): tp_ratio = remote_tp_size // local_tp_size From bf52923a5363fc1c6c5728d5403a1cddefeb62c3 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Sun, 26 Apr 2026 20:47:38 +0000 Subject: [PATCH 40/49] test Signed-off-by: Zhanqiu Hu --- .../kv_connector/v1/nixl/transfer_plan.py | 17 ++++++++++----- .../kv_connector/v1/nixl/worker.py | 21 +++++++++++++++++-- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 441d549f89a9..b968845b689c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -281,12 +281,17 @@ def generate_dense_plan( group_spec_types=group_spec_types, ) + num_attn_reads = next( + len(ranks) + for t, ranks in zip(group_spec_types, tp_mapping.source_ranks_per_group) + if _is_attention_spec(t) + ) fa_regions = _build_fa_regions( block_len_per_layer=block_len_per_layer, remote_block_lens=remote_meta.block_lens, is_blocks_first=transfer_topo.is_kv_layout_blocks_first, block_size_ratio=block_size_ratio, - num_attn_reads=len(tp_mapping.source_ranks_per_group[0]), + num_attn_reads=num_attn_reads, rank_offset_factor=tp_mapping.rank_offset_factor, remote_num_blocks=remote_meta.num_blocks, ) @@ -313,9 +318,6 @@ def generate_mamba_plan( ssm_sizes: tuple[int, int], ) -> EngineTransferPlan: """Generate transfer plan for hybrid Mamba (SSM + FA) models.""" - assert _is_attention_spec(group_spec_types[0]), ( - f"First group must be an attention spec, got {group_spec_types[0]}" - ) tp_rank = transfer_topo.tp_rank tp_size = transfer_topo.tp_size remote_tp_size = remote_info.remote_tp_size @@ -339,12 +341,17 @@ def generate_mamba_plan( ) # ---- FA regions ---- + num_attn_reads = next( + len(ranks) + for t, ranks in zip(group_spec_types, tp_mapping.source_ranks_per_group) + if _is_attention_spec(t) + ) fa_regions = _build_fa_regions( block_len_per_layer=block_len_per_layer, remote_block_lens=remote_block_lens, is_blocks_first=transfer_topo.is_kv_layout_blocks_first, block_size_ratio=block_size_ratio, - num_attn_reads=len(tp_mapping.source_ranks_per_group[0]), + num_attn_reads=num_attn_reads, rank_offset_factor=tp_mapping.rank_offset_factor, remote_num_blocks=remote_meta.num_blocks, ) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 5c08e7ee5aa9..62e786f99378 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -212,10 +212,27 @@ def _build_local_splits_from_plan( SSM uses the rank's positional index. """ # Mamba-HMA: FA and Mamba use different split factors. - fa_num_splits = len(plan.source_ranks_per_group[0]) + fa_num_splits = next( + len(ranks) + for t, ranks in zip(plan.group_spec_types, plan.source_ranks_per_group) + if _is_attention_spec(t) + ) has_ssm_descs = num_fa_descs < len(src_blocks_data) - ssm_num_splits = len(plan.source_ranks_per_group[-1]) if has_ssm_descs else 0 + ssm_num_splits = ( + next( + ( + len(ranks) + for t, ranks in zip( + plan.group_spec_types, plan.source_ranks_per_group + ) + if _is_ssm_spec(t) + ), + 0, + ) + if has_ssm_descs + else 0 + ) result: list[list[tuple[int, int, int]]] = [] From 0dc8e33267d03b22afe39da577678bd1c9c38382 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Mon, 27 Apr 2026 00:39:45 +0000 Subject: [PATCH 41/49] updates Signed-off-by: Zhanqiu Hu --- .../kv_transfer/kv_connector/v1/nixl/transfer_plan.py | 4 ++++ vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index b968845b689c..6fca950c0f57 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -222,6 +222,10 @@ def _build_fa_regions( V bytes = local_block_len / num_attn_reads (no block_size_ratio). Offset = rank_offset_factor * remote_kv_block_len per layer. 
""" + assert len(remote_block_lens) == len(block_len_per_layer), ( + f"Layer count mismatch: remote has {len(remote_block_lens)} layers " + f"but local has {len(block_len_per_layer)}" + ) fa_regions: list[RegionPlan] = [] for i in range(len(remote_block_lens)): local_block_len = _get_kv_block_len(i, block_len_per_layer, is_blocks_first) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 62e786f99378..cb64a80d4173 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -119,6 +119,10 @@ def _compute_desc_ids_from_plan( physical_blocks_per_logical: int, ) -> np.ndarray: """Compute NIXL descriptor IDs for given block IDs.""" + assert len(block_ids) == len(plan.group_spec_types), ( + f"block_ids has {len(block_ids)} groups but plan has " + f"{len(plan.group_spec_types)} group_spec_types" + ) num_fa_regions = len(plan.fa_regions) num_ssm_regions = len(plan.ssm_regions) From f8a01e68e24c995b2623ec140021fdb750a78aff Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Mon, 27 Apr 2026 01:49:16 +0000 Subject: [PATCH 42/49] fix: add Mamba guard to block ID trimming in _read_blocks Signed-off-by: Zhanqiu Hu --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index cb64a80d4173..c9b9ef599823 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1951,10 +1951,14 @@ def _read_blocks( # Skip mamba groups — their blocks represent full state (conv+ssm), # not per-token data, so trimming would corrupt the transfer. 
remote_block_ids = list(remote_block_ids) + group_specs = self.kv_cache_config.kv_cache_groups for i, remote_group in enumerate(remote_block_ids): + num_remote_blocks = len(remote_group) num_local_blocks = len(local_block_ids[i]) - assert num_local_blocks <= len(remote_group) - if num_local_blocks < len(remote_group): + is_mamba = isinstance(group_specs[i].kv_cache_spec, MambaSpec) + if not is_mamba: + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks and not is_mamba: remote_block_ids[i] = remote_group[-num_local_blocks:] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from From a6e52666cf76372d0466dc5f847fe4d099858939 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Mon, 27 Apr 2026 12:57:48 +0000 Subject: [PATCH 43/49] updates Signed-off-by: Zhanqiu Hu --- tests/v1/kv_connector/unit/test_nixl_connector.py | 4 ++++ .../kv_transfer/kv_connector/v1/nixl/metadata.py | 4 +++- .../kv_transfer/kv_connector/v1/nixl/worker.py | 11 ++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index d20f026da241..3803e4fd3869 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -527,6 +527,7 @@ def _nixl_handshake( block_size=self.block_size, ssm_sizes=(0, 0), attn_backend_name=self.backend_name, + physical_blocks_per_logical_kv_block=1, ), remote_tp_rank=remote_tp_rank, remote_tp_size=remote_tp_size, @@ -979,6 +980,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( block_size=worker.block_size, ssm_sizes=(0, 0), attn_backend_name=worker.backend_name, + physical_blocks_per_logical_kv_block=1, ) with pytest.raises(RuntimeError): @@ -1036,6 +1038,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( block_size=worker.block_size, ssm_sizes=(0, 0), attn_backend_name=worker.backend_name, + 
physical_blocks_per_logical_kv_block=1, ) # We don't check layout for homogeneous TP and MLA for now, as the @@ -2355,6 +2358,7 @@ def test_compatibility_hash_validation( block_size=prefill_block_size, ssm_sizes=(0, 0), attn_backend_name=decode_worker.backend_name, + physical_blocks_per_logical_kv_block=1, ) handshake_payload = NixlHandshakePayload( compatibility_hash=remote_hash, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py index 71ebbf1174fb..c56e373ba99d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py @@ -32,8 +32,9 @@ # Version History: # 1: Initial version with compatibility checking # 2: Add remote_request_id to kv_transfer_params +# 3: Add physical_blocks_per_logical_kv_block to NixlAgentMetadata # -NIXL_CONNECTOR_VERSION: int = 2 +NIXL_CONNECTOR_VERSION: int = 3 @dataclass @@ -48,6 +49,7 @@ class NixlAgentMetadata: block_size: int ssm_sizes: tuple[int, int] attn_backend_name: str + physical_blocks_per_logical_kv_block: int @dataclass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index c9b9ef599823..97b88ad06396 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -59,7 +59,6 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, - compute_physical_blocks_per_logical, derive_mamba_conv_split, ) from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config @@ -1047,6 +1046,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_size=self.block_size, ssm_sizes=self._mamba_ssm_size, attn_backend_name=self.backend_name, + physical_blocks_per_logical_kv_block=( + self._physical_blocks_per_logical_kv_block + ), ) # Wrap 
metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -1152,12 +1154,7 @@ def add_remote_agent( assert self.transfer_topo is not None transfer_topo = self.transfer_topo physical_blocks_per_logical = ( - compute_physical_blocks_per_logical( - nixl_agent_meta.ssm_sizes, - nixl_agent_meta.block_lens[0], - ) - if self._has_mamba - else 1 + nixl_agent_meta.physical_blocks_per_logical_kv_block ) transfer_info = EngineTransferInfo( remote_tp_size=remote_tp_size, From 2c920b5dcc0a8b368c6b8261c4a111c07d752031 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Sun, 26 Apr 2026 06:15:35 +0000 Subject: [PATCH 44/49] add gemma4 heterotp support for NIXL KV transfer Per-group TP mapping, sub-descriptor splitting, and block ID remapping for heterogeneous attention models (SWA + FA groups with different total_num_kv_heads). Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_transfer_plan.py | 200 ++++++++++++++- .../kv_connector/v1/nixl/metadata.py | 6 + .../kv_connector/v1/nixl/transfer_plan.py | 235 +++++++++++++++++- .../kv_connector/v1/nixl/worker.py | 57 ++++- .../layers/attention/attention.py | 4 + vllm/model_executor/models/gemma4.py | 1 + vllm/v1/kv_cache_interface.py | 1 + 7 files changed, 492 insertions(+), 12 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 377bc773cbc8..812da487c8eb 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -22,6 +22,7 @@ EngineTransferPlan, RegionPlan, generate_dense_plan, + generate_gemma4_plan, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, @@ -414,7 +415,7 @@ def _make_mamba_plan_for_desc_ids( group_spec_types=group_spec_types, source_ranks_per_group=source_ranks_per_group, all_source_ranks=(0,), - rank_to_attention_slot={0: 0}, + rank_to_attention_slot=({0: 0},) * len(group_kinds), 
remote_expansion_stride=1, ) @@ -487,7 +488,7 @@ def test_all_source_ranks_serve_fa(self): group_spec_types=(FullAttentionSpec, MambaSpec), source_ranks_per_group=(both, both), all_source_ranks=(0, 1), - rank_to_attention_slot={0: 0, 1: 1}, + rank_to_attention_slot=({0: 0, 1: 1}, {0: 0, 1: 1}), remote_expansion_stride=1, ) @@ -512,7 +513,7 @@ def test_non_fa_rank_skips_fa_groups(self): group_spec_types=(FullAttentionSpec, MambaSpec), source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1, 2), - rank_to_attention_slot={0: 0}, + rank_to_attention_slot=({0: 0}, {0: 0}), remote_expansion_stride=1, ) @@ -558,7 +559,7 @@ def test_fa_and_ssm_different_split_factors(self): group_spec_types=(FullAttentionSpec, MambaSpec), source_ranks_per_group=(fa_readers, ssm_readers), all_source_ranks=(0, 1), - rank_to_attention_slot={0: 0, 1: 0}, + rank_to_attention_slot=({0: 0, 1: 0}, {0: 0, 1: 0}), remote_expansion_stride=1, ) @@ -584,3 +585,194 @@ def test_fa_and_ssm_different_split_factors(self): # FA: chunk=200//1=200, slot=0 (skip_fa) → (1000, 200, 0), (2000, 200, 0) # SSM: chunk=400//2=200, idx=1 → (3200, 200, 0) assert splits[1] == [(1000, 200, 0), (2000, 200, 0), (3200, 200, 0)] + + +# ====================================================================== +# Gemma4 HeteroTP tests +# ====================================================================== + + +def _make_gemma4_plan_params( + tp_rank: int = 0, + tp_size: int = 4, + remote_tp_size: int = 2, +) -> dict: + """Build kwargs for generate_gemma4_plan at 2p4d. + + Gemma4-26B at P_TP=2, D_TP=4: + SWA: 25 layers, K=8, head_dim=256, block_size=16 on both sides + FA: 5 layers, K=2, head_dim=512, P block_size=32, D block_size=16 + + With page unification + HMA, all groups share one physical pool. + page_size: P=65536, D=32768 → remote_to_local_page_ratio=2. + For simplicity, use 2 physical layers in tests. 
+ """ + # D side (local): kv_heads_per_rank for all groups = page_size / block_size + # page_size = 32768 for both groups at D_TP=4. + d_page = 32768 + p_page = 65536 + num_layers = 2 + + return dict( + transfer_topo=_make_fake_topo( + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=False, + total_num_kv_heads=8, + block_size=16, + is_blocks_first=False, + ), + block_len_per_layer=[d_page] * num_layers, + remote_info=EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_size=16, + remote_block_len=p_page, + remote_physical_blocks_per_logical=1, + ), + remote_meta=_make_nixl_meta( + base_addrs=[0x10000 * (i + 1) for i in range(num_layers)], + num_blocks=500, + block_lens=[p_page] * num_layers, + block_size=16, + ), + group_kinds=(GroupKind.SWA, GroupKind.FA), + total_num_kv_heads_per_group=(8, 2), + local_tokens_per_block=(16, 16), + remote_tokens_per_block=(16, 32), + ) + + +class TestGemma4PlanStructure: + """Verify plan structure for Gemma4-style heterogeneous attention.""" + + def test_plan_fields_2p4d_rank0(self): + """D rank 0 at 2p4d: ratio=2, SWA head-split, FA multi-block.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=0)) + + assert plan.remote_to_local_page_ratio == 2 + assert plan.group_kinds == (GroupKind.SWA, GroupKind.FA) + assert plan.local_blocks_per_remote_block == (1, 2) + assert plan.sub_desc_index_per_group == (0, 0) # rank 0: index=0 + assert plan.all_source_ranks == (0,) + assert plan.source_ranks_per_group == ((0,), (0,)) + + def test_plan_fields_2p4d_rank1(self): + """D rank 1 at 2p4d: SWA reads second descriptor (index=1).""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=1)) + + assert plan.sub_desc_index_per_group == (1, 0) # rank 1: SWA=1 + assert plan.local_blocks_per_remote_block == (1, 2) + assert plan.all_source_ranks == (0,) + + def test_plan_fields_2p4d_rank2(self): + """D rank 2 reads from P rank 1.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=2)) + + assert 
plan.all_source_ranks == (1,) + assert plan.sub_desc_index_per_group == (0, 0) + + def test_fa_regions_have_multiple_descs_per_block(self): + """FA regions should have descs_per_block = page ratio.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + + for region in plan.fa_regions: + assert region.descs_per_block == 2 + assert region.desc_stride_bytes == 32768 # D page size + + +class TestGemma4RemoteDescs: + """Verify remote descriptor building with sub-descriptors.""" + + def test_descs_per_block(self): + """Each region produces num_blocks * descs_per_block descriptors.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[65536, 65536], + ) + descs = build_remote_descs_from_plan(plan, meta) + + # 2 layers × 1 region/layer × 500 blocks × 2 descs/block = 2000 + expected_count = 2 * 500 * 2 + assert len(descs) == expected_count + + def test_desc_stride_within_block(self): + """Descriptors within a block should be desc_stride_bytes apart.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[65536, 65536], + ) + descs = build_remote_descs_from_plan(plan, meta) + + # First block, layer 0: descriptor 0 and descriptor 1 + addr_d0, len_d0, _ = descs[0] + addr_d1, len_d1, _ = descs[1] + assert addr_d1 - addr_d0 == 32768 # desc_stride_bytes + assert len_d0 == len_d1 == 32768 # descriptor_bytes + + +class TestGemma4DescIds: + """Verify desc ID computation with sub-desc block IDs.""" + + def test_remapped_block_ids(self): + """After remapping, descriptor indices are correct.""" + plan = generate_gemma4_plan(**_make_gemma4_plan_params()) + + # SWA blocks [3, 7], FA blocks [10, 11] + # Remapped to descriptor indices: + # SWA (sub_desc_index=0): [3*2+0, 7*2+0] = [6, 14] + # FA (2 local per remote): [10*2, 10*2+1, 11*2, 11*2+1] = [20,21,22,23] + # + # dst_num_blocks = 
500 * 2 = 1000 (num_blocks * descs_per_block) + # 2 fa_regions (2 layers), each with 1000 desc slots + # SWA: [0*1000+6, 0*1000+14, 1*1000+6, 1*1000+14] + # = [6, 14, 1006, 1014] + # FA: [0*1000+20, 0*1000+21, 0*1000+22, 0*1000+23, + # 1*1000+20, 1*1000+21, 1*1000+22, 1*1000+23] + # = [20, 21, 22, 23, 1020, 1021, 1022, 1023] + + # First remap via read specs to get descriptor-level block IDs + local_swa = [10, 11] + local_fa = [20, 21, 22, 23] + remote_swa = [3, 7] + remote_fa = [10, 11] + + specs = compute_read_specs_from_plan( + plan, + local_block_ids=(local_swa, local_fa), + remote_block_ids=(remote_swa, remote_fa), + ) + assert len(specs) == 1 # Single source rank + spec = specs[0] + + # Verify remapped remote block IDs + assert list(spec.remote_block_ids[0]) == [6, 14] # SWA: b*2+0 + assert list(spec.remote_block_ids[1]) == [20, 21, 22, 23] # FA: 2 per + + # Verify local block IDs unchanged + assert list(spec.local_block_ids[0]) == [10, 11] + assert list(spec.local_block_ids[1]) == [20, 21, 22, 23] + + # Now compute desc IDs with the remapped remote blocks + remote_ids = compute_desc_ids_from_plan( + plan, + block_ids=spec.remote_block_ids, + dst_num_blocks=500 * 2, # num_blocks * descs_per_block + ) + expected_remote = [6, 14, 1006, 1014, 20, 21, 22, 23, 1020, 1021, 1022, 1023] + assert list(remote_ids) == expected_remote + + # Local desc IDs (standard, descs_per_block=1 locally) + local_ids = compute_desc_ids_from_plan( + plan, + block_ids=spec.local_block_ids, + dst_num_blocks=1000, # local num_blocks + ) + expected_local = [10, 11, 1010, 1011, 20, 21, 22, 23, 1020, 1021, 1022, 1023] + assert list(local_ids) == expected_local + + # Both have same length → can be paired for transfer + assert len(remote_ids) == len(local_ids) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py index c56e373ba99d..724fc709d841 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py @@ -50,6 +50,12 @@ class NixlAgentMetadata: ssm_sizes: tuple[int, int] attn_backend_name: str physical_blocks_per_logical_kv_block: int + # Per-group block_size in tokens after page unification, indexed by + # kv_cache_group position. Needed for HeteroTP models (e.g. Gemma4) + # where groups have different token counts per block. + # Example — Gemma4 at P_TP=2: [16, 32] for [SWA, FA]. + # None for homogeneous models (all groups share the same block_size). + tokens_per_block_per_group: list[int] | None = None @dataclass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 6fca950c0f57..5b7ad04ec3b9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -58,6 +58,12 @@ class RegionPlan: Everything needed to build NIXL descriptors and compute descriptor IDs is baked in. The caller plugs in ``base_addr`` and ``device_id`` when constructing the final descriptor tuples. + + When ``descs_per_block > 1``, each physical block produces multiple + NIXL descriptors. This happens when the remote page is larger than + the local page (e.g. Gemma4 2p4d where P page = 65536 bytes, + D page = 32768 bytes → ``descs_per_block = 2``). Each descriptor + covers one local-page-sized chunk of the remote block. """ layer_idx: int @@ -68,6 +74,13 @@ class RegionPlan: page_stride: int num_blocks: int + # How many NIXL descriptors to register per physical block. + # Default 1 (one desc per block). When the remote page is N times + # larger than local, set to N so each block produces N descriptors. + descs_per_block: int = 1 + # Byte offset between consecutive descriptors within the same block. 
+ desc_stride_bytes: int = 0 + @dataclass(frozen=True) class EngineTransferPlan: @@ -76,9 +89,12 @@ class EngineTransferPlan: Generated once during handshake. Regions are split into ``fa_regions`` and ``ssm_regions`` matching the descriptor handle layout. + + Per-group HeteroTP fields enable models where different attention + groups have different transfer behaviors (e.g. Gemma4 SWA + FA). """ - # Regions in descriptor handle order + # --- Core regions (descriptor handle order) --- fa_regions: tuple[RegionPlan, ...] ssm_regions: tuple[RegionPlan, ...] @@ -91,14 +107,35 @@ class EngineTransferPlan: # Superset of all source ranks (union of all groups). all_source_ranks: tuple[int, ...] - # Maps each source rank to its FA head slot index. - rank_to_attention_slot: dict[int, int] + # Per-group head slot mapping. Each dict maps source rank → slot. + # Per-group because different groups can have different num_kv_heads, + # leading to different head-to-slot assignments. + # Example: Gemma4 has SWA K=8 and FA K=2; at 4p2d these would + # produce genuinely different slot mappings. + rank_to_attention_slot: tuple[dict[int, int], ...] # Stride for expanding remote logical block IDs to kernel block IDs. - # Dense: local_physical_blocks_per_logical. - # Mamba: remote_physical_blocks_per_logical. remote_expansion_stride: int + # --- HeteroTP per-group fields (e.g. Gemma4 SWA + FA) --- + # Active only when remote_to_local_page_ratio > 1. + # For Dense/Mamba (ratio=1), these are unused and default to empty. + + # remote_page_size_bytes / local_page_size_bytes. + # Gemma4 2p4d: 65536 / 32768 = 2. + remote_to_local_page_ratio: int = 1 + + # Per-group: how many local (D) blocks correspond to one remote (P) + # block. Computed as remote_block_size / local_block_size per group. + # Gemma4 2p4d: SWA = 16/16 = 1, FA = 32/16 = 2. + local_blocks_per_remote_block: tuple[int, ...] 
= () + + # Per-group: which descriptor index to read from a multi-descriptor + # remote block (for head-split groups where local reads a portion). + # Gemma4 2p4d rank 0: SWA = 0 (first half), FA = 0 (unused, reads all). + # Gemma4 2p4d rank 1: SWA = 1 (second half), FA = 0. + sub_desc_index_per_group: tuple[int, ...] = () + @property def all_regions(self) -> tuple[RegionPlan, ...]: return self.fa_regions + self.ssm_regions @@ -306,7 +343,7 @@ def generate_dense_plan( group_spec_types=group_spec_types, source_ranks_per_group=tp_mapping.source_ranks_per_group, all_source_ranks=tp_mapping.all_source_ranks, - rank_to_attention_slot=tp_mapping.rank_to_attention_slot, + rank_to_attention_slot=(tp_mapping.rank_to_attention_slot,), remote_expansion_stride=local_physical_blocks_per_logical, ) @@ -416,17 +453,201 @@ def generate_mamba_plan( ) ) + n_groups = len(group_spec_types) return EngineTransferPlan( fa_regions=tuple(fa_regions), ssm_regions=tuple(ssm_regions), group_spec_types=group_spec_types, source_ranks_per_group=tp_mapping.source_ranks_per_group, all_source_ranks=tp_mapping.all_source_ranks, - rank_to_attention_slot=tp_mapping.rank_to_attention_slot, + rank_to_attention_slot=(tp_mapping.rank_to_attention_slot,) * n_groups, remote_expansion_stride=remote_phys_ratio, ) +def generate_gemma4_plan( + *, + transfer_topo: TransferTopology, + block_len_per_layer: list[int], + remote_info: EngineTransferInfo, + remote_meta: NixlAgentMetadata, + group_spec_types: tuple[type[KVCacheSpec], ...], + total_num_kv_heads_per_group: tuple[int, ...], + local_tokens_per_block: tuple[int, ...], + remote_tokens_per_block: tuple[int, ...], +) -> EngineTransferPlan: + """Generate transfer plan for Gemma4-style heterogeneous attention. + + Gemma4 has multiple attention groups (SWA, FA) with different + ``total_num_kv_heads`` and ``head_dim``. With page unification and + HMA, all groups share physical memory pools. This generator: + + 1. 
Calls ``_compute_tp_mapping`` per group with group-specific K. + 2. Builds FA regions with multiple descriptors per block when P and + D have different page sizes. + 3. Encodes per-group transfer behavior via + ``local_blocks_per_remote_block`` and ``sub_desc_index_per_group``. + """ + tp_rank = transfer_topo.tp_rank + tp_size = transfer_topo.tp_size + remote_tp_size = remote_info.remote_tp_size + is_mla = transfer_topo.is_mla + is_blocks_first = transfer_topo.is_kv_layout_blocks_first + n_groups = len(group_spec_types) + + local_page = block_len_per_layer[0] + remote_page = remote_meta.block_lens[0] + page_ratio = remote_page // local_page + assert page_ratio >= 1, ( + f"Remote page {remote_page} must be >= local page {local_page}" + ) + + blocks_per_remote: list[int] = [] + sub_desc_idx: list[int] = [] + + source_ranks_all: list[tuple[int, ...]] = [] + rank_to_slot_all: list[dict[int, int]] = [] + + for g in range(n_groups): + n_local = remote_tokens_per_block[g] // local_tokens_per_block[g] + blocks_per_remote.append(n_local) + + K_g = total_num_kv_heads_per_group[g] + m_g = _compute_tp_mapping( + tp_rank, + tp_size, + remote_tp_size, + is_mla, + K_g, + (group_spec_types[g],), + ) + source_ranks_all.append(m_g.source_ranks_per_group[0]) + rank_to_slot_all.append(m_g.rank_to_attention_slot) + + # Head-split groups: rank_offset_factor selects which descriptor. + if n_local == 1 and page_ratio > 1: + sub_desc_idx.append(m_g.rank_offset_factor) + else: + sub_desc_idx.append(0) + + all_ranks: set[int] = set() + for ranks in source_ranks_all: + all_ranks.update(ranks) + all_source_ranks = tuple(sorted(all_ranks)) + + # HMA: one K pool (+ optional V pool) shared by all groups. + # Register descs_per_block descriptors per physical block. 
+ fa_regions: list[RegionPlan] = [] + for i in range(len(remote_meta.block_lens)): + local_block_len = _get_kv_block_len( + i, + block_len_per_layer, + is_blocks_first, + ) + page_stride = remote_meta.block_lens[i] + + fa_regions.append( + RegionPlan( + kind=RegionKind.FA_K, + layer_idx=i, + descriptor_bytes=local_block_len, + offset_in_page=0, + page_stride=page_stride, + num_blocks=remote_meta.num_blocks, + descs_per_block=page_ratio, + desc_stride_bytes=local_block_len, + ) + ) + + if is_blocks_first: + fa_regions.append( + RegionPlan( + kind=RegionKind.FA_V, + layer_idx=i, + descriptor_bytes=local_block_len, + offset_in_page=page_stride // 2, + page_stride=page_stride, + num_blocks=remote_meta.num_blocks, + descs_per_block=page_ratio, + desc_stride_bytes=local_block_len, + ) + ) + + return EngineTransferPlan( + fa_regions=tuple(fa_regions), + ssm_regions=(), + group_spec_types=group_spec_types, + source_ranks_per_group=tuple(source_ranks_all), + all_source_ranks=all_source_ranks, + rank_to_attention_slot=tuple(rank_to_slot_all), + remote_expansion_stride=1, + remote_to_local_page_ratio=page_ratio, + local_blocks_per_remote_block=tuple(blocks_per_remote), + sub_desc_index_per_group=tuple(sub_desc_idx), + ) + + +# ====================================================================== +# 4. Local descriptor building +# ====================================================================== + + +def _remap_remote_blocks_to_subdesc_ids( + plan: EngineTransferPlan, + remote_block_ids: BlockIds, + local_block_ids: BlockIds, +) -> tuple[BlockIds, BlockIds]: + """Convert remote block IDs into descriptor-level indices. + + When ``remote_to_local_page_ratio > 1``, each remote physical block + is registered as multiple descriptors (one per local-page-sized + chunk). This function converts remote block IDs into the + descriptor index space so that ``_compute_desc_ids_from_plan`` can + look up the correct descriptors. 
+ + Two per-group cases: + + * **Multi-block** (``local_blocks_per_remote_block > 1``, e.g. FA): + One remote block covers multiple local blocks. + Remote block ``b`` → descriptor indices + ``[b*ratio, b*ratio+1, ..., b*ratio+(n-1)]``. + Example: FA block 10, ratio=2 → desc indices [20, 21]. + + * **Head-split** (``local_blocks_per_remote_block == 1``, e.g. SWA): + Local reads one specific chunk of the remote block. + Remote block ``b`` → descriptor index + ``b*ratio + sub_desc_index_per_group[g]``. + Example: SWA block 10, ratio=2, index=1 → desc index 21. + + Local block IDs are returned unchanged. + """ + if plan.remote_to_local_page_ratio <= 1: + return remote_block_ids, local_block_ids + + ratio = plan.remote_to_local_page_ratio + num_groups = len(remote_block_ids) + new_remote: list[list[int]] = [] + new_local: list[list[int]] = [] + + for g in range(num_groups): + n_local = plan.local_blocks_per_remote_block[g] + r_ids = list(remote_block_ids[g]) + l_ids = list(local_block_ids[g]) + + if n_local > 1: + remapped: list[int] = [] + for b in r_ids: + remapped.extend(b * ratio + s for s in range(n_local)) + new_remote.append(remapped) + else: + idx = plan.sub_desc_index_per_group[g] + new_remote.append([b * ratio + idx for b in r_ids]) + + new_local.append(l_ids) + + return new_remote, new_local + + # ====================================================================== # 4. 
Local descriptor building # ====================================================================== diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 97b88ad06396..cb1ae8dd256f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -51,6 +51,7 @@ build_fa_local_regions, build_mamba_local_regions, generate_dense_plan, + generate_gemma4_plan, generate_mamba_plan, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( @@ -71,6 +72,7 @@ from vllm.utils.network_utils import make_zmq_path from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.kv_cache_interface import ( + AttentionSpec, FullAttentionSpec, MambaSpec, UniformTypeKVCacheSpecs, @@ -364,6 +366,29 @@ def __init__( mamba_ssm_size = self._conv_decomp.ssm_sizes self._mamba_ssm_size = mamba_ssm_size + # ---- Heterogeneous attention detection (e.g. Gemma4 SWA + FA) ---- + tp_size = vllm_config.parallel_config.tensor_parallel_size + attn_specs = [ + g.kv_cache_spec + for g in kv_cache_config.kv_cache_groups + if isinstance(g.kv_cache_spec, AttentionSpec) + ] + self._is_hetero_attn = ( + len(attn_specs) > 1 and len({s.num_kv_heads for s in attn_specs}) > 1 + ) + if self._is_hetero_attn: + self._total_kv_heads_per_group = tuple( + s.total_num_kv_heads + if s.total_num_kv_heads is not None + else s.num_kv_heads * tp_size + for s in attn_specs + ) + unified_page = max(s.page_size_bytes for s in attn_specs) + self._local_tokens_per_block_per_group = tuple( + s.block_size * unified_page // s.real_page_size_bytes + for s in attn_specs + ) + # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] # Configure NIXL num_threads to avoid UAR exhaustion on Mellanox NICs. 
@@ -1049,6 +1074,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): physical_blocks_per_logical_kv_block=( self._physical_blocks_per_logical_kv_block ), + tokens_per_block_per_group=( + list(self._local_tokens_per_block_per_group) + if self._is_hetero_attn + else None + ), ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -1179,6 +1209,26 @@ def add_remote_agent( conv_decomp=self._conv_decomp, ssm_sizes=self._mamba_ssm_size, ) + elif self._is_hetero_attn: + remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group or ()) + group_spec_types = tuple( + type(g.kv_cache_spec) + for g in self.kv_cache_config.kv_cache_groups + ) + assert len(remote_tpb) == len(group_spec_types), ( + f"Remote tokens_per_block_per_group length " + f"{len(remote_tpb)} != {len(group_spec_types)} groups" + ) + self._transfer_plans[engine_id] = generate_gemma4_plan( + transfer_topo=transfer_topo, + block_len_per_layer=self.block_len_per_layer, + remote_info=transfer_info, + remote_meta=nixl_agent_meta, + group_spec_types=group_spec_types, + total_num_kv_heads_per_group=(self._total_kv_heads_per_group), + local_tokens_per_block=(self._local_tokens_per_block_per_group), + remote_tokens_per_block=remote_tpb, + ) else: self._transfer_plans[engine_id] = generate_dense_plan( transfer_topo=transfer_topo, @@ -1963,10 +2013,15 @@ def _read_blocks( # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. + # For HeteroTP (page_ratio > 1), each remote block is registered as + # multiple descriptors, so scale the descriptor-space block count. 
+ remote_desc_blocks = ( + self.dst_num_blocks[dst_engine_id] * plan.remote_to_local_page_ratio + ) remote_block_descs_ids = self._compute_desc_ids_from_plan( plan, block_ids=remote_block_ids, - dst_num_blocks=self.dst_num_blocks[dst_engine_id], + dst_num_blocks=remote_desc_blocks, block_size_ratio=None, physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, ) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index db9ae2bbda34..c0b83cf35e68 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -203,6 +203,7 @@ def __init__( kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, head_size_v: int | None = None, + total_num_kv_heads: int | None = None, **extra_impl_args, ) -> None: """ @@ -285,6 +286,7 @@ def __init__( self.head_size = head_size self.head_size_v = self.head_size if head_size_v is None else head_size_v self.num_kv_heads = num_kv_heads + self.total_num_kv_heads = total_num_kv_heads self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None @@ -552,6 +554,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, kv_quant_mode=quant_mode, + total_num_kv_heads=self.total_num_kv_heads, sliding_window=self.sliding_window, ) elif self.kv_cache_dtype.startswith("turboquant_"): @@ -579,6 +582,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, kv_quant_mode=quant_mode, + total_num_kv_heads=self.total_num_kv_heads, ) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index b724fa71968c..52543f04d654 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -492,6 +492,7 @@ def 
__init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + total_num_kv_heads=self.total_num_kv_heads, cache_config=cache_config, quant_config=quant_config, logits_soft_cap=attn_logits_soft_cap, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 2545c440368a..87b45195df61 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -133,6 +133,7 @@ class AttentionSpec(KVCacheSpec): dtype: torch.dtype kv_quant_mode: KVQuantMode = KVQuantMode.NONE page_size_padded: int | None = None + total_num_kv_heads: int | None = None @property def page_size_bytes(self) -> int: From 594e1ab16ee68e75a99ff71ff783df374ce47b6c Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 28 Apr 2026 15:46:29 -0400 Subject: [PATCH 45/49] add gather-read support for Gemma4 HeteroTP NIXL transfer When D_TP < P_TP (e.g. 4p2d), local pages are larger than remote pages. Introduce gather-read: split local blocks into sub-descriptors matching the remote page size so NIXL can pair them for RDMA. 
- Add local_to_remote_page_ratio and remote_blocks_per_local_block to EngineTransferPlan; generate_gemma4_plan now handles both split-read (2p4d) and gather-read (4p2d) directions - Add build_fa_local_descs_for_gather_read for local sub-desc registration - Worker: register gather-read handles, bidirectional block trimming, scale local desc-space block count for gather-read - Fix FullAttentionSpec.merge() and SinkFullAttentionSpec.merge() to propagate total_num_kv_heads (was causing AssertionError on Gemma4) - Add unit tests for gather-read configs (4p2d, 4p1d) Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_transfer_plan.py | 208 ++++++++++++++++++ .../kv_connector/v1/nixl/transfer_plan.py | 188 ++++++++++++++-- .../kv_connector/v1/nixl/worker.py | 95 ++++++-- vllm/v1/kv_cache_interface.py | 2 + 4 files changed, 459 insertions(+), 34 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 812da487c8eb..3087492d3258 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -776,3 +776,211 @@ def test_remapped_block_ids(self): # Both have same length → can be paired for transfer assert len(remote_ids) == len(local_ids) + + +# ====================================================================== +# Gemma4 Gather-Read tests (local page > remote page) +# ====================================================================== + + +def _make_gemma4_gather_plan_params( + tp_rank: int = 0, + tp_size: int = 2, + remote_tp_size: int = 4, +) -> dict: + """Build kwargs for generate_gemma4_plan at 4p2d (gather-read). + + Gemma4-26B at P_TP=4, D_TP=2: + SWA: K=8, head_dim=256, P_tpb=16, D_tpb=16 → concat (2 P ranks) + FA: K=2, head_dim=512, P_tpb=16, D_tpb=32 → gather (2P→1D block) + + page_size: P=32768, D=65536 → local_to_remote_page_ratio=2. 
+ """ + d_page = 65536 + p_page = 32768 + num_layers = 2 + + return dict( + transfer_topo=_make_fake_topo( + tp_rank=tp_rank, + tp_size=tp_size, + is_mla=False, + total_num_kv_heads=8, + block_size=16, + is_blocks_first=False, + ), + block_len_per_layer=[d_page] * num_layers, + remote_info=EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_size=16, + remote_block_len=p_page, + remote_physical_blocks_per_logical=1, + ), + remote_meta=_make_nixl_meta( + base_addrs=[0x10000 * (i + 1) for i in range(num_layers)], + num_blocks=500, + block_lens=[p_page] * num_layers, + block_size=16, + ), + group_kinds=(GroupKind.SWA, GroupKind.FA), + total_num_kv_heads_per_group=(8, 2), + local_tokens_per_block=(16, 32), + remote_tokens_per_block=(16, 16), + ) + + +class TestGemma4GatherReadPlanStructure: + """Verify plan structure for gather-read (4p2d).""" + + def test_plan_fields_4p2d_rank0(self): + """D rank 0 at 4p2d: gather_ratio=2, SWA concat, FA gather.""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params(tp_rank=0)) + + assert plan.local_to_remote_page_ratio == 2 + assert plan.remote_to_local_page_ratio == 1 + assert plan.group_kinds == (GroupKind.SWA, GroupKind.FA) + assert plan.remote_blocks_per_local_block == (1, 2) + assert plan.local_blocks_per_remote_block == (1, 1) + # SWA: D rank 0 reads from P rank 0 and P rank 1 + assert (0,) in plan.source_ranks_per_group[0] or \ + len(plan.source_ranks_per_group[0]) == 2 + # FA: after GQA dedup, D rank 0 reads from P rank 0 only + assert len(plan.source_ranks_per_group[1]) == 1 + + def test_no_assertion_error(self): + """4p2d should NOT crash (old code had assert page_ratio >= 1).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + assert plan is not None + + def test_fa_regions_standard_descs(self): + """Gather-read: FA regions have descs_per_block=1 (standard).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + + for region in plan.fa_regions: + assert 
region.descs_per_block == 1 + assert region.descriptor_bytes == 32768 # remote page size + + +class TestGemma4GatherReadRemoteDescs: + """Verify remote descriptor building for gather-read.""" + + def test_standard_descs_per_block(self): + """Gather-read: 1 desc per block (no remote sub-descs).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[32768, 32768], + ) + descs = build_remote_descs_from_plan(plan, meta) + + # 2 layers × 1 region/layer × 500 blocks × 1 desc/block = 1000 + assert len(descs) == 2 * 500 * 1 + + def test_desc_bytes_match_remote_page(self): + """Each remote desc should be remote_page_size bytes.""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params()) + meta = _make_nixl_meta( + base_addrs=[0x10000, 0x20000], + num_blocks=500, + block_lens=[32768, 32768], + ) + descs = build_remote_descs_from_plan(plan, meta) + + for _, length, _ in descs: + assert length == 32768 + + +class TestGemma4GatherReadSpecs: + """Verify read spec computation for gather-read.""" + + def test_gather_read_specs_4p2d_rank0(self): + """4p2d rank 0: SWA from 2 ranks, FA from 1 rank (gather).""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params(tp_rank=0)) + + # D has 2 SWA blocks and 1 FA block (32 tokens) + local_swa = [10, 11] + local_fa = [20] + # P has 2 SWA blocks per rank and 2 FA blocks (16 tokens each) + remote_swa = [5, 6] + remote_fa = [30, 31] + + specs = compute_read_specs_from_plan( + plan, + local_block_ids=(local_swa, local_fa), + remote_block_ids=(remote_swa, remote_fa), + ) + + # SWA reads from 2 P ranks → 2 specs + assert len(specs) == 2 + + # Spec 0 (P rank 0): + # SWA: local sub-desc slot 0 → [10*2+0, 11*2+0] = [20, 22] + # FA: expanded → [20*2+0, 20*2+1] = [40, 41] + spec0 = specs[0] + assert list(spec0.local_block_ids[0]) == [20, 22] # SWA slot 0 + assert list(spec0.local_block_ids[1]) == [40, 41] # FA gather + assert 
list(spec0.remote_block_ids[0]) == [5, 6] # SWA blocks + assert list(spec0.remote_block_ids[1]) == [30, 31] # FA blocks + + # Spec 1 (P rank 1): + # SWA: local sub-desc slot 1 → [10*2+1, 11*2+1] = [21, 23] + # FA: empty (rank 1 not in FA source_ranks after GQA dedup) + spec1 = specs[1] + assert list(spec1.local_block_ids[0]) == [21, 23] # SWA slot 1 + assert list(spec1.remote_block_ids[0]) == [5, 6] # SWA blocks + assert spec1.local_block_ids[1] == [] # FA empty for rank 1 + assert spec1.remote_block_ids[1] == [] + + def test_gather_read_desc_ids_match(self): + """Local and remote desc IDs should have same length for NIXL.""" + plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params(tp_rank=0)) + + local_swa = [10, 11] + local_fa = [20] + remote_swa = [5, 6] + remote_fa = [30, 31] + + specs = compute_read_specs_from_plan( + plan, + local_block_ids=(local_swa, local_fa), + remote_block_ids=(remote_swa, remote_fa), + ) + + for spec in specs: + # Remote desc IDs: standard (no sub-descs), num_blocks=500 + remote_ids = compute_desc_ids_from_plan( + plan, + block_ids=spec.remote_block_ids, + dst_num_blocks=500, + ) + # Local desc IDs: gather sub-descs, num_blocks=1000*gather_ratio + local_ids = compute_desc_ids_from_plan( + plan, + block_ids=spec.local_block_ids, + dst_num_blocks=1000 * 2, # local_num_blocks * gather_ratio + ) + assert len(remote_ids) == len(local_ids), ( + f"Desc ID length mismatch for rank {spec.remote_rank}: " + f"remote={len(remote_ids)}, local={len(local_ids)}" + ) + + +class TestGemma4GatherReadPlan4p1d: + """Verify gather-read for 4p1d (D_TP=1, P_TP=4).""" + + def test_4p1d_no_crash(self): + """4p1d should not crash.""" + params = _make_gemma4_gather_plan_params( + tp_rank=0, tp_size=1, remote_tp_size=4 + ) + # D_TP=1: D_page = 131072 (8 heads * 256 * 2 * 16 * 2 for SWA) + # P_TP=4: P_page = 32768 + params["block_len_per_layer"] = [131072, 131072] + params["local_tokens_per_block"] = (16, 32) + params["remote_tokens_per_block"] = (16, 16) + 
plan = generate_gemma4_plan(**params) + + assert plan.local_to_remote_page_ratio == 4 + assert plan.remote_to_local_page_ratio == 1 + assert plan.remote_blocks_per_local_block == (1, 2) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 5b7ad04ec3b9..f85fd359ca17 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -136,6 +136,18 @@ class EngineTransferPlan: # Gemma4 2p4d rank 1: SWA = 1 (second half), FA = 0. sub_desc_index_per_group: tuple[int, ...] = () + # --- Gather-read fields (local page > remote page, e.g. 4p2d FA) --- + # When D pages are larger than P pages, local blocks are split into + # sub-descriptors matching the remote block size for RDMA pairing. + + # local_page_size_bytes / remote_page_size_bytes. + # 4p2d Gemma4: 65536 / 32768 = 2. 1 when no gather-read. + local_to_remote_page_ratio: int = 1 + + # Per-group: how many remote blocks fill one local block. + # FA in 4p2d: D_tpb / P_tpb = 32 / 16 = 2. + remote_blocks_per_local_block: tuple[int, ...] = () + @property def all_regions(self) -> tuple[RegionPlan, ...]: return self.fa_regions + self.ssm_regions @@ -483,10 +495,14 @@ def generate_gemma4_plan( HMA, all groups share physical memory pools. This generator: 1. Calls ``_compute_tp_mapping`` per group with group-specific K. - 2. Builds FA regions with multiple descriptors per block when P and - D have different page sizes. + 2. Handles both **split-read** (remote page > local page, e.g. 2p4d) + and **gather-read** (local page > remote page, e.g. 4p2d). 3. Encodes per-group transfer behavior via - ``local_blocks_per_remote_block`` and ``sub_desc_index_per_group``. + ``local_blocks_per_remote_block`` / ``remote_blocks_per_local_block`` + and ``sub_desc_index_per_group``. + + Split-read (P_page > D_page): remote blocks are split into sub-descs. 
+ Gather-read (D_page > P_page): local blocks are split into sub-descs. """ tp_rank = transfer_topo.tp_rank tp_size = transfer_topo.tp_size @@ -497,20 +513,31 @@ def generate_gemma4_plan( local_page = block_len_per_layer[0] remote_page = remote_meta.block_lens[0] - page_ratio = remote_page // local_page - assert page_ratio >= 1, ( - f"Remote page {remote_page} must be >= local page {local_page}" - ) + + if remote_page >= local_page: + split_page_ratio = remote_page // local_page + gather_page_ratio = 1 + else: + split_page_ratio = 1 + gather_page_ratio = local_page // remote_page blocks_per_remote: list[int] = [] + remote_blocks_per_local: list[int] = [] sub_desc_idx: list[int] = [] source_ranks_all: list[tuple[int, ...]] = [] rank_to_slot_all: list[dict[int, int]] = [] for g in range(n_groups): - n_local = remote_tokens_per_block[g] // local_tokens_per_block[g] - blocks_per_remote.append(n_local) + r_tpb = remote_tokens_per_block[g] + l_tpb = local_tokens_per_block[g] + + if r_tpb >= l_tpb: + blocks_per_remote.append(r_tpb // l_tpb) + remote_blocks_per_local.append(1) + else: + blocks_per_remote.append(1) + remote_blocks_per_local.append(l_tpb // r_tpb) K_g = total_num_kv_heads_per_group[g] m_g = _compute_tp_mapping( @@ -524,8 +551,8 @@ def generate_gemma4_plan( source_ranks_all.append(m_g.source_ranks_per_group[0]) rank_to_slot_all.append(m_g.rank_to_attention_slot) - # Head-split groups: rank_offset_factor selects which descriptor. - if n_local == 1 and page_ratio > 1: + # Head-split groups (split-read only): rank_offset selects sub-desc. + if blocks_per_remote[-1] == 1 and split_page_ratio > 1: sub_desc_idx.append(m_g.rank_offset_factor) else: sub_desc_idx.append(0) @@ -536,7 +563,6 @@ def generate_gemma4_plan( all_source_ranks = tuple(sorted(all_ranks)) # HMA: one K pool (+ optional V pool) shared by all groups. - # Register descs_per_block descriptors per physical block. 
fa_regions: list[RegionPlan] = [] for i in range(len(remote_meta.block_lens)): local_block_len = _get_kv_block_len( @@ -546,16 +572,34 @@ def generate_gemma4_plan( ) page_stride = remote_meta.block_lens[i] + if split_page_ratio > 1: + # Split-read: remote blocks produce sub-descs of local page size + desc_bytes = local_block_len + descs_per_block = split_page_ratio + desc_stride = local_block_len + elif gather_page_ratio > 1: + # Gather-read: standard remote descs at remote page size + remote_block_len = _get_kv_block_len( + i, remote_meta.block_lens, is_blocks_first + ) + desc_bytes = remote_block_len + descs_per_block = 1 + desc_stride = 0 + else: + desc_bytes = local_block_len + descs_per_block = 1 + desc_stride = 0 + fa_regions.append( RegionPlan( kind=RegionKind.FA_K, layer_idx=i, - descriptor_bytes=local_block_len, + descriptor_bytes=desc_bytes, offset_in_page=0, page_stride=page_stride, num_blocks=remote_meta.num_blocks, - descs_per_block=page_ratio, - desc_stride_bytes=local_block_len, + descs_per_block=descs_per_block, + desc_stride_bytes=desc_stride, ) ) @@ -564,12 +608,12 @@ def generate_gemma4_plan( RegionPlan( kind=RegionKind.FA_V, layer_idx=i, - descriptor_bytes=local_block_len, + descriptor_bytes=desc_bytes, offset_in_page=page_stride // 2, page_stride=page_stride, num_blocks=remote_meta.num_blocks, - descs_per_block=page_ratio, - desc_stride_bytes=local_block_len, + descs_per_block=descs_per_block, + desc_stride_bytes=desc_stride, ) ) @@ -581,9 +625,11 @@ def generate_gemma4_plan( all_source_ranks=all_source_ranks, rank_to_attention_slot=tuple(rank_to_slot_all), remote_expansion_stride=1, - remote_to_local_page_ratio=page_ratio, + remote_to_local_page_ratio=split_page_ratio, local_blocks_per_remote_block=tuple(blocks_per_remote), sub_desc_index_per_group=tuple(sub_desc_idx), + local_to_remote_page_ratio=gather_page_ratio, + remote_blocks_per_local_block=tuple(remote_blocks_per_local), ) @@ -648,6 +694,67 @@ def _remap_remote_blocks_to_subdesc_ids( 
return new_remote, new_local +def _build_gather_read_specs( + plan: EngineTransferPlan, + local_block_ids: BlockIds, + remote_block_ids: BlockIds, +) -> list[ReadSpec]: + """Build read specs for gather-read (local page > remote page). + + In gather-read, local blocks are split into sub-descriptors matching + remote block size. Each rank's read targets specific local sub-descs: + + * **Gather groups** (``remote_blocks_per_local_block > 1``, e.g. FA): + N remote blocks fill one local block. + Local block ``b`` → sub-desc indices + ``[b*ratio, b*ratio+1, ..., b*ratio+(N-1)]``. + + * **Concat groups** (``remote_blocks_per_local_block == 1``, e.g. SWA): + Each rank writes to a specific slot of the local block. + Local block ``b`` → sub-desc index + ``b*ratio + rank_slot``. + """ + gather_ratio = plan.local_to_remote_page_ratio + num_groups = len(local_block_ids) + specs: list[ReadSpec] = [] + + for rank in plan.all_source_ranks: + rank_local: list[list[int]] = [] + rank_remote: list[list[int]] = [] + + for g in range(num_groups): + if rank not in plan.source_ranks_per_group[g]: + rank_local.append([]) + rank_remote.append([]) + continue + + n_remote_per_local = plan.remote_blocks_per_local_block[g] + + if n_remote_per_local > 1: + expanded_local: list[int] = [] + for b in local_block_ids[g]: + expanded_local.extend( + b * gather_ratio + s + for s in range(n_remote_per_local) + ) + rank_local.append(expanded_local) + rank_remote.append(list(remote_block_ids[g])) + else: + slot = plan.rank_to_attention_slot[g].get(rank, 0) + rank_local.append( + [b * gather_ratio + slot for b in local_block_ids[g]] + ) + rank_remote.append(list(remote_block_ids[g])) + + specs.append(ReadSpec( + remote_rank=rank, + local_block_ids=rank_local, + remote_block_ids=rank_remote, + )) + + return specs + + # ====================================================================== # 4. 
Local descriptor building # ====================================================================== @@ -695,6 +802,49 @@ def build_fa_local_regions( return regions +def build_fa_local_descs_for_gather_read( + base_addresses: list[int], + device_id: int, + num_blocks: int, + block_len_per_layer: list[int], + is_blocks_first: bool, + gather_page_ratio: int, +) -> list[tuple[int, int, int]]: + """Build FA local descriptors with sub-descriptors for gather-read. + + Each local block produces ``gather_page_ratio`` descriptors, each + covering ``kv_block_len // gather_page_ratio`` bytes. This allows + NIXL to pair each local sub-descriptor with a remote descriptor of + matching size (the remote's natural page size). + """ + result: list[tuple[int, int, int]] = [] + for i, base_addr in enumerate(base_addresses): + kv_block_len = _get_kv_block_len(i, block_len_per_layer, is_blocks_first) + page_stride = block_len_per_layer[i] + sub_desc_bytes = kv_block_len // gather_page_ratio + + for block_id in range(num_blocks): + blk_addr = base_addr + block_id * page_stride + for s in range(gather_page_ratio): + result.append( + (blk_addr + s * sub_desc_bytes, sub_desc_bytes, device_id) + ) + + if is_blocks_first: + v_sub_desc_bytes = kv_block_len // gather_page_ratio + for block_id in range(num_blocks): + v_blk_addr = ( + base_addr + block_id * page_stride + kv_block_len + ) + for s in range(gather_page_ratio): + result.append( + (v_blk_addr + s * v_sub_desc_bytes, v_sub_desc_bytes, + device_id) + ) + + return result + + def build_mamba_local_regions( block_len_per_layer: list[int], logical_num_blocks: int, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index cb1ae8dd256f..697c2b88352c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -46,8 +46,11 @@ from 
vllm.distributed.kv_transfer.kv_connector.v1.nixl.transfer_plan import ( EngineTransferPlan, ReadSpec, + _build_gather_read_specs, _is_attention_spec, _is_ssm_spec, + _remap_remote_blocks_to_subdesc_ids, + build_fa_local_descs_for_gather_read, build_fa_local_regions, build_mamba_local_regions, generate_dense_plan, @@ -180,9 +183,26 @@ def _compute_read_specs_from_plan( ) -> list[ReadSpec]: """Compute read specs from plan. - For each source rank, includes only the groups whose - source_ranks_per_group contains that rank. + Dispatches to the correct remapping strategy: + + - **Gather-read** (``local_to_remote_page_ratio > 1``): per-rank + local sub-desc remapping via ``_build_gather_read_specs``. + - **Split-read** (``remote_to_local_page_ratio > 1``): + rank-independent remote sub-desc remapping via + ``_remap_remote_blocks_to_subdesc_ids``. + - **Standard**: direct per-rank group filtering. """ + if plan.local_to_remote_page_ratio > 1: + return _build_gather_read_specs( + plan, local_block_ids, remote_block_ids + ) + + remote_block_ids, local_block_ids = ( + _remap_remote_blocks_to_subdesc_ids( + plan, remote_block_ids, local_block_ids + ) + ) + num_groups = len(local_block_ids) return [ ReadSpec( @@ -491,6 +511,9 @@ def __init__( # Populated dynamically during handshake based on remote configuration. # Keep track of regions at different tp_ratio values. tp_ratio->handles self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} + # Gather-read local handles: local blocks split into sub-descs + # matching remote page size. Keyed by engine_id. + self._gather_read_handles: dict[EngineId, int] = {} # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}. self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict) @@ -1278,12 +1301,37 @@ def add_remote_agent( plan = self._transfer_plans[engine_id] - ### (Optional) Register local agent memory regions. MLA is not split. + ### (Optional) Register local agent memory regions. 
if ( + plan.local_to_remote_page_ratio > 1 + and engine_id not in self._gather_read_handles + ): + # Gather-read: local page > remote page. Register local descs + # with sub-descriptors matching the remote block size. + assert self.transfer_topo is not None + local_base_addresses = self.kv_caches_base_addr[ + self.engine_id + ][self.tp_rank] + gather_blocks_data = build_fa_local_descs_for_gather_read( + base_addresses=local_base_addresses, + device_id=self.device_id, + num_blocks=self.num_blocks, + block_len_per_layer=self.block_len_per_layer, + is_blocks_first=self.transfer_topo.is_kv_layout_blocks_first, + gather_page_ratio=plan.local_to_remote_page_ratio, + ) + descs = self.nixl_wrapper.get_xfer_descs( + gather_blocks_data, self.nixl_memory_type + ) + self._gather_read_handles[engine_id] = ( + self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + ) + elif ( tp_ratio < 0 and not self.use_mla and tp_ratio not in self.src_xfer_handles_by_tp_ratio ): + # MLA is not split. # Remote tp_size > local tp_size: read from multiple remote ranks. # Logically "split" own regions into |tp_ratio| chunks. Mind that # we only do this once per remote tp_size (replica-friendly). @@ -1871,7 +1919,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): req_id, ) # Get side handles. - if tp_ratio < 0 and not self.use_mla: + if engine_id in self._gather_read_handles: + # Gather-read: local sub-desc handle matches remote page size. + local_xfer_side_handle = self._gather_read_handles[engine_id] + elif tp_ratio < 0 and not self.use_mla: assert remote_block_size == self.block_size # Remote tp_size > local tp_size: we must perform multiple # reads. Get the memory chunk onto which we will write to. @@ -1994,30 +2045,41 @@ def _read_blocks( == len(local_block_ids) == len(self.kv_cache_config.kv_cache_groups) ) - # Partial prefix cache hit: just read uncomputed blocks. + # Partial prefix cache hit: trim to the shorter of local/remote. 
# Skip mamba groups — their blocks represent full state (conv+ssm), # not per-token data, so trimming would corrupt the transfer. + # For standard and split-read: remote >= local (trim remote). + # For gather-read: local sub-descs may exceed remote (trim local). remote_block_ids = list(remote_block_ids) + local_block_ids = list(local_block_ids) group_specs = self.kv_cache_config.kv_cache_groups - for i, remote_group in enumerate(remote_block_ids): - num_remote_blocks = len(remote_group) - num_local_blocks = len(local_block_ids[i]) + for i in range(len(remote_block_ids)): is_mamba = isinstance(group_specs[i].kv_cache_spec, MambaSpec) - if not is_mamba: - assert num_local_blocks <= num_remote_blocks - if num_local_blocks < num_remote_blocks and not is_mamba: - remote_block_ids[i] = remote_group[-num_local_blocks:] + if is_mamba: + continue + n_local = len(local_block_ids[i]) + n_remote = len(remote_block_ids[i]) + n = min(n_local, n_remote) + if n_local > n: + local_block_ids[i] = local_block_ids[i][-n:] + if n_remote > n: + remote_block_ids[i] = remote_block_ids[i][-n:] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. - # For HeteroTP (page_ratio > 1), each remote block is registered as - # multiple descriptors, so scale the descriptor-space block count. + # For split-read (page_ratio > 1), each remote block is registered as + # multiple descriptors, so scale the remote descriptor-space count. + # For gather-read, scale the local descriptor-space count instead. 
remote_desc_blocks = ( self.dst_num_blocks[dst_engine_id] * plan.remote_to_local_page_ratio ) + local_desc_blocks = self.dst_num_blocks[self.engine_id] + if plan.local_to_remote_page_ratio > 1: + local_desc_blocks *= plan.local_to_remote_page_ratio + remote_block_descs_ids = self._compute_desc_ids_from_plan( plan, block_ids=remote_block_ids, @@ -2028,7 +2090,7 @@ def _read_blocks( local_block_descs_ids = self._compute_desc_ids_from_plan( plan, block_ids=local_block_ids, - dst_num_blocks=self.dst_num_blocks[self.engine_id], + dst_num_blocks=local_desc_blocks, block_size_ratio=block_size_ratio, physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, ) @@ -2188,6 +2250,9 @@ def shutdown(self): for handle in handles: self.nixl_wrapper.release_dlist_handle(handle) self.src_xfer_handles_by_tp_ratio.clear() + for handle in self._gather_read_handles.values(): + self.nixl_wrapper.release_dlist_handle(handle) + self._gather_read_handles.clear() for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): for dst_xfer_side_handle in dst_xfer_side_handles.values(): self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 87b45195df61..69661bf3d7e5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -235,6 +235,7 @@ def merge(cls, specs: list[Self]) -> Self: dtype=specs[0].dtype, kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, + total_num_kv_heads=specs[0].total_num_kv_heads, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) @@ -602,6 +603,7 @@ def merge(cls, specs: list[Self]) -> Self: dtype=specs[0].dtype, kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, + total_num_kv_heads=specs[0].total_num_kv_heads, sliding_window=cls.merge_window_sizes(sliding_window), 
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) From 06a31ce15b45deb9b02671e3702ff055bed4e33e Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 28 Apr 2026 15:46:29 -0400 Subject: [PATCH 46/49] rename sub_desc terminology, add gather-read pairing and assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename sub_desc_index_per_group → remote_desc_offset_per_group and _remap_remote_blocks_to_subdesc_ids → _remap_remote_blocks_to_desc_ids to eliminate confusing "sub-descriptor" naming throughout. - Add _pair_gather_group helper with remainder handling for partial block fills in gather-read (FA groups). - Add length-match assertions in both _build_gather_read_specs and _read_blocks to catch descriptor/block ID pairing mismatches early. - Add diagnostic INFO/DEBUG logging for HeteroTP plan generation, ReadSpec construction, and _read_blocks transfer details. Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_transfer_plan.py | 19 +- .../kv_connector/v1/nixl/transfer_plan.py | 206 +++++++++++++----- .../kv_connector/v1/nixl/worker.py | 140 ++++++++++-- 3 files changed, 288 insertions(+), 77 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 3087492d3258..99a03ea6269d 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -652,7 +652,7 @@ def test_plan_fields_2p4d_rank0(self): assert plan.remote_to_local_page_ratio == 2 assert plan.group_kinds == (GroupKind.SWA, GroupKind.FA) assert plan.local_blocks_per_remote_block == (1, 2) - assert plan.sub_desc_index_per_group == (0, 0) # rank 0: index=0 + assert plan.remote_desc_offset_per_group == (0, 0) # rank 0: index=0 assert plan.all_source_ranks == (0,) assert plan.source_ranks_per_group == ((0,), (0,)) @@ -660,7 +660,7 @@ def test_plan_fields_2p4d_rank1(self): """D rank 1 at 2p4d: SWA reads second 
descriptor (index=1).""" plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=1)) - assert plan.sub_desc_index_per_group == (1, 0) # rank 1: SWA=1 + assert plan.remote_desc_offset_per_group == (1, 0) # rank 1: SWA=1 assert plan.local_blocks_per_remote_block == (1, 2) assert plan.all_source_ranks == (0,) @@ -669,7 +669,7 @@ def test_plan_fields_2p4d_rank2(self): plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=2)) assert plan.all_source_ranks == (1,) - assert plan.sub_desc_index_per_group == (0, 0) + assert plan.remote_desc_offset_per_group == (0, 0) def test_fa_regions_have_multiple_descs_per_block(self): """FA regions should have descs_per_block = page ratio.""" @@ -842,8 +842,9 @@ def test_plan_fields_4p2d_rank0(self): assert plan.remote_blocks_per_local_block == (1, 2) assert plan.local_blocks_per_remote_block == (1, 1) # SWA: D rank 0 reads from P rank 0 and P rank 1 - assert (0,) in plan.source_ranks_per_group[0] or \ - len(plan.source_ranks_per_group[0]) == 2 + assert (0,) in plan.source_ranks_per_group[0] or len( + plan.source_ranks_per_group[0] + ) == 2 # FA: after GQA dedup, D rank 0 reads from P rank 0 only assert len(plan.source_ranks_per_group[1]) == 1 @@ -920,7 +921,7 @@ def test_gather_read_specs_4p2d_rank0(self): spec0 = specs[0] assert list(spec0.local_block_ids[0]) == [20, 22] # SWA slot 0 assert list(spec0.local_block_ids[1]) == [40, 41] # FA gather - assert list(spec0.remote_block_ids[0]) == [5, 6] # SWA blocks + assert list(spec0.remote_block_ids[0]) == [5, 6] # SWA blocks assert list(spec0.remote_block_ids[1]) == [30, 31] # FA blocks # Spec 1 (P rank 1): @@ -928,7 +929,7 @@ def test_gather_read_specs_4p2d_rank0(self): # FA: empty (rank 1 not in FA source_ranks after GQA dedup) spec1 = specs[1] assert list(spec1.local_block_ids[0]) == [21, 23] # SWA slot 1 - assert list(spec1.remote_block_ids[0]) == [5, 6] # SWA blocks + assert list(spec1.remote_block_ids[0]) == [5, 6] # SWA blocks assert spec1.local_block_ids[1] == [] # 
FA empty for rank 1 assert spec1.remote_block_ids[1] == [] @@ -971,9 +972,7 @@ class TestGemma4GatherReadPlan4p1d: def test_4p1d_no_crash(self): """4p1d should not crash.""" - params = _make_gemma4_gather_plan_params( - tp_rank=0, tp_size=1, remote_tp_size=4 - ) + params = _make_gemma4_gather_plan_params(tp_rank=0, tp_size=1, remote_tp_size=4) # D_TP=1: D_page = 131072 (8 heads * 256 * 2 * 16 * 2 for SWA) # P_TP=4: P_page = 32768 params["block_len_per_layer"] = [131072, 131072] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index f85fd359ca17..1b9743bc82e7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -8,6 +8,7 @@ from __future__ import annotations +import logging from dataclasses import dataclass from typing import TYPE_CHECKING @@ -23,6 +24,8 @@ ) from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec, MambaSpec +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( NixlAgentMetadata, @@ -130,15 +133,15 @@ class EngineTransferPlan: # Gemma4 2p4d: SWA = 16/16 = 1, FA = 32/16 = 2. local_blocks_per_remote_block: tuple[int, ...] = () - # Per-group: which descriptor index to read from a multi-descriptor + # Per-group: which descriptor offset to read from a multi-descriptor # remote block (for head-split groups where local reads a portion). # Gemma4 2p4d rank 0: SWA = 0 (first half), FA = 0 (unused, reads all). # Gemma4 2p4d rank 1: SWA = 1 (second half), FA = 0. - sub_desc_index_per_group: tuple[int, ...] = () + remote_desc_offset_per_group: tuple[int, ...] = () # --- Gather-read fields (local page > remote page, e.g. 4p2d FA) --- - # When D pages are larger than P pages, local blocks are split into - # sub-descriptors matching the remote block size for RDMA pairing. 
+ # When D pages are larger than P pages, each local block is registered + # as multiple NIXL descriptors matching the remote block size. # local_page_size_bytes / remote_page_size_bytes. # 4p2d Gemma4: 65536 / 32768 = 2. 1 when no gather-read. @@ -499,10 +502,10 @@ def generate_gemma4_plan( and **gather-read** (local page > remote page, e.g. 4p2d). 3. Encodes per-group transfer behavior via ``local_blocks_per_remote_block`` / ``remote_blocks_per_local_block`` - and ``sub_desc_index_per_group``. + and ``remote_desc_offset_per_group``. - Split-read (P_page > D_page): remote blocks are split into sub-descs. - Gather-read (D_page > P_page): local blocks are split into sub-descs. + Split-read (P_page > D_page): each remote block → multiple descriptors. + Gather-read (D_page > P_page): each local block → multiple descriptors. """ tp_rank = transfer_topo.tp_rank tp_size = transfer_topo.tp_size @@ -523,7 +526,7 @@ def generate_gemma4_plan( blocks_per_remote: list[int] = [] remote_blocks_per_local: list[int] = [] - sub_desc_idx: list[int] = [] + remote_desc_offset: list[int] = [] source_ranks_all: list[tuple[int, ...]] = [] rank_to_slot_all: list[dict[int, int]] = [] @@ -551,17 +554,48 @@ def generate_gemma4_plan( source_ranks_all.append(m_g.source_ranks_per_group[0]) rank_to_slot_all.append(m_g.rank_to_attention_slot) - # Head-split groups (split-read only): rank_offset selects sub-desc. + # Head-split groups (split-read only): rank_offset selects descriptor. 
if blocks_per_remote[-1] == 1 and split_page_ratio > 1: - sub_desc_idx.append(m_g.rank_offset_factor) + remote_desc_offset.append(m_g.rank_offset_factor) else: - sub_desc_idx.append(0) + remote_desc_offset.append(0) all_ranks: set[int] = set() for ranks in source_ranks_all: all_ranks.update(ranks) all_source_ranks = tuple(sorted(all_ranks)) + # --- Diagnostic logging for HeteroTP plan --- + logger.info( + "[HeteroTP Plan] tp_rank=%d, tp_size=%d, remote_tp_size=%d, " + "local_page=%d, remote_page=%d, " + "split_page_ratio=%d, gather_page_ratio=%d", + tp_rank, + tp_size, + remote_tp_size, + local_page, + remote_page, + split_page_ratio, + gather_page_ratio, + ) + for g in range(n_groups): + logger.info( + "[HeteroTP Plan] group=%d kind=%s: K=%d, " + "local_tpb=%d, remote_tpb=%d, " + "blocks_per_remote=%d, remote_blocks_per_local=%d, " + "desc_offset=%d, source_ranks=%s, slot_map=%s", + g, + group_kinds[g].value, + total_num_kv_heads_per_group[g], + local_tokens_per_block[g], + remote_tokens_per_block[g], + blocks_per_remote[g], + remote_blocks_per_local[g], + remote_desc_offset[g], + source_ranks_all[g], + rank_to_slot_all[g], + ) + # HMA: one K pool (+ optional V pool) shared by all groups. 
fa_regions: list[RegionPlan] = [] for i in range(len(remote_meta.block_lens)): @@ -573,7 +607,7 @@ def generate_gemma4_plan( page_stride = remote_meta.block_lens[i] if split_page_ratio > 1: - # Split-read: remote blocks produce sub-descs of local page size + # Split-read: remote blocks produce descriptors of local page size desc_bytes = local_block_len descs_per_block = split_page_ratio desc_stride = local_block_len @@ -627,7 +661,7 @@ def generate_gemma4_plan( remote_expansion_stride=1, remote_to_local_page_ratio=split_page_ratio, local_blocks_per_remote_block=tuple(blocks_per_remote), - sub_desc_index_per_group=tuple(sub_desc_idx), + remote_desc_offset_per_group=tuple(remote_desc_offset), local_to_remote_page_ratio=gather_page_ratio, remote_blocks_per_local_block=tuple(remote_blocks_per_local), ) @@ -638,7 +672,7 @@ def generate_gemma4_plan( # ====================================================================== -def _remap_remote_blocks_to_subdesc_ids( +def _remap_remote_blocks_to_desc_ids( plan: EngineTransferPlan, remote_block_ids: BlockIds, local_block_ids: BlockIds, @@ -662,8 +696,8 @@ def _remap_remote_blocks_to_subdesc_ids( * **Head-split** (``local_blocks_per_remote_block == 1``, e.g. SWA): Local reads one specific chunk of the remote block. Remote block ``b`` → descriptor index - ``b*ratio + sub_desc_index_per_group[g]``. - Example: SWA block 10, ratio=2, index=1 → desc index 21. + ``b*ratio + remote_desc_offset_per_group[g]``. + Example: SWA block 10, ratio=2, offset=1 → desc index 21. Local block IDs are returned unchanged. 
""" @@ -686,7 +720,7 @@ def _remap_remote_blocks_to_subdesc_ids( remapped.extend(b * ratio + s for s in range(n_local)) new_remote.append(remapped) else: - idx = plan.sub_desc_index_per_group[g] + idx = plan.remote_desc_offset_per_group[g] new_remote.append([b * ratio + idx for b in r_ids]) new_local.append(l_ids) @@ -701,21 +735,76 @@ def _build_gather_read_specs( ) -> list[ReadSpec]: """Build read specs for gather-read (local page > remote page). - In gather-read, local blocks are split into sub-descriptors matching - remote block size. Each rank's read targets specific local sub-descs: + In gather-read, each local block is registered as multiple NIXL + descriptors (``descs_per_block`` in ``RegionPlan``), each matching + the remote block byte size. Each rank's read targets specific + local descriptor IDs: * **Gather groups** (``remote_blocks_per_local_block > 1``, e.g. FA): N remote blocks fill one local block. - Local block ``b`` → sub-desc indices - ``[b*ratio, b*ratio+1, ..., b*ratio+(N-1)]``. + Local block ``b`` → descriptor IDs + ``[b*gather_ratio, b*gather_ratio+1, ..., b*gather_ratio+(N-1)]``. + Remote block IDs are kept as-is (one remote block = one + remote descriptor). The matched-length invariant + ``len(local_desc_ids) == len(remote_block_ids)`` must hold; + it is enforced by an assertion after construction. * **Concat groups** (``remote_blocks_per_local_block == 1``, e.g. SWA): - Each rank writes to a specific slot of the local block. - Local block ``b`` → sub-desc index - ``b*ratio + rank_slot``. + Each rank writes to a specific descriptor within the local block. + Local block ``b`` → descriptor ID + ``b*gather_ratio + rank_slot``. """ gather_ratio = plan.local_to_remote_page_ratio num_groups = len(local_block_ids) + + def _pair_gather_group( + g_local_block_ids: list[int], + g_remote_block_ids: list[int], + remote_blocks_per_local: int, + ) -> tuple[list[int], list[int]]: + """Pair local descriptor IDs with remote block IDs for a gather group. 
+ + With HMA, all groups receive the same block ID list. For gather + groups (``remote_blocks_per_local > 1``), every + ``remote_blocks_per_local`` consecutive remote blocks map to + descriptors of a single local block: + + local block b, remote blocks [r0, r1] → + local desc b*gather_ratio + 0 paired with r0 + local desc b*gather_ratio + 1 paired with r1 + + When the remote block count is not a multiple of + ``remote_blocks_per_local``, the remainder fills the first + descriptors of the next local block (partial fill). + + Returns matched-length lists: + (local_desc_ids, paired_remote_block_ids) + """ + n_local = len(g_local_block_ids) + n_remote = len(g_remote_block_ids) + n_full = min(n_remote // remote_blocks_per_local, n_local) + remainder_remote = n_remote - n_full * remote_blocks_per_local + + local_desc_ids: list[int] = [] + paired_remote: list[int] = [] + + for i in range(n_full): + b = g_local_block_ids[i] + for s in range(remote_blocks_per_local): + local_desc_ids.append(b * gather_ratio + s) + paired_remote.append( + g_remote_block_ids[i * remote_blocks_per_local + s] + ) + + if remainder_remote > 0 and n_full < n_local: + b = g_local_block_ids[n_full] + base = n_full * remote_blocks_per_local + for s in range(remainder_remote): + local_desc_ids.append(b * gather_ratio + s) + paired_remote.append(g_remote_block_ids[base + s]) + + return local_desc_ids, paired_remote + specs: list[ReadSpec] = [] for rank in plan.all_source_ranks: @@ -731,30 +820,50 @@ def _build_gather_read_specs( n_remote_per_local = plan.remote_blocks_per_local_block[g] if n_remote_per_local > 1: - expanded_local: list[int] = [] - for b in local_block_ids[g]: - expanded_local.extend( - b * gather_ratio + s - for s in range(n_remote_per_local) - ) - rank_local.append(expanded_local) - rank_remote.append(list(remote_block_ids[g])) + g_local, g_remote = _pair_gather_group( + local_block_ids[g], + remote_block_ids[g], + n_remote_per_local, + ) + rank_local.append(g_local) + 
rank_remote.append(g_remote) else: slot = plan.rank_to_attention_slot[g].get(rank, 0) + l_ids = local_block_ids[g] + r_ids = remote_block_ids[g] + n = min(len(l_ids), len(r_ids)) rank_local.append( - [b * gather_ratio + slot for b in local_block_ids[g]] + [ + l_ids[i] * gather_ratio + slot + for i in range(len(l_ids) - n, len(l_ids)) + ] ) - rank_remote.append(list(remote_block_ids[g])) + rank_remote.append(list(r_ids[len(r_ids) - n :])) + + for g in range(num_groups): + assert len(rank_local[g]) == len(rank_remote[g]), ( + f"Gather-read length mismatch: group={g}, rank={rank}, " + f"n_local_descs={len(rank_local[g])}, " + f"n_remote_blocks={len(rank_remote[g])}. " + f"Each local descriptor must pair with exactly one " + f"remote block ID." + ) - specs.append(ReadSpec( - remote_rank=rank, - local_block_ids=rank_local, - remote_block_ids=rank_remote, - )) + specs.append( + ReadSpec( + remote_rank=rank, + local_block_ids=rank_local, + remote_block_ids=rank_remote, + ) + ) return specs +logger = logging.getLogger(__name__) + + + # ====================================================================== # 4. Local descriptor building # ====================================================================== @@ -810,36 +919,31 @@ def build_fa_local_descs_for_gather_read( is_blocks_first: bool, gather_page_ratio: int, ) -> list[tuple[int, int, int]]: - """Build FA local descriptors with sub-descriptors for gather-read. + """Build FA local descriptors for gather-read. Each local block produces ``gather_page_ratio`` descriptors, each covering ``kv_block_len // gather_page_ratio`` bytes. This allows - NIXL to pair each local sub-descriptor with a remote descriptor of + NIXL to pair each local descriptor with a remote descriptor of matching size (the remote's natural page size). 
""" result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(base_addresses): kv_block_len = _get_kv_block_len(i, block_len_per_layer, is_blocks_first) page_stride = block_len_per_layer[i] - sub_desc_bytes = kv_block_len // gather_page_ratio + desc_bytes = kv_block_len // gather_page_ratio for block_id in range(num_blocks): blk_addr = base_addr + block_id * page_stride for s in range(gather_page_ratio): - result.append( - (blk_addr + s * sub_desc_bytes, sub_desc_bytes, device_id) - ) + result.append((blk_addr + s * desc_bytes, desc_bytes, device_id)) if is_blocks_first: - v_sub_desc_bytes = kv_block_len // gather_page_ratio + v_desc_bytes = kv_block_len // gather_page_ratio for block_id in range(num_blocks): - v_blk_addr = ( - base_addr + block_id * page_stride + kv_block_len - ) + v_blk_addr = base_addr + block_id * page_stride + kv_block_len for s in range(gather_page_ratio): result.append( - (v_blk_addr + s * v_sub_desc_bytes, v_sub_desc_bytes, - device_id) + (v_blk_addr + s * v_desc_bytes, v_desc_bytes, device_id) ) return result diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 697c2b88352c..b40bb1c6160a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -49,7 +49,7 @@ _build_gather_read_specs, _is_attention_spec, _is_ssm_spec, - _remap_remote_blocks_to_subdesc_ids, + _remap_remote_blocks_to_desc_ids, build_fa_local_descs_for_gather_read, build_fa_local_regions, build_mamba_local_regions, @@ -189,22 +189,38 @@ def _compute_read_specs_from_plan( local sub-desc remapping via ``_build_gather_read_specs``. - **Split-read** (``remote_to_local_page_ratio > 1``): rank-independent remote sub-desc remapping via - ``_remap_remote_blocks_to_subdesc_ids``. + ``_remap_remote_blocks_to_desc_ids``. - **Standard**: direct per-rank group filtering. 
""" if plan.local_to_remote_page_ratio > 1: - return _build_gather_read_specs( + specs = _build_gather_read_specs( plan, local_block_ids, remote_block_ids ) + if logger.isEnabledFor(logging.DEBUG): + for s in specs: + for g in range(len(s.local_block_ids)): + if s.local_block_ids[g]: + logger.debug( + "[ReadSpec gather] rank=%d group=%d: " + "local[:5]=%s remote[:5]=%s " + "(n_local=%d, n_remote=%d)", + s.remote_rank, + g, + s.local_block_ids[g][:5], + s.remote_block_ids[g][:5], + len(s.local_block_ids[g]), + len(s.remote_block_ids[g]), + ) + return specs remote_block_ids, local_block_ids = ( - _remap_remote_blocks_to_subdesc_ids( + _remap_remote_blocks_to_desc_ids( plan, remote_block_ids, local_block_ids ) ) num_groups = len(local_block_ids) - return [ + specs = [ ReadSpec( remote_rank=rank, local_block_ids=[ @@ -222,6 +238,22 @@ def _compute_read_specs_from_plan( ) for rank in plan.all_source_ranks ] + if logger.isEnabledFor(logging.DEBUG): + for s in specs: + for g in range(num_groups): + if s.local_block_ids[g]: + logger.debug( + "[ReadSpec std/split] rank=%d group=%d: " + "local[:5]=%s remote[:5]=%s " + "(n_local=%d, n_remote=%d)", + s.remote_rank, + g, + s.local_block_ids[g][:5], + s.remote_block_ids[g][:5], + len(s.local_block_ids[g]), + len(s.remote_block_ids[g]), + ) + return specs @staticmethod def _build_local_splits_from_plan( @@ -511,7 +543,7 @@ def __init__( # Populated dynamically during handshake based on remote configuration. # Keep track of regions at different tp_ratio values. tp_ratio->handles self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} - # Gather-read local handles: local blocks split into sub-descs + # Gather-read local handles: local blocks split into descriptors # matching remote page size. Keyed by engine_id. self._gather_read_handles: dict[EngineId, int] = {} # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}. 
@@ -1242,6 +1274,22 @@ def add_remote_agent( f"Remote tokens_per_block_per_group length " f"{len(remote_tpb)} != {len(group_spec_types)} groups" ) + logger.info( + "[HeteroTP] Generating Gemma4 plan: " + "group_kinds=%s, total_kv_heads_per_group=%s, " + "local_tpb_per_group=%s, remote_tpb_per_group=%s, " + "local_block_len_per_layer=%s, remote_block_lens=%s, " + "local_tp=%d, remote_tp=%d, tp_rank=%d", + [k.value for k in self._group_kinds], + self._total_kv_heads_per_group, + self._local_tokens_per_block_per_group, + remote_tpb, + self.block_len_per_layer[:3], + nixl_agent_meta.block_lens[:3], + transfer_topo.tp_size, + remote_tp_size, + transfer_topo.tp_rank, + ) self._transfer_plans[engine_id] = generate_gemma4_plan( transfer_topo=transfer_topo, block_len_per_layer=self.block_len_per_layer, @@ -1306,12 +1354,12 @@ def add_remote_agent( plan.local_to_remote_page_ratio > 1 and engine_id not in self._gather_read_handles ): - # Gather-read: local page > remote page. Register local descs - # with sub-descriptors matching the remote block size. + # Gather-read: local page > remote page. Register local + # descriptors matching the remote block size. 
assert self.transfer_topo is not None - local_base_addresses = self.kv_caches_base_addr[ - self.engine_id - ][self.tp_rank] + local_base_addresses = self.kv_caches_base_addr[self.engine_id][ + self.tp_rank + ] gather_blocks_data = build_fa_local_descs_for_gather_read( base_addresses=local_base_addresses, device_id=self.device_id, @@ -1323,8 +1371,8 @@ def add_remote_agent( descs = self.nixl_wrapper.get_xfer_descs( gather_blocks_data, self.nixl_memory_type ) - self._gather_read_handles[engine_id] = ( - self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + self._gather_read_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs ) elif ( tp_ratio < 0 @@ -1905,6 +1953,31 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): if self.use_mla and tp_ratio < 0: read_specs = read_specs[:1] + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[HeteroTP _read_blocks_for_req] req=%s engine=%s " + "tp_ratio=%d, n_read_specs=%d, " + "plan: split_ratio=%d, gather_ratio=%d, " + "group_kinds=%s, " + "local_physical_block_ids=[%s], " + "remote_block_ids=[%s]", + req_id, + engine_id, + tp_ratio, + len(read_specs), + plan.remote_to_local_page_ratio, + plan.local_to_remote_page_ratio, + [k.value for k in plan.group_kinds], + ", ".join( + f"g{i}:{meta.local_physical_block_ids[i][:5]}" + for i in range(len(meta.local_physical_block_ids)) + ), + ", ".join( + f"g{i}:{meta.remote.block_ids[i][:5]}" + for i in range(len(meta.remote.block_ids)) + ), + ) + for i, spec in enumerate(read_specs): remote_rank = spec.remote_rank local_block_ids = spec.local_block_ids @@ -1920,7 +1993,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): ) # Get side handles. if engine_id in self._gather_read_handles: - # Gather-read: local sub-desc handle matches remote page size. + # Gather-read: local descriptor handle matches remote page size. 
local_xfer_side_handle = self._gather_read_handles[engine_id] elif tp_ratio < 0 and not self.use_mla: assert remote_block_size == self.block_size @@ -2048,8 +2121,10 @@ def _read_blocks( # Partial prefix cache hit: trim to the shorter of local/remote. # Skip mamba groups — their blocks represent full state (conv+ssm), # not per-token data, so trimming would corrupt the transfer. - # For standard and split-read: remote >= local (trim remote). - # For gather-read: local sub-descs may exceed remote (trim local). + # After ReadSpec construction, local descriptor IDs and remote + # block IDs should already have matched lengths per group + # (gather-read pairing ensures this). Trim from the head to + # keep the tail (newest blocks). remote_block_ids = list(remote_block_ids) local_block_ids = list(local_block_ids) group_specs = self.kv_cache_config.kv_cache_groups @@ -2065,6 +2140,14 @@ def _read_blocks( if n_remote > n: remote_block_ids[i] = remote_block_ids[i][-n:] + for i in range(len(remote_block_ids)): + assert len(local_block_ids[i]) == len(remote_block_ids[i]), ( + f"Block ID length mismatch after trim: group={i}, " + f"n_local={len(local_block_ids[i])}, " + f"n_remote={len(remote_block_ids[i])}. " + f"ReadSpec should produce matched lengths." + ) + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. 
@@ -2097,6 +2180,31 @@ def _read_blocks( assert len(local_block_descs_ids) == len(remote_block_descs_ids) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[HeteroTP _read_blocks] req=%s rank=%d: " + "n_descs=%d, block_size_ratio=%s, " + "remote_desc_blocks=%d, local_desc_blocks=%d, " + "local_desc_ids[:10]=%s, remote_desc_ids[:10]=%s, " + "local_block_ids=[%s], remote_block_ids=[%s]", + request_id, + remote_rank, + len(local_block_descs_ids), + block_size_ratio, + remote_desc_blocks, + local_desc_blocks, + local_block_descs_ids[:10].tolist(), + remote_block_descs_ids[:10].tolist(), + ", ".join( + f"g{i}:{local_block_ids[i][:5]}" + for i in range(len(local_block_ids)) + ), + ", ".join( + f"g{i}:{remote_block_ids[i][:5]}" + for i in range(len(remote_block_ids)) + ), + ) + # Prepare transfer with Nixl. handle = None try: From e3d5ebb143d9200b615d9f5fcdb7809e4658bc85 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Tue, 28 Apr 2026 16:17:10 -0400 Subject: [PATCH 47/49] gemma4 Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_transfer_plan.py | 15 ++- .../kv_connector/v1/nixl/transfer_plan.py | 121 +++++++++++------- .../kv_connector/v1/nixl/worker.py | 54 +++++--- 3 files changed, 115 insertions(+), 75 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index 99a03ea6269d..d8b503e214da 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -604,7 +604,7 @@ def _make_gemma4_plan_params( FA: 5 layers, K=2, head_dim=512, P block_size=32, D block_size=16 With page unification + HMA, all groups share one physical pool. - page_size: P=65536, D=32768 → remote_to_local_page_ratio=2. + page_size: P=65536, D=32768 → remote_page > local_page (split-read). For simplicity, use 2 physical layers in tests. 
""" # D side (local): kv_heads_per_rank for all groups = page_size / block_size @@ -649,7 +649,8 @@ def test_plan_fields_2p4d_rank0(self): """D rank 0 at 2p4d: ratio=2, SWA head-split, FA multi-block.""" plan = generate_gemma4_plan(**_make_gemma4_plan_params(tp_rank=0)) - assert plan.remote_to_local_page_ratio == 2 + assert plan.remote_page_size == 65536 + assert plan.local_page_size == 32768 assert plan.group_kinds == (GroupKind.SWA, GroupKind.FA) assert plan.local_blocks_per_remote_block == (1, 2) assert plan.remote_desc_offset_per_group == (0, 0) # rank 0: index=0 @@ -794,7 +795,7 @@ def _make_gemma4_gather_plan_params( SWA: K=8, head_dim=256, P_tpb=16, D_tpb=16 → concat (2 P ranks) FA: K=2, head_dim=512, P_tpb=16, D_tpb=32 → gather (2P→1D block) - page_size: P=32768, D=65536 → local_to_remote_page_ratio=2. + page_size: P=32768, D=65536 → local_page > remote_page (gather-read). """ d_page = 65536 p_page = 32768 @@ -836,8 +837,8 @@ def test_plan_fields_4p2d_rank0(self): """D rank 0 at 4p2d: gather_ratio=2, SWA concat, FA gather.""" plan = generate_gemma4_plan(**_make_gemma4_gather_plan_params(tp_rank=0)) - assert plan.local_to_remote_page_ratio == 2 - assert plan.remote_to_local_page_ratio == 1 + assert plan.local_page_size == 65536 + assert plan.remote_page_size == 32768 assert plan.group_kinds == (GroupKind.SWA, GroupKind.FA) assert plan.remote_blocks_per_local_block == (1, 2) assert plan.local_blocks_per_remote_block == (1, 1) @@ -980,6 +981,6 @@ def test_4p1d_no_crash(self): params["remote_tokens_per_block"] = (16, 16) plan = generate_gemma4_plan(**params) - assert plan.local_to_remote_page_ratio == 4 - assert plan.remote_to_local_page_ratio == 1 + assert plan.local_page_size == 131072 + assert plan.remote_page_size == 32768 assert plan.remote_blocks_per_local_block == (1, 2) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 1b9743bc82e7..75b6d5279176 
100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -120,13 +120,17 @@ class EngineTransferPlan: # Stride for expanding remote logical block IDs to kernel block IDs. remote_expansion_stride: int - # --- HeteroTP per-group fields (e.g. Gemma4 SWA + FA) --- - # Active only when remote_to_local_page_ratio > 1. - # For Dense/Mamba (ratio=1), these are unused and default to empty. + # --- Page sizes (bytes per physical block, same for all groups) --- + # Used to determine transfer direction and descriptor layout. + # Split-read: remote_page_size > local_page_size (e.g. Gemma4 2p4d) + # Gather-read: local_page_size > remote_page_size (e.g. Gemma4 4p2d) + # Standard: local_page_size == remote_page_size + local_page_size: int + remote_page_size: int - # remote_page_size_bytes / local_page_size_bytes. - # Gemma4 2p4d: 65536 / 32768 = 2. - remote_to_local_page_ratio: int = 1 + # --- HeteroTP per-group fields (e.g. Gemma4 SWA + FA) --- + # For Dense/Mamba (equal page sizes), these are unused and default + # to empty. # Per-group: how many local (D) blocks correspond to one remote (P) # block. Computed as remote_block_size / local_block_size per group. @@ -139,18 +143,21 @@ class EngineTransferPlan: # Gemma4 2p4d rank 1: SWA = 1 (second half), FA = 0. remote_desc_offset_per_group: tuple[int, ...] = () - # --- Gather-read fields (local page > remote page, e.g. 4p2d FA) --- - # When D pages are larger than P pages, each local block is registered - # as multiple NIXL descriptors matching the remote block size. - - # local_page_size_bytes / remote_page_size_bytes. - # 4p2d Gemma4: 65536 / 32768 = 2. 1 when no gather-read. - local_to_remote_page_ratio: int = 1 - # Per-group: how many remote blocks fill one local block. # FA in 4p2d: D_tpb / P_tpb = 32 / 16 = 2. remote_blocks_per_local_block: tuple[int, ...] 
= () + def __post_init__(self): + big, small = ( + max(self.local_page_size, self.remote_page_size), + min(self.local_page_size, self.remote_page_size), + ) + assert small > 0, "Page sizes must be positive" + assert big % small == 0, ( + f"Page sizes must be evenly divisible: " + f"local={self.local_page_size}, remote={self.remote_page_size}" + ) + @property def all_regions(self) -> tuple[RegionPlan, ...]: return self.fa_regions + self.ssm_regions @@ -326,6 +333,13 @@ def generate_dense_plan( local_physical_blocks_per_logical: int, ) -> EngineTransferPlan: """Generate transfer plan for dense (attention-only) models.""" + local_page = block_len_per_layer[0] + remote_page = remote_meta.block_lens[0] + assert local_page == remote_page, ( + f"Dense plan does not support different page sizes: " + f"local={local_page}, remote={remote_page}" + ) + block_size_ratio = transfer_topo.block_size // remote_info.remote_block_size tp_mapping = _compute_tp_mapping( @@ -360,6 +374,8 @@ def generate_dense_plan( all_source_ranks=tp_mapping.all_source_ranks, rank_to_attention_slot=(tp_mapping.rank_to_attention_slot,), remote_expansion_stride=local_physical_blocks_per_logical, + local_page_size=local_page, + remote_page_size=remote_page, ) @@ -468,6 +484,13 @@ def generate_mamba_plan( ) ) + local_page = block_len_per_layer[0] + remote_page = remote_block_lens[0] + assert local_page == remote_page, ( + f"Mamba plan does not support different page sizes: " + f"local={local_page}, remote={remote_page}" + ) + n_groups = len(group_spec_types) return EngineTransferPlan( fa_regions=tuple(fa_regions), @@ -477,6 +500,8 @@ def generate_mamba_plan( all_source_ranks=tp_mapping.all_source_ranks, rank_to_attention_slot=(tp_mapping.rank_to_attention_slot,) * n_groups, remote_expansion_stride=remote_phys_ratio, + local_page_size=local_page, + remote_page_size=remote_page, ) @@ -518,11 +543,11 @@ def generate_gemma4_plan( remote_page = remote_meta.block_lens[0] if remote_page >= local_page: - 
split_page_ratio = remote_page // local_page - gather_page_ratio = 1 + descs_per_remote_block = remote_page // local_page + descs_per_local_block = 1 else: - split_page_ratio = 1 - gather_page_ratio = local_page // remote_page + descs_per_remote_block = 1 + descs_per_local_block = local_page // remote_page blocks_per_remote: list[int] = [] remote_blocks_per_local: list[int] = [] @@ -555,7 +580,7 @@ def generate_gemma4_plan( rank_to_slot_all.append(m_g.rank_to_attention_slot) # Head-split groups (split-read only): rank_offset selects descriptor. - if blocks_per_remote[-1] == 1 and split_page_ratio > 1: + if blocks_per_remote[-1] == 1 and descs_per_remote_block > 1: remote_desc_offset.append(m_g.rank_offset_factor) else: remote_desc_offset.append(0) @@ -569,14 +594,14 @@ def generate_gemma4_plan( logger.info( "[HeteroTP Plan] tp_rank=%d, tp_size=%d, remote_tp_size=%d, " "local_page=%d, remote_page=%d, " - "split_page_ratio=%d, gather_page_ratio=%d", + "descs_per_remote_block=%d, descs_per_local_block=%d", tp_rank, tp_size, remote_tp_size, local_page, remote_page, - split_page_ratio, - gather_page_ratio, + descs_per_remote_block, + descs_per_local_block, ) for g in range(n_groups): logger.info( @@ -606,12 +631,12 @@ def generate_gemma4_plan( ) page_stride = remote_meta.block_lens[i] - if split_page_ratio > 1: + if descs_per_remote_block > 1: # Split-read: remote blocks produce descriptors of local page size desc_bytes = local_block_len - descs_per_block = split_page_ratio + descs_per_block = descs_per_remote_block desc_stride = local_block_len - elif gather_page_ratio > 1: + elif descs_per_local_block > 1: # Gather-read: standard remote descs at remote page size remote_block_len = _get_kv_block_len( i, remote_meta.block_lens, is_blocks_first @@ -659,10 +684,10 @@ def generate_gemma4_plan( all_source_ranks=all_source_ranks, rank_to_attention_slot=tuple(rank_to_slot_all), remote_expansion_stride=1, - remote_to_local_page_ratio=split_page_ratio, + 
local_page_size=local_page, + remote_page_size=remote_page, local_blocks_per_remote_block=tuple(blocks_per_remote), remote_desc_offset_per_group=tuple(remote_desc_offset), - local_to_remote_page_ratio=gather_page_ratio, remote_blocks_per_local_block=tuple(remote_blocks_per_local), ) @@ -679,32 +704,32 @@ def _remap_remote_blocks_to_desc_ids( ) -> tuple[BlockIds, BlockIds]: """Convert remote block IDs into descriptor-level indices. - When ``remote_to_local_page_ratio > 1``, each remote physical block - is registered as multiple descriptors (one per local-page-sized - chunk). This function converts remote block IDs into the - descriptor index space so that ``_compute_desc_ids_from_plan`` can - look up the correct descriptors. + When ``remote_page_size > local_page_size`` (split-read), each remote + physical block is registered as multiple descriptors (one per + local-page-sized chunk). This function converts remote block IDs + into the descriptor index space so that ``_compute_desc_ids_from_plan`` + can look up the correct descriptors. Two per-group cases: * **Multi-block** (``local_blocks_per_remote_block > 1``, e.g. FA): One remote block covers multiple local blocks. Remote block ``b`` → descriptor indices - ``[b*ratio, b*ratio+1, ..., b*ratio+(n-1)]``. - Example: FA block 10, ratio=2 → desc indices [20, 21]. + ``[b*N, b*N+1, ..., b*N+(n-1)]`` where N = descs_per_remote_block. + Example: FA block 10, N=2 → desc indices [20, 21]. * **Head-split** (``local_blocks_per_remote_block == 1``, e.g. SWA): Local reads one specific chunk of the remote block. Remote block ``b`` → descriptor index - ``b*ratio + remote_desc_offset_per_group[g]``. - Example: SWA block 10, ratio=2, offset=1 → desc index 21. + ``b*N + remote_desc_offset_per_group[g]``. + Example: SWA block 10, N=2, offset=1 → desc index 21. Local block IDs are returned unchanged. 
""" - if plan.remote_to_local_page_ratio <= 1: + if plan.remote_page_size <= plan.local_page_size: return remote_block_ids, local_block_ids - ratio = plan.remote_to_local_page_ratio + descs_per_remote_block = plan.remote_page_size // plan.local_page_size num_groups = len(remote_block_ids) new_remote: list[list[int]] = [] new_local: list[list[int]] = [] @@ -717,11 +742,11 @@ def _remap_remote_blocks_to_desc_ids( if n_local > 1: remapped: list[int] = [] for b in r_ids: - remapped.extend(b * ratio + s for s in range(n_local)) + remapped.extend(b * descs_per_remote_block + s for s in range(n_local)) new_remote.append(remapped) else: idx = plan.remote_desc_offset_per_group[g] - new_remote.append([b * ratio + idx for b in r_ids]) + new_remote.append([b * descs_per_remote_block + idx for b in r_ids]) new_local.append(l_ids) @@ -743,7 +768,7 @@ def _build_gather_read_specs( * **Gather groups** (``remote_blocks_per_local_block > 1``, e.g. FA): N remote blocks fill one local block. Local block ``b`` → descriptor IDs - ``[b*gather_ratio, b*gather_ratio+1, ..., b*gather_ratio+(N-1)]``. + ``[b*descs_per_local_block, ..., b*descs_per_local_block+(N-1)]``. Remote block IDs are kept as-is (one remote block = one remote descriptor). The matched-length invariant ``len(local_desc_ids) == len(remote_block_ids)`` must hold; @@ -752,9 +777,9 @@ def _build_gather_read_specs( * **Concat groups** (``remote_blocks_per_local_block == 1``, e.g. SWA): Each rank writes to a specific descriptor within the local block. Local block ``b`` → descriptor ID - ``b*gather_ratio + rank_slot``. + ``b*descs_per_local_block + rank_slot``. 
""" - gather_ratio = plan.local_to_remote_page_ratio + descs_per_local_block = plan.local_page_size // plan.remote_page_size num_groups = len(local_block_ids) def _pair_gather_group( @@ -770,8 +795,8 @@ def _pair_gather_group( descriptors of a single local block: local block b, remote blocks [r0, r1] → - local desc b*gather_ratio + 0 paired with r0 - local desc b*gather_ratio + 1 paired with r1 + local desc b*descs_per_local_block + 0 paired with r0 + local desc b*descs_per_local_block + 1 paired with r1 When the remote block count is not a multiple of ``remote_blocks_per_local``, the remainder fills the first @@ -791,7 +816,7 @@ def _pair_gather_group( for i in range(n_full): b = g_local_block_ids[i] for s in range(remote_blocks_per_local): - local_desc_ids.append(b * gather_ratio + s) + local_desc_ids.append(b * descs_per_local_block + s) paired_remote.append( g_remote_block_ids[i * remote_blocks_per_local + s] ) @@ -800,7 +825,7 @@ def _pair_gather_group( b = g_local_block_ids[n_full] base = n_full * remote_blocks_per_local for s in range(remainder_remote): - local_desc_ids.append(b * gather_ratio + s) + local_desc_ids.append(b * descs_per_local_block + s) paired_remote.append(g_remote_block_ids[base + s]) return local_desc_ids, paired_remote @@ -834,7 +859,7 @@ def _pair_gather_group( n = min(len(l_ids), len(r_ids)) rank_local.append( [ - l_ids[i] * gather_ratio + slot + l_ids[i] * descs_per_local_block + slot for i in range(len(l_ids) - n, len(l_ids)) ] ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index b40bb1c6160a..d629b8899904 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -185,14 +185,14 @@ def _compute_read_specs_from_plan( Dispatches to the correct remapping strategy: - - **Gather-read** (``local_to_remote_page_ratio > 1``): per-rank - local sub-desc remapping via 
``_build_gather_read_specs``. - - **Split-read** (``remote_to_local_page_ratio > 1``): - rank-independent remote sub-desc remapping via + - **Gather-read** (``local_page_size > remote_page_size``): per-rank + local descriptor pairing via ``_build_gather_read_specs``. + - **Split-read** (``remote_page_size > local_page_size``): + rank-independent remote descriptor remapping via ``_remap_remote_blocks_to_desc_ids``. - **Standard**: direct per-rank group filtering. """ - if plan.local_to_remote_page_ratio > 1: + if plan.local_page_size > plan.remote_page_size: specs = _build_gather_read_specs( plan, local_block_ids, remote_block_ids ) @@ -1351,12 +1351,13 @@ def add_remote_agent( ### (Optional) Register local agent memory regions. if ( - plan.local_to_remote_page_ratio > 1 + plan.local_page_size > plan.remote_page_size and engine_id not in self._gather_read_handles ): # Gather-read: local page > remote page. Register local # descriptors matching the remote block size. assert self.transfer_topo is not None + descs_per_local_block = plan.local_page_size // plan.remote_page_size local_base_addresses = self.kv_caches_base_addr[self.engine_id][ self.tp_rank ] @@ -1366,7 +1367,7 @@ def add_remote_agent( num_blocks=self.num_blocks, block_len_per_layer=self.block_len_per_layer, is_blocks_first=self.transfer_topo.is_kv_layout_blocks_first, - gather_page_ratio=plan.local_to_remote_page_ratio, + gather_page_ratio=descs_per_local_block, ) descs = self.nixl_wrapper.get_xfer_descs( gather_blocks_data, self.nixl_memory_type @@ -1957,7 +1958,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( "[HeteroTP _read_blocks_for_req] req=%s engine=%s " "tp_ratio=%d, n_read_specs=%d, " - "plan: split_ratio=%d, gather_ratio=%d, " + "plan: local_page=%d, remote_page=%d, " "group_kinds=%s, " "local_physical_block_ids=[%s], " "remote_block_ids=[%s]", @@ -1965,8 +1966,8 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): engine_id, tp_ratio, len(read_specs), - 
plan.remote_to_local_page_ratio, - plan.local_to_remote_page_ratio, + plan.local_page_size, + plan.remote_page_size, [k.value for k in plan.group_kinds], ", ".join( f"g{i}:{meta.local_physical_block_ids[i][:5]}" @@ -2152,16 +2153,29 @@ def _read_blocks( # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. - # Get descs ids. - # For split-read (page_ratio > 1), each remote block is registered as - # multiple descriptors, so scale the remote descriptor-space count. - # For gather-read, scale the local descriptor-space count instead. - remote_desc_blocks = ( - self.dst_num_blocks[dst_engine_id] * plan.remote_to_local_page_ratio - ) - local_desc_blocks = self.dst_num_blocks[self.engine_id] - if plan.local_to_remote_page_ratio > 1: - local_desc_blocks *= plan.local_to_remote_page_ratio + # Get descs ids. Both calls use the same plan since region counts + # (len(fa_regions), len(ssm_regions)) are model-determined and + # identical across engines. + if plan.remote_page_size > plan.local_page_size: + # Split-read: each remote block → multiple descriptors. + remote_desc_blocks = ( + self.dst_num_blocks[dst_engine_id] + * plan.remote_page_size + // plan.local_page_size + ) + local_desc_blocks = self.dst_num_blocks[self.engine_id] + elif plan.local_page_size > plan.remote_page_size: + # Gather-read: each local block → multiple descriptors. + remote_desc_blocks = self.dst_num_blocks[dst_engine_id] + local_desc_blocks = ( + self.dst_num_blocks[self.engine_id] + * plan.local_page_size + // plan.remote_page_size + ) + else: + # Standard: 1:1 block-to-descriptor mapping. 
+ remote_desc_blocks = self.dst_num_blocks[dst_engine_id] + local_desc_blocks = self.dst_num_blocks[self.engine_id] remote_block_descs_ids = self._compute_desc_ids_from_plan( plan, From 56d12e75acf381fe66a921c2cead3a17ce73078d Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Wed, 29 Apr 2026 18:07:34 -0400 Subject: [PATCH 48/49] fix tests Signed-off-by: Zhanqiu Hu --- .../kv_connector/unit/test_transfer_plan.py | 50 ++++++++++++------- .../kv_connector/v1/nixl/transfer_plan.py | 17 ++----- .../kv_connector/v1/nixl/worker.py | 36 ++++++------- 3 files changed, 56 insertions(+), 47 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_transfer_plan.py b/tests/v1/kv_connector/unit/test_transfer_plan.py index d8b503e214da..fc7018fa2c54 100644 --- a/tests/v1/kv_connector/unit/test_transfer_plan.py +++ b/tests/v1/kv_connector/unit/test_transfer_plan.py @@ -415,8 +415,10 @@ def _make_mamba_plan_for_desc_ids( group_spec_types=group_spec_types, source_ranks_per_group=source_ranks_per_group, all_source_ranks=(0,), - rank_to_attention_slot=({0: 0},) * len(group_kinds), + rank_to_attention_slot=({0: 0},) * len(group_spec_types), remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, ) @@ -490,6 +492,8 @@ def test_all_source_ranks_serve_fa(self): all_source_ranks=(0, 1), rank_to_attention_slot=({0: 0, 1: 1}, {0: 0, 1: 1}), remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, ) local_ids = ([1, 2], [3, 4]) @@ -515,6 +519,8 @@ def test_non_fa_rank_skips_fa_groups(self): all_source_ranks=(0, 1, 2), rank_to_attention_slot=({0: 0}, {0: 0}), remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, ) local_ids = ([1, 2], [3, 4]) @@ -561,6 +567,8 @@ def test_fa_and_ssm_different_split_factors(self): all_source_ranks=(0, 1), rank_to_attention_slot=({0: 0, 1: 0}, {0: 0, 1: 0}), remote_expansion_stride=1, + local_page_size=100, + remote_page_size=100, ) # 2 FA descs + 1 SSM desc @@ -635,7 +643,7 @@ def _make_gemma4_plan_params( 
block_lens=[p_page] * num_layers, block_size=16, ), - group_kinds=(GroupKind.SWA, GroupKind.FA), + group_spec_types=(FullAttentionSpec, FullAttentionSpec), total_num_kv_heads_per_group=(8, 2), local_tokens_per_block=(16, 16), remote_tokens_per_block=(16, 32), @@ -651,7 +659,7 @@ def test_plan_fields_2p4d_rank0(self): assert plan.remote_page_size == 65536 assert plan.local_page_size == 32768 - assert plan.group_kinds == (GroupKind.SWA, GroupKind.FA) + assert plan.group_spec_types == (FullAttentionSpec, FullAttentionSpec) assert plan.local_blocks_per_remote_block == (1, 2) assert plan.remote_desc_offset_per_group == (0, 0) # rank 0: index=0 assert plan.all_source_ranks == (0,) @@ -692,7 +700,7 @@ def test_descs_per_block(self): num_blocks=500, block_lens=[65536, 65536], ) - descs = build_remote_descs_from_plan(plan, meta) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) # 2 layers × 1 region/layer × 500 blocks × 2 descs/block = 2000 expected_count = 2 * 500 * 2 @@ -706,7 +714,7 @@ def test_desc_stride_within_block(self): num_blocks=500, block_lens=[65536, 65536], ) - descs = build_remote_descs_from_plan(plan, meta) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) # First block, layer 0: descriptor 0 and descriptor 1 addr_d0, len_d0, _ = descs[0] @@ -724,7 +732,7 @@ def test_remapped_block_ids(self): # SWA blocks [3, 7], FA blocks [10, 11] # Remapped to descriptor indices: - # SWA (sub_desc_index=0): [3*2+0, 7*2+0] = [6, 14] + # SWA (desc_index=0): [3*2+0, 7*2+0] = [6, 14] # FA (2 local per remote): [10*2, 10*2+1, 11*2, 11*2+1] = [20,21,22,23] # # dst_num_blocks = 500 * 2 = 1000 (num_blocks * descs_per_block) @@ -741,7 +749,7 @@ def test_remapped_block_ids(self): remote_swa = [3, 7] remote_fa = [10, 11] - specs = compute_read_specs_from_plan( + specs = NixlConnectorWorker._compute_read_specs_from_plan( plan, local_block_ids=(local_swa, local_fa), remote_block_ids=(remote_swa, remote_fa), @@ -758,19 +766,23 @@ def 
test_remapped_block_ids(self): assert list(spec.local_block_ids[1]) == [20, 21, 22, 23] # Now compute desc IDs with the remapped remote blocks - remote_ids = compute_desc_ids_from_plan( + remote_ids = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=spec.remote_block_ids, dst_num_blocks=500 * 2, # num_blocks * descs_per_block + block_size_ratio=None, + physical_blocks_per_logical=1, ) expected_remote = [6, 14, 1006, 1014, 20, 21, 22, 23, 1020, 1021, 1022, 1023] assert list(remote_ids) == expected_remote # Local desc IDs (standard, descs_per_block=1 locally) - local_ids = compute_desc_ids_from_plan( + local_ids = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=spec.local_block_ids, dst_num_blocks=1000, # local num_blocks + block_size_ratio=None, + physical_blocks_per_logical=1, ) expected_local = [10, 11, 1010, 1011, 20, 21, 22, 23, 1020, 1021, 1022, 1023] assert list(local_ids) == expected_local @@ -823,7 +835,7 @@ def _make_gemma4_gather_plan_params( block_lens=[p_page] * num_layers, block_size=16, ), - group_kinds=(GroupKind.SWA, GroupKind.FA), + group_spec_types=(FullAttentionSpec, FullAttentionSpec), total_num_kv_heads_per_group=(8, 2), local_tokens_per_block=(16, 32), remote_tokens_per_block=(16, 16), @@ -839,7 +851,7 @@ def test_plan_fields_4p2d_rank0(self): assert plan.local_page_size == 65536 assert plan.remote_page_size == 32768 - assert plan.group_kinds == (GroupKind.SWA, GroupKind.FA) + assert plan.group_spec_types == (FullAttentionSpec, FullAttentionSpec) assert plan.remote_blocks_per_local_block == (1, 2) assert plan.local_blocks_per_remote_block == (1, 1) # SWA: D rank 0 reads from P rank 0 and P rank 1 @@ -874,7 +886,7 @@ def test_standard_descs_per_block(self): num_blocks=500, block_lens=[32768, 32768], ) - descs = build_remote_descs_from_plan(plan, meta) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) # 2 layers × 1 region/layer × 500 blocks × 1 desc/block = 1000 assert len(descs) == 2 * 500 
* 1 @@ -887,7 +899,7 @@ def test_desc_bytes_match_remote_page(self): num_blocks=500, block_lens=[32768, 32768], ) - descs = build_remote_descs_from_plan(plan, meta) + descs = NixlConnectorWorker._build_remote_descs_from_plan(plan, meta) for _, length, _ in descs: assert length == 32768 @@ -907,7 +919,7 @@ def test_gather_read_specs_4p2d_rank0(self): remote_swa = [5, 6] remote_fa = [30, 31] - specs = compute_read_specs_from_plan( + specs = NixlConnectorWorker._compute_read_specs_from_plan( plan, local_block_ids=(local_swa, local_fa), remote_block_ids=(remote_swa, remote_fa), @@ -943,7 +955,7 @@ def test_gather_read_desc_ids_match(self): remote_swa = [5, 6] remote_fa = [30, 31] - specs = compute_read_specs_from_plan( + specs = NixlConnectorWorker._compute_read_specs_from_plan( plan, local_block_ids=(local_swa, local_fa), remote_block_ids=(remote_swa, remote_fa), @@ -951,16 +963,20 @@ def test_gather_read_desc_ids_match(self): for spec in specs: # Remote desc IDs: standard (no sub-descs), num_blocks=500 - remote_ids = compute_desc_ids_from_plan( + remote_ids = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=spec.remote_block_ids, dst_num_blocks=500, + block_size_ratio=None, + physical_blocks_per_logical=1, ) # Local desc IDs: gather sub-descs, num_blocks=1000*gather_ratio - local_ids = compute_desc_ids_from_plan( + local_ids = NixlConnectorWorker._compute_desc_ids_from_plan( plan, block_ids=spec.local_block_ids, dst_num_blocks=1000 * 2, # local_num_blocks * gather_ratio + block_size_ratio=None, + physical_blocks_per_logical=1, ) assert len(remote_ids) == len(local_ids), ( f"Desc ID length mismatch for rank {spec.remote_rank}: " diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py index 75b6d5279176..66ff4f2e9be1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py @@ -335,10 +335,6 @@ def generate_dense_plan( """Generate transfer plan for dense (attention-only) models.""" local_page = block_len_per_layer[0] remote_page = remote_meta.block_lens[0] - assert local_page == remote_page, ( - f"Dense plan does not support different page sizes: " - f"local={local_page}, remote={remote_page}" - ) block_size_ratio = transfer_topo.block_size // remote_info.remote_block_size @@ -486,10 +482,6 @@ def generate_mamba_plan( local_page = block_len_per_layer[0] remote_page = remote_block_lens[0] - assert local_page == remote_page, ( - f"Mamba plan does not support different page sizes: " - f"local={local_page}, remote={remote_page}" - ) n_groups = len(group_spec_types) return EngineTransferPlan( @@ -605,12 +597,12 @@ def generate_gemma4_plan( ) for g in range(n_groups): logger.info( - "[HeteroTP Plan] group=%d kind=%s: K=%d, " + "[HeteroTP Plan] group=%d spec=%s: K=%d, " "local_tpb=%d, remote_tpb=%d, " "blocks_per_remote=%d, remote_blocks_per_local=%d, " "desc_offset=%d, source_ranks=%s, slot_map=%s", g, - group_kinds[g].value, + group_spec_types[g].__name__, total_num_kv_heads_per_group[g], local_tokens_per_block[g], remote_tokens_per_block[g], @@ -651,7 +643,6 @@ def generate_gemma4_plan( fa_regions.append( RegionPlan( - kind=RegionKind.FA_K, layer_idx=i, descriptor_bytes=desc_bytes, offset_in_page=0, @@ -665,7 +656,6 @@ def generate_gemma4_plan( if is_blocks_first: fa_regions.append( RegionPlan( - kind=RegionKind.FA_V, layer_idx=i, descriptor_bytes=desc_bytes, offset_in_page=page_stride // 2, @@ -728,6 +718,8 @@ def _remap_remote_blocks_to_desc_ids( """ if plan.remote_page_size <= plan.local_page_size: return remote_block_ids, local_block_ids + if not plan.local_blocks_per_remote_block: + return remote_block_ids, local_block_ids descs_per_remote_block = plan.remote_page_size // plan.local_page_size num_groups = len(remote_block_ids) @@ -888,7 +880,6 @@ def 
_pair_gather_group( logger = logging.getLogger(__name__) - # ====================================================================== # 4. Local descriptor building # ====================================================================== diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index d629b8899904..68cafed2dc8f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -109,8 +109,10 @@ def _build_remote_descs_from_plan( for region in plan.all_regions: base_addr = nixl_agent_meta.kv_caches_base_addr[region.layer_idx] for blk in range(region.num_blocks): - addr = base_addr + blk * region.page_stride + region.offset_in_page - result.append((addr, region.descriptor_bytes, dev_id)) + blk_addr = base_addr + blk * region.page_stride + region.offset_in_page + for sub in range(region.descs_per_block): + addr = blk_addr + sub * region.desc_stride_bytes + result.append((addr, region.descriptor_bytes, dev_id)) return result @@ -192,10 +194,11 @@ def _compute_read_specs_from_plan( ``_remap_remote_blocks_to_desc_ids``. - **Standard**: direct per-rank group filtering. 
""" - if plan.local_page_size > plan.remote_page_size: - specs = _build_gather_read_specs( - plan, local_block_ids, remote_block_ids - ) + if ( + plan.local_page_size > plan.remote_page_size + and plan.remote_blocks_per_local_block + ): + specs = _build_gather_read_specs(plan, local_block_ids, remote_block_ids) if logger.isEnabledFor(logging.DEBUG): for s in specs: for g in range(len(s.local_block_ids)): @@ -213,10 +216,8 @@ def _compute_read_specs_from_plan( ) return specs - remote_block_ids, local_block_ids = ( - _remap_remote_blocks_to_desc_ids( - plan, remote_block_ids, local_block_ids - ) + remote_block_ids, local_block_ids = _remap_remote_blocks_to_desc_ids( + plan, remote_block_ids, local_block_ids ) num_groups = len(local_block_ids) @@ -291,10 +292,12 @@ def _build_local_splits_from_plan( else 0 ) + fa_slot_map = plan.rank_to_attention_slot[0] + result: list[list[tuple[int, int, int]]] = [] for p_idx, p_rank in enumerate(plan.all_source_ranks): - fa_slot = plan.rank_to_attention_slot.get(p_rank, 0) + fa_slot = fa_slot_map.get(p_rank, 0) handle: list[tuple[int, int, int]] = [] for j, (addr, local_len, dev) in enumerate(src_blocks_data): @@ -1267,8 +1270,7 @@ def add_remote_agent( elif self._is_hetero_attn: remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group or ()) group_spec_types = tuple( - type(g.kv_cache_spec) - for g in self.kv_cache_config.kv_cache_groups + type(g.kv_cache_spec) for g in self.kv_cache_config.kv_cache_groups ) assert len(remote_tpb) == len(group_spec_types), ( f"Remote tokens_per_block_per_group length " @@ -1276,11 +1278,11 @@ def add_remote_agent( ) logger.info( "[HeteroTP] Generating Gemma4 plan: " - "group_kinds=%s, total_kv_heads_per_group=%s, " + "group_spec_types=%s, total_kv_heads_per_group=%s, " "local_tpb_per_group=%s, remote_tpb_per_group=%s, " "local_block_len_per_layer=%s, remote_block_lens=%s, " "local_tp=%d, remote_tp=%d, tp_rank=%d", - [k.value for k in self._group_kinds], + [t.__name__ for t in 
group_spec_types], self._total_kv_heads_per_group, self._local_tokens_per_block_per_group, remote_tpb, @@ -1959,7 +1961,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): "[HeteroTP _read_blocks_for_req] req=%s engine=%s " "tp_ratio=%d, n_read_specs=%d, " "plan: local_page=%d, remote_page=%d, " - "group_kinds=%s, " + "group_spec_types=%s, " "local_physical_block_ids=[%s], " "remote_block_ids=[%s]", req_id, @@ -1968,7 +1970,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): len(read_specs), plan.local_page_size, plan.remote_page_size, - [k.value for k in plan.group_kinds], + [t.__name__ for t in plan.group_spec_types], ", ".join( f"g{i}:{meta.local_physical_block_ids[i][:5]}" for i in range(len(meta.local_physical_block_ids)) From 80487f10611513efbdbc31c92c6baac1625c7747 Mon Sep 17 00:00:00 2001 From: Zhanqiu Hu Date: Wed, 29 Apr 2026 22:25:56 -0400 Subject: [PATCH 49/49] fix Signed-off-by: Zhanqiu Hu --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 68cafed2dc8f..2466293468e4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1354,6 +1354,7 @@ def add_remote_agent( ### (Optional) Register local agent memory regions. if ( plan.local_page_size > plan.remote_page_size + and plan.remote_blocks_per_local_block and engine_id not in self._gather_read_handles ): # Gather-read: local page > remote page. Register local @@ -2158,16 +2159,22 @@ def _read_blocks( # Get descs ids. Both calls use the same plan since region counts # (len(fa_regions), len(ssm_regions)) are model-determined and # identical across engines. - if plan.remote_page_size > plan.local_page_size: - # Split-read: each remote block → multiple descriptors. 
+ if ( + plan.remote_page_size > plan.local_page_size + and plan.local_blocks_per_remote_block + ): + # Split-read (Gemma4): each remote block → multiple descriptors. remote_desc_blocks = ( self.dst_num_blocks[dst_engine_id] * plan.remote_page_size // plan.local_page_size ) local_desc_blocks = self.dst_num_blocks[self.engine_id] - elif plan.local_page_size > plan.remote_page_size: - # Gather-read: each local block → multiple descriptors. + elif ( + plan.local_page_size > plan.remote_page_size + and plan.remote_blocks_per_local_block + ): + # Gather-read (Gemma4): each local block → multiple descriptors. remote_desc_blocks = self.dst_num_blocks[dst_engine_id] local_desc_blocks = ( self.dst_num_blocks[self.engine_id]