[KV Connector][3/N][NIXL] Per-layer-name HMA routing for hybrid (Mamba/SSM) models under PP #43368
zixi-qi wants to merge 5 commits into
Code Review
This pull request introduces support for pipeline parallelism (PP) in the NIXL KV transfer connector, updating the handshake protocol to version 5 and refactoring the worker and topology logic to manage transfers per PP shard. It also addresses a bug in the model runner where returning None instead of an empty output interfered with output aggregation. A critical issue was identified in the descriptor ID calculation for interleaved memory layouts, such as those used by FlashInfer, which would lead to corrupted KV transfers; a code suggestion was provided to correctly handle indexing for both interleaved and standard layouts.
    group_arr = np.asarray(block_ids[group_id], dtype=np.int64)
    if group_arr.size == 0:
        continue
    desc_ids.append(region_id * num_blocks + group_arr + offset)
The descriptor ID calculation region_id * num_blocks + group_arr assumes a non-interleaved layout in the dlist (i.e., all blocks for region 0, then all blocks for region 1). However, for backends where is_kv_layout_blocks_first is true (like FlashInfer), the registration logic in _build_fa_local and _build_fa_remote produces an interleaved layout [K0, V0, K1, V1, ...]. Using the current formula with an interleaved layout will result in incorrect descriptor indexing, leading to corrupted KV transfers. For interleaved layouts, the index for block i of region r (where r is 0 for K and 1 for V of the same layer) should be group_arr * 2 + (region_id % 2) relative to the layer's start offset.
    - desc_ids.append(region_id * num_blocks + group_arr + offset)
    + if not include_mamba and self.transfer_topo.is_kv_layout_blocks_first:
    +     # Interleaved layout: [K0, V0, K1, V1, ...]
    +     desc_ids.append((region_id // 2) * (2 * num_blocks) +
    +                     group_arr * 2 + (region_id % 2) + offset)
    + else:
    +     # Standard layout: [R0_B0, R0_B1, ..., R1_B0, R1_B1, ...]
    +     desc_ids.append(region_id * num_blocks + group_arr + offset)
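To make the indexing argument in this review comment concrete, here is a minimal pure-Python sketch that builds both descriptor-list layouts and checks the two formulas against them. The helper names (`build_interleaved_dlist`, `build_standard_dlist`, `desc_ids_for_region`) are hypothetical and not part of the connector; the PR code operates on NumPy arrays, while plain lists are used here to keep the example dependency-free. It also assumes exactly two regions (K and V) per layer, as the comment does.

```python
def build_interleaved_dlist(num_layers, num_blocks):
    """Blocks-first layout: per layer, [K0, V0, K1, V1, ...] as (region, block)."""
    return [
        (2 * layer + kv, block)
        for layer in range(num_layers)
        for block in range(num_blocks)
        for kv in (0, 1)  # 0 = K region, 1 = V region of the same layer
    ]


def build_standard_dlist(num_regions, num_blocks):
    """Region-major layout: [R0_B0, R0_B1, ..., R1_B0, R1_B1, ...]."""
    return [(r, b) for r in range(num_regions) for b in range(num_blocks)]


def desc_ids_for_region(region_id, block_ids, num_blocks, blocks_first, offset=0):
    """Descriptor indices for `block_ids` of one region under either layout."""
    if blocks_first:
        # Each layer owns 2 * num_blocks consecutive descriptors.
        layer_start = (region_id // 2) * (2 * num_blocks)
        return [layer_start + 2 * b + (region_id % 2) + offset for b in block_ids]
    return [region_id * num_blocks + b + offset for b in block_ids]


if __name__ == "__main__":
    num_layers, num_blocks = 2, 4
    interleaved = build_interleaved_dlist(num_layers, num_blocks)
    standard = build_standard_dlist(2 * num_layers, num_blocks)
    for region in range(2 * num_layers):
        ids_i = desc_ids_for_region(region, range(num_blocks), num_blocks, True)
        ids_s = desc_ids_for_region(region, range(num_blocks), num_blocks, False)
        # Every computed index must point at the (region, block) it was asked for.
        assert [interleaved[i] for i in ids_i] == [(region, b) for b in range(num_blocks)]
        assert [standard[i] for i in ids_s] == [(region, b) for b in range(num_blocks)]
    print("both formulas index their layouts correctly")
```

Applying the region-major formula to the interleaved list would, for region 1 and block 0, select index 4 (which holds K of block 2) instead of index 1, which is the kind of mismatch the comment warns about.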
This pull request has merge conflicts that must be resolved before it can be
…iate-PP output plumbing Co-authored-by: Claude Signed-off-by: zixi-qi <zixi@inferact.ai>
…superseded by vllm-project#43732 Co-authored-by: Claude Signed-off-by: zixi-qi <zixi@inferact.ai>
… base Co-authored-by: Claude Signed-off-by: zixi-qi <zixi@inferact.ai>
…, no HMA) Signed-off-by: zixi-qi <zixi@inferact.ai>
…r PP Signed-off-by: zixi-qi <zixi@inferact.ai>
Purpose
Extends NIXL PD-disaggregated serving to hybrid (Mamba/SSM) models under
pipeline parallelism. PR #43366 lands the PP consumer per-shard refactor but
explicitly rejects hybrid producers when pp_size > 1 because the per-shard
descriptor builder doesn't yet carry Mamba region state. This PR adds the
missing Mamba/SSM bookkeeping so hybrid models (Jamba-style, Mamba-based,
etc.) work end-to-end on heterogeneous PP × TP topologies.
Net delta (this PR only, on top of #43366)
Changes
- _ShardDescLayout grows two fields:
  - mamba_region_count: int = 0
  - mamba_region_group_ids: tuple[int, ...] = ()
- _register_local_xfer_handler_for_shard builds local Mamba descriptors
  when self._has_mamba is set, computes mamba_region_group_ids (each KV-group
  id replicated 4× for Mamba's 4 SSM regions per layer), and embeds the result
  into the per-shard layout.
- add_remote_agent registers remote Mamba blocks per shard via
  _build_mamba_remote(nixl_agent_meta, tp_ratio, transfer_info) and emits a
  remote _ShardDescLayout carrying mamba_region_count /
  mamba_region_group_ids for the consumer's transfer descriptor table.
- _get_block_descs_ids_for_shard routes Mamba shards through a
  logical-block-aware path: FA regions use layout.num_blocks, Mamba regions
  use layout.num_blocks // physical_blocks_per_logical with an offset of
  num_fa_descs. The non-Mamba path is unchanged.
- The register_kv_caches rejection guard added in #43366 is dropped; hybrid
  producers are now supported with pp_size > 1.
Test Plan
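As a reference for the bookkeeping listed under Changes (and for what the unit test below exercises), here is a minimal sketch of the per-shard layout and the descriptor-ID offset math. It is not the connector's code: `ShardDescLayoutSketch`, `fa_region_group_ids`, `replicate_group_ids`, and `block_desc_ids` are hypothetical names, and several details are assumptions, namely that each group id is repeated contiguously (g, g, g, g, h, ...), that FA descriptors use the region-major layout so num_fa_descs equals FA regions × num_blocks, and that Mamba block ids are already logical-block indices.

```python
from dataclasses import dataclass

MAMBA_REGIONS_PER_LAYER = 4  # the PR text states 4 SSM regions per Mamba layer


@dataclass(frozen=True)
class ShardDescLayoutSketch:
    """Hypothetical stand-in for _ShardDescLayout; only the two Mamba fields
    and num_blocks are named in the PR description."""
    num_blocks: int
    fa_region_group_ids: tuple[int, ...] = ()  # assumption: one entry per FA region
    mamba_region_count: int = 0
    mamba_region_group_ids: tuple[int, ...] = ()


def replicate_group_ids(mamba_layer_group_ids):
    """Repeat each Mamba layer's KV-group id once per SSM region (ordering assumed)."""
    return tuple(
        gid for gid in mamba_layer_group_ids for _ in range(MAMBA_REGIONS_PER_LAYER)
    )


def block_desc_ids(layout, block_ids_per_group, physical_blocks_per_logical):
    """Descriptor IDs for one shard: FA regions first, then Mamba regions."""
    ids = []
    # FA regions are indexed with the physical block count.
    for region_id, gid in enumerate(layout.fa_region_group_ids):
        ids += [region_id * layout.num_blocks + b for b in block_ids_per_group[gid]]
    # Mamba regions start after all FA descriptors and use logical blocks.
    num_fa_descs = len(layout.fa_region_group_ids) * layout.num_blocks
    logical_blocks = layout.num_blocks // physical_blocks_per_logical
    for region_id, gid in enumerate(layout.mamba_region_group_ids):
        ids += [
            num_fa_descs + region_id * logical_blocks + b
            for b in block_ids_per_group[gid]
        ]
    return ids


if __name__ == "__main__":
    mamba_ids = replicate_group_ids([1])  # one Mamba layer in KV group 1
    layout = ShardDescLayoutSketch(
        num_blocks=8,
        fa_region_group_ids=(0, 0),  # K and V regions of one FA layer, group 0
        mamba_region_count=len(mamba_ids),
        mamba_region_group_ids=mamba_ids,
    )
    print(block_desc_ids(layout, {0: [3], 1: [2]}, physical_blocks_per_logical=2))
```

With 8 physical blocks and 2 physical blocks per logical block, the two FA regions occupy descriptors 0 to 15, and each of the four Mamba regions then occupies 4 descriptors starting at 16.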
Unit test
Covers the Mamba region group construction path: validates
mamba_region_count, mamba_region_group_ids, descriptor ID offset math,
and the FA-vs-Mamba grouping in _ShardDescLayout.
End-to-end (not yet validated on this PR)
We have not run an E2E hybrid model on PP × PD yet on the GB200 rig
this branch was developed on — the rig doesn't have a Mamba-based hybrid
model loaded. The unit test covers the per-shard descriptor construction
logic, which is the part PR #43366's rejection guard explicitly defers.
E2E validation on Jamba (or a similar hybrid) would need:
Happy to keep this PR in draft until I (or a reviewer with access to a
Mamba-capable rig) runs that smoke test. Marking it draft for now.
Lint
All hooks: Passed.
Test Result
Why this is not a duplicate
Searched vLLM open PRs/issues on 2026-05-21 for HMA pipeline parallel,
Mamba disaggregated NIXL, hybrid PP P/D. No open work targets the
Mamba × NIXL PD path under pipeline parallelism. The HMA × P/D paths that
do exist (e.g. test_nixl_connector_hma.py upstream) cover non-PP
topologies only.
AI assistance disclosure
This change was drafted with AI assistance (Claude Code, Opus 4.7). The
submitting human reviewed every changed line and ran the unit test
referenced above. This PR is the deliberate HMA × PP follow-up referenced
in PR #43366's description.