[Feature] Add DCP support for GQA with flashinfer#14982
FENP wants to merge 15 commits into sgl-project:main
Conversation
Summary of Changes

Hello @FENP, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request implements Decode Context Parallelism (DCP) to enhance the efficiency of GQA models, particularly in scenarios where the tensor parallelism configuration leads to redundant KV cache storage. By distributing the KV cache across multiple devices and adapting the attention mechanism to operate on these partial caches, the system aims to reduce memory overhead and improve performance. The changes span distributed state management, KV cache allocation, and the FlashInfer attention implementation to ensure correct and efficient operation in a DCP environment.
Code Review
This pull request introduces Decode Context Parallelism (DCP) to enhance support for GQA models, particularly when the tensor parallelism size is greater than the number of KV heads. The implementation involves sharding the KV cache and employing an online softmax algorithm for attention computation, with changes spanning distributed communication, the FlashInfer attention backend, and memory management. The core logic, which includes an all-gather of queries followed by a distributed attention step and a reduce-scatter of the output, appears solid. My feedback primarily focuses on improving code structure by reducing duplication, enhancing clarity in data flow, and addressing a potential bug related to inconsistent return types.
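The merge step described above (partial attention per KV shard, then a log-sum-exp correction) can be checked against a small reference. The sketch below is illustrative and uses hypothetical names: `merge_partial_attention` stands in for the PR's `correct_attn_out`, and plain PyTorch replaces the FlashInfer kernels.

```python
import torch

def merge_partial_attention(outs, lses):
    """Exactly combine per-shard attention results via log-sum-exp weights.

    outs: list of [tokens, heads, head_dim] partial outputs (locally softmaxed)
    lses: list of [tokens, heads] log-sum-exp of each shard's attention scores
    """
    lse_stack = torch.stack(lses)                   # [shards, tokens, heads]
    global_lse = torch.logsumexp(lse_stack, dim=0)  # log of global softmax denominator
    # Each shard contributes with weight exp(local_lse - global_lse) = Z_i / Z.
    weights = torch.exp(lse_stack - global_lse)     # [shards, tokens, heads]
    out_stack = torch.stack(outs)                   # [shards, tokens, heads, dim]
    return (weights.unsqueeze(-1) * out_stack).sum(dim=0)
```

Merging two shards of a toy attention computation this way reproduces attention over the full KV set up to floating-point error, which is why sharding the KV cache across DCP ranks is lossless.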
```python
if dcp_size is None:
    dcp_size = 1
    dcp_rank = 0
```
Defaulting dcp_size and dcp_rank here if they are None can make the data flow harder to trace. It would be cleaner to ensure that these parameters are always provided with valid values from the calling context (e.g., init_forward_metadata), which already has access to the correct DCP group information. This would make the function signature more explicit and remove the need for this fallback logic.
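As a hypothetical illustration of this suggestion (the name `build_plan` and its round-robin sharding rule are invented for the sketch; only the idea of requiring explicit `dcp_size`/`dcp_rank` mirrors the comment):

```python
def build_plan(num_tokens: int, dcp_size: int, dcp_rank: int) -> list[int]:
    # Callers (e.g. init_forward_metadata) always pass real values from the
    # DCP group, so no `if dcp_size is None` fallback is needed here.
    assert dcp_size >= 1 and 0 <= dcp_rank < dcp_size
    # Shard token positions round-robin across DCP ranks (illustrative only).
    return [i for i in range(num_tokens) if i % dcp_size == dcp_rank]
```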
```python
assert out.is_contiguous()
out = cp_group.reduce_scatter_along_dim(out, dim=1)
```
The correct_attn_out function modifies its input out in-place. Since there's no guarantee that the input tensor cp_attn_out is contiguous, this assertion might fail. It would be safer to explicitly make the tensor contiguous before the reduce_scatter_along_dim call, which relies on a contiguous layout for its internal movedim operation.
```python
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
out = out.contiguous()
out = cp_group.reduce_scatter_along_dim(out, dim=1)
```
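The reviewer's point can be reproduced in isolation: PyTorch operations that assume a row-major layout fail or misbehave on transposed views, so calling `.contiguous()` before a layout-sensitive step is the safe pattern. A minimal standalone sketch:

```python
import torch

x = torch.randn(4, 6).t()      # transpose returns a non-contiguous view
assert not x.is_contiguous()   # a bare is_contiguous() assert would fail here

y = x.contiguous()             # explicit copy into row-major layout
assert y.is_contiguous()
assert torch.equal(x, y)       # same values, safe layout
```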
/tag-and-rerun-ci
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Motivation
Following RFC #12196, this PR adds decode context parallel (DCP) support for GQA models, aiming to reduce KV cache redundancy when the TP size exceeds the number of key-value heads. Similar work has already been adopted in vLLM (vllm-project/vllm#24864).
Difference between DCP and DP Attention
Although both DCP and DP Attention can reduce KV cache memory redundancy, their implementation approaches and effects are entirely different.
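The redundancy in question can be made concrete with a back-of-the-envelope sketch (numbers are illustrative; `kv_replication_factor` is a hypothetical helper, not code from this PR):

```python
def kv_replication_factor(tp_size: int, num_kv_heads: int) -> int:
    # With plain TP, each rank holds at least one KV head, so once tp_size
    # exceeds num_kv_heads every KV head is duplicated tp_size / num_kv_heads
    # times across the TP group.
    return max(1, tp_size // num_kv_heads)

# e.g. a GQA model with 4 KV heads on 8-way TP stores every KV entry twice;
# DCP instead shards the KV cache along the token dimension across dcp_size
# ranks, removing this duplication.
assert kv_replication_factor(8, 4) == 2
assert kv_replication_factor(4, 4) == 1
```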
In comparison, DCP offers the following advantages that DP Attention does not have.
attn_tp_size). This means DCP has the potential to avoid KV redundancy while maintaining low latency (close to that of tensor parallelism). We need to further optimize the overhead of DCP, especially its communication overhead, to achieve this benefit.

Modifications
--dcp-size)

Usage
Accuracy Tests
Qwen3-235B-A22B-Instruct-2507 test results.
Benchmarking and Profiling
DCP can expand the available KV cache memory by reducing KV redundancy, thereby allowing a larger batch size or more prefix cache. We tested Qwen3-235B-A22B-Instruct-2507 on H20 with 32k/4k.
The performance gain will be even more significant with longer contexts or on devices with less GPU memory. The test results for 32K/8K are as follows:
Limitations
Some optimization efforts are currently underway.
TODOs
Checklist