
[Feature] Add DCP support for GQA with flashinfer#14982

Open
FENP wants to merge 15 commits into sgl-project:main from FENP:dcp_for_gqa_fi
Conversation


@FENP FENP commented Dec 12, 2025

Motivation

Following RFC #12196, this PR adds decode context parallel (DCP) support for GQA models, aiming to reduce KV cache redundancy when the TP size exceeds the number of key-value heads. Similar work has already been adopted in vLLM (vllm-project/vllm#24864).

Difference between DCP and DP Attention

Although both DCP and DP Attention can reduce KV cache memory redundancy, their implementation approaches and effects are entirely different.

  • DP Attention applies data parallelism to the attention layers: each DP worker independently processes a distinct subset of the batch. In other words, DP Attention splits along the sequence dimension, placing different sequences on different GPUs to avoid KV cache redundancy.
  • DCP is built on top of tensor parallelism (TP) and stores the tokens (KV cache) of a sequence across different DCP workers in an interleaved manner. In other words, DCP splits along the token dimension, placing tokens from the same sequence on different GPUs to avoid KV cache redundancy.
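The interleaved placement can be pictured as a simple index mapping. This is a hypothetical sketch assuming token-granular round-robin interleaving (i.e., page size 1); the PR actually interleaves at page granularity, but the idea is the same:

```python
def dcp_placement(token_idx: int, dcp_size: int) -> tuple[int, int]:
    """Map the token_idx-th token of a sequence to (dcp_rank, local_slot)
    under round-robin interleaving across DCP ranks."""
    return token_idx % dcp_size, token_idx // dcp_size


# With dcp_size=2, tokens 0,2,4,... live on rank 0 and tokens 1,3,5,... on rank 1,
# so each rank stores only half of the sequence's KV cache.
print(dcp_placement(5, 2))
```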

In comparison, DCP offers the following advantages that DP Attention does not have.

  1. The KV cache for a single sequence can be distributed across all GPUs, which is highly beneficial for handling extremely long contexts. In contrast, DP Attention can only store the KV cache of a single sequence on a subset of GPUs (determined by attn_tp_size).
  2. The attention computation for each sequence can leverage all GPUs. In contrast, DP Attention can only utilize a subset of GPUs (determined by attn_tp_size). This means DCP has the potential to avoid KV redundancy while maintaining low latency (close to that of tensor parallelism). We need to further optimize the overhead of DCP, especially its communication overhead, to achieve this benefit.
  3. It avoids the redundant Q, K, V, and O weight storage that data parallelism additionally introduces.

Modifications

  • Add DCP args (--dcp-size)
  • Add DCP communication group
  • Modify KV cache management and allocation.
    • In the logical view (Radix cache and TokenToKVPoolAllocator), the page size is scaled up by a factor of dcp-size compared to the base page size.
    • For the actual KV buffer (TokenToKVPool), the page size remains unchanged.
    • KV are stored in an interleaved style across the KV Pool of the corresponding DCP rank.
  • Modify the flashinfer backend. Each DCP rank stores only a partial segment of the KV cache, so the attention output must be computed with online softmax to ensure correctness.
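The online-softmax correction works like this: each rank computes attention over its own KV shard and also returns the log-sum-exp (LSE) of its scores; the per-rank partial outputs can then be combined exactly. Below is a single-head numpy sketch of that math, not the actual flashinfer kernel; `partial_attn` and `merge_partials` are hypothetical names:

```python
import numpy as np

def partial_attn(q, k, v):
    """Attention of query q [D] over one KV shard (k, v: [T, D]).
    Returns the shard-local output and the log-sum-exp of the scores."""
    s = k @ q / np.sqrt(q.shape[0])          # scaled scores, [T]
    m = s.max()
    p = np.exp(s - m)                        # max-shifted for stability
    out = (p[:, None] * v).sum(axis=0) / p.sum()
    lse = m + np.log(p.sum())                # log of total softmax mass
    return out, lse

def merge_partials(outs, lses):
    """Combine per-shard outputs into the full-attention result:
    each shard is weighted by its softmax mass exp(lse)."""
    lses = np.array(lses)
    m = lses.max()
    w = np.exp(lses - m)                     # per-shard mass (rescaled)
    return sum(wi * oi for wi, oi in zip(w, outs)) / w.sum()
```

Because the weights are exactly the softmax masses of each shard, merging the partials reproduces attention over the full KV cache, regardless of how the tokens were split across ranks.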

Usage

# TP8 (Baseline)
python -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-Instruct-2507 --tp 8 --attention-backend flashinfer --enable-symm-mem

# DCP2TP8
python -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-Instruct-2507 --tp 8 --dcp 2 --attention-backend flashinfer --enable-symm-mem

Accuracy Tests

Qwen3-235B-A22B-Instruct-2507 test results.

  • TP8 (Baseline)
python3 -m sglang.test.few_shot_gsm8k --parallel 128 --max-new-tokens 512
100%|███████████████████| 200/200 [00:16<00:00, 12.42it/s]
Accuracy: 0.975
Invalid: 0.000
Latency: 17.437 s
Output throughput: 1586.356 token/s
  • DCP2TP8
python3 -m sglang.test.few_shot_gsm8k  --parallel 128 --max-new-tokens 512
100%|███████████████████| 200/200 [00:20<00:00,  9.76it/s]
Accuracy: 0.965
Invalid: 0.000
Latency: 22.336 s
Output throughput: 1272.156 token/s

Benchmarking and Profiling

DCP can expand KV cache capacity by reducing KV redundancy, thereby allowing a larger batch size or a larger prefix cache. We tested Qwen3-235B-A22B-Instruct-2507 on H20 with 32K/4K input/output lengths.

  1. We first compare at bs=1 to evaluate the impact of DCP overhead on TPOT. The results show that the impact of DCP on TPOT is within 10% (< 1 ms).
  • TP8
============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 1         
Successful requests:                     10        
Benchmark duration (s):                  217.32    
Total input tokens:                      124262    
Total input text tokens:                 124262    
Total input vision tokens:               0         
Total generated tokens:                  20099     
Total generated tokens (retokenized):    20097     
Request throughput (req/s):              0.05      
Input token throughput (tok/s):          571.81    
Output token throughput (tok/s):         92.49     
Peak output token throughput (tok/s):    103.00    
Peak concurrent requests:                2         
Total token throughput (tok/s):          664.29    
Concurrency:                             1.00      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   21729.72  
Median E2E Latency (ms):                 20860.90  
P90 E2E Latency (ms):                    36148.13  
P99 E2E Latency (ms):                    37502.47  
---------------Time to First Token----------------
Mean TTFT (ms):                          1366.96   
Median TTFT (ms):                        1065.90   
P99 TTFT (ms):                           3954.15   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.18     
Median TPOT (ms):                        10.15     
P99 TPOT (ms):                           10.59     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           10.14     
Median ITL (ms):                         10.13     
P95 ITL (ms):                            10.55     
P99 ITL (ms):                            10.95     
Max ITL (ms):                            15.91     
==================================================
  • DCP2TP8
============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 1         
Successful requests:                     10        
Benchmark duration (s):                  231.46    
Total input tokens:                      124262    
Total input text tokens:                 124262    
Total input vision tokens:               0         
Total generated tokens:                  20099     
Total generated tokens (retokenized):    20033     
Request throughput (req/s):              0.04      
Input token throughput (tok/s):          536.85    
Output token throughput (tok/s):         86.83     
Peak output token throughput (tok/s):    95.00     
Peak concurrent requests:                2         
Total token throughput (tok/s):          623.69    
Concurrency:                             1.00      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   23144.42  
Median E2E Latency (ms):                 22191.90  
P90 E2E Latency (ms):                    38961.84  
P99 E2E Latency (ms):                    40010.64  
---------------Time to First Token----------------
Mean TTFT (ms):                          1402.02   
Median TTFT (ms):                        1091.36   
P99 TTFT (ms):                           4025.15   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.86     
Median TPOT (ms):                        10.81     
P99 TPOT (ms):                           11.24     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           10.82     
Median ITL (ms):                         10.76     
P95 ITL (ms):                            11.16     
P99 ITL (ms):                            11.74     
Max ITL (ms):                            17.22     
==================================================
  2. We compared the throughput gains from DCP at large batch sizes. Compared to TP8, DCP2TP8 can double the KV cache memory, increasing TPS by up to 16%.

bs   TP8 (TPS)   DCP2TP8 (TPS)   Throughput gain
32   4626        4635            0.2%
48   4891        5165            5.6%
64   4916        5386            9.6%
80   4950        5540            12.0%
96   4897        5679            16.0%

The performance gain will be even more significant with longer contexts or on devices with less GPU memory. The test results for 32K/8K are as follows:

bs   TP8 (TPS)   DCP2TP8 (TPS)   Throughput gain
32   3376        3511            4.0%
48   3589        3978            10.8%
64   3620        4191            15.7%
80   3541        4415            24.7%

Limitations

  1. DCP requires inter-rank communication to correct the attention output, which introduces additional communication overhead. This results in a ~10% increase in TPOT at small batch sizes (tested with Qwen3-235B-A22B on H20).
    Some optimization efforts are currently underway:
    • Reducing communication overhead through NCCL symmetric memory.
    • Avoiding the all-gather of Q by using replicated QKVO linear layers.
  2. DCP cannot increase the KV cache capacity indefinitely. Its maximum expansion factor is limited by the ratio of the TP size to the number of KV heads.
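That cap can be expressed as a one-liner. This is a hypothetical helper for illustration, not a function in the PR:

```python
def max_useful_dcp_size(tp_size: int, num_kv_heads: int) -> int:
    """KV cache expansion from DCP is capped by the KV-head replication
    factor under TP; beyond that there is no redundancy left to reclaim."""
    return max(tp_size // num_kv_heads, 1)
```

Assuming Qwen3-235B-A22B uses 4 KV heads, TP8 gives a factor of 2, consistent with DCP2TP8 doubling the KV cache memory in the benchmarks above.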

TODOs

  • Modify the apply_rope_pos_ids_cos_sin_cache kernel to enable DCP support for fused_set_kv_buffer

@gemini-code-assist
Contributor

Summary of Changes

Hello @FENP, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements Decode Context Parallelism (DCP) to enhance the efficiency of GQA models, particularly in scenarios where the tensor parallelism configuration leads to redundant KV cache storage. By distributing the KV cache across multiple devices and adapting the attention mechanism to operate on these partial caches, the system aims to reduce memory overhead and improve performance. The changes span across distributed state management, KV cache allocation, and the FlashInfer attention implementation to ensure correct and efficient operation in a DCP environment.

Highlights

  • Decode Context Parallel (DCP) Support: Introduced a new parallelization strategy, Decode Context Parallel (DCP), specifically for GQA models to optimize KV cache usage when the Tensor Parallel (TP) size exceeds the number of Key-Value (KV) heads.
  • Distributed Communication Group: Added a dedicated communication group for DCP, including new reduce_scatter_along_dim functionality and initialization/destruction logic within the parallel state management.
  • KV Cache Management Adaptation: Modified KV cache allocation and management. The logical page size in the Radix cache and TokenToKVPoolAllocator is now scaled by dcp-size, while the actual KV buffer page size remains unchanged. KV entries are stored in an interleaved manner across DCP ranks.
  • FlashInfer Backend Integration: Updated the FlashInfer attention backend to support DCP. Each DCP rank now stores only a partial segment of the KV cache, necessitating the use of online softmax for correct attention output computation and cross-rank aggregation of results.
  • New CLI Argument: Added a new command-line argument --dcp-size (or --decode-context-parallel-size) to configure the decode context parallel size.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces Decode Context Parallelism (DCP) to enhance support for GQA models, particularly when the tensor parallelism size is greater than the number of KV heads. The implementation involves sharding the KV cache and employing an online softmax algorithm for attention computation, with changes spanning distributed communication, the FlashInfer attention backend, and memory management. The core logic, which includes an all-gather of queries followed by a distributed attention step and a reduce-scatter of the output, appears solid. My feedback primarily focuses on improving code structure by reducing duplication, enhancing clarity in data flow, and addressing a potential bug related to inconsistent return types.

Comment on lines +1102 to +1104
if dcp_size is None:
    dcp_size = 1
    dcp_rank = 0
Contributor


medium

Defaulting dcp_size and dcp_rank here if they are None can make the data flow harder to trace. It would be cleaner to ensure that these parameters are always provided with valid values from the calling context (e.g., init_forward_metadata), which already has access to the correct DCP group information. This would make the function signature more explicit and remove the need for this fallback logic.

Comment on lines +375 to +376
assert out.is_contiguous()
out = cp_group.reduce_scatter_along_dim(out, dim=1)
Contributor


medium

The correct_attn_out function modifies its input out in-place. Since there's no guarantee that the input tensor cp_attn_out is contiguous, this assertion might fail. It would be safer to explicitly make the tensor contiguous before the reduce_scatter_along_dim call, which relies on a contiguous layout for its internal movedim operation.

    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out = out.contiguous()
    out = cp_group.reduce_scatter_along_dim(out, dim=1)
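For intuition, the semantics of a reduce-scatter along a dimension can be modeled in a single process. This is a numpy sketch of what the collective computes, not the actual torch.distributed-based implementation:

```python
import numpy as np

def simulated_reduce_scatter_along_dim(per_rank_tensors, dim):
    """Single-process model of reduce-scatter along `dim`: every rank
    contributes a full tensor; the tensors are summed elementwise, and
    rank r receives the r-th shard of the sum along `dim`."""
    world = len(per_rank_tensors)
    total = np.sum(per_rank_tensors, axis=0)          # the "reduce"
    return np.array_split(total, world, axis=dim)     # the "scatter"


# Two ranks, each holding a full [4, 6] tensor; after the op each rank
# keeps only a [4, 3] shard of the elementwise sum.
a = np.ones((4, 6))
b = 2 * np.ones((4, 6))
shards = simulated_reduce_scatter_along_dim([a, b], dim=1)
```

This is why a contiguous layout matters in the real implementation: the scatter step hands each rank a view of the summed tensor along the split dimension.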

@github-actions github-actions bot added the npu label Dec 18, 2025
@FENP FENP changed the title [WIP] Add DCP support for GQA with flashinfer [Feature] Add DCP support for GQA with flashinfer Dec 18, 2025
@ClawSeven
Collaborator

/tag-and-rerun-ci

@FENP FENP force-pushed the dcp_for_gqa_fi branch 4 times, most recently from 93f4b70 to 0d7f580 Compare December 28, 2025 07:38
@FENP FENP force-pushed the dcp_for_gqa_fi branch 2 times, most recently from dfbfe2d to b74a483 Compare January 15, 2026 06:06
FENP added 15 commits January 27, 2026 14:37
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>