[Feature] KV pool supports sparse attention #6339

Merged
LCAIZJ merged 1 commit into vllm-project:main from luoxiaolin712:kvpool_sparse_attention
Feb 5, 2026

Conversation

@luoxiaolin712
Contributor

@luoxiaolin712 luoxiaolin712 commented Jan 28, 2026

What this PR does / why we need it?

The kv pooling feature is adapted to Sparse Attention to support models such as Deepseek V3.2.

Does this PR introduce any user-facing change?

NA

How was this patch tested?

```
vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \
  --host $local_ip \
  --port 8002 \
  --served-model-name model \
  --data-parallel-size 1 \
  --tensor-parallel-size 8 \
  --prefill-context-parallel-size 2 \
  --decode-context-parallel-size 1 \
  --cp-kv-cache-interleave-size 128 \
  --block-size 128 \
  --enable-expert-parallel \
  --no-enable-prefix-caching \
  --no-enable-chunked-prefill \
  --max-num-seqs 4 \
  --max-model-len 8192 \
  --max-num-batched-tokens 8192 \
  --gpu-memory-utilization 0.95 \
  --trust-remote-code \
  --enforce-eager \
  --quantization ascend \
  --additional_config '{"ascend_scheduler_config":{"enabled":false}}' \
  --kv-transfer-config \
    '{
      "kv_connector": "AscendStoreConnector",
      "kv_role": "kv_both",
      "kv_connector_extra_config": {
        "backend": "mooncake",
        "lookup_rpc_port": "0",
        "use_layerwise": false
      }
    }'
```

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds support for sparse attention to the KV pooling feature, enabling models like Deepseek V3.2. The changes introduce a use_sparse flag and associated logic to handle the multi-component KV cache structure of sparse attention models. My review identifies a critical bug in the block length calculation that could lead to incorrect behavior if cache components have different data types. Additionally, I've pointed out an opportunity to refactor a method with significant code duplication to improve maintainability. Both comments include code suggestions to address the issues.

Comment on lines +155 to +159
```
self.block_len = [
    first_kv_cache[0].element_size() * math.prod(block_shape_norm),
    first_kv_cache[1].element_size() * math.prod(block_shape_pe),
    first_kv_cache[2].element_size() * math.prod(block_shape_k),
]
```
Contributor

critical

There appears to be a bug in how self.block_len is calculated. The code uses first_kv_cache[i].element_size(), but first_kv_cache is an alias for first_kv_cache_tuple[0]. This means you are incorrectly using the element size of the first tensor in first_kv_cache_tuple for all components. If the tensors for different components (norm, pe, k) have different data types, this will lead to incorrect block length calculations. You should use the element_size() of each respective tensor from first_kv_cache_tuple.

Suggested change

```diff
-self.block_len = [
-    first_kv_cache[0].element_size() * math.prod(block_shape_norm),
-    first_kv_cache[1].element_size() * math.prod(block_shape_pe),
-    first_kv_cache[2].element_size() * math.prod(block_shape_k),
-]
+self.block_len = [
+    first_kv_cache_tuple[0].element_size() * math.prod(block_shape_norm),
+    first_kv_cache_tuple[1].element_size() * math.prod(block_shape_pe),
+    first_kv_cache_tuple[2].element_size() * math.prod(block_shape_k),
+]
```
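To see why using each component's own `element_size()` matters, here is a minimal standalone sketch of the corrected calculation. The dtype sizes and block shapes below are hypothetical (e.g. an int8-quantized "k" cache next to fp16 "norm"/"pe" caches), not taken from the PR:

```python
import math

# Hypothetical (dtype_size_bytes, block_shape) per cache component of one
# sparse-attention layer. The "k" cache is assumed int8 (1 byte) while the
# others are fp16 (2 bytes) -- these values are illustrative only.
components = {
    "norm": (2, (128, 1, 512)),  # block_shape_norm
    "pe":   (2, (128, 1, 64)),   # block_shape_pe
    "k":    (1, (128, 1, 128)),  # block_shape_k
}

# Correct: each component contributes its own element size.
block_len = [size * math.prod(shape) for size, shape in components.values()]

# Buggy variant: reusing the first component's element size everywhere
# silently doubles the "k" block length here (2 bytes assumed instead of 1).
first_size = next(iter(components.values()))[0]
buggy = [first_size * math.prod(shape) for _, shape in components.values()]

print(block_len)  # per-component block lengths in bytes
print(buggy)
```

With mixed dtypes the two lists diverge in the last entry, which is exactly the failure mode the review describes.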

Comment on lines +138 to +160
```
block_id = block_ids[start // self.block_size]
if self.use_sparse:
    addr_k = self.kv_caches_base_addr[layer_id * 3] + block_id * self.block_len[0]
    addr_v = self.kv_caches_base_addr[layer_id * 3 + 1] + block_id * self.block_len[1]
    addr_dsa_k = self.kv_caches_base_addr[layer_id * 3 + 2] + block_id * self.block_len[2]
    length_k = int(self.block_len[0] / self.block_size * (end - start))
    length_v = int(self.block_len[1] / self.block_size * (end - start))
    length_dsa_k = int(self.block_len[2] / self.block_size * (end - start))
    size_list = [length_k, length_v, length_dsa_k]
    addr_list = [addr_k, addr_v, addr_dsa_k]
elif self.use_mla:
    addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
    addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1]
    length_k = int(self.block_len[0] / self.block_size * (end - start))
    length_v = int(self.block_len[1] / self.block_size * (end - start))
    size_list = [length_k, length_v]
    addr_list = [addr_k, addr_v]
else:
    addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
    addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0]
    length = int(self.block_len[0] / self.block_size * (end - start))
    size_list = [length, length]
    addr_list = [addr_k, addr_v]
```
Contributor

high

The logic for use_sparse, use_mla, and the default case in this method is very similar and contains a lot of duplicated code. This makes the code harder to maintain and increases the risk of introducing bugs if one path is updated but not the others. Consider refactoring this method to remove the duplication.

```
block_id = block_ids[start // self.block_size]
if self.use_sparse:
    num_components = 3
    block_lens = self.block_len
elif self.use_mla:
    num_components = 2
    block_lens = self.block_len
else:
    num_components = 2
    block_lens = [self.block_len[0]] * 2

addr_list = []
size_list = []
len_multiplier = (end - start) / self.block_size

for i in range(num_components):
    base_addr = self.kv_caches_base_addr[layer_id * num_components + i]
    block_len = block_lens[i]
    addr = base_addr + block_id * block_len
    length = int(block_len * len_multiplier)
    addr_list.append(addr)
    size_list.append(length)
```
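The suggested loop can be exercised standalone. Below is a sketch with the method state passed as plain arguments; all base addresses, block lengths, and sizes are dummy values for illustration, not taken from the PR:

```python
# Standalone version of the suggested refactor, with dummy state.
def prepare_addrs(kv_caches_base_addr, block_len, block_size,
                  start, end, block_ids, layer_id, use_sparse, use_mla):
    block_id = block_ids[start // block_size]
    if use_sparse:
        num_components, block_lens = 3, block_len
    elif use_mla:
        num_components, block_lens = 2, block_len
    else:
        num_components, block_lens = 2, [block_len[0]] * 2

    addr_list, size_list = [], []
    len_multiplier = (end - start) / block_size
    for i in range(num_components):
        base_addr = kv_caches_base_addr[layer_id * num_components + i]
        addr_list.append(base_addr + block_id * block_lens[i])
        size_list.append(int(block_lens[i] * len_multiplier))
    return addr_list, size_list

# Sparse case: three cache components per layer, one full block transferred.
addrs, sizes = prepare_addrs(
    kv_caches_base_addr=[0x1000, 0x2000, 0x3000],  # layer 0, dummy bases
    block_len=[512, 64, 128], block_size=128,
    start=0, end=128, block_ids=[5],
    layer_id=0, use_sparse=True, use_mla=True)
print(addrs, sizes)
```

A single loop covers all three branches, so a future change to the address arithmetic only has to be made in one place.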

@luoxiaolin712 luoxiaolin712 force-pushed the kvpool_sparse_attention branch from 9c480b1 to f4c8524 on January 28, 2026 06:54
@Pz1116
Collaborator

Pz1116 commented Jan 29, 2026

You should submit an RFC as an issue describing the design of KV pool support for sparse attention, and you can mention this PR in the RFC. There are plenty of RFC issues you can refer to as templates.
Consider renaming the PR to [Feature] KV pool supports sparse attention.

@luoxiaolin712 luoxiaolin712 changed the title from [RFC] KV pool supports sparse attention to [Feature] KV pool supports sparse attention on Jan 29, 2026
```
def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
    block_id = block_ids[start // self.block_size]
    if self.use_mla:
        if self.use_sparse:
```
Collaborator

TODO: Refactor this part as the three if-else are pretty similar

Contributor Author

TODO: Refactor this part as the three if-else are pretty similar

Done

@luoxiaolin712 luoxiaolin712 force-pushed the kvpool_sparse_attention branch 2 times, most recently from 1c40726 to b0dcc29 on January 29, 2026 02:53
@luoxiaolin712
Contributor Author

luoxiaolin712 commented Feb 2, 2026

You should submit an RFC as an issue describing the design of KV pool support for sparse attention, and you can mention this PR in the RFC. There are plenty of RFC issues you can refer to as templates. Consider renaming the PR to [Feature] KV pool supports sparse attention.

Thanks, the RFC is #6411.

@luoxiaolin712 luoxiaolin712 force-pushed the kvpool_sparse_attention branch from 75e551e to b8e93eb on February 2, 2026 02:26
@luoxiaolin712 luoxiaolin712 force-pushed the kvpool_sparse_attention branch 3 times, most recently from 362bd97 to f4f4b57 on February 2, 2026 07:16
Collaborator

@wangxiyuan wangxiyuan left a comment

@jianzs @LCAIZJ @liziyu179 please take a look at this refactor

@LCAIZJ
Collaborator

LCAIZJ commented Feb 5, 2026

This PR modifies the KV pool to support sparse attention, but it also changes the P2P connector. It's recommended to split it into two separate PRs for better clarity.

Signed-off-by: lty <linhebiwen@gmail.com>
@luoxiaolin712 luoxiaolin712 force-pushed the kvpool_sparse_attention branch from f4f4b57 to 2981abc on February 5, 2026 02:25
@LCAIZJ LCAIZJ merged commit 33b8ca4 into vllm-project:main Feb 5, 2026
16 checks passed
@LCAIZJ LCAIZJ added the ready (read for review) and ready-for-test (start test by label for PR) labels Feb 5, 2026
@luoxiaolin712
Contributor Author

This PR modifies the KV pool to support sparse attention, but it also changes the P2P connector. It's recommended to split it into two separate PRs for better clarity.

Thanks. This has been split into two PRs; the P2P PR is #6551.

845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Feb 6, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (59 commits)
  [Feat.]: 310p support MOE models (vllm-project#6530)
  [Doc] backport 0.13.0 release note (vllm-project#6584)
  [CI] Update UT CANN version to 8.5.0 for main branch (vllm-project#6564)
  [CI] Change A2 runner (vllm-project#6557)
  [Bugfix] Fix the incorrect use of the output parameter in _forward_fia_slidingwindow (vllm-project#6469)
  [main2main] upgrade vllm main 0202 (vllm-project#6560)
  [CI][npugraph_ex]Fix npugraph ex e2e test (vllm-project#6553)
  [Feature]KV pool supports sparse attention (vllm-project#6339)
  [bugfix]Fix accuracy issue in PCP/DCP with speculative decoding (vllm-project#6491)
  perf: adaptive block size selection in linear_persistent kernel (vllm-project#6537)
  [ModelRunner][Fix] Pads query_start_loc to satisfy FIA/TND constraint (vllm-project#6475)
  [Bugfix]Fix of Pooling Code and Update of Pooling Usage Guide (vllm-project#6126)
  [Fusion] Add rmsnorm dynamic quant fusion pass (vllm-project#6274)
  [Bugfix] Synchronize only the current stream to avoid device sync (vllm-project#6432)
  [CI] Add long and short prompt tests for DeepSeek-V3.2 (vllm-project#6499)
  [Refactor] MLP weight prefetch to consistency with MoE Model's prefetching in terms of code and usage (vllm-project#6442)
  [bugfix][npugraph_ex]duplicate pattern issue (vllm-project#6513)
  [bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (vllm-project#6430)
  [Quant] GLM4.7-Flash Support W8A8 (vllm-project#6492)
  [Nightly][BugFix] Remove kv_cache nz test case for test_mla_preprocess_nq.py (vllm-project#6505)
  ...
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
### What this PR does / why we need it?
The kv pooling feature is adapted to Sparse Attention to support models
such as Deepseek V3.2.

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?
```
vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \
  --host $local_ip \
  --port 8002 \
  --served-model-name model \
  --data-parallel-size 1 \
  --tensor-parallel-size 8 \
  --prefill-context-parallel-size 2 \
  --decode-context-parallel-size 1 \
  --cp-kv-cache-interleave-size 128 \
  --block-size 128 \
  --enable-expert-parallel \
  --no-enable-prefix-caching \
  --no-enable-chunked-prefill \
  --max-num-seqs 4 \
  --max-model-len 8192 \
  --max-num-batched-tokens 8192 \
  --gpu-memory-utilization 0.95 \
  --trust-remote-code \
  --enforce-eager \
  --quantization ascend \
  --additional_config '{"ascend_scheduler_config":{"enabled":false}}' \
  --kv-transfer-config \
    '{
            "kv_connector": "AscendStoreConnector",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {
	            "backend": "mooncake",
              "lookup_rpc_port":"0",
              "use_layerwise": false
            }
    }'
```

- vLLM version: v0.14.1
- vLLM main:
vllm-project/vllm@dc917cc

Signed-off-by: lty <linhebiwen@gmail.com>
Signed-off-by: momochenchuw <chenchuw@huawei.com>
@wangxiyuan wangxiyuan mentioned this pull request Feb 24, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
jiangyunfan1 pushed a commit to jiangyunfan1/vllm-ascend that referenced this pull request Apr 9, 2026