
nixl refactor [2/N]: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology #39529

Merged
NickLucche merged 11 commits into vllm-project:main from ZhanqiuHu:nixl-tpkv-transferconfig-unification on Apr 20, 2026
Conversation

@ZhanqiuHu (Contributor) commented Apr 10, 2026

Summary

  • Introduces TransferTopology class that unifies TpKVTopology and HeteroTPTransferConfig into a single source of truth
  • Adds EngineTransferInfo / MambaEngineTransferInfo frozen dataclasses for per-remote-engine transfer state
  • Migrates worker.py to use TransferTopology, removing redundant _tp_size, _block_size, _transfer_configs dicts
  • Deletes HeteroTPTransferConfig class entirely

Depends on #39354

New: EngineTransferInfo (frozen dataclass)

Stores per-remote-engine transfer facts, computed once at handshake:

| Field | Description |
| --- | --- |
| `remote_tp_size` | Remote engine's tensor parallel size |
| `remote_block_len` | Remote block length in bytes |
| `remote_block_size` | Remote tokens per block |
| `remote_physical_blocks_per_logical` | Physical-to-logical block ratio |

New: MambaEngineTransferInfo (inherits EngineTransferInfo)

Extends with Mamba-hybrid transfer geometry:

| Field | Description |
| --- | --- |
| `remote_fa_source_ranks` | Remote ranks carrying unique FA heads for this local rank |
| `remote_all_source_ranks` | All remote ranks to read from (FA + Mamba union) |
| `remote_num_fa_reads` | Number of distinct remote ranks needed for FA data |
| `remote_num_mamba_reads` | Number of distinct remote ranks needed for Mamba data |
| `remote_fa_descriptor_bytes` | Byte size of one FA K/V descriptor entry |
| `is_remote_replicated` | Whether remote TP > total KV heads |
| `remote_physical_heads` | Physical KV heads stored per remote rank |

New: TransferTopology (replaces TpKVTopology + HeteroTPTransferConfig)

Local info (one copy, set at init):

| Field | Description |
| --- | --- |
| `tp_rank`, `tp_size`, `block_size` | Local TP identity |
| `engine_id` | Local engine ID |
| `is_mla`, `is_mamba`, `total_num_kv_heads` | Model topology flags |
| `local_physical_heads` | `max(1, total_num_kv_heads // tp_size)` |
| `is_kv_layout_blocks_first`, `cross_layers_blocks`, `split_k_and_v` | Layout detection flags |

Remote info (dict, populated via register_remote_engine()):

| Field | Description |
| --- | --- |
| `_engines: dict[EngineId, EngineTransferInfo]` | One entry per remote engine |
| `_fa_source_sets`, `_fa_source_indices` | Mamba FA lookup caches (per engine) |

Methods (unified from both old classes):

| Category | Methods |
| --- | --- |
| General | `tp_ratio()`, `block_size_ratio()`, `handshake_target_ranks()`, `target_remote_ranks()`, `is_kv_replicated()`, `replicates_kv_cache()`, `get_transfer_cache_regions()` |
| Mamba-specific | `should_skip_fa()`, `fa_head_slot()`, `fa_rank_offset()`, `needs_split_handles()`, `compute_split_handle_data()`, `filter_block_ids_for_rank()`, `describe_mamba()` |
| Internal | `register_remote_engine()`, `get_engine_info()`, `_get_mamba_info()`, `_build_mamba_info()` |

worker.py migration

| Before | After |
| --- | --- |
| `self.kv_topo: TpKVTopology` | `self.transfer_topo: TransferTopology` |
| `self._tp_size: dict` | Removed — in `EngineTransferInfo.remote_tp_size` |
| `self._block_size: dict` | Removed — in `EngineTransferInfo.remote_block_size` |
| `self._transfer_configs: dict` | Removed — absorbed into `MambaEngineTransferInfo` |
| 3 scattered dict updates + `HeteroTPTransferConfig()` | Single `register_remote_engine()` call |
| `compute_mamba_phys_ratio()` | Renamed to `compute_physical_blocks_per_logical()` |

Deleted

  • HeteroTPTransferConfig class (entire class, ~300 lines in utils.py)

What's NOT changed (kept for now)

  • TpKVTopology still exists — used by mooncake_connector.py and tests
  • _physical_blocks_per_logical dict still in worker.py (duplicates info for local engine)
  • Model-specific logic still lives in TransferTopology (Phase 2 will extract into policy)

Change Log

Annotated diff: utils.py (full new code)

File: vllm/distributed/kv_transfer/kv_connector/utils.py
Branch: nixl-tpkv-transferconfig-unification
Base: PR #39354 (9cd664152)

All changes are pure additions appended after line 940. No existing code modified.

Legend (on each line or block):

  • [NEW] — New code, no existing equivalent
  • [COPIED] — Logic identical to existing code
  • [RENAMED] — Same logic, variable/field names changed (d_/p_ → local_/remote_)
  • [SIG CHANGE] — Signature changed but body logic identical

Each section shows the original source being compared to.

Note: is_kv_replicated and replicates_kv_cache currently take remote_engine_id
but only do a single lookup (remote_tp_size > total_num_kv_heads). They could be
reverted to take remote_tp_size: int for consistency with tp_ratio/block_size_ratio.
Left as-is for now — can be cleaned up in a follow-up.


EngineTransferInfo (lines 946–962)

Source: [NEW] — fields previously scattered as dicts in worker.py

# ---- Per-engine transfer info ----


@dataclass(frozen=True)
class EngineTransferInfo:                                          # [NEW]
    """Common per-remote-engine transfer state, computed at handshake.

    Stored per ``engine_id`` inside ``TransferTopology._engines``.
    """

    remote_tp_size: int                                            # [NEW] was: worker._tp_size[eid]

    remote_block_len: int                                          # [NEW] was: local var in worker
    """Block length (bytes)"""

    remote_block_size: int                                         # [NEW] was: worker._block_size[eid]
    """Tokens per block."""

    remote_physical_blocks_per_logical: int                        # [NEW] was: worker._mamba_phys_ratio[eid]
    """Physical blocks per logical block."""

MambaEngineTransferInfo (lines 965–994)

Source: [RENAMED] from HeteroTPTransferConfig derived fields (lines 580–609)

@dataclass(frozen=True)
class MambaEngineTransferInfo(EngineTransferInfo):                 # [NEW] inheritance structure
    """Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry.

    For hybrid SSM+Attention models, FA and Mamba layers may require
    different numbers of reads from different remote ranks.  This
    dataclass captures that per-engine transfer plan.
    """

    remote_fa_source_ranks: tuple[int, ...]                        # [RENAMED] was: fa_read_targets: list[int]
    """Remote ranks carrying unique FA heads for this local rank."""

    remote_all_source_ranks: tuple[int, ...]                       # [RENAMED] was: transfer_targets: list[int]
    """All remote ranks this local rank reads from (FA + Mamba)."""

    remote_num_fa_reads: int                                       # [RENAMED] was: physical_fa_num_reads
    """Number of distinct remote ranks needed for FA data."""

    remote_num_mamba_reads: int                                    # [RENAMED] was: mamba_num_reads
    """Number of distinct remote ranks needed for Mamba data."""

    remote_fa_descriptor_bytes: int                                # [RENAMED] was: fa_entry_size
    """Byte size of one FA K (or V) descriptor entry."""

    is_remote_replicated: bool                                     # [RENAMED] was: is_p_replicated
    """Whether the remote engine has replicated KV heads
    (remote_tp_size > total_num_kv_heads)."""

    remote_physical_heads: int                                     # [RENAMED] was: p_physical_heads
    """Physical KV heads stored per remote rank."""

Fields moved out:

-    local_physical_heads: int      → now TransferTopology.local_physical_heads (set in __init__)

Old fields NOT carried over:

-    _fa_target_set: frozenset[int]      → now TransferTopology._fa_source_sets[eid]
-    _fa_target_index: dict[int, int]    → now TransferTopology._fa_source_indices[eid]
-    def __post_init__(self)             → now TransferTopology._build_mamba_info()

TransferTopology class + __init__ (lines 1000–1075)

Source: __init__ params from TpKVTopology class attrs (lines 323–337);
layout detection from TpKVTopology.__post_init__ (lines 339–378)

# ---- Transfer topology ----


class TransferTopology:                                            # [NEW] class (replaces TpKVTopology + HeteroTPTransferConfig)
    """Single source of truth for local TP identity and per-engine remote info.

    Replaces the combination of ``TpKVTopology`` (local identity + layout
    detection + per-engine dicts) and ``HeteroTPTransferConfig`` (Mamba
    transfer geometry) with a single object per worker.
    """

    def __init__(
        self,
        tp_rank: int,                                              # [COPIED]  was: TpKVTopology.tp_rank
        tp_size: int,                                              # [NEW]     was: property from remote_tp_size[self.engine_id]
        block_size: int,                                           # [NEW]     was: property from remote_block_size[self.engine_id]
        engine_id: EngineId,                                       # [COPIED]  was: TpKVTopology.engine_id
        is_mla: bool,                                              # [COPIED]  was: TpKVTopology.is_mla
        is_mamba: bool,                                            # [COPIED]  was: TpKVTopology.is_mamba
        total_num_kv_heads: int,                                   # [COPIED]  was: TpKVTopology.total_num_kv_heads
        attn_backends: list[type[AttentionBackend]],               # [COPIED]  was: TpKVTopology.attn_backends
        tensor_shape: torch.Size | None = None,                    # [COPIED]  was: TpKVTopology.tensor_shape
    ):
        self.tp_rank = tp_rank                                     # [COPIED]
        self.tp_size = tp_size                                     # [NEW]  was: @property
        self.block_size = block_size                               # [NEW]  was: @property
        self.engine_id = engine_id                                 # [COPIED]
        self.is_mla = is_mla                                      # [COPIED]
        self.is_mamba = is_mamba                                   # [COPIED]
        self.total_num_kv_heads = total_num_kv_heads               # [COPIED]
        self.attn_backends = attn_backends                         # [COPIED]
        self.tensor_shape = tensor_shape                           # [COPIED]

        self.local_physical_heads = max(1, total_num_kv_heads // tp_size)  # [NEW] moved from MambaEngineTransferInfo
        self._engines: dict[EngineId, EngineTransferInfo] = {}     # [NEW]  replaces remote_tp_size/remote_block_size dicts
        # FA source lookup caches (Mamba only, built in register_remote_engine)
        self._fa_source_sets: dict[EngineId, frozenset[int]] = {}  # [NEW]  replaces HeteroTP._fa_target_set
        self._fa_source_indices: dict[EngineId, dict[int, int]] = {}  # [NEW]  replaces HeteroTP._fa_target_index
        # ---- Layout detection (from TpKVTopology.__post_init__) ----
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks.
        attn_backend = attn_backends[0]                            # [COPIED]
        if not is_mamba:                                           # [COPIED]
            _MOCK_BLOCK_SIZE = 16                                  # [COPIED]
            kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(  # [COPIED]
                num_blocks=1,                                      # [COPIED]
                block_size=_MOCK_BLOCK_SIZE,                       # [COPIED]
                num_kv_heads=1,                                    # [COPIED]
                head_size=1,                                       # [COPIED]
            )                                                      # [COPIED]
            logger.debug("Test kv_cache_shape: %s", kv_cache_shape)  # [COPIED]
        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
        # we just mock num_blocks to 1 for the dimension check below.
        # Hybrid SSM models assume a single blocks_first layout
        self._is_kv_layout_blocks_first = is_mamba or (            # [COPIED]
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1    # [COPIED]
        )                                                          # [COPIED]

        self._cross_layers_blocks = False                          # [COPIED]
        if tensor_shape is not None:                               # [COPIED]
            self._cross_layers_blocks = (                          # [COPIED]
                len(tensor_shape) == len(kv_cache_shape) + 1       # [COPIED]
            )                                                      # [COPIED]

        if self._cross_layers_blocks:                              # [COPIED]
            logger.debug("Using cross-layer KV cache")             # [COPIED]
            _MOCK_NUM_LAYERS = 80                                  # [COPIED]
            kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape  # [COPIED]
            try:                                                   # [COPIED]
                kv_cache_stride_order = (                          # [COPIED]
                    attn_backend.get_kv_cache_stride_order(        # [COPIED]
                        include_num_layers_dimension=self._cross_layers_blocks  # [COPIED]
                    )                                              # [COPIED]
                )                                                  # [COPIED]
            except (AttributeError, NotImplementedError):          # [COPIED]
                assert tensor_shape is not None                    # [COPIED]
                kv_cache_stride_order = tuple(range(len(tensor_shape)))  # [COPIED]
            kv_cache_shape = tuple(                                # [COPIED]
                kv_cache_shape[i] for i in kv_cache_stride_order   # [COPIED]
            )                                                      # [COPIED]

register_remote_engine (lines 1081–1133)

Source: [NEW] — replaces scattered dict updates in worker.py

    def register_remote_engine(                                    # [NEW]
        self,
        remote_engine_id: EngineId,
        remote_tp_size: int,
        remote_block_size: int,
        remote_block_len: int,
        remote_physical_blocks_per_logical: int,
        *,
        local_block_len: int = 0,
    ) -> EngineTransferInfo:
        """Register a remote engine, replacing scattered worker dicts.

        Only remote engines should be registered here — the local engine's
        identity (tp_size, block_size, etc.) is set via ``__init__`` params.
        """
        assert remote_engine_id != self.engine_id, (               # [NEW] guard
            ...
        )
        info: EngineTransferInfo                                   # [NEW] type annotation for mypy
        if self.is_mamba:
            info = self._build_mamba_info(...)                     # [NEW]
            self._fa_source_sets[remote_engine_id] = frozenset(...)  # [NEW]
            self._fa_source_indices[remote_engine_id] = {...}      # [NEW]
        else:
            info = EngineTransferInfo(...)                          # [NEW]
        self._engines[remote_engine_id] = info                     # [NEW]
        return info

    def get_engine_info(                                           # [NEW]
        self, remote_engine_id: EngineId
    ) -> EngineTransferInfo:
        return self._engines[remote_engine_id]
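
For orientation, a minimal usage sketch of the registration flow. It follows the constructor and method signatures annotated above, but the backend class, engine ids, and all concrete values are hypothetical:

```python
# Hypothetical handshake-time registration; values are illustrative only.
topo = TransferTopology(
    tp_rank=0,
    tp_size=4,
    block_size=16,
    engine_id="decode-engine",             # hypothetical local engine id
    is_mla=False,
    is_mamba=False,
    total_num_kv_heads=8,
    attn_backends=[SomeAttentionBackend],  # hypothetical backend class
)
info = topo.register_remote_engine(
    remote_engine_id="prefill-engine",     # hypothetical remote engine id
    remote_tp_size=2,
    remote_block_size=16,
    remote_block_len=4096,
    remote_physical_blocks_per_logical=1,
)
# Per-engine facts are now readable through a single lookup.
assert topo.get_engine_info("prefill-engine") is info
assert info.remote_tp_size == 2
```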

Layout properties (lines 1139–1152)

Source: [COPIED] from TpKVTopology (lines 380–401)

    @property                                                      # [COPIED]
    def is_kv_layout_blocks_first(self) -> bool:                   # [COPIED]
        return self._is_kv_layout_blocks_first                     # [COPIED]

    @property                                                      # [COPIED]
    def cross_layers_blocks(self) -> bool:                         # [COPIED]
        return self._cross_layers_blocks                           # [COPIED]

    @property                                                      # [COPIED]
    def split_k_and_v(self) -> bool:                               # [COPIED]
        return not (                                               # [COPIED]
            self._cross_layers_blocks                              # [COPIED]
            or self.is_mla                                         # [COPIED]
            or self.is_kv_layout_blocks_first                      # [COPIED]
        )                                                          # [COPIED]

tp_ratio (lines 1158–1175)

Source: [COPIED] from TpKVTopology.tp_ratio (lines 403–427)

    def tp_ratio(self, remote_tp_size: int) -> int:                # [COPIED] identical signature
        """Calculate the tensor parallel ratio between local and remote TP.

        Positive when local_tp >= remote_tp (local workers read from the
        same remote worker in groups of size ``tp_ratio``).  Negative when
        remote_tp > local_tp (ratio is flipped).
        """
        if self.tp_size >= remote_tp_size:                         # [COPIED]
            assert self.tp_size % remote_tp_size == 0, (...)       # [COPIED]
            return self.tp_size // remote_tp_size                  # [COPIED]
        assert remote_tp_size % self.tp_size == 0, (...)           # [COPIED]
        return -(remote_tp_size // self.tp_size)                   # [COPIED]

block_size_ratio (lines 1177–1183)

Source: [COPIED] from TpKVTopology.block_size_ratio (lines 429–440)

    def block_size_ratio(self, remote_block_size: int) -> int:     # [COPIED] identical signature
        """Calculate the block size ratio between local and remote."""
        assert self.block_size % remote_block_size == 0, (...)     # [COPIED]
        return self.block_size // remote_block_size                # [COPIED]
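
A standalone re-implementation of the two ratio helpers, tracing the sign convention with made-up TP and block sizes (a sketch, not the vLLM code itself):

```python
def tp_ratio(local_tp: int, remote_tp: int) -> int:
    """Sign convention from TransferTopology.tp_ratio, re-implemented."""
    if local_tp >= remote_tp:
        assert local_tp % remote_tp == 0
        return local_tp // remote_tp
    assert remote_tp % local_tp == 0
    return -(remote_tp // local_tp)

assert tp_ratio(8, 2) == 4    # 4 local ranks share each remote rank
assert tp_ratio(2, 8) == -4   # each local rank reads from 4 remote ranks
assert tp_ratio(4, 4) == 1    # homogeneous TP

# block_size_ratio: the local block size must be a multiple of the remote's.
local_block, remote_block = 32, 16
assert local_block % remote_block == 0 and local_block // remote_block == 2
```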

is_kv_replicated (lines 1185–1189)

Source: [COPIED] from TpKVTopology.is_kv_replicated (lines 456–464)

    def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:  # [COPIED]
        return (                                                   # [COPIED]
            self._engines[remote_engine_id].remote_tp_size         # [COPIED] was: self.remote_tp_size[engine_id]
            > self.total_num_kv_heads                              # [COPIED]
        )

replicates_kv_cache (lines 1191–1193)

Source: [COPIED] from TpKVTopology.replicates_kv_cache (lines 466–468)

    def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:  # [COPIED]
        return self.is_mla or self.is_kv_replicated(remote_engine_id)  # [COPIED]

handshake_target_ranks (lines 1195–1205)

Source: [NEW] — extracted from TpKVTopology.get_target_remote_ranks (lines 470–485)

    def handshake_target_ranks(self, remote_tp_size: int) -> list[int]:  # [NEW]
        """Pre-registration: compute which remote TP ranks to handshake with.

        Pure math based on local/remote TP sizes — does not require
        the remote engine to be registered yet.
        """
        tp_ratio = self.tp_ratio(remote_tp_size)                   # [COPIED] logic from get_target_remote_ranks
        if tp_ratio > 0:                                           # [COPIED]
            return [self.tp_rank // tp_ratio]                      # [COPIED]
        abs_ratio = -tp_ratio                                      # [COPIED]
        return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)]  # [COPIED]
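
The same rank math, traced standalone with made-up sizes (a sketch under the conventions above, not the vLLM code):

```python
def handshake_target_ranks(tp_rank: int, local_tp: int, remote_tp: int) -> list[int]:
    if local_tp >= remote_tp:
        ratio = local_tp // remote_tp
        return [tp_rank // ratio]  # `ratio` local ranks share one remote rank
    abs_ratio = remote_tp // local_tp
    return [tp_rank * abs_ratio + i for i in range(abs_ratio)]

# Local TP=4 reading from remote TP=2: ranks 0,1 -> remote 0; ranks 2,3 -> remote 1.
assert [handshake_target_ranks(r, 4, 2) for r in range(4)] == [[0], [0], [1], [1]]
# Local TP=2 reading from remote TP=4: rank 0 -> remotes [0,1]; rank 1 -> [2,3].
assert [handshake_target_ranks(r, 2, 4) for r in range(2)] == [[0, 1], [2, 3]]
```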

target_remote_ranks (lines 1207–1224)

Source: [SIG CHANGE] from TpKVTopology.get_target_remote_ranks (lines 470–485),
with [NEW] Mamba branch

    def target_remote_ranks(                                       # [SIG CHANGE] was: get_target_remote_ranks(self, remote_tp_size: int)
        self, remote_engine_id: EngineId
    ) -> list[int]:
        """Get the remote TP rank(s) that the current local TP rank will
        read from.

        For Mamba models, returns the precomputed ``all_source_ranks``
        (FA + Mamba union).
        """
        info = self._engines[remote_engine_id]                     # [NEW]
        if isinstance(info, MambaEngineTransferInfo):              # [NEW] *** Mamba branch ***
            return list(info.remote_all_source_ranks)              # [NEW]

        tp_ratio = self.tp_ratio(info.remote_tp_size)              # [COPIED] uses info lookup
        if tp_ratio > 0:                                           # [COPIED]
            return [self.tp_rank // tp_ratio]                      # [COPIED]
        # remote TP > local TP: read from |tp_ratio| remote workers
        abs_ratio = -tp_ratio                                      # [COPIED]
        return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)]  # [COPIED]

get_transfer_cache_regions (lines 1226–1249)

Source: [COPIED] from TpKVTopology.get_transfer_cache_regions (lines 494–516)

    def get_transfer_cache_regions(                                # [COPIED]
        self, cache: torch.Tensor, layer_spec: "KVCacheSpec"
    ) -> list[torch.Tensor] | torch.Tensor:
        ...                                                        # [COPIED] — all logic identical

_get_mamba_info (lines 1258–1268)

Source: [NEW] — helper to type-check engine info lookup

    def _get_mamba_info(                                           # [NEW]
        self, remote_engine_id: EngineId
    ) -> MambaEngineTransferInfo:
        assert self.is_mamba, (...)                                # [NEW] guard
        info = self._engines[remote_engine_id]                     # [NEW]
        assert isinstance(info, MambaEngineTransferInfo), (...)    # [NEW]
        return info                                                # [NEW]

should_skip_fa (lines 1270–1272)

Source: [SIG CHANGE] from HeteroTPTransferConfig.should_skip_fa (lines 733–735)

    def should_skip_fa(                                            # [SIG CHANGE] added remote_engine_id
        self, remote_engine_id: EngineId, remote_rank: int
    ) -> bool:
        return remote_rank not in self._fa_source_sets[remote_engine_id]  # [RENAMED]

fa_head_slot (lines 1274–1292)

Source: [SIG CHANGE] from HeteroTPTransferConfig.fa_head_slot (lines 737–752)

    def fa_head_slot(                                              # [SIG CHANGE] added remote_engine_id
        self, remote_engine_id: EngineId, remote_rank: int
    ) -> int:
        fa_index = self._fa_source_indices[remote_engine_id]       # [RENAMED] was: self._fa_target_index
        if remote_rank in fa_index:                                # [RENAMED]
            return fa_index[remote_rank]                           # [RENAMED]
        mamba_info = self._get_mamba_info(remote_engine_id)        # [NEW] lookup
        K = self.total_num_kv_heads                                # [RENAMED] was: self.K
        remote_tp = mamba_info.remote_tp_size                      # [RENAMED] was: self.p_tp
        r_head = _physical_head_range(remote_tp, K, remote_rank)   # [RENAMED]
        for target in mamba_info.remote_fa_source_ranks:           # [RENAMED] was: self.fa_read_targets
            t_head = _physical_head_range(remote_tp, K, target)    # [RENAMED]
            if _range_overlap(r_head, t_head):                     # [COPIED]
                return fa_index[target]                            # [RENAMED]
        return 0                                                   # [COPIED]
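
The two helpers used above, `_physical_head_range` and `_range_overlap`, are defined elsewhere in utils.py and not reproduced in this diff. A plausible sketch of their contract, inferred from the docstring quoted in review below and from how they are called here (assumed, not the actual bodies):

```python
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
    """Physical KV head range stored in a rank's KV cache tensor (assumed)."""
    if tp_size <= num_heads:
        per_rank = num_heads // tp_size
        return range(rank * per_rank, (rank + 1) * per_rank)
    # Replicated case (tp_size > num_heads): one physical head per rank.
    start = rank * num_heads // tp_size
    return range(start, start + 1)

def _range_overlap(a: range, b: range) -> range:
    """Intersection of two head ranges (may be empty)."""
    return range(max(a.start, b.start), min(a.stop, b.stop))

# e.g. with K=8 heads: remote TP=4 rank 1 holds heads [2, 4); local TP=2
# rank 0 needs heads [0, 4); the non-empty overlap makes rank 1 an FA source.
assert _physical_head_range(4, 8, 1) == range(2, 4)
assert _range_overlap(range(0, 4), range(2, 4)) == range(2, 4)
```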

fa_rank_offset (lines 1294–1315)

Source: [SIG CHANGE] from HeteroTPTransferConfig.fa_rank_offset (lines 754–770)

    def fa_rank_offset(                                            # [SIG CHANGE] added remote_engine_id
        self, remote_engine_id: EngineId, remote_kv_block_len: int
    ) -> int:
        mamba_info = self._get_mamba_info(remote_engine_id)        # [NEW] lookup
        tp_ratio = self.tp_ratio(mamba_info.remote_tp_size)        # [NEW] computed; was: self.tp_ratio (stored field)
        if self.is_mla or tp_ratio <= 0:                           # [RENAMED] was: self.use_mla
            return 0                                               # [COPIED]
        K = self.total_num_kv_heads                                # [RENAMED] was: self.K
        is_local_replicated = self.tp_size > K                     # [RENAMED] was: self.is_d_replicated
        if is_local_replicated:                                    # [RENAMED]
            local_head = self.tp_rank * K // self.tp_size          # [RENAMED] was: self.d_rank * self.K // self.d_tp
            p_rank = mamba_info.remote_fa_source_ranks[0]          # [RENAMED] was: self.fa_read_targets[0]
            p_start = p_rank * K // mamba_info.remote_tp_size      # [RENAMED]
            return (local_head - p_start) * remote_kv_block_len    # [RENAMED]
        return self.tp_rank % tp_ratio * remote_kv_block_len       # [RENAMED]
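
A numeric trace of the common non-replicated branch (local TP larger than remote TP), with made-up sizes. Four local ranks share each remote rank's block, each reading a distinct slice:

```python
# Hypothetical: local TP=8, remote TP=2, so tp_ratio=4.
remote_kv_block_len = 4096
tp_ratio = 8 // 2
offsets = [tp_rank % tp_ratio * remote_kv_block_len for tp_rank in range(8)]
assert offsets == [0, 4096, 8192, 12288, 0, 4096, 8192, 12288]
```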

needs_split_handles (lines 1317–1329)

Source: [SIG CHANGE] from HeteroTPTransferConfig.needs_split_handles property (lines 772–779)

    def needs_split_handles(                                       # [SIG CHANGE] was: @property
        self, remote_engine_id: EngineId
    ) -> bool:
        mamba_info = self._get_mamba_info(remote_engine_id)        # [NEW] lookup
        tp_ratio = self.tp_ratio(mamba_info.remote_tp_size)        # [NEW] computed; was: self.tp_ratio (stored)
        return (                                                   # [COPIED]
            tp_ratio < 0                                           # [RENAMED] was: self.tp_ratio < 0
            and not self.is_mla                                    # [RENAMED] was: not self.use_mla
            and len(mamba_info.remote_all_source_ranks) > 1        # [RENAMED] was: len(self.transfer_targets) > 1
        )

compute_split_handle_data (lines 1331–1360)

Source: [SIG CHANGE] from HeteroTPTransferConfig.compute_split_handle_data (lines 781–810)

    def compute_split_handle_data(                                 # [SIG CHANGE] added remote_engine_id
        self,
        remote_engine_id: EngineId,
        src_blocks_data: list[tuple[int, int, int]],
        num_fa_descs: int,
        abs_tp: int,
    ) -> list[list[tuple[int, int, int]]]:
        mamba_info = self._get_mamba_info(remote_engine_id)        # [NEW] lookup
        all_handle_data: list[list[tuple[int, int, int]]] = []     # [COPIED]
        for p_idx, p_rank in enumerate(                            # [RENAMED] was: self.transfer_targets
            mamba_info.remote_all_source_ranks
        ):
            handle_data: list[tuple[int, int, int]] = []           # [COPIED]
            skip_fa = self.should_skip_fa(remote_engine_id, p_rank)  # [RENAMED]
            fa_slot = (                                            # [COPIED]
                self.fa_head_slot(remote_engine_id, p_rank)        # [RENAMED]
                if not skip_fa else 0
            )
            for j, (addr, local_len, dev) in enumerate(src_blocks_data):  # [COPIED]
                if j < num_fa_descs:                               # [COPIED]
                    assert mamba_info.remote_num_fa_reads >= 1     # [RENAMED]
                    fa_chunk = local_len // mamba_info.remote_num_fa_reads  # [RENAMED]
                    handle_data.append(                            # [COPIED]
                        (addr + fa_slot * fa_chunk, fa_chunk, dev)
                    )
                else:                                              # [COPIED]
                    mamba_chunk = local_len // abs_tp               # [COPIED]
                    handle_data.append(                            # [COPIED]
                        (addr + p_idx * mamba_chunk, mamba_chunk, dev)
                    )
            all_handle_data.append(handle_data)                    # [COPIED]
        return all_handle_data                                     # [COPIED]
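
A standalone trace of the chunk addressing above, with made-up descriptor data; `fa_slot` is pinned to the source index purely for illustration (in the real code it comes from `fa_head_slot`):

```python
# One FA descriptor followed by one Mamba descriptor; two source ranks.
src_blocks_data = [(0x1000, 4096, 0), (0x9000, 8192, 0)]  # (addr, len, device)
num_fa_descs, abs_tp, num_fa_reads = 1, 2, 2

all_handle_data = []
for p_idx in range(abs_tp):
    fa_slot = p_idx  # stand-in for fa_head_slot(engine_id, p_rank)
    handle_data = []
    for j, (addr, local_len, dev) in enumerate(src_blocks_data):
        if j < num_fa_descs:
            chunk = local_len // num_fa_reads   # FA: split across FA reads
            handle_data.append((addr + fa_slot * chunk, chunk, dev))
        else:
            chunk = local_len // abs_tp         # Mamba: split across all sources
            handle_data.append((addr + p_idx * chunk, chunk, dev))
    all_handle_data.append(handle_data)

assert all_handle_data[0] == [(0x1000, 2048, 0), (0x9000, 4096, 0)]
assert all_handle_data[1] == [(0x1000 + 2048, 2048, 0), (0x9000 + 4096, 4096, 0)]
```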

filter_block_ids_for_rank (lines 1362–1385)

Source: [SIG CHANGE] from HeteroTPTransferConfig.filter_block_ids_for_rank (lines 812–834)

    def filter_block_ids_for_rank(                                 # [SIG CHANGE] added remote_engine_id
        self,
        remote_engine_id: EngineId,
        remote_rank: int,
        local_ids: BlockIds,
        remote_ids: BlockIds,
        is_mamba_group: list[bool],
    ) -> tuple[BlockIds, BlockIds]:
        if not self.should_skip_fa(remote_engine_id, remote_rank): # [RENAMED]
            return local_ids, remote_ids                           # [COPIED]
        num_groups = len(local_ids)                                # [COPIED]
        filtered_local: list[list[int]] = [                        # [COPIED]
            [] if not is_mamba_group[g] else local_ids[g]
            for g in range(num_groups)
        ]
        filtered_remote: list[list[int]] = [                       # [COPIED]
            [] if not is_mamba_group[g] else remote_ids[g]
            for g in range(num_groups)
        ]
        return filtered_local, filtered_remote                     # [COPIED]
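
A tiny standalone trace of the filtering: when FA reads are skipped for a remote rank, only Mamba groups keep their block ids (the ids here are made up):

```python
local_ids = [[1, 2], [10, 11]]
is_mamba_group = [False, True]  # group 0 is FA, group 1 is Mamba
filtered = [ids if is_mamba_group[g] else [] for g, ids in enumerate(local_ids)]
assert filtered == [[], [10, 11]]
```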

describe_mamba (lines 1387–1403)

Source: [SIG CHANGE] from HeteroTPTransferConfig.describe (lines 836–848)

    def describe_mamba(                                            # [SIG CHANGE] was: describe(self) -> str
        self, remote_engine_id: EngineId
    ) -> str:
        mamba_info = self._get_mamba_info(remote_engine_id)        # [NEW] lookup
        return (
            f"TransferTopology.mamba("                             # [RENAMED] was: f"HeteroTPTransferConfig("
            f"tp_ratio={self.tp_ratio(mamba_info.remote_tp_size)}, "  # [RENAMED] uses method call
            f"K={self.total_num_kv_heads}, "                       # [RENAMED] was: self.K
            f"local_tp={self.tp_size}, "                           # [RENAMED] was: d_tp={self.d_tp}
            f"remote_tp={mamba_info.remote_tp_size}, "             # [RENAMED] was: p_tp={self.p_tp}
            f"local_rank={self.tp_rank}, "                         # [RENAMED] was: d_rank={self.d_rank}
            f"fa_reads={mamba_info.remote_num_fa_reads}, "         # [RENAMED]
            f"mamba_reads={mamba_info.remote_num_mamba_reads}, "   # [RENAMED]
            f"fa_sources={list(mamba_info.remote_fa_source_ranks)}, "  # [RENAMED]
            f"all_sources={list(mamba_info.remote_all_source_ranks)}, "  # [RENAMED]
            f"fa_desc_bytes={mamba_info.remote_fa_descriptor_bytes}, "  # [RENAMED]
            f"remote_block_len={mamba_info.remote_block_len})"     # [RENAMED]
        )

_build_mamba_info (lines 1410–1531)

Source: [SIG CHANGE] from HeteroTPTransferConfig.__post_init__ (lines 611–687)
and HeteroTPTransferConfig._validate (lines 689–729)

    def _build_mamba_info(                                         # [SIG CHANGE] was: __post_init__(self)
        self,
        remote_tp_size: int,                                       # [NEW] was: self.p_tp
        remote_block_size: int,                                    # [NEW] was: not in HeteroTP
        remote_block_len: int,                                     # [NEW] was: self.p_block_len
        remote_physical_blocks_per_logical: int,                   # [NEW] was: not in HeteroTP
        local_block_len: int,                                      # [NEW] was: self.d_block_len
    ) -> MambaEngineTransferInfo:                                  # [NEW] returns dataclass
        """Compute Mamba transfer plan."""
        K = self.total_num_kv_heads                                # [RENAMED] was: K = self.K
        local_tp = self.tp_size                                    # [RENAMED] was: self.d_tp
        local_rank = self.tp_rank                                  # [RENAMED] was: self.d_rank

        is_remote_replicated = remote_tp_size > K                  # [RENAMED]
        remote_physical_heads = max(1, K // remote_tp_size)        # [RENAMED]

        if local_tp >= remote_tp_size:                             # [RENAMED]
            assert local_tp % remote_tp_size == 0                  # [COPIED]
            tp_ratio = local_tp // remote_tp_size                  # [COPIED]
        else:                                                      # [COPIED]
            assert remote_tp_size % local_tp == 0                  # [COPIED]
            tp_ratio = -(remote_tp_size // local_tp)               # [COPIED]

        abs_tp = -tp_ratio if tp_ratio < 0 else 1                  # [COPIED]

        mamba_range: range | None = None                           # [COPIED]
        if tp_ratio < 0:                                           # [COPIED]
            mamba_range = range(                                   # [COPIED]
                local_rank * abs_tp, (local_rank + 1) * abs_tp     # [RENAMED]
            )

        # ---- FA read targets ----
        if self.is_mla or tp_ratio >= 0:                           # [RENAMED] was: self.use_mla or self.tp_ratio >= 0
            num_fa_reads = 1                                       # [RENAMED]
            fa_source_ranks: list[int] = (                         # [RENAMED] was: self.fa_read_targets
                [0] if self.is_mla                                 # [RENAMED]
                else [local_rank // tp_ratio if tp_ratio > 0 else local_rank]  # [RENAMED]
            )
        else:
            local_needs = _physical_head_range(local_tp, K, local_rank)  # [RENAMED]
            search_range = (                                       # [COPIED]
                mamba_range if mamba_range is not None
                else range(remote_tp_size)                         # [RENAMED]
            )
            seen: set[tuple[int, int]] = set()                     # [COPIED]
            fa_source_ranks = []                                   # [RENAMED]
            for p in search_range:                                 # [COPIED]
                p_has = _physical_head_range(remote_tp_size, K, p) # [RENAMED]
                ov = _range_overlap(local_needs, p_has)            # [RENAMED]
                if len(ov) > 0:                                    # [COPIED]
                    key = (ov.start, ov.stop)                      # [COPIED]
                    if key not in seen:                             # [COPIED]
                        seen.add(key)                              # [COPIED]
                        fa_source_ranks.append(p)                  # [RENAMED]
            if not fa_source_ranks:                                # [RENAMED]
                for p in range(remote_tp_size):                    # [RENAMED]
                    ...                                            # [COPIED] (same fallback loop)
            num_fa_reads = len(fa_source_ranks)                    # [RENAMED]

        # ---- All source ranks (mamba + FA) ----
        if mamba_range is not None and abs_tp > num_fa_reads:      # [RENAMED]
            num_mamba_reads = abs_tp                               # [RENAMED]
            all_source_ranks = list(mamba_range)                   # [RENAMED]
        else:
            num_mamba_reads = num_fa_reads                         # [RENAMED]
            all_source_ranks = list(fa_source_ranks)               # [RENAMED]

        # ---- FA descriptor bytes ----
        effective_block_len = min(local_block_len, remote_block_len)  # [RENAMED]
        if self.is_kv_layout_blocks_first:                         # [RENAMED] was: self.is_blocks_first
            fa_descriptor_bytes = effective_block_len // 2          # [RENAMED]
        else:
            fa_descriptor_bytes = effective_block_len               # [RENAMED]

        # ---- Validation (from HeteroTPTransferConfig._validate) ----
        ...                                                        # [RENAMED] all variable names, logic identical

        return MambaEngineTransferInfo(                            # [NEW] returns dataclass
            remote_tp_size=remote_tp_size,                         # [RENAMED]
            remote_block_len=remote_block_len,                     # [RENAMED]
            remote_block_size=remote_block_size,                   # [NEW] not in original HeteroTP
            remote_physical_blocks_per_logical=...,                # [NEW] not in original HeteroTP
            remote_fa_source_ranks=tuple(fa_source_ranks),         # [RENAMED]
            remote_all_source_ranks=tuple(all_source_ranks),       # [RENAMED]
            remote_num_fa_reads=num_fa_reads,                      # [RENAMED]
            remote_num_mamba_reads=num_mamba_reads,                # [RENAMED]
            remote_fa_descriptor_bytes=fa_descriptor_bytes,        # [RENAMED]
            is_remote_replicated=is_remote_replicated,             # [RENAMED]
            remote_physical_heads=remote_physical_heads,           # [RENAMED]
        )
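
A short standalone trace of the `tp_ratio` / `abs_tp` / `mamba_range` setup at the top of `_build_mamba_info`, with made-up sizes:

```python
# Hypothetical: local TP=2, remote TP=4.
local_tp, remote_tp = 2, 4
tp_ratio = -(remote_tp // local_tp)        # -2: remote TP is larger
abs_tp = -tp_ratio if tp_ratio < 0 else 1  # each local rank reads 2 remotes
assert (tp_ratio, abs_tp) == (-2, 2)
# Contiguous remote-rank window per local rank (the mamba_range above):
windows = [list(range(r * abs_tp, (r + 1) * abs_tp)) for r in range(local_tp)]
assert windows == [[0, 1], [2, 3]]
```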
Annotated diff: worker.py (all changes, consolidated)

File: vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Branch: nixl-tpkv-transferconfig-unification
Base: PR #39354 (9cd664152)

All changes are in-place modifications to migrate worker.py from
TpKVTopology + HeteroTPTransferConfig to the unified TransferTopology.

Legend:

  • [RENAME] — Variable/method name change, no logic change
  • [REPLACE] — Swapped to new API, equivalent behavior
  • [DELETE] — Code removed (consolidated into TransferTopology)
  • [NEW] — New code not in original

1. Imports

# BEFORE:
from vllm.distributed.kv_transfer.kv_connector.utils import (
    BlockIds,
    EngineId,
    HeteroTPTransferConfig,       # [DELETE]
    TpKVTopology,                 # [DELETE]
    ...
)

# AFTER:
from vllm.distributed.kv_transfer.kv_connector.utils import (
    BlockIds,
    EngineId,
    MambaEngineTransferInfo,      # [NEW] for isinstance checks in logging
    TransferTopology,             # [REPLACE] replaces both deleted imports
    ...
)

2. Instance variable declarations (__init__)

# BEFORE:
self._transfer_configs: dict[str, HeteroTPTransferConfig] = {}   # [DELETE]
self.kv_topo: TpKVTopology | None = None                         # [RENAME]
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}  # [DELETE]
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}  # [DELETE]
self._mamba_phys_ratio: dict[EngineId, int] = {}                 # [RENAME]

# AFTER:
self.transfer_topo: TransferTopology | None = None                # [RENAME]
self._physical_blocks_per_logical: dict[EngineId, int] = {}      # [RENAME]
# _transfer_configs, _tp_size, _block_size are all gone —
# TransferTopology._engines holds per-engine info internally.

3. Dead write in _setup_hma removed

# BEFORE:
self.block_size = kernel_block_size
self._block_size[self.engine_id] = kernel_block_size    # [DELETE]
self.num_blocks *= self._physical_blocks_per_logical_kv_block

# AFTER:
self.block_size = kernel_block_size
# dead write removed — local block_size is passed to TransferTopology.__init__
self.num_blocks *= self._physical_blocks_per_logical_kv_block

4. Handshake: get_target_remote_ranks → handshake_target_ranks

# BEFORE:
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)

# AFTER:
assert self.transfer_topo is not None
p_remote_ranks = self.transfer_topo.handshake_target_ranks(remote_tp_size)

5. register_kv_caches: TpKVTopology → TransferTopology construction

# BEFORE:
self.kv_topo = TpKVTopology(
    tp_rank=self.tp_rank,
    engine_id=self.engine_id,
    remote_tp_size=self._tp_size,      # shared mutable dict
    remote_block_size=self._block_size, # shared mutable dict
    is_mla=self.use_mla,
    total_num_kv_heads=...,
    attn_backends=...,
    kv_cache_spec=...,
    block_size=self.block_size,
    is_mamba=self._has_mamba,
)

# AFTER:
self.transfer_topo = TransferTopology(
    tp_rank=self.tp_rank,
    tp_size=self.world_size,           # [REPLACE] explicit int, not shared dict
    block_size=self.block_size,        # [REPLACE] explicit int, not shared dict
    engine_id=self.engine_id,
    is_mla=self.use_mla,
    total_num_kv_heads=...,
    attn_backends=...,
    kv_cache_spec=...,
    is_mamba=self._has_mamba,
)

6. All self.kv_topo → self.transfer_topo renames

Every self.kv_topo reference is renamed to self.transfer_topo.
Every local kv_topo = self.kv_topo is renamed to transfer_topo = self.transfer_topo.

Affected locations (all [RENAME], no logic change):

| Method | What changed |
| --- | --- |
| `register_kv_caches` | `self.kv_topo` → `self.transfer_topo` |
| (descriptor registration loop) | `self.kv_topo._cross_layers_blocks` → `self.transfer_topo._cross_layers_blocks` |
| (descriptor registration loop) | `self.kv_topo.is_kv_layout_blocks_first` → `self.transfer_topo.is_kv_layout_blocks_first` |
| `_validate_remote_agent_handshake` | `self.kv_topo` → `self.transfer_topo` |
| `build_remote_descs` | `kv_topo` local var → `transfer_topo` |
| `_build_fa_remote_for_mamba` | `kv_topo` param → `transfer_topo` |
| `sync_recved_kv_to_device` | `self.kv_topo` → `self.transfer_topo` |
| `get_finished` | `self.kv_topo` → `self.transfer_topo` |
| `_get_new_notifs` | `self.kv_topo` → `self.transfer_topo` |
| `_read_blocks_for_req` | `self.kv_topo` → `self.transfer_topo` |
| `_read_blocks` | `self.kv_topo` → `self.transfer_topo` |
| `get_backend_aware_kv_block_len` | `self.kv_topo` → `self.transfer_topo` |

7. Remote engine registration replaces scattered dict updates

# BEFORE (in build_remote_descs):
if engine_id not in self._tp_size:
    self._tp_size[engine_id] = remote_tp_size
if engine_id not in self._block_size:
    self._block_size[engine_id] = nixl_agent_meta.block_size

# ... later ...
if self._has_mamba and engine_id not in self._transfer_configs:
    self._transfer_configs[engine_id] = HeteroTPTransferConfig(
        tp_ratio=tp_ratio,
        K=kv_topo.total_num_kv_heads,
        d_tp=self.world_size,
        p_tp=remote_tp_size,
        d_rank=self.tp_rank,
        use_mla=self.use_mla,
        d_block_len=self.block_len_per_layer[0],
        p_block_len=nixl_agent_meta.block_lens[0],
        is_blocks_first=kv_topo.is_kv_layout_blocks_first,
    )

# AFTER (single call, idempotent):                       [REPLACE]
physical_blocks_per_logical = (                            # [NEW] computed here
    compute_physical_blocks_per_logical(                    #   instead of later
        nixl_agent_meta.ssm_sizes,
        nixl_agent_meta.block_lens[0],
    )
    if self._has_mamba
    else 1
)
transfer_topo.register_remote_engine(
    remote_engine_id=engine_id,
    remote_tp_size=remote_tp_size,
    remote_block_size=nixl_agent_meta.block_size,
    remote_block_len=nixl_agent_meta.block_lens[0],
    remote_physical_blocks_per_logical=physical_blocks_per_logical,
    local_block_len=(self.block_len_per_layer[0] if self._has_mamba else 0),
)

Also, _physical_blocks_per_logical[engine_id] assignment was moved up to right after
register_remote_engine:

# BEFORE (was inside the descriptor registration block):
self._mamba_phys_ratio[engine_id] = compute_physical_blocks_per_logical(...)

# AFTER (immediately after register_remote_engine):       [REPLACE]
if self._has_mamba and engine_id not in self._physical_blocks_per_logical:
    self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical
    logger.info(
        "Mamba transfer plan: %s",
        transfer_topo.describe_mamba(engine_id),
    )

8. _validate_remote_agent_handshake: dict lookups → get_engine_info

# BEFORE:
assert self._tp_size[remote_engine_id] == remote_tp_size
assert self.kv_topo is not None

tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(remote_engine_id)

# AFTER:                                                  [REPLACE]
assert self.transfer_topo is not None
remote_info = self.transfer_topo.get_engine_info(remote_engine_id)
assert remote_info.remote_tp_size == remote_tp_size

tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
block_size_ratio = self.transfer_topo.block_size_ratio(nixl_agent_meta.block_size)

Same pattern for is_kv_replicated and replicates_kv_cache — still take
engine_id (they need Mamba-specific lookup internally):

# BEFORE:
self.kv_topo.is_kv_replicated(remote_engine_id)
self.kv_topo.replicates_kv_cache(engine_id)

# AFTER (unchanged signature):
self.transfer_topo.is_kv_replicated(remote_engine_id)
self.transfer_topo.replicates_kv_cache(engine_id)

9. build_remote_descs: block_size_ratio lookup change

# BEFORE:
block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id)

# AFTER:                                                  [REPLACE]
block_size_ratio = transfer_topo.block_size_ratio(nixl_agent_meta.block_size)

# BEFORE:
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)

# AFTER:                                                  [REPLACE]
tp_ratio = transfer_topo.tp_ratio(remote_tp_size)

10. Mamba split handles: transfer_cfg.method() → transfer_topo.method(eid)

# BEFORE:
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
if transfer_cfg.needs_split_handles:
    for handle_data in transfer_cfg.compute_split_handle_data(
        self.src_blocks_data, self.num_descs, abs_tp
    ):
        ...

# AFTER:                                                  [REPLACE]
if transfer_topo.needs_split_handles(engine_id):
    for handle_data in transfer_topo.compute_split_handle_data(
        engine_id, self.src_blocks_data, self.num_descs, abs_tp
    ):
        ...

Logging also updated to use MambaEngineTransferInfo fields:

# BEFORE:
logger.info(
    "Mamba-HMA split handles: targets=%s, fa_reads=%s, ..."
    transfer_cfg.transfer_targets,
    transfer_cfg.physical_fa_num_reads,
    transfer_cfg.fa_entry_size,
    transfer_cfg.mamba_num_reads,
    ...
)

# AFTER:                                                  [REPLACE]
mamba_info = transfer_topo.get_engine_info(engine_id)
assert isinstance(mamba_info, MambaEngineTransferInfo)
logger.info(
    "Mamba-HMA split handles: targets=%s, fa_reads=%s, ..."
    mamba_info.remote_all_source_ranks,
    mamba_info.remote_num_fa_reads,
    mamba_info.remote_fa_descriptor_bytes,
    mamba_info.remote_num_mamba_reads,
    ...
)

11. _build_fa_remote_for_mamba signature change

# BEFORE:
def _build_fa_remote_for_mamba(
    self,
    nixl_agent_meta: NixlAgentMetadata,
    transfer_cfg: HeteroTPTransferConfig,      # [DELETE]
    block_size_ratio: int,
    kv_topo: TpKVTopology,                     # [DELETE]
) -> list[tuple[int, int, int]]:

# AFTER:
def _build_fa_remote_for_mamba(
    self,
    nixl_agent_meta: NixlAgentMetadata,
    block_size_ratio: int,
    transfer_topo: TransferTopology,            # [REPLACE] replaces both params
    remote_engine_id: EngineId,                 # [NEW] needed for lookups
) -> list[tuple[int, int, int]]:

Body changes:

# BEFORE:
tp_ratio = transfer_cfg.tp_ratio
...
local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads
...
rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len)

# AFTER:                                                  [REPLACE]
mamba_info = transfer_topo.get_engine_info(remote_engine_id)
assert isinstance(mamba_info, MambaEngineTransferInfo)
tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size)
...
local_block_len = local_block_len // mamba_info.remote_num_fa_reads
...
rank_offset = transfer_topo.fa_rank_offset(remote_engine_id, remote_kv_block_len)

Call site also updated:

# BEFORE:
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
blocks_data.extend(
    self._build_fa_remote_for_mamba(
        nixl_agent_meta, transfer_cfg, block_size_ratio, kv_topo,
    )
)

# AFTER:                                                  [REPLACE]
blocks_data.extend(
    self._build_fa_remote_for_mamba(
        nixl_agent_meta, block_size_ratio, transfer_topo, engine_id,
    )
)

12. get_finished: block_size_ratio_from_engine_id → raw int

# BEFORE:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
    meta.remote.engine_id
)

# AFTER:                                                  [REPLACE]
remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id)
block_size_ratio = self.transfer_topo.block_size_ratio(remote_info.remote_block_size)

13. _get_new_notifs: kv_topo.tp_ratio() rename only

# BEFORE:
tp_ratio = self.kv_topo.tp_ratio(n_consumers)

# AFTER:                                                  [RENAME]
tp_ratio = self.transfer_topo.tp_ratio(n_consumers)

14. _read_blocks_for_req: multiple API changes

# BEFORE:
assert meta.remote is not None and self.kv_topo is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
    meta.remote.engine_id
)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)

# AFTER:                                                  [REPLACE]
assert meta.remote is not None and self.transfer_topo is not None
engine_id = meta.remote.engine_id
remote_ranks = self.transfer_topo.target_remote_ranks(engine_id)
remote_info = self.transfer_topo.get_engine_info(engine_id)
tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size)

# BEFORE:
remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]

# AFTER:                                                  [REPLACE]
remote_block_size = remote_info.remote_block_size

# BEFORE:
transfer_cfg = self._transfer_configs.get(meta.remote.engine_id)
assert transfer_cfg is not None
local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank(
    remote_rank, local_ids, remote_ids,
)

# AFTER:                                                  [REPLACE]
local_ids, remote_ids = self.transfer_topo.filter_block_ids_for_rank(
    engine_id, remote_rank, local_ids, remote_ids,
)

15. _read_blocks: same pattern

# BEFORE:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)

# AFTER:                                                  [REPLACE]
remote_info = self.transfer_topo.get_engine_info(dst_engine_id)
block_size_ratio = self.transfer_topo.block_size_ratio(remote_info.remote_block_size)

Summary

| Category | Count | Description |
| --- | --- | --- |
| [RENAME] | ~20 | `kv_topo` → `transfer_topo`, `info` → `remote_info`, `_mamba_phys_ratio` → `_physical_blocks_per_logical`, `compute_mamba_phys_ratio` → `compute_physical_blocks_per_logical` |
| [REPLACE] | ~15 | API calls migrated to new signatures |
| [DELETE] | 4 | `_transfer_configs`, `_tp_size`, `_block_size` dicts; `HeteroTPTransferConfig` construction; dead `_block_size` write |
| [NEW] | 2 | `register_remote_engine()` call block; `remote_engine_id` param in `_build_fa_remote_for_mamba` |

Net lines: roughly −20 (removed boilerplate dicts + HeteroTPTransferConfig
construction, replaced with single register_remote_engine call).

Logic changes: Zero. All computation paths are identical — only the API
surface changed (where data is stored and how it's accessed).

Test plan

  • NIXL connector test cases

@mergify (Bot) commented Apr 10, 2026

Documentation preview: https://vllm--39529.org.readthedocs.build/en/39529/

@mergify Bot added the documentation, v1, and kv-connector labels Apr 10, 2026
@mergify (Bot) commented Apr 10, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify Bot added the needs-rebase label Apr 10, 2026
@ZhanqiuHu changed the title from "nixl: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology" to "nixl refactor: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology" Apr 10, 2026
@ZhanqiuHu changed the title from "nixl refactor: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology" to "nixl refactor [2/N]: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology" Apr 10, 2026
@gemini-code-assist (Bot) left a comment

Code Review

This pull request refactors the NIXL connector by modularizing its components into a new package structure and introducing a TransferTopology class to centralize transfer geometry logic. My review identified three critical issues: a potential NameError in TransferTopology when is_mamba is true, an incorrect registration of Mamba KV cache tensors that would exclude SSM state, and a missing import for compute_mamba_phys_ratio in the worker module.

Comment on lines +758 to +759
        if tensor_shape is not None:
            self._cross_layers_blocks = len(tensor_shape) == len(kv_cache_shape) + 1
critical

If is_mamba is True, kv_cache_shape is never defined (see line 741). This will cause a NameError here if tensor_shape is provided. Since hybrid SSM models do not yet support cross-layer layout (as noted in connector.py), this logic should be guarded by not is_mamba to avoid accessing the undefined variable.

Suggested change
-        if tensor_shape is not None:
-            self._cross_layers_blocks = len(tensor_shape) == len(kv_cache_shape) + 1
+        if not is_mamba and tensor_shape is not None:
+            self._cross_layers_blocks = len(tensor_shape) == len(kv_cache_shape) + 1

Comment thread on vllm/distributed/kv_transfer/kv_connector/utils.py
assert self.transfer_topo is not None
transfer_topo = self.transfer_topo
physical_blocks_per_logical = (
    compute_mamba_phys_ratio(
high

compute_mamba_phys_ratio is used here but it is not imported in this file. This will cause a NameError at runtime when registering a remote engine for a Mamba model. Please ensure it is imported from vllm.distributed.kv_transfer.kv_connector.utils.

@ZhanqiuHu force-pushed the nixl-tpkv-transferconfig-unification branch 2 times, most recently from 6988b62 to 0a6585a on April 13, 2026 13:55
@mergify Bot removed the needs-rebase label Apr 13, 2026
@ZhanqiuHu marked this pull request as ready for review on April 13, 2026 14:41
@NickLucche (Collaborator) left a comment

reviewed offline

@NickLucche (Collaborator) left a comment

I think you did not delete TpKVTopology, which in turn leads to changes to mooncake and test files. We should be able to just replace it safely though right?

I left some comments for now, will continue asap

Comment thread on vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
remote_block_size=nixl_agent_meta.block_size,
remote_block_len=nixl_agent_meta.block_lens[0],
remote_physical_blocks_per_logical=physical_blocks_per_logical,
local_block_len=(self.block_len_per_layer[0] if self._has_mamba else 0),

I think this sort of logic is already handled in the register_remote_engine function so you can just do

Suggested change
- local_block_len=(self.block_len_per_layer[0] if self._has_mamba else 0),
+ local_block_len=self.block_len_per_layer[0],

@NickLucche (Collaborator) left a comment

Thanks for the work @ZhanqiuHu !
Looking good, I only left a few minor comments to address.
Getting CI rolling in the meantime

Comment on lines +331 to +333
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
"""Physical KV head range stored in a rank's KV cache tensor.


I think this method could be moved into TransferTopology; it's not used anywhere else

kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

# ============================================================
# Engine registration (new)

Suggested change
- # Engine registration (new)
+ # Engine registration

*,
local_block_len: int = 0,
) -> EngineTransferInfo:
"""Register a remote engine, replacing scattered worker dicts.

nit:

Suggested change
- """Register a remote engine, replacing scattered worker dicts.
+ """Register a remote engine, unifying worker dicts state.

Comment on lines +1181 to +1184
logger.info(
    "Mamba transfer plan: %s",
    transfer_topo.describe_mamba(engine_id),
)
Collaborator

I think we should have a single describe method here and describe the instance regardless (not under the has_mamba guard), figuring out inside the method whether we're describing Mamba or not.
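A minimal sketch of the unified entry point being suggested, dispatching on the instance's own flag; the field choices in the non-Mamba branch are illustrative, not the merged code:

```python
def describe(self, engine_id: str) -> str:
    """Single entry point; decides internally whether this is a Mamba plan."""
    if self.is_mamba:
        return self.describe_mamba(engine_id)  # reuse the existing formatter
    info = self.get_engine_info(engine_id)
    return (
        f"engine={engine_id}: remote_tp={info.remote_tp_size}, "
        f"remote_block_size={info.remote_block_size}"
    )
```

The call site above would then log unconditionally: logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)).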

Comment on lines +1255 to +1258
mamba_info.remote_all_source_ranks,
mamba_info.remote_num_fa_reads,
mamba_info.remote_fa_descriptor_bytes,
mamba_info.remote_num_mamba_reads,
Collaborator

nit: is this ~kinda like a .describe?

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 17, 2026
…ransferTopology

Replace scattered per-engine dicts (_tp_size, _block_size) and separate
HeteroTPTransferConfig construction with unified TransferTopology that
stores per-engine facts atomically via register_remote_engine().

Key changes in worker.py:
- kv_topo → transfer_topo (TransferTopology type)
- TpKVTopology constructor → TransferTopology (explicit int params)
- register_remote_engine() replaces _tp_size/_block_size dict updates
  and HeteroTPTransferConfig instantiation (idempotent, like original)
- _from_engine_id() calls → get_engine_info() + raw-int method calls
- transfer_cfg.method() → transfer_topo.method(engine_id)
- _build_fa_remote_for_mamba takes TransferTopology + engine_id
- Dead code removed: _transfer_configs, _tp_size, _block_size dicts,
  HeteroTPTransferConfig/TpKVTopology imports

Key changes in utils.py:
- HeteroTPTransferConfig deleted (~300 lines, zero consumers)
- TransferTopology: added idempotent registration, get_engine_info
- Unused `field` import removed

TpKVTopology retained for mooncake_connector.py and tests.

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
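To make the worker.py migration concrete, a hedged before/after sketch of the registration path (keyword names follow the diff context quoted earlier in this thread; the remaining attribute names are assumptions):

```python
# Before: scattered per-engine dict updates plus a separate config object.
self._tp_size[engine_id] = nixl_agent_meta.tp_size
self._block_size[engine_id] = nixl_agent_meta.block_size
self._transfer_configs[engine_id] = HeteroTPTransferConfig(...)

# After: one idempotent call that records the same facts atomically.
self.transfer_topo.register_remote_engine(
    engine_id,
    remote_tp_size=nixl_agent_meta.tp_size,
    remote_block_size=nixl_agent_meta.block_size,
    remote_block_len=nixl_agent_meta.block_lens[0],
    local_block_len=self.block_len_per_layer[0],
)
```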
@ZhanqiuHu ZhanqiuHu force-pushed the nixl-tpkv-transferconfig-unification branch from 8b5fc45 to e4153f3 Compare April 17, 2026 17:31
@NickLucche NickLucche merged commit cc3993b into vllm-project:main Apr 20, 2026
57 checks passed
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Apr 20, 2026
… TransferTopology (vllm-project#39529)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Apr 23, 2026
… TransferTopology (vllm-project#39529)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
… TransferTopology (vllm-project#39529)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
iboiko-habana pushed a commit to vllm-project/vllm-gaudi that referenced this pull request Apr 29, 2026
…stream breakages: NIXL connector, TpKVTopology rename, MoE refactor, transformers v5 (#1377)

## Summary

Compatibility fixes for vLLM bump to `3975eb6de6`. Addresses breakages
from multiple upstream PRs affecting NIXL connectors, MoE runner
refactor, offloading tests, Qwen3 MoE models, and transformers v5
upgrade.

## Root Cause

1. **NIXL import gate** — Upstream PR
vllm-project/vllm#39529 (commit `cc3993b05d`)
moved NIXL imports to `vllm/distributed/nixl_utils.py` and changed the
platform gate from `if not is_rocm()` to `if is_cuda()`. HPU is neither
CUDA nor ROCm, so it falls into the `else` branch → tries `rixl._api`
(ROCm-only) → fails → `NixlWrapper = None` → `RuntimeError("NIXL is not
available")`. See the sketch after this list.

2. **TpKVTopology rename** — Same upstream PR #39529 unified
`TpKVTopology` + `HeteroTPTransferConfig` into `TransferTopology`,
breaking vllm-gaudi NIXL connector imports.

3. **Offloading tests** — Upstream PR
vllm-project/vllm#36645 changed
`OffloadingManager.lookup()` API.

4. **MoE runner refactor** — Upstream PR
vllm-project/vllm#35949 (commit `726efe177b`)
moved reduce logic into `MoERunnerBase`, removing `reduce_results`,
renaming `forward_dispatch` → `_forward_dispatch`, `forward_entry` →
`_forward_entry`, `_maybe_reduce_output` → `_maybe_reduce_final_output`.
Follow-up PR moved `MoERunnerBase` and `get_layer_from_name` to
`moe_runner_base.py`.

5. **Qwen3 MoE** — `SharedFusedMoE` returns a combined tensor (not a
tuple), and MoE runner now handles TP reduction internally, causing
double-reduce in `qwen3_moe.py` / `qwen3_next.py`.

6. **Transformers v5 — granite tokenizer** — Upstream PR
vllm-project/vllm#30566 updated transformers to
allow v5. GPT2Tokenizer in v5 now respects `add_bos_token=True`
(silently ignored in v4), causing degenerate outputs and 0.0 GSM8K
accuracy on granite models.

7. **Transformers v5.6.x — DeepSeek-V2-Lite tokenizer** — In
transformers v5.6.x, `LlamaTokenizerFast` was unified into
`LlamaTokenizer`, which does not apply the ByteLevel BPE decoder
declared in `tokenizer.json`. DeepSeek-V2-Lite-Chat's tokenizer decoding
strips all spaces (Ġ chars not converted back), producing garbled output
and 0.0 accuracy on GSM8K. Fixed natively in transformers v5.7.0.
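
For item 1, a hedged reconstruction of the failing gate; the platform-check names come from this message, and the exact upstream code may differ:

```python
# Illustrative only: why HPU falls through to the ROCm import path.
from vllm.platforms import current_platform

if current_platform.is_cuda():
    from nixl._api import nixl_agent as NixlWrapper
else:  # intended for ROCm, but HPU lands here too
    try:
        from rixl._api import nixl_agent as NixlWrapper
    except ImportError:
        NixlWrapper = None  # later surfaces as RuntimeError("NIXL is not available")
```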

## Fix

1. **NIXL import patch**: Add `patch_nixl_utils_for_hpu()` in
`register_utils()` to monkey-patch `vllm.distributed.nixl_utils` so it
imports from `nixl._api` instead of `rixl._api` on HPU (sketched after
this list). Update `hetero_hpu_nixl_connector.py` to import from
`vllm.distributed.nixl_utils` instead of the hardcoded `nixl._api`.
2. **TpKVTopology → TransferTopology**: Rename in NIXL connector imports
and monkey-patches.
3. **Offloading tests**: Replace `runner.manager.lookup.return_value`
with `connector_scheduler._maximal_prefix_lookup`.
4. **MoE refactor**: Update imports (`MoERunnerBase` from
`moe_runner_base`), method names (`_forward_dispatch`, `_forward_entry`,
`_maybe_reduce_final_output`), remove dead `reduce_results` /
`reduce_output()`.
5. **Qwen3 MoE**: Remove incorrect shared_expert tuple indexing and
double TP reduction.
6. **Transformers v5 — granite**: Remove hardcoded `add_bos_token=True`
from lm-eval model_args to fix GSM8K accuracy regression.
7. **Transformers v5.6.x — DeepSeek-V2-Lite**: Exclude `transformers
5.6.*` in `requirements.txt` to prevent installation of versions with
broken ByteLevel BPE tokenizer decoding. Verified on Gaudi2: gsm8k
accuracy 0.65 (expected 0.66, within tolerance) with transformers 5.7.0.
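
And a minimal sketch of the monkey-patch from fix 1; the NixlWrapper attribute on nixl_utils is inferred from the root-cause section, so treat it as an assumption:

```python
def patch_nixl_utils_for_hpu() -> None:
    """Make vllm.distributed.nixl_utils use nixl._api on HPU."""
    import vllm.distributed.nixl_utils as nixl_utils

    try:
        from nixl._api import nixl_agent  # CUDA-path import; also valid on HPU
    except ImportError:
        nixl_agent = None
    nixl_utils.NixlWrapper = nixl_agent  # assumed attribute name
```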

---------

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
… TransferTopology (vllm-project#39529)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Adrian <info@zzit.ch>
Labels

documentation, kv-connector, ready, v1