[NIXL] Per-region KV transfer classification for mixed full-attn + MLA groups#44583
Conversation
Within a single KV-cache group the NIXL connector may need to transfer regions of two kinds: full-attention (GQA) layers, whose KV is head-sharded across TP, and MLA layers, whose latent KV is replicated on every rank and is key-only. Previously this was handled by scattering `if use_mla` / `if is_replicated` branches across the block-len gate and the descriptor builders, with a single homogeneous `block_len` assumption that hit a hard `assert tensor_size_bytes == curr_tensor_size_bytes` (and, on the heterogeneous branch, `NotImplementedError`) for any mixed group at tp_ratio != 1. Two motivations: - Simplify: introduce a single `RegionTransferClass` (SPLIT vs REPLICATE) that owns the stream count, remote read count, rank offset, local-split descriptor, and block-len validation rules for a region. The gate and the descriptor builders now dispatch per region instead of branching on global flags. Each registered layer is tagged from its concrete spec (`isinstance(layer_spec, MLAAttentionSpec)`), so homogeneous models are all-SPLIT or all-REPLICATE and behavior is unchanged for them. - Make a Full-Attention (GQA) main model work with an MLA Eagle-3 draft: such a model registers GQA `FullAttentionSpec` regions alongside MLA `MLAAttentionSpec` regions with different per-layer KV sizes (both land in one `UniformTypeKVCacheSpecs` group since `MLAAttentionSpec` subclasses `FullAttentionSpec`). The old non-MLA path rejected this at registration. Per-region validation/transfer classifies the GQA regions as SPLIT and the MLA region as REPLICATE, which is what makes tp_ratio != 1 work for the mix. Adds a unit test covering a mixed full-attn + MLA single KV group under heterogeneous TP, including the per-region gate rejecting a wrongly scaled replicated block_len. Signed-off-by: Dao Le <Dao007forever@gmail.com> Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Trim verbose comments/docstrings in the per-region KV transfer code, remove the unused RegionTransferClass.name field, inline the single-use _region_is_replicate helper, and drop the redundant __init__ init of _region_is_mla (now only set in register_kv_caches, like block_len_per_layer). Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
|
Hi @Dao007forever, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, |
ivanium
left a comment
There was a problem hiding this comment.
LGTM with a minor comment. cc @NickLucche for more comments.
Signed-off-by: Dao Le <Dao007forever@gmail.com>
NickLucche
left a comment
There was a problem hiding this comment.
Hey @ivanium @Dao007forever thanks for the PR.
Do you have some code we can run to reproduce the setup?
Enable Full-Attention (GQA) main model + MLA Eagle-3 draft
This bit is slightly more concerning to me, because this is kind of config was not on my radar and I would feel better if we implemented it once the infra proposed here #42082 lands.
That should enable much higher customization by pushing this sort of model-dependent behavior into KVCacheSpec/model-scope, so one wouldn't need to support all possible combinations with a single connector.
I am very open to chat some more on this to get aligned on the use case though.
Signed-off-by: Dao Le <Dao007forever@gmail.com>
|
Hi @Dao007forever, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, |
|
Hi @Dao007forever, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, |
There was a problem hiding this comment.
I've taken a deeper look at the issue, sorry for the initial hasty review @Dao007forever @ivanium .
Generally, I believe the root problem that is hurting both maintainability and clarity boils down to UniformTypeKVCacheSpecs being kinda of an HMA exception.
In the connector, it forces the worker to reason in terms of layers rather than groups with separate views on the same shared num_blocks x page_size tensor.
It's much more similar to handling tensors before we switched to HMA.
This forces the worker to remain in a state where it needs to be able to handle both cases with ambiguous branches.
I think we should tackle the problem much more drastically here #42082 - #42449 and have a single way to address mixed Attention compositions/specs @LucasWilkinson
Coming to this PR, this is quite similar to what attempted with Gemma4 here #41169 where SW/FA layers might need to be replicated differently.
It is also related to #44848.
However I am struggling to understand the value of the RegionTransferClass abstraction. It doesn't serve purpose outside of the MLA/GQA issue here, so it appears to me to be a more elaborate way to solve this very specific combination.
It's not really generalizable either, as with mamba (or "pure HMA") regions != layers but rather group_size.
For the time being I would suggest to simplify the abstraction and keep logic plain as in #44848 while trying to sync on some shared structure like region_spec so we don't have to keep is_ssm/is_mla/is_fa..
This might look uglier, but it's easier to clean up than a lot of dataclasses, as I ultimately believe we should go toward the solution for UniformTypeKVCacheSpecs proposed here #42449.
PS to repro the setup described in this PR
vllm serve moonshotai/Kimi-K2.5 \
--tensor-parallel-size {{PREFILL_TP_SIZE}} \
--gpu-memory-utilization {{MEMORY_UTIL}} \
--speculative_config '{"model": "lightseekorg/kimi-k2.5-eagle3", "num_speculative_tokens": 3, "method": "eagle3"}' \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
Signed-off-by: Dao Le <Dao007forever@gmail.com>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
| # Per-stream block length: a 2-stream SPLIT region (full-attn under | ||
| # the virtually-split layout) uses block_len//2; REPLICATE (MLA, | ||
| # key-only) and non-split layouts use the whole block. | ||
| two_streams = virtually_split and not self._is_region_replicated(layer_idx) |
There was a problem hiding this comment.
nit: "stream" feels a bit odd without the Region* dataclass
NickLucche
left a comment
There was a problem hiding this comment.
Thanks for iterating on the PR @Dao007forever !
Given timeline I am going to accept this PR as is before any of the major RFC rework lands.
I only left a few minor comments, feel free to address them as you see fit :)
It would be quite important to follow up with an e2e test on CI with any MLA/GQA combo to sure it's actually covered.
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Signed-off-by: Dao007forever <dao007forever@gmail.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Signed-off-by: Dao007forever <dao007forever@gmail.com>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
|
Hi @Dao007forever, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, |
Signed-off-by: Dao Le <Dao007forever@gmail.com>
|
Re-ran with the newest code and confirmed it works. |
Purpose
Within a single KV-cache group the NIXL connector may need to transfer regions of two kinds:
Previously this was handled by scattering
if use_mla/if is_replicatedbranches across the block-len gate and the descriptor builders, under a single homogeneousblock_lenassumption. That hit a hardassert tensor_size_bytes == curr_tensor_size_bytes(andNotImplementedErroron the heterogeneous branch) for any mixed group attp_ratio != 1.This PR classifies each region once, at registration, from its concrete spec (
isinstance(layer_spec, MLAAttentionSpec)), storing a per-region_region_is_mlaflag. Everything that used to branch on global flags now dispatches per region off a single_is_region_replicated(region_idx)predicate:_build_fa_local/_build_fa_remote),num_regionsaccounting (key-only REPLICATE regions contribute one desc stream, not two),get_backend_aware_kv_block_len.Homogeneous models are all-SPLIT or all-REPLICATE, so their behavior is unchanged.
The size-uniformity assertion is kept, not dropped: every non-MLA region in a group must still share one tensor size (this also covers Mamba-like models). The sole exemption is MLA regions — the DeepSeek indexer legitimately differs in size within a
UniformTypeKVCacheSpecsgroup.Two motivations:
FullAttentionSpecregions alongside MLAMLAAttentionSpecregions with different per-layer KV sizes (both land in oneUniformTypeKVCacheSpecsgroup, sinceMLAAttentionSpecsubclassesFullAttentionSpec). The old non-MLA path rejected this at registration. Per-region classification (GQA → SPLIT, MLA → REPLICATE) is what makestp_ratio != 1work for the mix.Repro
vllm serve moonshotai/Kimi-K2.5 \ --tensor-parallel-size {{PREFILL_TP_SIZE}} \ --gpu-memory-utilization {{MEMORY_UTIL}} \ --speculative_config '{"model": "lightseekorg/kimi-k2.5-eagle3", "num_speculative_tokens": 3, "method": "eagle3"}' \ --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'Test Plan
Adds
test_handshake_mixed_fa_mla_hetero_tp, covering a mixed full-attn + MLA single KV group under heterogeneous TP, including the per-region gate rejecting a wrongly-scaled replicatedblock_len.Test Result
Full
tests/v1/kv_connector/unit/test_nixl_connector.pyrun: 60 passed, 2 skipped, including the newtest_handshake_mixed_fa_mla_hetero_tpand all pre-existing handshake / heterogeneous-TP tests (test_prefill_tp_size_greater_than_decode_tp_size,test_prefill_tp_size_greater_than_decode_tp_size_mla,test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental,test_kv_connector_stats).test_tp_mapping.py(per-region split handles, FA + SSM split factors): all pass. The refactor that removedRegionTransferClassis behavior-preserving — no tests were added or removed by it.Not a duplicate
The closest open / related NIXL PRs address different axes:
tp_ratio != 1. This PR's per-region predicate is complementary — open to converging on a shared region descriptor with [Core] Enable KimiLinear (KDA/GDN + MLA) PD Separation via NIXL #44848 (see thread).split_k_and_vpolicies) handles joint-vs-separate K/V tensor layouts across heterogeneous backends — orthogonal to mixing SPLIT and REPLICATE regions within one group.None classify per-region transfer behavior to support a mixed full-attn + MLA group at
tp_ratio != 1.AI assistance disclosure
AI assistance (Claude) was used in preparing this change. A human author has reviewed every changed line and is accountable for the change end-to-end.