Skip to content

[refactor]Optimized the kvcache usage of Deepseek v3.2#6610

Merged
wangxiyuan merged 2 commits intovllm-project:mainfrom
kunpengW-code:v3.2-kvcache-refactor
Feb 9, 2026
Merged

[refactor]Optimized the kvcache usage of Deepseek v3.2#6610
wangxiyuan merged 2 commits intovllm-project:mainfrom
kunpengW-code:v3.2-kvcache-refactor

Conversation

@kunpengW-code
Copy link
Copy Markdown
Contributor

@kunpengW-code kunpengW-code commented Feb 7, 2026

What this PR does / why we need it?

For deepseek v3.2, DSA use FullAttentionSpec, allocate 2 * mla page size bytes, and we use half of that for k cache in DSA

However, the actual proportion of k cache is not high, which results in a large amount of kvcache being wasted. The proportion of discarded kvcache is (576-128)/(576 x 2) = 0.388.

Run the same script to start DeepSeek V3.2 on a single A3 server. The following shows the comparison of kvcache usage:
Before refactoring

[kv_cache_utils.py:1307] GPU KV cache size: 15,872 tokens

After refactoring

[kv_cache_utils.py:1307] GPU KV cache size: 25,984 tokens

This pull request refactors the KV cache allocation for Deepseek v3.2 models that use sparse attention. It replaces the use of FullAttentionSpec with MLAAttentionSpec and introduces a more principled way of calculating KV cache tensor split factors based on model configuration.

This change removes hardcoded values and correctly sizes the cache tensors, leading to optimized memory usage and improved code maintainability.

Does this PR introduce any user-facing change?

No, this is an internal optimization and does not introduce any user-facing changes.

How was this patch tested?

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @kunpengW-code, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the KV cache allocation and management for Deepseek v3.2 models, specifically targeting sparse attention mechanisms. The changes aim to optimize memory usage and improve compatibility by introducing a more dynamic and configuration-driven approach to determine KV cache tensor splitting and shaping, moving away from static or less precise calculations.

Highlights

  • KV Cache Specification Update: The KV cache attention specification for Deepseek v3.2 models, when using sparse attention, has been updated from FullAttentionSpec to MLAAttentionSpec. This change likely better aligns with the underlying attention mechanism of Deepseek v3.2.
  • Dynamic KV Cache Ratio Calculation: The logic for splitting KV cache tensors (k_tensor_split_factor, v_tensor_split_factor, and dsa_k_cache_factor) for Deepseek v3.2 sparse attention has been refactored. Instead of using hardcoded values or simple calculations, these factors are now dynamically derived from a new _get_sparse_kv_cache_ratio() method, which retrieves specific head dimensions from the model's configuration.
  • Refined DSA K-Cache Reshaping: The dsa_k_cache_shape calculation has been improved to use the index_head_dim and num_kv_heads from the model's configuration, providing a more accurate and flexible sizing for the DSA K-cache.
  • New Helper Method for Sparse KV Cache Ratios: A new private method, _get_sparse_kv_cache_ratio(), has been introduced to centralize the retrieval of Deepseek v3.2 specific KV cache ratios (kv_lora_rank, qk_rope_head_dim, index_head_dim), making the code cleaner and more maintainable.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • vllm_ascend/worker/model_runner_v1.py
    • Removed FullAttentionSpec import and added MLAAttentionSpec import.
    • Modified the calculation of k_tensor_split_factor, v_tensor_split_factor, and dsa_k_cache_factor for sparse attention in Deepseek v3.2, now using _get_sparse_kv_cache_ratio().
    • Adjusted the dsa_k_cache_shape calculation to incorporate index_head_dim and num_kv_heads for more precise tensor shaping.
    • Replaced FullAttentionSpec with MLAAttentionSpec in get_kv_cache_spec for sparse attention, updating head_size and adding cache_dtype_str.
    • Added a new method _get_sparse_kv_cache_ratio() to retrieve model-specific KV cache dimension ratios.
Activity
  • The pull request was opened by kunpengW-code with a title indicating a refactor to optimize KV cache usage for Deepseek v3.2.
  • The description body is currently a template, suggesting no specific human comments or detailed explanations have been added by the author yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Feb 7, 2026

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the KV cache handling for Deepseek v3.2 models using sparse attention. The changes replace a hardcoded implementation based on FullAttentionSpec with a more robust and configurable approach using MLAAttentionSpec. This results in more accurate KV cache allocation and optimized memory usage. The changes look good and improve code clarity and maintainability.

As per the repository's style guide, I've provided suggestions for the pull request title and summary below.

Suggested PR Title:

[worker][refactor] Optimize kvcache usage for Deepseek v3.2

Suggested PR Summary:

### What this PR does / why we need it?

This pull request refactors the KV cache allocation for Deepseek v3.2 models that use sparse attention. It replaces the use of `FullAttentionSpec` with `MLAAttentionSpec` and introduces a more principled way of calculating KV cache tensor split factors based on model configuration.

This change removes hardcoded values and correctly sizes the cache tensors, leading to optimized memory usage and improved code maintainability.

### Does this PR introduce _any_ user-facing change?

No, this is an internal optimization and does not introduce any user-facing changes.

### How was this patch tested?

CI tests should be sufficient to validate this internal refactoring.

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's better to return a tuple instead of a list.

@wangxiyuan wangxiyuan merged commit 156976b into vllm-project:main Feb 9, 2026
25 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Feb 11, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend:
  [Feat] 310p support MoE W8A8 quantizaition (vllm-project#6641)
  [TEST]add a qwen3-30b acc case with mooncake mempool (vllm-project#6244)
  [MOE Refactor] Remove QuantType in prepare_finalize.py (vllm-project#6534)
  [EPLB] Avoiding eplb's dependency on a specified model (vllm-project#6528)
  [Doc][Misc] Restructure tutorial documentation (vllm-project#6501)
  implement batch invariant with ascendc (vllm-project#6590)
  [Refact]Refact MLA/SFA weight prefetch to consist with moe weight prefetch (vllm-project#6629)
  [Misc] upgrade to vllm main (vllm-project#6646)
  [main][Docs] Fix spelling errors across documentation (vllm-project#6649)
  [bugfix]Fix no attribute 'data' when MLAPO is enable  (vllm-project#6601)
  [DOC]Add Memcache Usage Guide (vllm-project#6476)
  [main][bugfix] Fix spec acceptance rate problem in vllm_0.15.0 (vllm-project#6606)
  [Test][LoRA] Add e2e test for base model inference (vllm-project#6624)
  [refactor]Optimized the kvcache usage of Deepseek v3.2 (vllm-project#6610)
  [Feat](sfa,dcp) support dcp for sfa (vllm-project#6563)
  [BugFix] Add support for rotary_dim parameter when using partial rope in rotary_embedding (vllm-project#6581)
  [fix bug] fix tensor mismatch bug in sigmoid operate test case (vllm-project#6619)
  [Kernel]: Optimize DispatchFFNCombine performance (vllm-project#6468)
  [MISC] Clean up useless env USE_OPTIMIZED_MODEL (vllm-project#6618)
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
…6610)

### What this PR does / why we need it?

For deepseek v3.2, DSA use FullAttentionSpec, allocate 2 * mla page size
bytes, and we use half of that for k cache in DSA

However, the actual proportion of k cache is not high, which results in
a large amount of kvcache being wasted. The proportion of discarded
kvcache is (576-128)/(576 x 2) = 0.388.

Run the same script to start DeepSeek V3.2 on a single A3 server. The
following shows the comparison of kvcache usage:
Before refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 15,872 tokens
```
After refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 25,984 tokens
```

This pull request refactors the KV cache allocation for Deepseek v3.2
models that use sparse attention. It replaces the use of
`FullAttentionSpec` with `MLAAttentionSpec` and introduces a more
principled way of calculating KV cache tensor split factors based on
model configuration.

This change removes hardcoded values and correctly sizes the cache
tensors, leading to optimized memory usage and improved code
maintainability.

### Does this PR introduce _any_ user-facing change?

No, this is an internal optimization and does not introduce any
user-facing changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
vllm-project/vllm@d7e17aa

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: momochenchuw <chenchuw@huawei.com>
@wangxiyuan wangxiyuan mentioned this pull request Feb 24, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…6610)

### What this PR does / why we need it?

For deepseek v3.2, DSA use FullAttentionSpec, allocate 2 * mla page size
bytes, and we use half of that for k cache in DSA

However, the actual proportion of k cache is not high, which results in
a large amount of kvcache being wasted. The proportion of discarded
kvcache is (576-128)/(576 x 2) = 0.388.

Run the same script to start DeepSeek V3.2 on a single A3 server. The
following shows the comparison of kvcache usage:
Before refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 15,872 tokens
```
After refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 25,984 tokens
```

This pull request refactors the KV cache allocation for Deepseek v3.2
models that use sparse attention. It replaces the use of
`FullAttentionSpec` with `MLAAttentionSpec` and introduces a more
principled way of calculating KV cache tensor split factors based on
model configuration.

This change removes hardcoded values and correctly sizes the cache
tensors, leading to optimized memory usage and improved code
maintainability.

### Does this PR introduce _any_ user-facing change?

No, this is an internal optimization and does not introduce any
user-facing changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
vllm-project/vllm@d7e17aa

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…6610)

### What this PR does / why we need it?

For deepseek v3.2, DSA use FullAttentionSpec, allocate 2 * mla page size
bytes, and we use half of that for k cache in DSA

However, the actual proportion of k cache is not high, which results in
a large amount of kvcache being wasted. The proportion of discarded
kvcache is (576-128)/(576 x 2) = 0.388.

Run the same script to start DeepSeek V3.2 on a single A3 server. The
following shows the comparison of kvcache usage:
Before refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 15,872 tokens
```
After refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 25,984 tokens
```

This pull request refactors the KV cache allocation for Deepseek v3.2
models that use sparse attention. It replaces the use of
`FullAttentionSpec` with `MLAAttentionSpec` and introduces a more
principled way of calculating KV cache tensor split factors based on
model configuration.

This change removes hardcoded values and correctly sizes the cache
tensors, leading to optimized memory usage and improved code
maintainability.

### Does this PR introduce _any_ user-facing change?

No, this is an internal optimization and does not introduce any
user-facing changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
vllm-project/vllm@d7e17aa

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…6610)

### What this PR does / why we need it?

For deepseek v3.2, DSA use FullAttentionSpec, allocate 2 * mla page size
bytes, and we use half of that for k cache in DSA

However, the actual proportion of k cache is not high, which results in
a large amount of kvcache being wasted. The proportion of discarded
kvcache is (576-128)/(576 x 2) = 0.388.

Run the same script to start DeepSeek V3.2 on a single A3 server. The
following shows the comparison of kvcache usage:
Before refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 15,872 tokens
```
After refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 25,984 tokens
```

This pull request refactors the KV cache allocation for Deepseek v3.2
models that use sparse attention. It replaces the use of
`FullAttentionSpec` with `MLAAttentionSpec` and introduces a more
principled way of calculating KV cache tensor split factors based on
model configuration.

This change removes hardcoded values and correctly sizes the cache
tensors, leading to optimized memory usage and improved code
maintainability.

### Does this PR introduce _any_ user-facing change?

No, this is an internal optimization and does not introduce any
user-facing changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
vllm-project/vllm@d7e17aa

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…6610)

### What this PR does / why we need it?

For deepseek v3.2, DSA use FullAttentionSpec, allocate 2 * mla page size
bytes, and we use half of that for k cache in DSA

However, the actual proportion of k cache is not high, which results in
a large amount of kvcache being wasted. The proportion of discarded
kvcache is (576-128)/(576 x 2) = 0.388.

Run the same script to start DeepSeek V3.2 on a single A3 server. The
following shows the comparison of kvcache usage:
Before refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 15,872 tokens
```
After refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 25,984 tokens
```

This pull request refactors the KV cache allocation for Deepseek v3.2
models that use sparse attention. It replaces the use of
`FullAttentionSpec` with `MLAAttentionSpec` and introduces a more
principled way of calculating KV cache tensor split factors based on
model configuration.

This change removes hardcoded values and correctly sizes the cache
tensors, leading to optimized memory usage and improved code
maintainability.

### Does this PR introduce _any_ user-facing change?

No, this is an internal optimization and does not introduce any
user-facing changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
vllm-project/vllm@d7e17aa

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants