Skip to content

[NIXL] Per-region KV transfer classification for mixed full-attn + MLA groups#44583

Merged
WoosukKwon merged 20 commits into
vllm-project:mainfrom
Dao007forever:nixl_clean
Jun 12, 2026
Merged

[NIXL] Per-region KV transfer classification for mixed full-attn + MLA groups#44583
WoosukKwon merged 20 commits into
vllm-project:mainfrom
Dao007forever:nixl_clean

Conversation

@Dao007forever

@Dao007forever Dao007forever commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

Purpose

Within a single KV-cache group the NIXL connector may need to transfer regions of two kinds:

  • Full-attention (GQA) layers, whose KV is head-sharded across TP → SPLIT (each rank reads its head slice from the remote at a per-rank offset).
  • MLA layers, whose latent KV is replicated on every rank and is key-only → REPLICATE (whole block read from one rank at offset 0, no V stream).

Previously this was handled by scattering if use_mla / if is_replicated branches across the block-len gate and the descriptor builders, under a single homogeneous block_len assumption. That hit a hard assert tensor_size_bytes == curr_tensor_size_bytes (and NotImplementedError on the heterogeneous branch) for any mixed group at tp_ratio != 1.

This PR classifies each region once, at registration, from its concrete spec (isinstance(layer_spec, MLAAttentionSpec)), storing a per-region _region_is_mla flag. Everything that used to branch on global flags now dispatches per region off a single _is_region_replicated(region_idx) predicate:

  • the P/D block-len validation gate,
  • the local/remote descriptor builders (_build_fa_local / _build_fa_remote),
  • num_regions accounting (key-only REPLICATE regions contribute one desc stream, not two),
  • the per-stream block length in get_backend_aware_kv_block_len.

Homogeneous models are all-SPLIT or all-REPLICATE, so their behavior is unchanged.

The size-uniformity assertion is kept, not dropped: every non-MLA region in a group must still share one tensor size (this also covers Mamba-like models). The sole exemption is MLA regions — the DeepSeek indexer legitimately differs in size within a UniformTypeKVCacheSpecs group.

Two motivations:

  1. Enable Full-Attention (GQA) main model + MLA Eagle-3 draft — such a model registers GQA FullAttentionSpec regions alongside MLA MLAAttentionSpec regions with different per-layer KV sizes (both land in one UniformTypeKVCacheSpecs group, since MLAAttentionSpec subclasses FullAttentionSpec). The old non-MLA path rejected this at registration. Per-region classification (GQA → SPLIT, MLA → REPLICATE) is what makes tp_ratio != 1 work for the mix.
  2. Simplification — replace scattered global-flag branches with one per-region predicate, keeping the logic plain and inline (no new abstraction layer).

Review history: an earlier revision introduced a RegionTransferClass dataclass to own these rules. Per review feedback (@NickLucche), that abstraction was removed in favor of inlining the SPLIT/REPLICATE branches directly in the worker, keyed off the single _is_region_replicated predicate — easier to clean up / converge later than a set of dataclasses.

Repro

vllm serve moonshotai/Kimi-K2.5 \
  --tensor-parallel-size {{PREFILL_TP_SIZE}} \
  --gpu-memory-utilization {{MEMORY_UTIL}} \
  --speculative_config '{"model": "lightseekorg/kimi-k2.5-eagle3", "num_speculative_tokens": 3, "method": "eagle3"}' \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

Test Plan

Adds test_handshake_mixed_fa_mla_hetero_tp, covering a mixed full-attn + MLA single KV group under heterogeneous TP, including the per-region gate rejecting a wrongly-scaled replicated block_len.

.venv/bin/python -m pytest \
  tests/v1/kv_connector/unit/test_nixl_connector.py \
  tests/v1/kv_connector/unit/test_tp_mapping.py -v

Test Result

Full tests/v1/kv_connector/unit/test_nixl_connector.py run: 60 passed, 2 skipped, including the new test_handshake_mixed_fa_mla_hetero_tp and all pre-existing handshake / heterogeneous-TP tests (test_prefill_tp_size_greater_than_decode_tp_size, test_prefill_tp_size_greater_than_decode_tp_size_mla, test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental, test_kv_connector_stats). test_tp_mapping.py (per-region split handles, FA + SSM split factors): all pass. The refactor that removed RegionTransferClass is behavior-preserving — no tests were added or removed by it.

Not a duplicate

The closest open / related NIXL PRs address different axes:

None classify per-region transfer behavior to support a mixed full-attn + MLA group at tp_ratio != 1.

AI assistance disclosure

AI assistance (Claude) was used in preparing this change. A human author has reviewed every changed line and is accountable for the change end-to-end.

Within a single KV-cache group the NIXL connector may need to transfer
regions of two kinds: full-attention (GQA) layers, whose KV is head-sharded
across TP, and MLA layers, whose latent KV is replicated on every rank and is
key-only. Previously this was handled by scattering `if use_mla` /
`if is_replicated` branches across the block-len gate and the descriptor
builders, with a single homogeneous `block_len` assumption that hit a hard
`assert tensor_size_bytes == curr_tensor_size_bytes` (and, on the heterogeneous
branch, `NotImplementedError`) for any mixed group at tp_ratio != 1.

Two motivations:

- Simplify: introduce a single `RegionTransferClass` (SPLIT vs REPLICATE)
  that owns the stream count, remote read count, rank offset, local-split
  descriptor, and block-len validation rules for a region. The gate and the
  descriptor builders now dispatch per region instead of branching on global
  flags. Each registered layer is tagged from its concrete spec
  (`isinstance(layer_spec, MLAAttentionSpec)`), so homogeneous models are
  all-SPLIT or all-REPLICATE and behavior is unchanged for them.

- Make a Full-Attention (GQA) main model work with an MLA Eagle-3 draft: such
  a model registers GQA `FullAttentionSpec` regions alongside MLA
  `MLAAttentionSpec` regions with different per-layer KV sizes (both land in
  one `UniformTypeKVCacheSpecs` group since `MLAAttentionSpec` subclasses
  `FullAttentionSpec`). The old non-MLA path rejected this at registration.
  Per-region validation/transfer classifies the GQA regions as SPLIT and the
  MLA region as REPLICATE, which is what makes tp_ratio != 1 work for the mix.

Adds a unit test covering a mixed full-attn + MLA single KV group under
heterogeneous TP, including the per-region gate rejecting a wrongly scaled
replicated block_len.

Signed-off-by: Dao Le <Dao007forever@gmail.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Trim verbose comments/docstrings in the per-region KV transfer code, remove the unused RegionTransferClass.name field, inline the single-use _region_is_replicate helper, and drop the redundant __init__ init of _region_is_mla (now only set in register_kv_caches, like block_len_per_layer).

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
@zixi-qi zixi-qi added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 5, 2026
@mergify

mergify Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Hi @Dao007forever, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

@ivanium ivanium left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a minor comment. cc @NickLucche for more comments.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Signed-off-by: Dao Le <Dao007forever@gmail.com>

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ivanium @Dao007forever thanks for the PR.
Do you have some code we can run to reproduce the setup?

Enable Full-Attention (GQA) main model + MLA Eagle-3 draft

This bit is slightly more concerning to me, because this is kind of config was not on my radar and I would feel better if we implemented it once the infra proposed here #42082 lands.
That should enable much higher customization by pushing this sort of model-dependent behavior into KVCacheSpec/model-scope, so one wouldn't need to support all possible combinations with a single connector.

I am very open to chat some more on this to get aligned on the use case though.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Outdated
Signed-off-by: Dao Le <Dao007forever@gmail.com>
@mergify

mergify Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Hi @Dao007forever, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

@mergify

mergify Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Hi @Dao007forever, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've taken a deeper look at the issue, sorry for the initial hasty review @Dao007forever @ivanium .

Generally, I believe the root problem that is hurting both maintainability and clarity boils down to UniformTypeKVCacheSpecs being kinda of an HMA exception.
In the connector, it forces the worker to reason in terms of layers rather than groups with separate views on the same shared num_blocks x page_size tensor.
It's much more similar to handling tensors before we switched to HMA.
This forces the worker to remain in a state where it needs to be able to handle both cases with ambiguous branches.
I think we should tackle the problem much more drastically here #42082 - #42449 and have a single way to address mixed Attention compositions/specs @LucasWilkinson

Coming to this PR, this is quite similar to what attempted with Gemma4 here #41169 where SW/FA layers might need to be replicated differently.
It is also related to #44848.

However I am struggling to understand the value of the RegionTransferClass abstraction. It doesn't serve purpose outside of the MLA/GQA issue here, so it appears to me to be a more elaborate way to solve this very specific combination.
It's not really generalizable either, as with mamba (or "pure HMA") regions != layers but rather group_size.
For the time being I would suggest to simplify the abstraction and keep logic plain as in #44848 while trying to sync on some shared structure like region_spec so we don't have to keep is_ssm/is_mla/is_fa..
This might look uglier, but it's easier to clean up than a lot of dataclasses, as I ultimately believe we should go toward the solution for UniformTypeKVCacheSpecs proposed here #42449.

PS to repro the setup described in this PR

       vllm serve moonshotai/Kimi-K2.5  \
      --tensor-parallel-size {{PREFILL_TP_SIZE}} \
      --gpu-memory-utilization {{MEMORY_UTIL}} \
      --speculative_config '{"model": "lightseekorg/kimi-k2.5-eagle3", "num_speculative_tokens": 3, "method": "eagle3"}' \
      --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' 

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Signed-off-by: Dao Le <Dao007forever@gmail.com>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
@Dao007forever Dao007forever changed the title [NIXL] Per-region KV transfer classes for mixed full-attn + MLA groups [NIXL] Per-region KV transfer classification for mixed full-attn + MLA groups Jun 10, 2026
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Outdated
Comment on lines +2507 to +2510
# Per-stream block length: a 2-stream SPLIT region (full-attn under
# the virtually-split layout) uses block_len//2; REPLICATE (MLA,
# key-only) and non-split layouts use the whole block.
two_streams = virtually_split and not self._is_region_replicated(layer_idx)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "stream" feels a bit odd without the Region* dataclass

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating on the PR @Dao007forever !
Given timeline I am going to accept this PR as is before any of the major RFC rework lands.
I only left a few minor comments, feel free to address them as you see fit :)

It would be quite important to follow up with an e2e test on CI with any MLA/GQA combo to sure it's actually covered.

Dao007forever and others added 5 commits June 11, 2026 10:16
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Signed-off-by: Dao007forever <dao007forever@gmail.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Signed-off-by: Dao007forever <dao007forever@gmail.com>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
@mergify

mergify Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Hi @Dao007forever, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Dao007forever and others added 2 commits June 11, 2026 12:10
Signed-off-by: Dao Le <Dao007forever@gmail.com>
@Dao007forever

Copy link
Copy Markdown
Contributor Author

Re-ran with the newest code and confirmed it works.

@WoosukKwon WoosukKwon merged commit 6fbfdd1 into vllm-project:main Jun 12, 2026
77 of 79 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants