[Feature] Add DCP support for DeepSeek v3.2#18167

Open
FENP wants to merge 5 commits into sgl-project:main from FENP:dcp_for_ds32

Conversation

@FENP

@FENP FENP commented Feb 3, 2026

Motivation

Following PR #14982, add decode context parallel (DCP) support for DeepSeek v3.2. The KV cache redundancy issue in DeepSeek v3.2 is even more severe because the MLA architecture uses only 1 KV head. DCP can completely eliminate this redundancy: compared to TP8, TP8+DCP8 can expand the KV cache capacity by 8×.

Modifications

The changes are largely the same as those in PR #14982. For DSA, this PR includes the following additional design considerations:

  1. Do not shard the indexer's K cache across ranks to avoid introducing additional complexity in the top-k computation.
  2. Filter the top-k indices based on the DCP size and rank, keeping only those indices whose KV cache resides on the current rank.
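The filtering in point 2 can be sketched in plain Python (all names here are illustrative, not the PR's actual code; it assumes KV slots are sharded round-robin across DCP ranks):

```python
# Hypothetical sketch of design point 2: keep only the top-k indices whose
# KV cache entry lives on this DCP rank, remapped to local slots.

def filter_topk_indices_by_dcp(topk_indices, dcp_size, dcp_rank, pad=-1):
    """Global KV slot i resides on rank i % dcp_size at local slot
    i // dcp_size. Entries equal to `pad` are invalid and dropped."""
    kept = [
        i // dcp_size
        for i in topk_indices
        if i != pad and i % dcp_size == dcp_rank
    ]
    # Pad back to a fixed top-k length so downstream kernels see a
    # constant-shape index list; `pad` marks slots this rank skips.
    return kept + [pad] * (len(topk_indices) - len(kept))
```

With dcp_size=2, rank 0 keeps the even global indices and rank 1 the odd ones; the indexer's K cache itself stays replicated per point 1, so the top-k computation is unchanged.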

Usage

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3.2/ --tp 8 --dcp <1/2/4/8> --nsa-decode-backend flashmla_sparse

Accuracy Tests

few_shot_gsm8k

# TP8DCP1
100%|██████████████████████| 200/200 [00:17<00:00, 11.30it/s]
Accuracy: 0.955
Invalid: 0.000
Latency: 17.949 s
Output throughput: 1015.863 token/s

# TP8DCP2
100%|██████████████████████| 200/200 [00:22<00:00,  9.05it/s]
Accuracy: 0.965
Invalid: 0.000
Latency: 22.422 s
Output throughput: 810.458 token/s

# TP8DCP4
100%|██████████████████████| 200/200 [00:27<00:00,  7.26it/s]
Accuracy: 0.975
Invalid: 0.000
Latency: 27.896 s
Output throughput: 654.467 token/s

# TP8DCP8
100%|██████████████████████| 200/200 [00:38<00:00,  5.21it/s]
Accuracy: 0.975
Invalid: 0.000
Latency: 38.854 s
Output throughput: 472.079 token/s

Benchmarking and Profiling

Hardware: 8 × H20

4K/1.5K

  • max-concurrency=1

DCP results in an 8% to 13% performance degradation compared to TP. Further testing will be conducted after optimizing communication (e.g., symmetric memory, replicated linear layers).

It is worth noting that, under the TP8DCP4 configuration, the number of attention heads in Q becomes exactly 64 after the DCP all-gather, enabling flashmla_sparse to run on SM90 without padding.

python3 -m sglang.bench_serving --backend sglang --dataset-name random-ids --num-prompt 10 --random-input-len 4096 --random-output-len 1536 --max-concurrency 1 --random-range-ratio 1.0
| Parallel Config | P50 TTFT (ms) | P99 TPOT (ms) | Total TPS |
| --- | --- | --- | --- |
| TP8 | 1106.99 | 21.75 | 163.31 |
| TP8DCP2 | 1115.29 | 23.60 | 150.82 |
| TP8DCP4 | 1172.82 | 23.78 | 149.53 |
| TP8DCP8 | 1930.39 | 24.55 | 142.25 |

TODOs

  • Support symmetric memory to reduce communication overhead.
  • Support FA3 as the backend; currently blocked by an illegal memory access, because FA3 cannot handle the invalid value -1 in the page table and attempts to load data from an invalid address.
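One possible direction for the FA3 item above, purely a hypothetical sketch and not the planned fix: replace the -1 entries with a safe dummy page before the kernel call and keep a mask so those positions' contributions can be discarded afterwards (e.g., via the LSE correction). The function name below is made up.

```python
# Hypothetical workaround sketch for a kernel that cannot tolerate -1
# entries in its page table: substitute a safe dummy page and remember
# which entries were invalid.

def sanitize_page_table(page_table, dummy_page=0, invalid=-1):
    """Return (safe_table, invalid_mask): invalid entries are redirected
    to `dummy_page`; the mask marks them for post-hoc exclusion."""
    mask = [entry == invalid for entry in page_table]
    safe = [dummy_page if m else entry for entry, m in zip(page_table, mask)]
    return safe, mask
```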

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @FENP, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Decode Context Parallel (DCP) support into the SGLang framework, primarily targeting the DeepSeek v3.2 model. The core motivation is to mitigate KV cache redundancy, a critical concern for models like DeepSeek v3.2 that use a single KV head in their MLA architecture. By distributing the KV cache across multiple devices, DCP effectively multiplies the available cache capacity, allowing for larger context windows and more efficient memory utilization. The changes span across distributed state management, attention mechanisms, memory allocation, and scheduling, ensuring a cohesive and performant implementation of this new parallelism strategy.

Highlights

  • Decode Context Parallel (DCP) Group Management: Introduced a new parallel group, _DCP, and associated functions (get_dcp_group, initialize_model_parallel, destroy_model_parallel) to manage decode context parallel operations, integrating it into the existing distributed state and CUDA graph capture mechanisms.
  • KV Cache Redundancy Elimination for DeepSeek v3.2: Implemented Decode Context Parallel (DCP) specifically for DeepSeek v3.2, addressing the severe KV cache redundancy issue in its MLA architecture. This allows for an 8x expansion of KV cache capacity when using TP8+DCP8, significantly improving memory efficiency.
  • Attention Mechanism Adaptation for DCP: Modified NSA attention backends (flashmla_sparse, flashmla_kv, fa3) to incorporate DCP. This includes all-gathering query tensors across DCP ranks, passing dcp_size to cache dequantization and index transformation functions, and aggregating attention outputs and log-sum-exp (LSE) values using new Triton kernels and helper functions (cp_lse_ag_out_rs).
  • Memory Pool and Scheduler Updates: Adjusted memory pool allocation and scheduler logic to account for DCP. The page_size and max_total_num_tokens are scaled by dcp_size, and token usage reporting is updated to reflect the effective capacity increase. KV cache writes now utilize a dcp_kv_mask to ensure each rank processes its assigned portion of the cache.
  • CLI Argument for DCP Configuration: Added a new command-line argument --dcp-size (or --decode-context-parallel-size) to allow users to configure the decode context parallel size, enabling flexible deployment and scaling of the feature.
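The LSE aggregation mentioned in the attention bullet can be written as a scalar reference. This is a numerically stable sketch of what a cp_lse_ag_out_rs-style merge computes; the PR's actual implementation is a Triton kernel operating on tensors.

```python
import math

# Reference merge of per-rank partial attention results over disjoint KV
# shards. Rank r contributes (o_r, lse_r): o_r is its shard-local
# softmax-weighted value sum, lse_r = log(sum(exp(scores_r))).

def merge_partial_attn(outputs, lses):
    """Exact global output = sum_r o_r * exp(lse_r - lse_global), where
    lse_global = logsumexp(lse_1, ..., lse_n)."""
    lse_max = max(lses)  # subtract the max for numerical stability
    lse_global = lse_max + math.log(sum(math.exp(l - lse_max) for l in lses))
    return sum(o * math.exp(l - lse_global) for o, l in zip(outputs, lses))
```

Because each shard's weights are rescaled by exp(lse_r - lse_global), the merged result equals attention over the full, unsharded KV exactly, which is consistent with the accuracy numbers above being unaffected by the DCP degree.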


Changelog
  • python/sglang/srt/distributed/parallel_state.py
    • Added _DCP global variable and get_dcp_group() for managing the Decode Context Parallel group.
    • Updated graph_capture to include get_dcp_group().graph_capture(context).
    • Modified initialize_model_parallel to accept decode_context_parallel_size and initialize the DCP group.
    • Updated destroy_model_parallel to destroy the _DCP group.
  • python/sglang/srt/entrypoints/engine.py
    • Added logic to set NCCL_GRAPH_MIXING_SUPPORT="0" if dcp_size > 1 to optimize CUDA graph performance.
  • python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
    • Added dcp_size parameter to dequantize_k_cache_paged and its Triton kernel.
    • Modified the kernel to divide token_id_paged by DCP_SIZE.
  • python/sglang/srt/layers/attention/nsa/transform_index.py
    • Added dcp_size parameter to various transform_index_page_table functions.
    • Modified kernels/functions to divide loaded_kv_indices or result by dcp_size when dcp_size > 1.
  • python/sglang/srt/layers/attention/nsa/utils.py
    • Added is_nsa_enable_decode_cp() function to check if decode context parallel is enabled.
  • python/sglang/srt/layers/attention/nsa_backend.py
    • Imported get_dcp_group and cp_lse_ag_out_rs.
    • Added self.dcp_size and self.dcp_rank attributes to NSARadixAttentionBackend.
    • Introduced _save_kv_cache method to handle DCP-aware KV cache saving.
    • Modified forward_extend and forward_decode to use _save_kv_cache and perform all_gather for query tensors across the DCP group.
    • Passed dcp_size to cache dequantization and index transformation functions.
    • Modified attention forward methods to return softmax LSE and apply cp_lse_ag_out_rs for DCP aggregation.
    • Updated _forward_flashmla_kv to adjust q_all view based on self.dcp_size.
    • Modified set_nsa_prefill_impl to check is_nsa_enable_decode_cp().
  • python/sglang/srt/layers/attention/utils.py
    • Imported GroupCoordinator.
    • Added _correct_attn_cp_out_kernel Triton kernel for correcting attention outputs with all-gathered LSEs.
    • Added CPTritonContext class to manage Triton kernel recompilation.
    • Added correct_attn_out and cp_lse_ag_out_rs functions for DCP attention output correction and aggregation.
  • python/sglang/srt/managers/scheduler.py
    • Added self.dcp_size attribute to Scheduler.
    • Modified logging of max_total_num_tokens to reflect DCP's impact on capacity.
    • Scaled page_size by self.dcp_size during cache initialization and for PrefillAdder.
  • python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
    • Modified _get_token_info to divide available_size and evictable_size by self.dcp_size for accurate token usage.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Added dcp_kv_mask and dcp_size parameters to set_mla_kv_buffer.
    • Passed dcp_size to NSATokenToKVPoolAllocator constructor and adjusted indexer K cache size calculation.
  • python/sglang/srt/mem_cache/utils.py
    • Added set_mla_kv_buffer_with_mask_kernel Triton kernel for masked KV cache writes in DCP.
    • Modified set_mla_kv_buffer_triton to conditionally use the new kernel when dcp_kv_mask is provided.
  • python/sglang/srt/model_executor/cuda_graph_runner.py
    • Added self.dcp_size attribute to CudaGraphRunner.
    • Passed dcp_size to GraphInputBuffers.create and introduced dcp_kv_mask handling in capture_one_batch_size.
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Added dcp_kv_mask attribute to ForwardBatch.
    • Modified init_new to calculate and assign dcp_kv_mask based on out_cache_loc, dcp_size, and dcp_rank.
  • python/sglang/srt/model_executor/input_buffers.py
    • Added dcp_kv_mask attribute to GraphInputBuffers and updated its creation and population logic.
  • python/sglang/srt/model_executor/model_runner.py
    • Added self.dcp_size and self.dcp_rank attributes to ModelRunner.
    • Passed decode_context_parallel_size to initialize_model_parallel and dcp_size to GraphInputBuffers.create.
  • python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
    • Modified get_cell_size_per_token to multiply indexer K cache size by self.dcp_size.
    • Passed dcp_size to NSATokenToKVPoolAllocator constructor and scaled max_total_num_tokens and page_size for PagedTokenToKVPoolAllocator.
  • python/sglang/srt/models/deepseek_v2.py
    • Imported get_dcp_group.
    • Added self.dcp_size and self.dcp_rank attributes to DeepSeekV2Attention.
    • Introduced _filter_topk_indices_by_dcp method to filter topk_indices based on DCP rank, applied in forward_absorb_prepare.
  • python/sglang/srt/server_args.py
    • Added dcp_size attribute to ServerArgs with a default of 1.
    • Added --decode-context-parallel-size (--dcp-size) CLI argument and mapped it to args.dcp_size.
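To make the dcp_kv_mask entries in the changelog concrete, here is a hypothetical pure-Python sketch of the masked-write idea. The real implementation is the Triton kernel set_mla_kv_buffer_with_mask_kernel; the function names and the round-robin sharding scheme below are assumptions.

```python
# Illustrative sketch: each DCP rank writes only the KV slots it owns.

def make_dcp_kv_mask(out_cache_loc, dcp_size, dcp_rank):
    """True where a token's global cache slot belongs to this DCP rank,
    assuming round-robin sharding of slots across ranks."""
    return [loc % dcp_size == dcp_rank for loc in out_cache_loc]

def write_kv_with_mask(kv_buffer, out_cache_loc, new_kv, dcp_size, mask):
    """Write only the masked entries into this rank's local buffer,
    mapping global slot -> local slot by integer division."""
    for loc, kv, owned in zip(out_cache_loc, new_kv, mask):
        if owned:
            kv_buffer[loc // dcp_size] = kv
```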
Activity
  • The pull request was initiated by FENP to add Decode Context Parallel (DCP) support for DeepSeek v3.2.
  • The motivation clearly states the goal of eliminating KV cache redundancy in DeepSeek v3.2's MLA architecture.
  • A usage example is provided, demonstrating how to launch the server with --tp 8, --dcp <1/2/4/8>, and the flashmla_sparse decode backend.
  • The checklist items for formatting, unit tests, documentation, accuracy tests, and benchmarking are currently unchecked, indicating potential areas for further work or verification.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces Decode Context Parallel (DCP) support for DeepSeek v3.2, which is a significant feature for improving KV cache capacity. The changes are extensive, touching many parts of the distributed state management, memory allocation, and attention kernels. The implementation correctly follows standard patterns for context parallelism, including sharding the KV cache, gathering query tensors, and using log-sum-exp correction to combine partial attention results. The reuse of vllm's context parallelism utilities is a good practice.

I've identified a critical issue related to page size configuration in the memory allocator when DCP is enabled, which could break the sharding logic. I've also included a couple of suggestions to refactor duplicated code in the NSA backend to improve maintainability. Overall, this is a solid contribution, and addressing the identified issues will make it even better.

Comment on lines +638 to +639
self.max_total_num_tokens * self.dcp_size,
page_size=self.page_size * self.dcp_size,

critical

The page_size passed to PagedTokenToKVPoolAllocator appears to be incorrect when DCP is enabled. It should be the physical page size (self.page_size), not scaled by dcp_size. The total number of tokens is already correctly scaled. By also scaling page_size, the number of pages managed by the allocator becomes (max_total_num_tokens * dcp_size) / (page_size * dcp_size) = max_total_num_tokens / page_size, which is the per-rank page count. This breaks the global page pool assumption required for DCP sharding logic (e.g., page_index % dcp_size). The allocator should manage the global pool of physical pages.

Suggested change
self.max_total_num_tokens * self.dcp_size,
page_size=self.page_size * self.dcp_size,
self.max_total_num_tokens * self.dcp_size,
page_size=self.page_size,
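Plugging illustrative numbers into the reviewer's arithmetic shows the cancellation (all values below are made up):

```python
max_total_num_tokens = 1024   # per-rank token capacity (illustrative)
page_size = 64                # physical page size in tokens (illustrative)
dcp_size = 8

# Intended: the allocator manages the global pool of physical pages.
global_pages = (max_total_num_tokens * dcp_size) // page_size

# As written in the PR: dcp_size cancels out, leaving only the per-rank
# page count, which breaks page_index % dcp_size sharding over the
# global pool.
buggy_pages = (max_total_num_tokens * dcp_size) // (page_size * dcp_size)
```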

# Prefill policy
adder = PrefillAdder(
self.page_size,
self.page_size * self.dcp_size,

critical

Consistent with the proposed change for PagedTokenToKVPoolAllocator, the page_size passed to PrefillAdder should also be the physical page size, not scaled by dcp_size. PrefillAdder uses this to calculate the number of pages required for a request, and this calculation should be based on the physical page size to correctly interact with the global page pool.

Suggested change
self.page_size * self.dcp_size,
self.page_size,

Comment on lines 1258 to 1266
if self.dcp_size > 1:
q_nope = get_dcp_group().all_gather(q_nope.contiguous(), dim=1)
q_rope = get_dcp_group().all_gather(q_rope.contiguous(), dim=1)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
if self.dcp_size > 1:
q_all = get_dcp_group().all_gather(q_all, dim=1)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]

medium

This logic for preparing and gathering query tensors (q_nope, q_rope, q_all) is duplicated in forward_decode (lines 1430-1438). To improve code maintainability and reduce redundancy, consider refactoring this logic into a helper method. For example:

def _gather_q(self, q, q_rope, layer):
    if q_rope is not None:
        q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
        q_rope = q_rope.view(
            -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
        )
        if self.dcp_size > 1:
            q_nope = get_dcp_group().all_gather(q_nope.contiguous(), dim=1)
            q_rope = get_dcp_group().all_gather(q_rope.contiguous(), dim=1)
        return q_nope, q_rope, None
    else:
        q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
        if self.dcp_size > 1:
            q_all = get_dcp_group().all_gather(q_all, dim=1)
        q_nope = q_all[:, :, : layer.v_head_dim]
        q_rope = q_all[:, :, layer.v_head_dim :]
        return q_nope, q_rope, q_all

Comment on lines +1344 to +1346
if self.dcp_size > 1:
return cp_lse_ag_out_rs(o, s, get_dcp_group())
return o

medium

The pattern of checking self.dcp_size > 1 and calling cp_lse_ag_out_rs is repeated multiple times (6 times in total) in both forward_extend and forward_decode for different attention backends. This could be extracted into a helper method to reduce code duplication and improve clarity. For example:

def _process_context_parallel_output(self, o, s):
    if self.dcp_size > 1:
        return cp_lse_ag_out_rs(o, s, get_dcp_group())
    return o

Then you could replace this block with a single call: return self._process_context_parallel_output(o, s).

@Fridge003
Collaborator

Hi @FENP, can you please update your code against the latest main branch?
We reorganized the CP communication groups in #17213, so follow-up CP features can be better aligned.

FENP added 5 commits on February 27, 2026 at 14:31, each signed off by FENP <yuanyongjie.yyj@antgroup.com>.
@FENP
Author

FENP commented Feb 27, 2026

Hi @FENP, can you please update your code against the latest main branch? We reorganized the CP communication groups in #17213, so follow-up CP features can be better aligned.

Of course! I've already rebased the code, and the accuracy tests are passing normally.

@mjp9527

mjp9527 commented Feb 28, 2026

Great work! Do you have any plans to support Prefill-Decode Disaggregation and MTP?

@FENP
Author

FENP commented Mar 2, 2026

Great work! Do you have any plans to support Prefill-Decode Disaggregation and MTP?

Thank you for your attention. These two features will be supported in future work. Contributions from the community are also welcome.

@mjp9527

mjp9527 commented Mar 12, 2026

🌹 I've observed that for the sparse_attn kernel, even when topk_indices is locally set to -1, there is still no performance gain. Therefore, partitioning the KV is much slower than partitioning the Q. Do you have any best practices or experience regarding this issue?

@wangfakang
Contributor

wangfakang commented Mar 17, 2026

Hello @FENP, Great work! May I ask, after enabling DCP, in the pressure testing scenario above, what is the highest proportion of AllGather communication in a single forward operation? Recently, I have been optimizing the symm-related issues in sglang. If the proportion is high, we can cherry-pick these PRs: #17756 (merged) #20406 (merged) #19329 (review) #20153 (review) and give it a try. Of course, DCP needs to be adapted for symm, that is, use with use_symmetric_memory(${DCP_GROUP}).

@FENP
Author

FENP commented Mar 25, 2026

🌹 I've observed that for the sparse_attn kernel, even when topk_indices is locally set to -1, there is still no performance gain. Therefore, partitioning the KV is much slower than partitioning the Q. Do you have any best practices or experience regarding this issue?

Do you mean pure kernel performance testing or DCP performance testing?

@FENP
Author

FENP commented Mar 25, 2026

Hello @FENP, Great work! May I ask, after enabling DCP, in the pressure testing scenario above, what is the highest proportion of AllGather communication in a single forward operation? Recently, I have been optimizing the symm-related issues in sglang. If the proportion is high, we can cherry-pick these PRs: #17756 (merged) #20406 (merged) #19329 (review) #20153 (review) and give it a try. Of course, DCP needs to be adapted for symm, that is, use with use_symmetric_memory(${DCP_GROUP}).


Good! Based on my previous experience (#14982), symmetric_memory is an effective way to reduce DCP communication overhead (by approximately 10%). This PR does not yet integrate symmetric_memory; I will attempt to add it later.


Labels

deepseek, npu, quant, LLM Quantization