[Model] Introduce Kimi Linear to vLLM#27809

Merged
youkaichao merged 2 commits into vllm-project:main from zhiyuan1i:kimi-linear
Oct 30, 2025
Conversation

zhiyuan1i (Contributor) commented Oct 30, 2025

Purpose

This PR introduces Kimi Linear, a hybrid attention model built around Kimi Delta Attention (KDA), a refined version of Gated DeltaNet. The architecture combines reduced memory requirements with strong performance across short-context, long-context, and reinforcement-learning settings. Its optimized gating mechanism cuts KV-cache requirements by up to 75% and boosts decoding throughput by up to $6\times$ on long-context tasks (up to 1M tokens), improving hardware efficiency.
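For readers unfamiliar with the delta-rule family, below is a minimal NumPy sketch of a gated delta-rule recurrence, the kind of update that Gated DeltaNet (and, by extension, KDA) builds on. All names and shapes here are illustrative assumptions, not the actual KDA kernel, which is chunked and hardware-fused; the point is that the state `S` has constant size regardless of sequence length, which is where the KV-cache savings come from.

```python
import numpy as np

def gated_delta_attention(q, q_like_k, v, beta, alpha):
    """Per-token sketch of a gated delta-rule recurrence (illustrative only).

    Shapes: q, q_like_k (T, dk); v (T, dv); beta (T,); alpha (T, dk).
    beta is a per-token write strength, alpha a per-channel decay gate.
    """
    k = q_like_k
    T, dk = k.shape
    dv = v.shape[-1]
    S = np.zeros((dk, dv))    # fast-weight state; constant size, unlike a KV cache
    out = np.empty((T, dv))
    for t in range(T):
        S = alpha[t][:, None] * S                        # gated per-channel decay
        v_pred = k[t] @ S                                # value the state predicts for this key
        S = S + beta[t] * np.outer(k[t], v[t] - v_pred)  # delta-rule correction write
        out[t] = q[t] @ S                                # read out with the query
    return out
```

Because the loop carries only the fixed-size matrix `S`, memory does not grow with context length; the real implementation processes tokens in chunks inside a fused kernel rather than one at a time.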

Previous PR: #27654

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>

mergify bot commented Oct 30, 2025

Documentation preview: https://vllm--27809.org.readthedocs.build/en/27809/

@mergify mergify bot added the documentation (Improvements or additions to documentation), new-model (Requests to new models), and v1 labels on Oct 30, 2025
@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the Kimi Linear model, a hybrid attention model. The changes include adding the model to the documentation and test registry, and implementing the core model logic in new files, notably vllm/model_executor/layers/kda.py for Kimi Delta Attention and vllm/model_executor/models/kimi_linear.py for the overall model structure. My review identified a critical issue in the KimiMoE implementation regarding the handling of shared experts, which is both inefficient and fragile. Additionally, I found a latent bug in KimiDeltaAttention where incorrect biases were used for k and v convolutions in the prefill path. These issues should be addressed to ensure correctness and robustness.

Comment on lines +163 to +171
if self.num_shared_experts is not None:
    shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = (
    self.experts(hidden_states=hidden_states, router_logits=router_logits)
    * self.routed_scaling_factor
)
if shared_output is not None:
    final_hidden_states = final_hidden_states + shared_output

critical

The logic for handling shared_experts is inefficient and not robust. When config.num_shared_experts is 0, a KimiMLP instance is still created and called, performing a no-op. More critically, if config.num_shared_experts could be None, this would lead to an UnboundLocalError as shared_output would be used before assignment.

To fix this, self.shared_experts should be initialized to None in __init__ and only created if num_shared_experts > 0. The forward method should then handle self.shared_experts being None.

Suggested change
- if self.num_shared_experts is not None:
-     shared_output = self.shared_experts(hidden_states)
- router_logits, _ = self.gate(hidden_states)
- final_hidden_states = (
-     self.experts(hidden_states=hidden_states, router_logits=router_logits)
-     * self.routed_scaling_factor
- )
- if shared_output is not None:
-     final_hidden_states = final_hidden_states + shared_output
+ shared_output = None
+ if self.shared_experts is not None:
+     shared_output = self.shared_experts(hidden_states)
+ router_logits, _ = self.gate(hidden_states)
+ final_hidden_states = (
+     self.experts(hidden_states=hidden_states, router_logits=router_logits)
+     * self.routed_scaling_factor
+ )
+ if shared_output is not None:
+     final_hidden_states = final_hidden_states + shared_output

k = causal_conv1d_fn(
    k_proj_states,
    k_conv_weights,
    self.q_conv1d.bias,

high

There appears to be a copy-paste error here. The bias for the k convolution (self.k_conv1d.bias) should be used, but self.q_conv1d.bias is used instead. Although all biases are currently None, this is a latent bug that could cause issues if biases are introduced later. The decode path correctly uses the respective biases.

Suggested change
-     self.q_conv1d.bias,
+     self.k_conv1d.bias,

v = causal_conv1d_fn(
    v_proj_states,
    v_conv_weights,
    self.q_conv1d.bias,

high

Similar to the k convolution, the bias for the v convolution (self.v_conv1d.bias) should be used here instead of self.q_conv1d.bias.

Suggested change
-     self.q_conv1d.bias,
+     self.v_conv1d.bias,


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.



P1: Handle missing Mamba chunk size when prefix caching is enabled

The new prefix-caching branch now sets base_chunk_size = model_config.get_mamba_chunk_size() unconditionally. Many Mamba-derived configs do not expose a mamba_chunk_size/chunk_size attribute, so this helper returns None. The subsequent lcm(base_chunk_size, kernel_block_alignment_size) will then raise a TypeError as soon as prefix caching is enabled for such a model, whereas the previous code fell back to the user-resolved mamba_block_size. This regression prevents hybrid Mamba models without explicit chunk-size metadata from starting with prefix caching. Consider keeping the previous fallback to cache_config.mamba_block_size, or guarding against None before calling lcm.
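One way to restore the fallback this review asks for is sketched below. The function name and parameters are hypothetical, not the actual vLLM API; the sketch only shows the guard order (model-reported chunk size first, then the user-resolved block size, then a loud failure):

```python
from math import gcd

def lcm(a: int, b: int) -> int:
    """Least common multiple, as defined in the snippet under review."""
    return a * b // gcd(a, b)

def resolve_base_chunk_size(model_chunk_size, mamba_block_size, alignment):
    """Hypothetical helper: pick a usable chunk size before calling lcm().

    Prefers the chunk size reported by the model config, falls back to the
    user-resolved mamba_block_size, and raises instead of passing None into
    lcm() (which would otherwise surface as a TypeError).
    """
    base = model_chunk_size if model_chunk_size is not None else mamba_block_size
    if base is None:
        raise ValueError("no Mamba chunk size available for prefix caching")
    return lcm(base, alignment)
```

Either ordering of the two fallbacks could be argued; the essential part is that `lcm` is never reached with `None`.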


Both k_conv1d and v_conv1d have no bias.

Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>

@youkaichao youkaichao left a comment


Looks great!

@youkaichao youkaichao merged commit 4e68cc9 into vllm-project:main Oct 30, 2025
3 of 5 checks passed
Comment on lines +160 to +161
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens, hidden_size = hidden_states.shape

There's an issue with TP Attn + EP MoE that we should handle here. We can follow what I did in #24134 and #24982.

def lcm(a, b):
    return a * b // gcd(a, b)

base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()

This is a regression of #27289

ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>

Labels

documentation (Improvements or additions to documentation), new-model (Requests to new models), v1


5 participants