[bugfix] Restore attn_tp_rank/size reset in DeepseekV2AttentionMLA when CP enabled#19495

Closed
huangzhilin-hzl wants to merge 1 commit intosgl-project:mainfrom
huangzhilin-hzl:fix_ds32_precision_in_cp

Conversation

@huangzhilin-hzl
Contributor

@huangzhilin-hzl huangzhilin-hzl commented Feb 27, 2026

Motivation

ref #19483

PR #17213 removed the reset of attn_tp_rank and attn_tp_size when nsa_enable_prefill_cp and use_nsa are true, causing an incorrect num_local_heads calculation in DeepseekV2AttentionMLA. This PR restores the reset to fix the ds32 (DeepSeek-V3.2) model precision issue.
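The restored behavior can be sketched as follows. This is a minimal illustration based on the PR description only; the helper function and parameter names here are assumptions, and the real code in python/sglang/srt/models/deepseek_v2.py sets instance attributes inside __init__ rather than calling a standalone function:

```python
# Hypothetical sketch of the restored reset in DeepseekV2AttentionMLA.
# Names (nsa_enable_prefill_cp, use_nsa, attn_tp_rank, attn_tp_size) follow
# the PR description; the actual implementation may differ.

def compute_num_local_heads(num_attention_heads: int,
                            attn_tp_rank: int,
                            attn_tp_size: int,
                            nsa_enable_prefill_cp: bool,
                            use_nsa: bool) -> int:
    if nsa_enable_prefill_cp and use_nsa:
        # With NSA prefill context parallelism, each rank must have the
        # complete set of attention heads, so the attention-TP view
        # collapses to a single rank.
        attn_tp_rank = 0
        attn_tp_size = 1
    # Without the reset above, attn_tp_size > 1 would shrink
    # num_local_heads under CP and corrupt the attention computation.
    return num_attention_heads // attn_tp_size

# DeepSeek-style head count, with and without the CP path active:
print(compute_num_local_heads(128, 3, 8, True, True))    # CP: all heads local
print(compute_num_local_heads(128, 3, 8, False, False))  # no CP: heads sharded
```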

Modifications

Accuracy Tests

GPU: H20-141GB

Case 1: In-seq splitting mode launched with EP + DP
python3 -m sglang.launch_server --model-path /root/.cache/models/DeepSeek-V3.2 --trust-remote-code \
--port 8000 --host 0.0.0.0 --attention-backend  nsa \
--enable-metrics --mem-fraction-static 0.8 --max-running-requests 128 --enable-cache-report --page-size 64 \
--tp-size 8 \
--tool-call-parser deepseekv32 \
--reasoning-parser deepseek-v3 \
--chunked-prefill-size 16384 \
--nsa-decode-backend fa3 \
--enable-nsa-prefill-context-parallel \
--attn-cp-size 4 \
--ep 8 \
--dp 2 \
--enable-dp-attention \
--nsa-prefill-cp-mode in-seq-split 

Before this PR:

# python3 benchmark/gsm8k/bench_sglang.py --host http://0.0.0.0 --port 8000 --num-shots 8 --num-questions 1319 --parallel 200 
Accuracy: 0.947
Invalid: 0.000
Latency: 364.807 s
Output throughput: 349.867 token/s

After this PR:

# python3 benchmark/gsm8k/bench_sglang.py --host http://0.0.0.0 --port 8000 --num-shots 8 --num-questions 1319 --parallel 200 
Accuracy: 0.945
Invalid: 0.000
Latency: 369.949 s
Output throughput: 344.164 token/s

Case 2: Round-robin splitting

python3 -m sglang.launch_server --model-path /root/.cache/models/DeepSeek-V3.2 --trust-remote-code \
--port 8000 --host 0.0.0.0 --attention-backend  nsa \
--enable-metrics --mem-fraction-static 0.8 --max-running-requests 128 --enable-cache-report --page-size 64 \
--tp-size 8 \
--tool-call-parser deepseekv32 \
--reasoning-parser deepseek-v3 \
--chunked-prefill-size 16384 \
--nsa-decode-backend fa3 \
--enable-nsa-prefill-context-parallel \
--nsa-prefill-cp-mode round-robin-split

Before this PR:

# python3 benchmark/gsm8k/bench_sglang.py --host http://0.0.0.0 --port 8000 --num-shots 8 --num-questions 1319 --parallel 500 
Accuracy: 0.008
Invalid: 0.033
Latency: 258.471 s
Output throughput: 2612.782 token/s

After this PR:

# python3 benchmark/gsm8k/bench_sglang.py --host http://0.0.0.0 --port 8000 --num-shots 8 --num-questions 1319 --parallel 200 
Accuracy: 0.948
Invalid: 0.000
Latency: 205.447 s
Output throughput: 617.236 token/s

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a precision bug affecting the DeepseekV2AttentionMLA module. It reintroduces variable resets that were previously removed, ensuring the correct calculation of local attention heads when NSA prefill context parallelism is enabled. This fix is essential for maintaining the model's accuracy.

Highlights

  • Bugfix for DeepseekV2AttentionMLA Precision: Restored the reset of attn_tp_rank and attn_tp_size to 0 and 1 respectively, within the conditional block where nsa_enable_prefill_cp and use_nsa are true. This corrects an issue where num_local_heads was incorrectly calculated, leading to precision problems in the DeepseekV2AttentionMLA model, which was introduced by a previous PR.


Changelog
  • python/sglang/srt/models/deepseek_v2.py
    • Restored the assignment of attn_tp_rank = 0 and attn_tp_size = 1 inside the if self.nsa_enable_prefill_cp and self.use_nsa: block within the __init__ method of DeepseekV2AttentionMLA.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request restores the reset of attn_tp_rank and attn_tp_size in DeepseekV2AttentionMLA when context parallelism is enabled. This is a correct bug fix that addresses an issue where num_local_heads was calculated incorrectly, leading to precision problems. The change is small, targeted, and looks good.

@huangzhilin-hzl huangzhilin-hzl marked this pull request as ready for review February 27, 2026 10:19
Fridge003
Fridge003 previously approved these changes Feb 27, 2026
@Fridge003 Fridge003 dismissed their stale review February 27, 2026 20:55

misoperation

@Fridge003
Collaborator

Fridge003 commented Feb 27, 2026

Can you explain why the accuracy issue only happens in PD + CP cases, and how this fix helps with PD?

@huangzhilin-hzl
Contributor Author

Can you explain why the accuracy issue only happens in PD + CP cases, and how this fix helps with PD?

Hi @Fridge003, the issue is fixed in the single-node scenario rather than in the PD scenario. When CP is active, each rank must have complete attention heads (attn_tp_size = 1). However, if attn_cp_size is not explicitly provided in the launch arguments and falls back to the default value of 1, attn_tp_size is computed incorrectly. See https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/dp_attention.py#L230
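The failure mode described here can be illustrated with a minimal sketch. The derivation formula below is an assumption inferred from the described behavior of the linked dp_attention.py code, not a copy of it:

```python
# Hypothetical reconstruction of how attn_tp_size is derived from the
# parallelism sizes. The point: with CP enabled but attn_cp_size left at
# its default of 1, attn_tp_size no longer collapses to 1, so each rank
# only holds a fraction of the attention heads.

def derive_attn_tp_size(tp_size: int, dp_size: int, attn_cp_size: int) -> int:
    # Assumed relationship: attention TP fills whatever is left of the TP
    # group after data parallelism and context parallelism are carved out.
    return tp_size // (dp_size * attn_cp_size)

# Launch args from case 1: --tp-size 8 --dp 2 --attn-cp-size 4
print(derive_attn_tp_size(8, 2, 4))  # 1 -> each rank keeps all heads

# Same launch but with --attn-cp-size omitted (defaults to 1)
print(derive_attn_tp_size(8, 2, 1))  # 4 -> heads sharded, precision breaks
```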

@vladnosiv
Contributor

When CP is active, each rank must have complete attention heads (attn_tp_size = 1).

Is that actually right? I think after the refactoring you can use, for example, a DP2 ATTN_CP2 ATTN_TP2 configuration.

I also noticed accuracy problems after the refactoring, but they more likely stem from the fact that when CP is enabled without passing the --attn-cp-size N flag, the default of 1 behaves inconsistently: the CP code path is not actually activated, because attn_cp_size = 1, while most likely some all_gather across ranks is still gated by the old flag (rather than by attn_cp_size), or something like that.

The easiest way to see problems is by running prefill separately and seeing that the first token is always completely random.

But since adding of --attn-cp-size 8 flag helps to restore parity with the previous version, I did not study the problem in depth.

@huangzhilin-hzl
Contributor Author

When CP is active, each rank must have complete attention heads (attn_tp_size = 1).

Is that actually right? I think after the refactoring you can use, for example, a DP2 ATTN_CP2 ATTN_TP2 configuration.

I also noticed accuracy problems after the refactoring, but they more likely stem from the fact that when CP is enabled without passing the --attn-cp-size N flag, the default of 1 behaves inconsistently: the CP code path is not actually activated, because attn_cp_size = 1, while most likely some all_gather across ranks is still gated by the old flag (rather than by attn_cp_size), or something like that.

The easiest way to see problems is by running prefill separately and seeing that the first token is always completely random.

But since adding of --attn-cp-size 8 flag helps to restore parity with the previous version, I did not study the problem in depth.

Agree. The simplest fix is to explicitly set attn_cp_size.
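For illustration, such an explicit requirement could be enforced at startup with a check like the sketch below. This is a hypothetical validation helper, not code from this PR or from SGLang; the argument names mirror the launch flags used in the test cases above:

```python
# Hypothetical sketch: fail fast when NSA prefill context parallelism is
# requested but attn_cp_size was left at its default of 1, since that
# silently skips the CP code path and produces wrong results.

def validate_cp_args(enable_nsa_prefill_cp: bool, attn_cp_size: int) -> None:
    if enable_nsa_prefill_cp and attn_cp_size <= 1:
        raise ValueError(
            "--enable-nsa-prefill-context-parallel requires an explicit "
            "--attn-cp-size > 1; the default of 1 disables the CP path."
        )

validate_cp_args(True, 4)   # accepted: CP explicitly sized
try:
    validate_cp_args(True, 1)  # rejected: CP requested but not sized
except ValueError as e:
    print("rejected:", e)
```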
