[Model] Introduce Kimi Linear to vLLM#27809
Conversation
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Documentation preview: https://vllm--27809.org.readthedocs.build/en/27809/
Code Review
This pull request introduces support for the Kimi Linear model, a hybrid attention model. The changes include adding the model to the documentation and test registry, and implementing the core model logic in new files, notably vllm/model_executor/layers/kda.py for Kimi Delta Attention and vllm/model_executor/models/kimi_linear.py for the overall model structure. My review identified a critical issue in the KimiMoE implementation regarding the handling of shared experts, which is both inefficient and fragile. Additionally, I found a latent bug in KimiDeltaAttention where incorrect biases were used for k and v convolutions in the prefill path. These issues should be addressed to ensure correctness and robustness.
```python
if self.num_shared_experts is not None:
    shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = (
    self.experts(hidden_states=hidden_states, router_logits=router_logits)
    * self.routed_scaling_factor
)
if shared_output is not None:
    final_hidden_states = final_hidden_states + shared_output
```
The logic for handling shared_experts is inefficient and not robust. When config.num_shared_experts is 0, a KimiMLP instance is still created and called, performing a no-op. More critically, if config.num_shared_experts could be None, this would lead to an UnboundLocalError as shared_output would be used before assignment.
To fix this, self.shared_experts should be initialized to None in __init__ and only created if num_shared_experts > 0. The forward method should then handle self.shared_experts being None.
```diff
-if self.num_shared_experts is not None:
-    shared_output = self.shared_experts(hidden_states)
+shared_output = None
+if self.shared_experts is not None:
+    shared_output = self.shared_experts(hidden_states)
 router_logits, _ = self.gate(hidden_states)
 final_hidden_states = (
     self.experts(hidden_states=hidden_states, router_logits=router_logits)
     * self.routed_scaling_factor
 )
 if shared_output is not None:
     final_hidden_states = final_hidden_states + shared_output
```
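A minimal sketch of the suggested `__init__` guard. The class names follow the PR, but the `KimiMLP` stub and `Cfg` helper here are stand-ins for illustration, not the real vLLM classes:

```python
class KimiMLP:
    """Stand-in for the real KimiMLP in vllm (illustrative only)."""
    def __init__(self, config):
        self.config = config

class KimiMoE:
    """Sketch: build the shared-expert MLP only when num_shared_experts
    is a positive integer, so forward() can simply check for None."""
    def __init__(self, config):
        self.num_shared_experts = getattr(config, "num_shared_experts", None)
        if self.num_shared_experts:  # falsy for both None and 0
            self.shared_experts = KimiMLP(config)
        else:
            self.shared_experts = None

class Cfg:
    """Minimal stand-in config."""
    def __init__(self, n):
        self.num_shared_experts = n

assert KimiMoE(Cfg(0)).shared_experts is None     # no no-op MLP is created
assert KimiMoE(Cfg(None)).shared_experts is None  # no UnboundLocalError path
assert KimiMoE(Cfg(2)).shared_experts is not None
```

With this guard in place, the `forward` path in the suggestion above works for all three cases without ever calling a no-op expert.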
vllm/model_executor/layers/kda.py
Outdated
```python
k = causal_conv1d_fn(
    k_proj_states,
    k_conv_weights,
    self.q_conv1d.bias,
```
There appears to be a copy-paste error here. The bias for the k convolution (self.k_conv1d.bias) should be used, but self.q_conv1d.bias is used instead. Although all biases are currently None, this is a latent bug that could cause issues if biases are introduced later. The decode path correctly uses the respective biases.
```diff
-    self.q_conv1d.bias,
+    self.k_conv1d.bias,
```
vllm/model_executor/layers/kda.py
Outdated
```python
v = causal_conv1d_fn(
    v_proj_states,
    v_conv_weights,
    self.q_conv1d.bias,
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
vllm/model_executor/models/config.py
Outdated
Handle missing Mamba chunk size when prefix caching is enabled
The new prefix-caching branch now sets base_chunk_size = model_config.get_mamba_chunk_size() unconditionally. Many Mamba-derived configs do not expose a mamba_chunk_size/chunk_size attribute and this helper therefore returns None. The subsequent lcm(base_chunk_size, kernel_block_alignment_size) will raise a TypeError as soon as prefix caching is enabled for such a model, whereas the previous code fell back to the user‑resolved mamba_block_size. This regression prevents hybrid Mamba models without explicit chunk size metadata from starting with prefix caching. Consider keeping the previous fallback to cache_config.mamba_block_size or guarding against None values before calling lcm.
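One way to keep the previous fallback, sketched with hypothetical parameter names (in the real code these values are resolved from `model_config` and `cache_config`, and this helper does not exist under this name):

```python
from math import gcd

def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)

def resolve_chunk_size(mamba_chunk_size, mamba_block_size,
                       kernel_block_alignment_size):
    """Hypothetical sketch: fall back to the user-resolved block size when
    the model config exposes no chunk size, and guard against None before
    calling lcm() so prefix caching does not raise a TypeError."""
    base_chunk_size = mamba_chunk_size or mamba_block_size
    if base_chunk_size is None:
        return kernel_block_alignment_size
    return lcm(base_chunk_size, kernel_block_alignment_size)

# Configs without chunk-size metadata fall back gracefully:
assert resolve_chunk_size(None, 256, 16) == 256
assert resolve_chunk_size(None, None, 16) == 16
assert resolve_chunk_size(24, None, 16) == 48
```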
both k_conv1d and v_conv1d have no bias
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
```python
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens, hidden_size = hidden_states.shape
```
```python
def lcm(a, b):
    return a * b // gcd(a, b)

base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
```
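For illustration, with assumed values not taken from the PR: the `lcm` here rounds the chunk size up to the smallest size divisible by both the base chunk size and the kernel block alignment. (Python 3.9+ also ships `math.lcm`.)

```python
from math import gcd

def lcm(a: int, b: int) -> int:
    # Same definition as in the snippet above.
    return a * b // gcd(a, b)

# Illustrative values: a base chunk size of 96 and a kernel block
# alignment of 64 give an effective block size of 192, the smallest
# size both constraints divide evenly.
print(lcm(96, 64))  # 192
```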
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn> Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
Purpose
Introduce Kimi Linear, a hybrid attention model built around Kimi Delta Attention (KDA), a refined version of Gated DeltaNet, offering reduced memory requirements and strong performance across short-context, long-context, and reinforcement-learning settings. Through an optimized gating mechanism, Kimi Linear cuts KV-cache requirements by up to 75%, improving hardware efficiency and boosting decoding throughput by up to $6\times$ on long-context tasks such as those with up to 1M tokens.
Previous PR: #27654
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.