[Core] Enable KimiLinear (KDA/GDN + MLA) PD Separation via NIXL#44848

Open
JaredforReal wants to merge 37 commits into
vllm-project:mainfrom
JaredforReal:kda

@JaredforReal

@JaredforReal JaredforReal commented Jun 8, 2026

Contributor

Summary

Enable homogeneous-TP PD separation for KimiLinear — a hybrid model with 20 KDA/GDN (SSM) layers and 7 MLA attention layers sharing physical tensors via HMA. All changes are in vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py.

Problem

KimiLinear's HMA pools KDA and MLA layers into shared physical tensors, creating 7 dual-purpose regions. Each region has two views with different strides:

  • KDA stride: TP-dependent (conv/ssm dimensions scale with TP)
  • MLA stride: TP-independent (num_kv_heads=1)

The existing code stored only the KDA stride in block_len_per_layer and used it for all FA descriptors, causing MLA data to be read at the wrong stride.
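To make the stride mismatch concrete, here is a minimal sketch. The byte sizes are made up for illustration (the real values depend on the model config and TP size), and `block_offset` is a hypothetical helper, not a function in worker.py.

```python
# Hypothetical per-block strides for one shared (dual-purpose) region.
KDA_STRIDE = 4096  # assumed KDA view stride in bytes; shrinks as TP grows
MLA_STRIDE = 1152  # assumed MLA view stride in bytes; identical on every rank


def block_offset(block_id: int, stride: int) -> int:
    """Byte offset of a block inside a region for a view with this stride."""
    return block_id * stride


block_ids = [0, 1, 2]
# Offsets at which the MLA data actually lives.
correct = [block_offset(b, MLA_STRIDE) for b in block_ids]
# Offsets computed when the KDA stride is applied to MLA descriptors.
wrong = [block_offset(b, KDA_STRIDE) for b in block_ids]

print(correct)
print(wrong)
```

Only block 0 lands at the same offset under both strides; every later block is read from the wrong location, which is consistent with the accuracy drop reported below.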

On the main branch, PD disaggregation with P4+D4 gives a gsm8k score that is far too low:
| Tasks | Version | Filter           | n-shot | Metric      |     |  Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | -----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑   | 0.3624 | ±   | 0.0132 |
|       |         | strict-match     |      5 | exact_match | ↑   | 0.3245 | ±   | 0.0129 |

Changes

  1. Dual-purpose region detection: Added _is_ssm_region, _is_attn_region, and _attn_block_len (maps region index → MLA stride). Populated during HMA dedup when multiple layer types share the same base_addr.

  2. FA descriptors use MLA stride for dual-purpose regions (MLA only): In _build_fa_local/_build_fa_remote, dual-purpose regions use _attn_block_len[i] as page_stride and kv_block_len. This path is gated by self.use_mla — standard attention models (e.g. Qwen GQA) fall through to the original block_size_ratio path, preserving heterogeneous TP correctness.

  3. NIXL registration expansion: Dual-purpose regions expand registration size to num_blocks * max(KDA_stride, MLA_stride).

  4. num_regions = sum(_is_attn_region): Correctly counts dual-purpose regions as attention regions for FA descriptor ID computation.

  5. SSM/attention block length validation: Handshake validation checks SSM and attention regions independently with appropriate assertions.

  6. HMA assertion relaxed for MLA+SSM: The assert block_size_ratio == 1 check is relaxed when both MLA and SSM are present, since SSM scales with TP but MLA attention is replicated.
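Changes 1, 3, and 4 can be sketched roughly as follows. This is an illustrative reconstruction, not the code in worker.py: `LayerView` and `classify_regions` are hypothetical names, the byte sizes are invented, and only the names `_is_ssm_region`, `_is_attn_region`, and `_attn_block_len` come from the PR description.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class LayerView:
    """One layer's view of a KV-cache tensor (hypothetical stand-in type)."""
    name: str
    base_addr: int
    is_ssm: bool    # True for KDA/GDN state, False for MLA attention
    block_len: int  # per-block stride of this view, in bytes


def classify_regions(layers, num_blocks):
    # Deduplicate by base address: layers sharing an address share a region.
    by_addr = defaultdict(list)
    for layer in layers:
        by_addr[layer.base_addr].append(layer)

    is_ssm_region, is_attn_region = [], []
    attn_block_len = {}  # region index -> MLA stride (dual-purpose regions only)
    reg_sizes = []       # bytes to register per region
    for idx, views in enumerate(by_addr.values()):
        ssm = [v for v in views if v.is_ssm]
        attn = [v for v in views if not v.is_ssm]
        is_ssm_region.append(bool(ssm))
        is_attn_region.append(bool(attn))
        if ssm and attn:
            attn_block_len[idx] = attn[0].block_len
        # Register enough bytes to cover the larger of the views' strides.
        reg_sizes.append(num_blocks * max(v.block_len for v in views))
    return is_ssm_region, is_attn_region, attn_block_len, reg_sizes


layers = [
    LayerView("kda.0", 0x1000, True, 4096),
    LayerView("mla.0", 0x1000, False, 1152),  # shares a tensor with kda.0
    LayerView("kda.1", 0x2000, True, 4096),   # SSM-only region
]
is_ssm, is_attn, attn_len, sizes = classify_regions(layers, num_blocks=10)
num_regions = sum(is_attn)  # dual-purpose regions count as attention regions
```

In this toy input, region 0 is dual-purpose and region 1 is SSM-only, so only region 0 gets an entry in the stride map and contributes to `num_regions`.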

Non-HMA Model Safety

  • The attn_stride path is gated by self.use_mla — Qwen (standard GQA) always uses the original block_size_ratio path, so heterogeneous TP is unaffected.
  • For non-HMA models, _attn_block_len is empty, so all new code paths are no-ops.
  • Byte-level block_size_ratio computation is also gated by self.use_mla.
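The gating described above can be illustrated with a small sketch. `select_fa_stride` is a hypothetical helper written for this example; the real fall-through path also involves block_size_ratio handling that is not modeled here.

```python
def select_fa_stride(region_idx, use_mla, attn_block_len, block_len_per_layer):
    """Pick the per-block stride for an FA descriptor (illustrative only)."""
    # The MLA stride is consulted only when the model uses MLA.
    attn_stride = attn_block_len.get(region_idx) if use_mla else None
    if attn_stride is not None:
        return attn_stride
    # Non-MLA models, or regions with no entry, keep the original stride.
    return block_len_per_layer[region_idx]


block_len_per_layer = [4096, 4096]  # assumed KDA strides in bytes
kimi_like = select_fa_stride(0, True, {0: 1152}, block_len_per_layer)
gqa_like = select_fa_stride(0, False, {0: 1152}, block_len_per_layer)
non_hma = select_fa_stride(0, True, {}, block_len_per_layer)
```

With an empty stride map or `use_mla=False`, the helper always returns the original stride, which is the no-op behavior the bullets above claim for non-HMA and standard-attention models.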

Testing

  • Verified: KimiLinear homogeneous TP PD separation with accuracy ~0.89
  • Verified: Qwen3.5 heterogeneous TP PD separation (P4D2) still works correctly
  • Unit tests: 20+ tests covering:
    • KimiLinear: MLA stride selection, block_size_ratio bypass, remote descriptors
    • Qwen regression: attn_stride ignored when use_mla=False, block_size_ratio correctly applied for heterogeneous TP
    • num_regions computation, SSM-only skipping, non-HMA paths
Kimi Linear TP8 gsm8k

| Tasks | Version | Filter           | n-shot | Metric      |     |  Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | -----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑   | 0.8886 | ±   | 0.0087 |
|       |         | strict-match     |      5 | exact_match | ↑   | 0.8688 | ±   | 0.0093 |

Kimi Linear P4D4 gsm8k
| Tasks | Version | Filter           | n-shot | Metric      |     |  Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | -----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑   | 0.8832 | ±   | 0.0088 |
|       |         | strict-match     |      5 | exact_match | ↑   | 0.8666 | ±   | 0.0094 |

Kimi Linear P2D2 gsm8k
| Tasks | Version | Filter           | n-shot | Metric      |     |  Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | -----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑   | 0.8916 | ±   | 0.0086 |
|       |         | strict-match     |      5 | exact_match | ↑   | 0.8749 | ±   | 0.0091 |

Non-HMA Model Safety

All new code paths are gated behind _attn_block_len.get(i) which returns None for non-HMA models, falling through to the existing code path. Verified for:

  • Qwen3.5 GDN (pure GDN, no HMA): _is_attn_region all False → no FA descs built
  • Pure MLA (no HMA): _attn_block_len empty → standard path
  • Pure attention (no SSM): unaffected

Backward-compatibility test with Qwen3.6-27B

# P2+D2 gsm8k
| Tasks | Version | Filter           | n-shot | Metric      |     |  Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | -----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑   | 0.6005 | ±   | 0.0135 |
|       |         | strict-match     |      5 | exact_match | ↑   | 0.5845 | ±   | 0.0136 |

# P4+D2
| Tasks | Version | Filter           | n-shot | Metric      |     |  Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | -----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑   | 0.5974 | ±   | 0.0135 |
|       |         | strict-match     |      5 | exact_match | ↑   | 0.5777 | ±   | 0.0136 |

# P4+D4
| Tasks | Version | Filter           | n-shot | Metric      |     |  Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | -----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑   | 0.5967 | ±   | 0.0135 |
|       |         | strict-match     |      5 | exact_match | ↑   | 0.5785 | ±   | 0.0136 |

Detailed Serving and Testing:

1P1D TP4:

# Prefill:
CUDA_VISIBLE_DEVICES=0,1,2,3 \
VLLM_SSM_CONV_STATE_LAYOUT=DS \
VLLM_KV_CACHE_LAYOUT=HND \
VLLM_NIXL_SIDE_CHANNEL_PORT=5560 \
UCX_NET_DEVICES=all \
vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct \
  --trust-remote-code \
  --port 8100 \
  --host 0.0.0.0 \
  --tensor-parallel-size 4 \
  --gpu-memory-utilization 0.9 \
  --no-disable-hybrid-kv-cache-manager \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' \
  --served-model-name kimi \
  --no-async-scheduling

# Decode
CUDA_VISIBLE_DEVICES=4,5,6,7 \
VLLM_SSM_CONV_STATE_LAYOUT=DS \
VLLM_KV_CACHE_LAYOUT=HND \
VLLM_NIXL_SIDE_CHANNEL_PORT=5660 \
UCX_NET_DEVICES=all \
vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct \
  --trust-remote-code \
  --port 8200 \
  --host 0.0.0.0 \
  --tensor-parallel-size 4 \
  --gpu-memory-utilization 0.9 \
  --no-disable-hybrid-kv-cache-manager \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' \
  --served-model-name kimi \
  --no-async-scheduling

# Router
vllm-router --policy round_robin \
  --vllm-pd-disaggregation \
  --prefill http://10.1.50.70:8100 \
  --decode http://10.1.50.70:8200 \
  --host 0.0.0.0 \
  --port 8000 \
  --intra-node-data-parallel-size 1

TP8:

vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct/ \
  --tensor-parallel 8 \
  --served-model-name kimi \
  --trust-remote-code \
  --gpu-memory-utilization 0.9 \
  --port 8000 \
  --host 0.0.0.0

Testing

lm_eval --model local-completions --model_args "model=kimi,base_url=http://0.0.0.0:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=256,timeout=5000,max_length=128000" --tasks gsm8k --num_fewshot 5

Known Issue

With Kimi Linear PD disaggregation, the TP=2 instance hangs during CUDA graph (CG) warmup; it needs --enforce-eager as a workaround.

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Copilot AI review requested due to automatic review settings June 8, 2026 07:59

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the kv-connector label Jun 8, 2026
@JaredforReal JaredforReal marked this pull request as draft June 8, 2026 08:00
@JaredforReal

Contributor Author

need some accuracy test, draft for now

Copilot AI left a comment

Contributor

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Update NIXL KV-transfer worker logic to better support hybrid attention+SSM (Mamba) models and heterogeneous TP/block-size setups by correctly separating FA vs Mamba descriptor regions and relaxing strict block-size divisibility assumptions.

Changes:

  • Track per-region SSM vs attention layout (_is_ssm_region) and use it to filter FA/Mamba descriptor construction and region counts.
  • Improve handling of heterogeneous block sizes by extending block_size_ratio() semantics and adding fallbacks during handshake/reads.
  • Adjust MLA+SSM behavior for multi-rank reads and notifications under TP down/up-scaling.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

File Description
vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Adds SSM region tracking to prevent invalid descriptor builds; adds hetero block-size fallbacks and MLA+SSM multi-rank handling.
vllm/distributed/kv_transfer/kv_connector/utils.py Updates block_size_ratio() to support remote>local via negative ratios and improves docstring/error messages.


Comment on lines +535 to +553
 def block_size_ratio(self, remote_block_size: int) -> int:
-    """Calculate the block size ratio between local and remote."""
-    assert self.block_size % remote_block_size == 0, (
-        f"Local block size {self.block_size} is not divisible "
-        f"by remote block size {remote_block_size} or vice versa."
+    """Calculate the block size ratio between local and remote.
+
+    Positive when local >= remote (local blocks are larger).
+    Negative when remote > local (remote blocks are larger).
+    """
+    if self.block_size == remote_block_size:
+        return 1
+    if self.block_size > remote_block_size:
+        assert self.block_size % remote_block_size == 0, (
+            f"Local block size {self.block_size} is not divisible "
+            f"by remote block size {remote_block_size}."
+        )
+        return self.block_size // remote_block_size
+    assert remote_block_size % self.block_size == 0, (
+        f"Remote block size {remote_block_size} is not divisible "
+        f"by local block size {self.block_size}."
     )
-    return self.block_size // remote_block_size
+    return -(remote_block_size // self.block_size)
Comment on lines 2210 to 2218
try:
    block_size_ratio = self.transfer_topo.block_size_ratio(
        remote_info.remote_block_size
    )
except AssertionError:
    block_size_ratio = 1
if block_size_ratio > 1:
    # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
    assert not self._is_hma_required
Comment on lines +1488 to +1495
try:
    block_size_ratio = self.transfer_topo.block_size_ratio(
        nixl_agent_meta.block_size
    )
except AssertionError:
    # Heterogeneous TP with non-divisible block sizes (e.g. hybrid
    # MLA+GDN). Use 1 as a safe fallback for validation checks.
    block_size_ratio = 1

Member

not a fan of this pattern either, let's do a proper check inside transfer_topo. We should have the elements to determine if this is the case we're trying to catch, and get rid of the except here @JaredforReal

@JaredforReal JaredforReal Jun 11, 2026

Contributor Author

I have moved the except to utils.py, and replaced block size ratio assert with new comment.
Open to more suggestions
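One possible shape for an explicit check that avoids catching AssertionError is sketched below. This is only an illustration of the idea raised in the review: `signed_block_size_ratio` is a hypothetical free function, not the code that landed in utils.py, and it returns None instead of raising when neither size divides the other, so the caller can decide how to handle that case.

```python
from typing import Optional


def signed_block_size_ratio(local: int, remote: int) -> Optional[int]:
    """Signed ratio between local and remote block sizes.

    Positive when local >= remote, negative when remote > local,
    None when neither size is a multiple of the other.
    """
    if local <= 0 or remote <= 0:
        raise ValueError("block sizes must be positive")
    if local == remote:
        return 1
    if local > remote:
        return local // remote if local % remote == 0 else None
    return -(remote // local) if remote % local == 0 else None


same = signed_block_size_ratio(16, 16)
local_larger = signed_block_size_ratio(32, 16)
remote_larger = signed_block_size_ratio(16, 32)
not_divisible = signed_block_size_ratio(48, 32)
```

Returning None makes the non-divisible case an ordinary value the caller must branch on, rather than a control-flow exception that could also mask unrelated assertion failures.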

Comment on lines +1358 to +1373
if (
    self.block_len_per_layer
    and nixl_agent_meta.block_lens
    and self.block_len_per_layer[0] != nixl_agent_meta.block_lens[0]
):
    local_bytes = self.block_len_per_layer[0]
    remote_bytes = nixl_agent_meta.block_lens[0]
    if local_bytes > remote_bytes and local_bytes % remote_bytes == 0:
        block_size_ratio = local_bytes // remote_bytes
    elif remote_bytes > local_bytes and remote_bytes % local_bytes == 0:
        block_size_ratio = -(remote_bytes // local_bytes)
    else:
        # Non-exact byte division (e.g. hybrid models with
        # TP-independent MLA component). Use 1 as fallback;
        # _build_fa_remote handles bytes via remote block_lens.
        block_size_ratio = 1
Comment on lines 908 to 912
 # Only record non-Mamba page sizes.
-if isinstance(layer_spec, MambaSpec):
+if is_ssm:
     self.block_len_per_layer.append(
         physical_page_size // self._physical_blocks_per_logical_kv_block
     )
Comment on lines 850 to +851
 self.block_len_per_layer = list[int]()
+self._is_ssm_region = list[bool]()
Comment on lines +1104 to +1107
# Only build Mamba descriptors for SSM/Mamba regions.
# Attention/MLA regions do not contain conv or temporal state.
if i < len(self._is_ssm_region) and not self._is_ssm_region[i]:
    continue
Signed-off-by: JaredforReal <w13431838023@gmail.com>
@mergify mergify Bot added the v1 label Jun 9, 2026
Signed-off-by: JaredforReal <w13431838023@gmail.com>
@JaredforReal JaredforReal changed the title [Core] Support heterogeneous TP disaggregated inference for hybrid MLA+GDN (KDA) models [Core] Enable KimiLinear (KDA/GDN + MLA) PD Separation via NIXL Jun 9, 2026
@JaredforReal JaredforReal marked this pull request as ready for review June 9, 2026 07:54
Signed-off-by: JaredforReal <w13431838023@gmail.com>

@NickLucche NickLucche left a comment

Member

will get back to it cc @ZhanqiuHu

Comment on lines +1488 to +1495
try:
    block_size_ratio = self.transfer_topo.block_size_ratio(
        nixl_agent_meta.block_size
    )
except AssertionError:
    # Heterogeneous TP with non-divisible block sizes (e.g. hybrid
    # MLA+GDN). Use 1 as a safe fallback for validation checks.
    block_size_ratio = 1

Member

not a fan of this pattern either, let's do a proper check inside transfer_topo. We should have the elements to determine if this is the case we're trying to catch, and get rid of the except here @JaredforReal

Signed-off-by: JaredforReal <w13431838023@gmail.com>
@ZhanqiuHu

Contributor

will get back to it cc @ZhanqiuHu

We are working on this as a part of the KV cache layout standardization, and also this would make NIXL connector easier to support arbitrary number of views (currently we have dual-view for Attention + Mamba): #45205

@ZJY0516

ZJY0516 commented Jun 11, 2026

Member

will get back to it cc @ZhanqiuHu

We are working on this as a part of the KV cache layout standardization, and also this would make NIXL connector easier to support arbitrary number of views (currently we have dual-view for Attention + Mamba): #45205

Do you have any timeline?

@mergify

mergify Bot commented Jun 12, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JaredforReal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 12, 2026

@NickLucche NickLucche left a comment

Member

@JaredforReal thanks for your patience, I am trying to land a big PR here #35264 which has been stuck due to rebasing this week.
In the interest of time, we can get this merged after that one.
This way we don't have to wait for the big refactor, as timeline is still unclear @ZJY0516 (we were hoping 0.24 but it'll slip to 0.25)

@NickLucche

Member

@JaredforReal in the meantime, is there any model we can use for e2e testing to ensure KDA+MLA is covered?
I am afraid moonshotai/Kimi-Linear-48B-A3B-Instruct is larger than the models we'd like to target for CI runs

@ZJY0516

ZJY0516 commented Jun 12, 2026

Member

@JaredforReal in the meantime, is there any model we can use for e2e testing to ensure KDA+MLA is covered? I am afraid moonshotai/Kimi-Linear-48B-A3B-Instruct is larger than the models we'd like to target for CI runs

This is the only OSS model that uses MLA + KDA.

@ZhanqiuHu

Contributor

Hi, I wasn't able to reproduce the accuracy regression on main using moonshotai/Kimi-Linear-48B-A3B-Instruct. Do you have the commands and logs you used?

@JaredforReal

Contributor Author

Hi, I wasn't able to reproduce the accuracy regression on main using moonshotai/Kimi-Linear-48B-A3B-Instruct. Do you have the commands and logs you used?

'lm_eval --model local-completions --model_args "model=kimi,base_url=http://0.0.0.0:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=256,timeout=5000,max_length=128000" --tasks gsm8k --num_fewshot 5'

Right here @ZhanqiuHu

@ZhanqiuHu

ZhanqiuHu commented Jun 12, 2026

Contributor

Hi, I wasn't able to reproduce the accuracy regression on main using moonshotai/Kimi-Linear-48B-A3B-Instruct. Do you have the commands and logs you used?

'lm_eval --model local-completions --model_args "model=kimi,base_url=http://0.0.0.0:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=256,timeout=5000,max_length=128000" --tasks gsm8k --num_fewshot 5'

Right here @ZhanqiuHu

What about the serving commands?
Great if you have the logs too.

You mentioned in your PR description that your setting is PD disagg 4+4, but are you actually using hetero TP by any chance?

@JaredforReal

Contributor Author

@ZhanqiuHu The serving and benchmark commands are already in the PR description.
Because the KDA state is split across ranks according to the TP size while the MLA KV cache is replicated on every rank, the TP2 block size is not evenly divisible by the TP4 block size. That should be the main issue blocking heterogeneous-TP PD for KimiLinear.
SGL does not seem to support heterogeneous TP for KimiLinear either.
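As a rough illustration of how such a non-divisible pair could arise, the sketch below assumes the attention block size (in tokens) is the smallest multiple of an alignment whose MLA page covers the per-rank KDA state. That sizing rule, the helper names, and all byte counts are assumptions made for this example, not values taken from vLLM or KimiLinear.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def attn_block_size(kda_state_bytes_total, tp, mla_bytes_per_token, alignment):
    """Tokens per attention block under the assumed sizing rule."""
    # KDA state is split across ranks; MLA bytes per token are not.
    per_rank_state = kda_state_bytes_total // tp
    return alignment * cdiv(per_rank_state, alignment * mla_bytes_per_token)


KDA_TOTAL = 184_320    # hypothetical total KDA state bytes per layer
MLA_PER_TOKEN = 1_152  # hypothetical MLA KV bytes per token (TP-independent)
ALIGN = 16             # hypothetical block-size alignment in tokens

bs_tp2 = attn_block_size(KDA_TOTAL, 2, MLA_PER_TOKEN, ALIGN)
bs_tp4 = attn_block_size(KDA_TOTAL, 4, MLA_PER_TOKEN, ALIGN)
remainder = bs_tp2 % bs_tp4
print(bs_tp2, bs_tp4, remainder)
```

Under these assumed numbers, halving the per-rank state does not halve the block size because of the round-up, so the TP2 block size is not a multiple of the TP4 one and a ratio-based mapping between the two sides has no exact integer value.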

@ZhanqiuHu

ZhanqiuHu commented Jun 12, 2026

Contributor

Sorry, let me confirm whether I understand the issue correctly. It seems you also mentioned a bug with the homogeneous TP setting? I would just like to confirm whether homogeneous TP is broken on your end too.

5 participants