Skip to content

nixl refactor: new transfer design#40731

Merged
vllm-bot merged 44 commits into
vllm-project:mainfrom
ZhanqiuHu:nixl-refactor-plan-based-poc
May 6, 2026
Merged

nixl refactor: new transfer design#40731
vllm-bot merged 44 commits into
vllm-project:mainfrom
ZhanqiuHu:nixl-refactor-plan-based-poc

Conversation

@ZhanqiuHu

@ZhanqiuHu ZhanqiuHu commented Apr 23, 2026

Copy link
Copy Markdown
Contributor

Refactor 3/N

@ZhanqiuHu ZhanqiuHu changed the title nixl refactor: plan-based transfer design nixl refactor: new transfer design Apr 23, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the NIXL connector to use a plan-based transfer design, which unifies Dense and Mamba implementations into a model-agnostic execution path. By pre-generating an EngineTransferPlan during the handshake, the hot path is simplified and model-specific branching is removed. The review feedback highlights several opportunities to improve robustness, specifically by replacing assertions with explicit error handling for block count mismatches and adding divisibility checks when calculating logical blocks and chunk sizes to prevent potential data corruption or crashes in heterogeneous tensor parallel configurations.

Comment on lines 1788 to +1790
num_local_blocks = len(local_block_ids[i])
if not self._is_mamba_group[i]:
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
assert num_local_blocks <= len(remote_group)
if num_local_blocks < len(remote_group):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The assertion assert num_local_blocks <= len(remote_group) can cause a hard crash of the worker if the producer sends fewer blocks than the consumer expects (e.g., due to a race condition or misconfiguration). It is safer to handle this as a transfer failure for the specific request, allowing the engine to continue processing other requests. Since _read_blocks is called within a try-except block in _read_blocks_for_req, raising a ValueError will be caught and handled gracefully.

Suggested change
num_local_blocks = len(local_block_ids[i])
if not self._is_mamba_group[i]:
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
assert num_local_blocks <= len(remote_group)
if num_local_blocks < len(remote_group):
num_local_blocks = len(local_block_ids[i])
if num_local_blocks > len(remote_group):
raise ValueError(
f"Group {i}: local block count ({num_local_blocks}) "
f"exceeds remote block count ({len(remote_group)})")
if num_local_blocks < len(remote_group):

Comment on lines +565 to +566
ratio = physical_blocks_per_logical
logical_blocks = num_blocks // ratio

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The calculation of logical_blocks = num_blocks // ratio assumes that the total number of kernel blocks is a perfect multiple of the physical-to-logical ratio. If this invariant is violated, the descriptor IDs for SSM regions will be incorrectly computed, leading to memory corruption or incorrect data transfer. An explicit check should be added to verify this invariant.

Suggested change
ratio = physical_blocks_per_logical
logical_blocks = num_blocks // ratio
ratio = physical_blocks_per_logical
if num_blocks % ratio != 0:
raise ValueError(f"num_blocks {num_blocks} is not a multiple of "
f"physical_blocks_per_logical {ratio}")
logical_blocks = num_blocks // ratio

Comment on lines +646 to +651
if j < num_fa_descs:
chunk = local_len // fa_num_splits
handle.append((addr + fa_slot * chunk, chunk, dev))
else:
chunk = local_len // ssm_num_splits
handle.append((addr + p_idx * chunk, chunk, dev))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using integer division to compute chunk sizes for split handles can lead to silent data loss if the total length is not perfectly divisible by the number of splits. While standard vLLM configurations usually satisfy this, heterogeneous TP scenarios with non-standard world sizes could trigger this issue. It is safer to explicitly check for divisibility.

Suggested change
if j < num_fa_descs:
chunk = local_len // fa_num_splits
handle.append((addr + fa_slot * chunk, chunk, dev))
else:
chunk = local_len // ssm_num_splits
handle.append((addr + p_idx * chunk, chunk, dev))
if j < num_fa_descs:
if local_len % fa_num_splits != 0:
raise ValueError(f"FA descriptor length {local_len} is not "
f"divisible by split count {fa_num_splits}")
chunk = local_len // fa_num_splits
handle.append((addr + fa_slot * chunk, chunk, dev))
else:
if local_len % ssm_num_splits != 0:
raise ValueError(f"SSM descriptor length {local_len} is not "
f"divisible by split count {ssm_num_splits}")
chunk = local_len // ssm_num_splits
handle.append((addr + p_idx * chunk, chunk, dev))

@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch 2 times, most recently from 4852347 to fcf7418 Compare April 27, 2026 15:41
@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review April 27, 2026 18:01

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments, thanks @ZhanqiuHu !

Comment on lines +91 to +93
# ------------------------------------------------------------------
# Plan executors (static — no self access)
# ------------------------------------------------------------------

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels a little bit too "claudy"

Comment on lines +54 to +60
@dataclass(frozen=True)
class RegionPlan:
"""Geometry for one descriptor region.

Everything needed to build NIXL descriptors and compute descriptor
IDs is baked in. The caller plugs in ``base_addr`` and
``device_id`` when constructing the final descriptor tuples.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not fully convinced about this abstraction.
I am afraid this may actually be harder to work with rather than a basic region described by (base_addr, len).

Like do we care about keeping track of things like

  • page_stride
  • offset_in_page

once the starting address of the region has been computed?
'Cause if we don't, we might as well just store the address and len directly

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is purely my personal preference, but I was thinking of moving the
geometry computation (stride, offset, descriptor size) from worker.py
into transfer_plan.py to reduce the worker code. RegionPlan packs
the output of that. Originally the functions included base_addr and
device_id, but then I wanted to reduce the arguments to
_build_fa_regions, build_fa_local_regions, and
build_mamba_local_regions, and base_addr and device_id are not used for block geometry computation.

I was also thinking of adding parameters like descs_per_block and desc_stride_bytes to RegionPlan,
so we can handle different cache groups with different block size ratios
(e.g, Gemma4 HeteroTP where SWA and FA have different tokens-per-block).

Comment on lines +252 to +255
handle.append((addr + p_idx * chunk, chunk, dev))
result.append(handle)

return result

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not that it matters much in terms of speed, but this whole method could yield handle here an be a generator

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, will update _build_local_splits_from_plan to yield handle.

Comment on lines +143 to +151
all_descs: list[np.ndarray] = []
for i, group in enumerate(block_ids):
group_arr = np.asarray(group)
spec_type = plan.group_spec_types[i]
if _is_attention_spec(spec_type):
fa_region_ids = np.arange(num_fa_regions)[:, None]
all_descs.append(
(fa_region_ids * num_blocks + group_arr[None, :]).flatten()
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we lost a nice optimization for non mamba models that used vectorized np ops only here

        if not self._has_mamba:
            block_ids = np.concatenate(block_ids)[None, :]
            descs_ids = region_ids * num_blocks + block_ids
            return descs_ids.flatten()
        else:
            # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged

@ZhanqiuHu could you double check

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's truth. I was thinking of removing the branch. I can add it back like this:

@staticmethod
def _compute_desc_ids_from_plan(
    plan: EngineTransferPlan,
    block_ids: BlockIds,
    dst_num_blocks: int,
    block_size_ratio: float | None,
    physical_blocks_per_logical: int,
) -> np.ndarray:
    """Compute NIXL descriptor IDs for given block IDs."""
    num_fa_regions = len(plan.fa_regions)
    num_ssm_regions = len(plan.ssm_regions)

    num_blocks = dst_num_blocks
    if block_size_ratio is not None:
        num_blocks = int(num_blocks * block_size_ratio)
    ratio = physical_blocks_per_logical
    logical_blocks = num_blocks // ratio

    num_fa_descs = num_fa_regions * num_blocks

    # All-attention fast path: single vectorized broadcast.
    if num_ssm_regions == 0:
        block_arr = np.concatenate(block_ids)[None, :]
        region_ids = np.arange(num_fa_regions)[:, None]
        return (region_ids * num_blocks + block_arr).flatten()

    # NOTE (NickLucche) With HMA, every kv group has the same number
    # of layers and layers from different groups share the same kv
    # tensor.  Therefore we compute desc IDs per group using the
    # right stride:
    # FA descs have num_blocks entries per region (kernel granularity),
    # SSM descs have logical_blocks entries per region (no kernel
    # splitting).
    all_descs: list[np.ndarray] = []
    for i, group in enumerate(block_ids):
        group_arr = np.asarray(group)
        if _is_attention_spec(plan.group_spec_types[i]):
            fa_region_ids = np.arange(num_fa_regions)[:, None]
            all_descs.append(
                (fa_region_ids * num_blocks + group_arr[None, :]).flatten()
            )
        elif _is_ssm_spec(plan.group_spec_types[i]):
            # NOTE (NickLucche) SSM and Attention block regions can
            # be exchanged arbitrarily by manager.  Therefore, descs
            # are laid out as:
            #   [descs_fa (all regions) | descs_ssm (all regions)].
            # num_fa_descs offset must be computed per-engine since
            # P and D can have different num_blocks (and thus
            # different FA desc counts).
            ssm_region_ids = np.arange(num_ssm_regions)[:, None]
            all_descs.append(
                (
                    ssm_region_ids * logical_blocks
                    + group_arr[None, :]
                    + num_fa_descs
                ).flatten()
            )
        else:
            raise ValueError(
                f"Unknown spec type {plan.group_spec_types[i]} at index {i}"
            )

    return np.concatenate(all_descs)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's ok

Comment on lines +88 to +92
# Per-group ordered source ranks. Position = local piece index.
source_ranks_per_group: tuple[tuple[int, ...], ...]

# Superset of all source ranks (union of all groups).
all_source_ranks: tuple[int, ...]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think it's very clear what "source_rank" is here..

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

source_rank is suppose to mean the remote TP ranks that this local rank reads from.

source_ranks_per_group is the source_rank for each kv cache group (e.g., FA source ranks will be < mamba source ranks if FA is replicated and Mamba is sharded).

Should we renamed it to something else?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just add this as comment above here

group_spec_types: tuple[type[KVCacheSpec], ...],
local_physical_blocks_per_logical: int,
) -> EngineTransferPlan:
"""Generate transfer plan for dense (attention-only) models."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: very minor, but we may want to choose some other name or clarify that for dense we still encompass all non-mamba, including SW and DSA

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about generate_pure_attention_plan() and
generate_ssm_attention_hybrid_plan()?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's ask claude for some more options here

Comment on lines +257 to +261
def _build_local_descs(
self,
base_addresses: list[int],
block_size_ratio: int,
) -> list[tuple[int, int, int]]:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method is only used once in register_local_xfer_handler.
I don't think there's a lot of value in added clarity in separating this snippet as a staticmethod

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, will inline it.

Comment on lines +174 to +179
def _compute_read_specs_from_plan(
plan: EngineTransferPlan,
local_block_ids: BlockIds,
remote_block_ids: BlockIds,
) -> list[ReadSpec]:
"""Compute read specs from plan.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also not sure whether this should be a function, or an inline for

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, will inline it.

Comment on lines +1853 to +1854
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this continuation comment has lost its first part now. It was meant to be

            # MLA opt: when P TP > D TP, only a single read is executed for
            # the first remote rank (cache is duplicated)..
            # ..but we still need to notify the other remote ranks that we
            # have the blocks we need so they can update the request state.

but we can just rephrase

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first part is actually before

if self.use_mla and tp_ratio < 0:
            read_specs = read_specs[:1]

The change is already in main (main ref), probably introduced by a previous PR.

Comment on lines +1841 to +1846
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=local_ids,
remote_block_ids=remote_ids,
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not fully sure, but should we change _read_blocks interface to just be

    def _read_blocks(
        self,
        dst_engine_id: str,
        request_id: str,
        read_spec: ReadSpec, 
        remote_request_id: str,
        local_xfer_side_handle: int,
        remote_xfer_side_handle: int,
    ):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we can do

def _read_blocks(
    self,
    read_spec: ReadSpec,
    dst_engine_id: str,
    remote_request_id: str,  
    local_xfer_side_handle: int,  
    remote_xfer_side_handle: int,  
):
    local_block_ids = read_spec.local_block_ids
    remote_block_ids = read_spec.remote_block_ids
    remote_rank = read_spec.remote_rank

@ZhanqiuHu ZhanqiuHu left a comment

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed comments

Comment on lines +88 to +92
# Per-group ordered source ranks. Position = local piece index.
source_ranks_per_group: tuple[tuple[int, ...], ...]

# Superset of all source ranks (union of all groups).
all_source_ranks: tuple[int, ...]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just add this as comment above here

group_spec_types: tuple[type[KVCacheSpec], ...],
local_physical_blocks_per_logical: int,
) -> EngineTransferPlan:
"""Generate transfer plan for dense (attention-only) models."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's ask claude for some more options here

Comment on lines +143 to +151
all_descs: list[np.ndarray] = []
for i, group in enumerate(block_ids):
group_arr = np.asarray(group)
spec_type = plan.group_spec_types[i]
if _is_attention_spec(spec_type):
fa_region_ids = np.arange(num_fa_regions)[:, None]
all_descs.append(
(fa_region_ids * num_blocks + group_arr[None, :]).flatten()
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's ok

Comment on lines -2267 to -2295
Get the block length for one K/V element (K and V have the same size).

For FA and other backends, this is equal to the length of the whole
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
Similarly, for SSM-based models, state and conv are interleaved, but crucially
the their size differs.
Reference diagram:
KVCacheTensor (Shared)
/ \\
/ \\
/ \\
Attention (FlashInfer) View Mamba View
| |
| |
+-------------------+ +-------------------+
| KVCacheTensor | | KVCacheTensor |
| | | |
|<----- page ------>| |<----- page ------->|
| size | | size |
| Key 0 | Val 0 | |Conv 0 | SSM 0 |
| Key 1 | Val 1 | |Conv 1 | SSM 1 |
| ... | ... | | ... | ... |
| Key N-2 | Val N-2 | |Conv N-2| SSM N-2 |
| Key N-1 | Val N-1 | |Conv N-1| SSM N-1 |
+-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split |
"""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we lost this whole diagram :( @ZhanqiuHu

Comment on lines -389 to -391
@dataclass(frozen=True)
class MambaEngineTransferInfo(EngineTransferInfo):
"""Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice cleanup in this file

@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from a6e5266 to 51f297a Compare May 4, 2026 18:11
@mergify

mergify Bot commented May 4, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 4, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from 51f297a to 7881c46 Compare May 4, 2026 18:17
@mergify mergify Bot removed the needs-rebase label May 4, 2026

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a particularly complicated part of the codebase, thanks for the great work here @ZhanqiuHu trying to improve clarity and maintainability of it!

Happy to finally approve this PR. Left some comments, mostly around preserving some of the context we had in comments. Will push something to help the work around the clock

Comment on lines +867 to +869
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might be cruft

Comment on lines +1 to +3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TP mapping computation for NIXL KV cache transfers."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file is tp_mapping now :)

Comment on lines -1269 to -1274
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].

# Register all remote blocks, but only the corresponding kv heads.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might still be good context

)

if transfer_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too

Comment on lines -1061 to -1063
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Comment on lines -964 to -965
# Mamba conv state is always TP-sharded, even when attention KV
# is replicated (num_kv_heads < tp_size).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:(

Comment on lines -271 to -276
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
# NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine.
# physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len)
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With
# heterogeneous TP the per-rank sizes differ, so the ratio differs:
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is also good context

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label May 5, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from f23d68c to 98c3207 Compare May 5, 2026 15:43
@NickLucche NickLucche enabled auto-merge (squash) May 5, 2026 17:13
ZhanqiuHu added 3 commits May 5, 2026 14:30
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>

Signed-off-by: ZhanqiuHu <zhu@redhat.com>
ZhanqiuHu and others added 5 commits May 5, 2026 14:30
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
auto-merge was automatically disabled May 5, 2026 18:31

Head branch was pushed to by a user without write access

@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from 98c3207 to 1232865 Compare May 5, 2026 18:31
@vllm-bot vllm-bot merged commit df8e63f into vllm-project:main May 6, 2026
62 of 64 checks passed
@ywang96

ywang96 commented May 7, 2026

Copy link
Copy Markdown
Member

It seems that this PR breaks DSv4 PD with dynamo

(Worker_DP0_EP0 pid=2959201) INFO 05-07 07:42:03 [worker.py:541] NIXL compatibility check passed (hash: eb7849ad8e051023aa810813258c453818ec45450c54665207c2bd83b5606974)
(Worker_DP0_EP0 pid=2959201) INFO 05-07 07:42:03 [worker.py:1259] Transfer plan: TransferTopology(tp_ratio=1, K=1, local_tp=1, remote_tp=1, local_rank=0, remote_block_len=1728)
(Worker_DP1_EP1 pid=2959202) INFO 05-07 07:42:03 [worker.py:541] NIXL compatibility check passed (hash: eb7849ad8e051023aa810813258c453818ec45450c54665207c2bd83b5606974)
(Worker_DP1_EP1 pid=2959202) INFO 05-07 07:42:03 [worker.py:1259] Transfer plan: TransferTopology(tp_ratio=1, K=1, local_tp=1, remote_tp=1, local_rank=0, remote_block_len=1728)
(Worker_DP1_EP1 pid=2959202) ERROR 05-07 07:42:04 [worker.py:706] NIXL transfer failure: handshake_setup_failed | Context: {'failure_type': 'handshake_setup_failed', 'request_id': None, 'engine_id': '6d8107b7-096a-4da8-9fe2-929d116adffa_dp0', 'remote_engine_id': '74b0ea09-24ec-4d38-a3c7-6d0c9d305e3a_dp0'}
(Worker_DP1_EP1 pid=2959202) ERROR 05-07 07:42:04 [worker.py:706] Traceback (most recent call last):
(Worker_DP0_EP0 pid=2959201) ERROR 05-07 07:42:04 [worker.py:706] NIXL transfer failure: handshake_setup_failed | Context: {'failure_type': 'handshake_setup_failed', 'request_id': None, 'engine_id': '815af56e-bd22-4382-8f4a-d809aadf0930_dp0', 'remote_engine_id': '5126e220-4b46-4a89-9e3c-050321fd6fc2_dp0'}
(Worker_DP1_EP1 pid=2959202) ERROR 05-07 07:42:04 [worker.py:706]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py", line 704, in done_callback

Have we verified the implementation with an end-to-end integration load test? We're investigating on parallel.

@ZhanqiuHu

ZhanqiuHu commented May 7, 2026

Copy link
Copy Markdown
Contributor Author

It seems that this PR breaks DSv4 PD with dynamo

(Worker_DP0_EP0 pid=2959201) INFO 05-07 07:42:03 [worker.py:541] NIXL compatibility check passed (hash: eb7849ad8e051023aa810813258c453818ec45450c54665207c2bd83b5606974)
(Worker_DP0_EP0 pid=2959201) INFO 05-07 07:42:03 [worker.py:1259] Transfer plan: TransferTopology(tp_ratio=1, K=1, local_tp=1, remote_tp=1, local_rank=0, remote_block_len=1728)
(Worker_DP1_EP1 pid=2959202) INFO 05-07 07:42:03 [worker.py:541] NIXL compatibility check passed (hash: eb7849ad8e051023aa810813258c453818ec45450c54665207c2bd83b5606974)
(Worker_DP1_EP1 pid=2959202) INFO 05-07 07:42:03 [worker.py:1259] Transfer plan: TransferTopology(tp_ratio=1, K=1, local_tp=1, remote_tp=1, local_rank=0, remote_block_len=1728)
(Worker_DP1_EP1 pid=2959202) ERROR 05-07 07:42:04 [worker.py:706] NIXL transfer failure: handshake_setup_failed | Context: {'failure_type': 'handshake_setup_failed', 'request_id': None, 'engine_id': '6d8107b7-096a-4da8-9fe2-929d116adffa_dp0', 'remote_engine_id': '74b0ea09-24ec-4d38-a3c7-6d0c9d305e3a_dp0'}
(Worker_DP1_EP1 pid=2959202) ERROR 05-07 07:42:04 [worker.py:706] Traceback (most recent call last):
(Worker_DP0_EP0 pid=2959201) ERROR 05-07 07:42:04 [worker.py:706] NIXL transfer failure: handshake_setup_failed | Context: {'failure_type': 'handshake_setup_failed', 'request_id': None, 'engine_id': '815af56e-bd22-4382-8f4a-d809aadf0930_dp0', 'remote_engine_id': '5126e220-4b46-4a89-9e3c-050321fd6fc2_dp0'}
(Worker_DP1_EP1 pid=2959202) ERROR 05-07 07:42:04 [worker.py:706]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py", line 704, in done_callback

Have we verified the implementation with an end-to-end integration load test? We're investigating on parallel.

I tested this end-to-end with Qwen and Nemotron and Deepseek-V2, haven't test with Deepseek v4 yet.

Will look into this. By the way do you have the setup and full traces? Thanks!

@NickLucche

Copy link
Copy Markdown
Member

might be a compat version bump, thanks for reporting @ywang96 will check this out

@gau-nernst

gau-nernst commented May 7, 2026

Copy link
Copy Markdown
Contributor

@ZhanqiuHu Here is an MRE using Slurm, launching 1 prefill and 1 decode worker

Script
"""Launch one DSv4-Flash prefill worker and one decode worker via Slurm.

Please change the hard-coded nodes accordingly.

Install NIXL
  uv pip install 'nixl[cu13]==0.10.1' 'nixl-cu12==0.10.1' 'nixl-cu13==0.10.1' 'cupy-cuda12x==14.0.1'
"""

import contextlib
from dataclasses import dataclass
import json
import os
import signal
import shlex
import subprocess
import sys
import time
import uuid
from pathlib import Path

import requests

MODEL = "deepseek-ai/DeepSeek-V4-Flash"
SERVER_HOST = "0.0.0.0"
LOG_DIR = Path(__file__).parent
GPUS_PER_WORKER = 4
CPUS_PER_WORKER = 32
READY_TIMEOUT_S = 900
REQUEST_TIMEOUT_S = 300


@dataclass(frozen=True)
class Worker:
    name: str
    node: str
    port: int
    nixl_port: int

    @property
    def host(self) -> str:
        return self.node

    @property
    def base_url(self) -> str:
        return f"http://{self.host}:{self.port}"

    @property
    def log_path(self) -> Path:
        return LOG_DIR / f"{self.name}.log"

### PLEASE UPDATE THIS ###
PREFILL = Worker(name="prefill", node="node-01", port=8100, nixl_port=5610)
DECODE = Worker(name="decode", node="node-02", port=8200, nixl_port=5620)

PROMPT = (
    "Write a detailed explanation of paged attention for transformer inference, "
    "including how KV cache blocks are allocated, transferred, and reused across "
    "prefill and decode workers in a disaggregated serving system. Include enough "
    "detail that this prompt exceeds the KV cache block threshold."
)

VLLM_CMD = f"""
{shlex.quote(sys.executable)}
-m vllm.entrypoints.openai.api_server
--model {shlex.quote(MODEL)}
--host {SERVER_HOST}
--trust-remote-code
--load-format dummy
--data-parallel-size {GPUS_PER_WORKER}
--enable-expert-parallel
--max-model-len 512
--max-num-batched-tokens 512
--block-size 256
--gpu-memory-utilization 0.80
--kv-cache-dtype fp8
--tokenizer-mode deepseek_v4
--moe-backend deep_gemm_mega_moe
--enforce-eager
--compilation-config '{{"mode":0}}'
--attention-config '{{"use_fp4_indexer_cache":true}}'
--kv-transfer-config '{{"kv_connector":"NixlConnector","kv_role":"kv_both"}}'
"""


def launch(worker: Worker) -> subprocess.Popen:
    cmd = [
        "srun",
        "--nodes=1",
        "--ntasks=1",
        f"--nodelist={worker.node}",
        f"--gpus-per-task={GPUS_PER_WORKER}",
        f"--cpus-per-task={CPUS_PER_WORKER}",
        "--export=ALL",
        *shlex.split(VLLM_CMD),
        "--port",
        str(worker.port),
    ]
    env = os.environ | {
        "VLLM_NIXL_SIDE_CHANNEL_HOST": worker.host,
        "VLLM_NIXL_SIDE_CHANNEL_PORT": str(worker.nixl_port),
        "VLLM_WORKER_MULTIPROC_METHOD": os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
        "UCX_TLS": os.getenv("UCX_TLS", "cuda_copy,cuda_ipc,tcp"),
        "UCX_MEMTYPE_CACHE": os.getenv("UCX_MEMTYPE_CACHE", "n"),
        "UCX_MEMTYPE_REG_WHOLE": os.getenv("UCX_MEMTYPE_REG_WHOLE", "n"),
    }

    print(f"\n{worker.name}: {shlex.join(cmd)}", flush=True)
    print(
        f"{worker.name} srun: node={worker.node} gpus_per_task={GPUS_PER_WORKER} "
        f"cpus_per_task={CPUS_PER_WORKER} "
        f"{worker.name} env: VLLM_NIXL_SIDE_CHANNEL_HOST={env['VLLM_NIXL_SIDE_CHANNEL_HOST']} "
        f"VLLM_NIXL_SIDE_CHANNEL_PORT={env['VLLM_NIXL_SIDE_CHANNEL_PORT']}",
        flush=True,
    )
    return subprocess.Popen(
        cmd,
        env=env,
        stdout=worker.log_path.open("w", encoding="utf-8"),
        stderr=subprocess.STDOUT,
        text=True,
        start_new_session=True,
    )


def wait_ready(worker: Worker, proc: subprocess.Popen) -> None:
    url = f"{worker.base_url}/v1/models"
    deadline = time.monotonic() + READY_TIMEOUT_S
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            raise RuntimeError(f"{worker.name} exited early; see {worker.log_path}")
        try:
            requests.get(url, timeout=5).raise_for_status()
            print(f"{worker.name} ready: {url}", flush=True)
            return
        except requests.RequestException:
            time.sleep(2)
    raise TimeoutError(f"{worker.name} did not become ready; see {worker.log_path}")


def raise_for_status(response: requests.Response) -> None:
    try:
        response.raise_for_status()
    except requests.HTTPError as exc:
        body = response.text[:4000]
        raise RuntimeError(f"{response.request.method} {response.url} failed: {exc}\n{body}") from exc


def stop(proc: subprocess.Popen | None) -> None:
    if proc is None or proc.poll() is not None:
        return
    with contextlib.suppress(ProcessLookupError):
        os.killpg(proc.pid, signal.SIGTERM)
    try:
        proc.wait(timeout=20)
    except subprocess.TimeoutExpired:
        with contextlib.suppress(ProcessLookupError):
            os.killpg(proc.pid, signal.SIGKILL)


def main() -> None:
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    print(f"logs: {LOG_DIR}", flush=True)
    print(
        f"prefill: node={PREFILL.node} host={PREFILL.host} port={PREFILL.port} nixl_port={PREFILL.nixl_port}",
        flush=True,
    )
    print(
        f"decode: node={DECODE.node} host={DECODE.host} port={DECODE.port} nixl_port={DECODE.nixl_port}",
        flush=True,
    )

    prefill = decode = None
    try:
        prefill = launch(PREFILL)
        decode = launch(DECODE)
        wait_ready(PREFILL, prefill)
        wait_ready(DECODE, decode)

        request_id = str(uuid.uuid4())
        prefill_response = requests.post(
            f"{PREFILL.base_url}/v1/completions",
            headers={"X-Request-Id": request_id},
            timeout=REQUEST_TIMEOUT_S,
            json={
                "model": MODEL,
                "prompt": PROMPT,
                "max_tokens": 1,
                "temperature": 0,
                "kv_transfer_params": {
                    "do_remote_decode": True,
                    "do_remote_prefill": False,
                    "remote_engine_id": None,
                    "remote_block_ids": None,
                    "remote_host": None,
                    "remote_port": None,
                },
            },
        )
        raise_for_status(prefill_response)
        prefill_json = prefill_response.json()
        print("\nprefill response:", json.dumps(prefill_json, indent=2)[:4000])

        kv_transfer_params = prefill_json.get("kv_transfer_params")
        if not kv_transfer_params:
            raise RuntimeError(f"prefill response did not include kv_transfer_params: {prefill_json}")

        decode_response = requests.post(
            f"{DECODE.base_url}/v1/completions",
            headers={"X-Request-Id": request_id},
            timeout=REQUEST_TIMEOUT_S,
            json={
                "model": MODEL,
                "prompt": "decode side placeholder prompt",
                "max_tokens": 8,
                "temperature": 0,
                "kv_transfer_params": kv_transfer_params,
            },
        )
        print("\ndecode status:", decode_response.status_code)
        print(decode_response.text[:4000])
        raise_for_status(decode_response)
    finally:
        print(f"\nlogs: {LOG_DIR}", flush=True)
        stop(prefill)
        stop(decode)


if __name__ == "__main__":
    main()

Here is another MRE generated by Codex, looking into the internals. I'm not too familiar with this code so I'm not entirely sure if it makes sense. But maybe you will understand

"""Minimal repro for the suspected NIXL PR #40731 MLA classification bug.
"""

import torch

from vllm.distributed.kv_transfer.kv_connector.v1.nixl.tp_mapping import (
    _is_attention_spec,
)
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    MambaSpec,
    UniformTypeKVCacheSpecs,
)


def classify_like_pr40731(spec: KVCacheSpec) -> type[KVCacheSpec]:
    return type(spec)


def classify_with_uniform_unwrap(spec: KVCacheSpec) -> type[KVCacheSpec]:
    if isinstance(spec, UniformTypeKVCacheSpecs):
        inner_specs = tuple(spec.kv_cache_specs.values())
        if inner_specs and all(isinstance(s, AttentionSpec) for s in inner_specs):
            return AttentionSpec
        if inner_specs and all(isinstance(s, MambaSpec) for s in inner_specs):
            return MambaSpec
    return type(spec)


def find_fa_group(group_spec_types: tuple[type[KVCacheSpec], ...]) -> int:
    return next(
        i for i, spec_type in enumerate(group_spec_types) if _is_attention_spec(spec_type)
    )


def main() -> None:
    mla_spec = MLAAttentionSpec(
        block_size=64,
        num_kv_heads=1,
        head_size=512,
        dtype=torch.bfloat16,
        cache_dtype_str="fp8_ds_mla",
        compress_ratio=1,
        model_version="deepseek_v4",
    )
    uniform_mla_spec = UniformTypeKVCacheSpecs(
        block_size=64,
        kv_cache_specs={
            "model.layers.0.self_attn": mla_spec,
            "model.layers.1.self_attn": mla_spec,
        },
    )

    old_group_spec_types = (classify_like_pr40731(uniform_mla_spec),)
    print(
        "old_group_spec_types =",
        [spec_type.__name__ for spec_type in old_group_spec_types],
    )
    try:
        old_idx = find_fa_group(old_group_spec_types)
    except StopIteration:
        print("old path reproduced StopIteration")
    else:
        raise AssertionError(f"old path unexpectedly found FA group {old_idx}")

    fixed_group_spec_types = (classify_with_uniform_unwrap(uniform_mla_spec),)
    print(
        "fixed_group_spec_types =",
        [spec_type.__name__ for spec_type in fixed_group_spec_types],
    )
    fixed_idx = find_fa_group(fixed_group_spec_types)
    print("fixed path found FA group", fixed_idx)


if __name__ == "__main__":
    main()

@ZhanqiuHu

Copy link
Copy Markdown
Contributor Author

@gau-nernst Thanks! I will take a look. IIRC UniformTypeKVCacheSpecs is not specially handled in PD

@NickLucche

Copy link
Copy Markdown
Member

UniformTypeKVCacheSpecs is not specially handled in PD

yep when registering kv tensors

NickLucche pushed a commit to NickLucche/vllm that referenced this pull request May 7, 2026
…roject#40731

PR vllm-project#40449 fixed the MLA broadcast notification to use
`meta.remote.request_id` (the prefill-side id) instead of
`req_id` (the decode-side id). The refactor in vllm-project#40731
inadvertently reverted this back to `req_id`, causing
"Potentially invalid KV blocks for unrecognized request"
errors on non-reading P-ranks when P_TP > D_TP.

Fixes vllm-project#41974

Signed-off-by: Zhanqiu Hu <zhanqiuhu@gmail.com>
NickLucche added a commit to NickLucche/vllm that referenced this pull request May 7, 2026
[Bugfix] restore MLA notif request_id accidentally reverted in vllm-project#40731
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request May 7, 2026
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request May 8, 2026
…roject#40731

PR vllm-project#40449 fixed the MLA broadcast notification to use
`meta.remote.request_id` (the prefill-side id) instead of
`req_id` (the decode-side id). The refactor in vllm-project#40731
inadvertently reverted this back to `req_id`, causing
"Potentially invalid KV blocks for unrecognized request"
errors on non-reading P-ranks when P_TP > D_TP.

Fixes vllm-project#41974

Signed-off-by: Zhanqiu Hu <zhanqiuhu@gmail.com>

Signed-off-by: NickLucche <nlucches@redhat.com>
libinta pushed a commit to libinta/vllm that referenced this pull request May 8, 2026
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
@ZhanqiuHu ZhanqiuHu deleted the nixl-refactor-plan-based-poc branch June 4, 2026 17:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants