[Core] Enable KimiLinear (KDA/GDN + MLA) PD Separation via NIXL#44848
[Core] Enable KimiLinear (KDA/GDN + MLA) PD Separation via NIXL#44848JaredforReal wants to merge 37 commits into
Conversation
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
|
need some accuracy test, draft for now |
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Update NIXL KV-transfer worker logic to better support hybrid attention+SSM (Mamba) models and heterogeneous TP/block-size setups by correctly separating FA vs Mamba descriptor regions and relaxing strict block-size divisibility assumptions.
Changes:
- Track per-region SSM vs attention layout (
_is_ssm_region) and use it to filter FA/Mamba descriptor construction and region counts. - Improve handling of heterogeneous block sizes by extending
block_size_ratio()semantics and adding fallbacks during handshake/reads. - Adjust MLA+SSM behavior for multi-rank reads and notifications under TP down/up-scaling.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py | Adds SSM region tracking to prevent invalid descriptor builds; adds hetero block-size fallbacks and MLA+SSM multi-rank handling. |
| vllm/distributed/kv_transfer/kv_connector/utils.py | Updates block_size_ratio() to support remote>local via negative ratios and improves docstring/error messages. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def block_size_ratio(self, remote_block_size: int) -> int: | ||
| """Calculate the block size ratio between local and remote.""" | ||
| assert self.block_size % remote_block_size == 0, ( | ||
| f"Local block size {self.block_size} is not divisible " | ||
| f"by remote block size {remote_block_size} or vice versa." | ||
| """Calculate the block size ratio between local and remote. | ||
|
|
||
| Positive when local >= remote (local blocks are larger). | ||
| Negative when remote > local (remote blocks are larger). | ||
| """ | ||
| if self.block_size == remote_block_size: | ||
| return 1 | ||
| if self.block_size > remote_block_size: | ||
| assert self.block_size % remote_block_size == 0, ( | ||
| f"Local block size {self.block_size} is not divisible " | ||
| f"by remote block size {remote_block_size}." | ||
| ) | ||
| return self.block_size // remote_block_size | ||
| assert remote_block_size % self.block_size == 0, ( | ||
| f"Remote block size {remote_block_size} is not divisible " | ||
| f"by local block size {self.block_size}." | ||
| ) | ||
| return self.block_size // remote_block_size | ||
| return -(remote_block_size // self.block_size) |
| try: | ||
| block_size_ratio = self.transfer_topo.block_size_ratio( | ||
| remote_info.remote_block_size | ||
| ) | ||
| except AssertionError: | ||
| block_size_ratio = 1 | ||
| if block_size_ratio > 1: | ||
| # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. | ||
| assert not self._is_hma_required |
| try: | ||
| block_size_ratio = self.transfer_topo.block_size_ratio( | ||
| nixl_agent_meta.block_size | ||
| ) | ||
| except AssertionError: | ||
| # Heterogeneous TP with non-divisible block sizes (e.g. hybrid | ||
| # MLA+GDN). Use 1 as a safe fallback for validation checks. | ||
| block_size_ratio = 1 |
There was a problem hiding this comment.
not a fan of this pattern either, let's do a proper check inside transfer_topo, we should have the elements to determine is this is the case we're trying to catch, and get rid of the except here @JaredforReal
There was a problem hiding this comment.
I have moved the except to utils.py, and replaced block size ratio assert with new comment.
Open to more suggestions
| if ( | ||
| self.block_len_per_layer | ||
| and nixl_agent_meta.block_lens | ||
| and self.block_len_per_layer[0] != nixl_agent_meta.block_lens[0] | ||
| ): | ||
| local_bytes = self.block_len_per_layer[0] | ||
| remote_bytes = nixl_agent_meta.block_lens[0] | ||
| if local_bytes > remote_bytes and local_bytes % remote_bytes == 0: | ||
| block_size_ratio = local_bytes // remote_bytes | ||
| elif remote_bytes > local_bytes and remote_bytes % local_bytes == 0: | ||
| block_size_ratio = -(remote_bytes // local_bytes) | ||
| else: | ||
| # Non-exact byte division (e.g. hybrid models with | ||
| # TP-independent MLA component). Use 1 as fallback; | ||
| # _build_fa_remote handles bytes via remote block_lens. | ||
| block_size_ratio = 1 |
| # Only record non-Mamba page sizes. | ||
| if isinstance(layer_spec, MambaSpec): | ||
| if is_ssm: | ||
| self.block_len_per_layer.append( | ||
| physical_page_size // self._physical_blocks_per_logical_kv_block | ||
| ) |
| self.block_len_per_layer = list[int]() | ||
| self._is_ssm_region = list[bool]() |
| # Only build Mamba descriptors for SSM/Mamba regions. | ||
| # Attention/MLA regions do not contain conv or temporal state. | ||
| if i < len(self._is_ssm_region) and not self._is_ssm_region[i]: | ||
| continue |
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
NickLucche
left a comment
There was a problem hiding this comment.
will get back to it cc @ZhanqiuHu
| try: | ||
| block_size_ratio = self.transfer_topo.block_size_ratio( | ||
| nixl_agent_meta.block_size | ||
| ) | ||
| except AssertionError: | ||
| # Heterogeneous TP with non-divisible block sizes (e.g. hybrid | ||
| # MLA+GDN). Use 1 as a safe fallback for validation checks. | ||
| block_size_ratio = 1 |
There was a problem hiding this comment.
not a fan of this pattern either, let's do a proper check inside transfer_topo, we should have the elements to determine is this is the case we're trying to catch, and get rid of the except here @JaredforReal
Signed-off-by: JaredforReal <w13431838023@gmail.com>
We are working on this as a part of the KV cache layout standardization, and also this would make NIXL connector easier to support arbitrary number of views (currently we have dual-view for Attention + Mamba): #45205 |
Do you have any timeline? |
|
This pull request has merge conflicts that must be resolved before it can be |
NickLucche
left a comment
There was a problem hiding this comment.
@JaredforReal thanks for your patience, I am trying to land a big PR here #35264 which has been stuck due to rebasing this week.
In the interest of time, we can get this merged after that one.
This way we don't have to wait for the big refactor, as timeline is still unclear @ZJY0516 (we were hoping 0.24 but it'll slip to 0.25)
|
@JaredforReal in the meantime, is there any model we can use for e2e testing to ensure KDA+MLA is covered? |
This is the only oss model uses MLA + KDA |
|
Hi, I wasn't able to reproduce the accurate regression on main using |
'lm_eval --model local-completions --model_args "model=kimi,base_url=http://0.0.0.0:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=256,timeout=5000,max_length=128000" --tasks gsm8k --num_fewshot 5' Right here @ZhanqiuHu |
What about the serving commands? You mentioned in your PR description that your setting is |
|
@ZhanqiuHu Serving command and Benchmark command are already in the PR Description |
|
Sorry let me confirm if I understand the issue correctly. It seems like you also mentioned a bug with homogeneous TP setting? Just would like to confirm is homogeneous TP is broken on your end too? |
Summary
Enable homogeneous-TP PD separation for KimiLinear — a hybrid model with 20 KDA/GDN (SSM) layers and 7 MLA attention layers sharing physical tensors via HMA. All changes are in
vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py.Problem
KimiLinear's HMA pools KDA and MLA layers into shared physical tensors, creating 7 dual-purpose regions. Each region has two views with different strides:
num_kv_heads=1)The existing code stored only the KDA stride in
block_len_per_layerand used it for all FA descriptors, causing MLA data to be read at the wrong stride.Changes
Dual-purpose region detection: Added
_is_ssm_region,_is_attn_region, and_attn_block_len(maps region index → MLA stride). Populated during HMA dedup when multiple layer types share the samebase_addr.FA descriptors use MLA stride for dual-purpose regions (MLA only): In
_build_fa_local/_build_fa_remote, dual-purpose regions use_attn_block_len[i]aspage_strideandkv_block_len. This path is gated byself.use_mla— standard attention models (e.g. Qwen GQA) fall through to the originalblock_size_ratiopath, preserving heterogeneous TP correctness.NIXL registration expansion: Dual-purpose regions expand registration size to
num_blocks * max(KDA_stride, MLA_stride).num_regions = sum(_is_attn_region): Correctly counts dual-purpose regions as attention regions for FA descriptor ID computation.SSM/attention block length validation: Handshake validation checks SSM and attention regions independently with appropriate assertions.
HMA assertion relaxed for MLA+SSM: The
assert block_size_ratio == 1check is relaxed when both MLA and SSM are present, since SSM scales with TP but MLA attention is replicated.Non-HMA Model Safety
attn_stridepath is gated byself.use_mla— Qwen (standard GQA) always uses the originalblock_size_ratiopath, so heterogeneous TP is unaffected._attn_block_lenis empty, so all new code paths are no-ops.block_size_ratiocomputation is also gated byself.use_mla.Testing
block_size_ratiobypass, remote descriptorsattn_strideignored whenuse_mla=False,block_size_ratiocorrectly applied for heterogeneous TPnum_regionscomputation, SSM-only skipping, non-HMA pathsNon-HMA Model Safety
All new code paths are gated behind
_attn_block_len.get(i)which returnsNonefor non-HMA models, falling through to the existing code path. Verified for:_is_attn_regionall False → no FA descs built_attn_block_lenempty → standard pathbackward compatibility for Qwen3.6-27B test
Detailed Serving and Testing:
1P1D TP4:
TP8:
Testing
Known Issue
Kimi Linear PD dissagg the TP=2 instance would hang in CG warmup, need
--enforce-eager