[DeepSeekV3.2] Enable pure TP & Partial DP Attention #13646
Fridge003 merged 10 commits into sgl-project:main
Conversation
Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces support for pure Tensor Parallelism (TP) for the DeepSeekV3.2 model's Non-Standard Attention (NSA) mechanism. It addresses compatibility issues with the FlashMLA sparse kernel when running in pure TP mode.
Code Review
This pull request enables pure Tensor Parallelism (TP) for DeepSeekV3.2 models. My review focuses on correctness and performance. I've found a critical issue in nsa_backend.py where new_empty is used for padding, which can lead to incorrect results due to uninitialized memory; I've suggested using new_zeros instead. I also pointed out a performance concern with calling get_device_capability inside a forward pass. In server_args.py, I've identified a condition that seems to be dead code and prevents a warning from being logged, and I've suggested a fix to ensure the warning is displayed correctly under pure TP mode.
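The `new_empty` vs `new_zeros` point in miniature (NumPy's `empty`/`zeros` used as stand-ins for the torch calls; shapes are arbitrary):

```python
import numpy as np

# Padding created with np.empty (like tensor.new_empty) holds whatever bytes
# were last in that memory, so any kernel that reads the padded region gives
# nondeterministic results. Zero-initialized padding (like tensor.new_zeros)
# is deterministic and safe to feed to the kernel.
q = np.ones((2, 3), dtype=np.float32)
pad_zeros = np.zeros((2, 5), dtype=np.float32)  # safe: padded columns are 0
pad_zeros[:, :3] = q
assert (pad_zeros[:, 3:] == 0).all()
```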
Thanks @YAMY1234~ If your PR is blocked on the FlashMLA side, you can create a new branch at https://github.com/sgl-project/FlashMLA. The FlashMLA kernels now integrated in sglang are built from this repo.
@YAMY1234 For the Hxx device, the padded heads may exhibit poor performance. Could we consider swapping the head and token dimensions via all-to-all (a2a) communication, where each rank increases the head dimension while reducing the token dimension?
Thanks for the suggestion! For now, this padding logic only applies to the FlashMLA sparse path under pure TP / partial DP attention. The original normal DP attention path is unchanged: if the layout doesn't match, it will just fail as before, so there should be no behavioral change there. I agree that an a2a-based head/token swap could be a good follow-up optimization for the TP / partial DP case on Hxx to reduce the padding overhead. For this PR, I think we can keep the change scoped to making pure TP & partial DP attention functional and stable, and we can explore the a2a approach in a separate perf-focused change.
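The a2a idea discussed above can be illustrated with a single-process sketch (`all_to_all_swap` is hypothetical; plain NumPy arrays stand in for per-rank tensors and for the NCCL all-to-all collective):

```python
import numpy as np

def all_to_all_swap(per_rank):
    """Simulate the proposed a2a: swap sharding from heads to tokens.

    per_rank[r] has shape (heads_per_rank, tokens); after the exchange,
    rank r holds (heads_per_rank * P, tokens // P), i.e. all heads but only
    its slice of tokens, so no head padding is needed for the sparse kernel.
    """
    P = len(per_rank)
    heads_pr, tokens = per_rank[0].shape
    assert tokens % P == 0
    chunk = tokens // P
    out = []
    for dst in range(P):
        # dst receives its token slice from every rank's heads
        parts = [per_rank[src][:, dst * chunk:(dst + 1) * chunk] for src in range(P)]
        out.append(np.concatenate(parts, axis=0))  # (heads_pr * P, chunk)
    return out

P, heads, tokens = 4, 16, 8
rng = np.random.default_rng(0)
full = rng.random((heads, tokens)).astype(np.float32)
shards = [full[r * (heads // P):(r + 1) * (heads // P)] for r in range(P)]
swapped = all_to_all_swap(shards)
assert swapped[0].shape == (heads, tokens // P)
assert (swapped[2] == full[:, 4:6]).all()
```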
@YAMY1234 Hi, I used your branch (https://github.com/YAMY1234/sglang/tree/dpsk_tp) and got some errors in PD: I launched prefill and decode as TP16 DP16 EP16
Thanks for pointing this out! For now this PR is mainly focused on and validated under the aggregated (agg) setup 🥺.
@YAMY1234 Can you add a benchmark for bs=1?
@Fridge003 |
Oh, I meant a performance benchmark. You can test with
Sorry, the benchmark I added earlier disappeared 😂, it might have been a saving error. I tested with sglang.bench_serving, and it looks like TP is 4-5x faster than DP Attention during prefill at bs=1. Could you take a second look at the PR description? Thanks!
@YAMY1234 Thanks~
@Fridge003 Thanks! Added docs and unittest~
/tag-and-rerun-ci
Motivation
DeepSeekV3.2 NSA currently has rough edges when running in pure TP mode (`dp_size < tp_size`):

- The FlashMLA sparse kernel does not support the reduced `num_heads` per rank after TP sharding.
- The indexer materializes the full `(num_q, num_k)` logits matrix in large pure-TP batches.

This PR makes NSA + pure TP & partial DP Attention a supported and stable configuration for DeepSeekV3.2.
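For a sense of why the full logits matrix is a problem, a back-of-envelope estimate (all sizes below are hypothetical, not measurements from this PR):

```python
# Back-of-envelope size of a dense fp8 (num_q, num_k) MQA logits matrix.
num_q = 128 * 1024   # query tokens in one large pure-TP prefill batch
num_k = 128 * 1024   # context length in tokens
fp8_bytes = 1        # one byte per fp8 element
logits_bytes = num_q * num_k * fp8_bytes
print(f"{logits_bytes / 2**30:.0f} GiB")   # 16 GiB for this single buffer
```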
Should be merged after "Upgrade flashmla kernel for NSA tp support" #13718.
Sample launch commands:
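The PR's actual commands are not reproduced here; purely as an illustration, a pure-TP vs. partial-DP launch using sglang's standard server flags might look like the following (the model path and parallel sizes are placeholders):

```shell
# Pure TP attention (dp_size defaults to 1 < tp_size):
python -m sglang.launch_server --model-path <deepseek-v3.2-model> --tp-size 8

# Partial DP attention (dp_size < tp_size), e.g. TP 8 / DP 4:
python -m sglang.launch_server --model-path <deepseek-v3.2-model> \
    --tp-size 8 --dp-size 4 --enable-dp-attention
```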
Modifications
NSA backend (`nsa_backend.py`)

- In `_forward_flashmla_sparse(...)`, pad `q`'s head dimension to the required multiple (64 on SM90, 128 on SM100+), call the sparse kernel with the padded heads, then trim the output back to the original `num_heads` (TP support).
- Extend `topk_transform(...)` with `topk_indices_offset_override` to accept precomputed ragged offsets (indexer chunking support).
- Cache `device_capability` / `device_sm_major` on init and reuse it in the TRTLLM ragged and FlashMLA paths.
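The pad-then-trim step can be sketched as follows (a NumPy stand-in; `run_sparse_with_padded_heads` and the identity `kernel` are hypothetical placeholders for the real FlashMLA sparse call):

```python
import numpy as np

def required_head_multiple(sm_major: int) -> int:
    # Head-count granularity per the PR: 64 on SM90, 128 on SM100+.
    return 64 if sm_major == 9 else 128

def run_sparse_with_padded_heads(q, sm_major, kernel):
    """Pad q's head dim up to the kernel's multiple, run, trim back (sketch)."""
    tokens, heads, dim = q.shape
    mult = required_head_multiple(sm_major)
    padded_heads = -(-heads // mult) * mult   # ceil(heads / mult) * mult
    q_pad = np.zeros((tokens, padded_heads, dim), dtype=q.dtype)
    q_pad[:, :heads] = q                      # zero-init keeps padding deterministic
    out = kernel(q_pad)                       # kernel sees a supported head count
    return out[:, :heads]                     # trim back to the original num_heads

rng = np.random.default_rng(0)
q = rng.random((4, 48, 64)).astype(np.float32)  # e.g. 48 heads per rank under TP
out = run_sparse_with_padded_heads(q, sm_major=9, kernel=lambda x: x)
assert out.shape == q.shape and (out == q).all()
```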
Server args (`server_args.py`)

- Allow `dp_size < tp_size` for DeepSeekV3.2 NSA and warn when running in this pure TP / partial DP mode.

FlashMLA cmake (`flashmla.cmake`)

- Bump the FlashMLA commit to `be055fb7df0090fde45f08e9cb5b8b4c0272da73` to use the latest sparse kernel (avoids crashing with large bs).
NSA indexer (`nsa_indexer.py`)

- Add `_should_chunk_mqa_logits(...)` to decide when to chunk fp8 MQA logits based on workload size and free GPU memory.
- In `_get_topk_ragged(...)`, add a chunked path that processes tokens chunk by chunk, passes `topk_indices_offset_override` so each chunk can reuse the global ragged offsets safely, and writes results into a single `(token_nums, index_topk)` buffer for all tokens.
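The chunked top-k idea might look roughly like this single-process sketch (NumPy stand-in; `topk_indices`, `topk_chunked`, and all sizes are hypothetical, and the real code operates on ragged fp8 logits rather than a dense matrix):

```python
import numpy as np

def topk_indices(logits, k):
    """Top-k key indices per query row, sorted by descending score."""
    idx = np.argpartition(-logits, k - 1, axis=-1)[:, :k]
    order = np.argsort(-np.take_along_axis(logits, idx, axis=-1), axis=-1)
    return np.take_along_axis(idx, order, axis=-1)

def topk_chunked(logits, k, chunk_size):
    """Process query tokens chunk by chunk instead of all at once,
    writing each chunk's result into one preallocated (token_nums, k) buffer."""
    n = logits.shape[0]
    out = np.empty((n, k), dtype=np.int64)
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        out[start:end] = topk_indices(logits[start:end], k)
    return out

rng = np.random.default_rng(0)
logits = rng.random((37, 256))  # 37 query tokens, 256 keys (illustrative)
# Chunked and unchunked paths select identical top-k indices.
assert (topk_chunked(logits, 8, 10) == topk_indices(logits, 8)).all()
```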
when MLP-sync pads tokens to TP multiples (
7→8), the indexer still returns only real-token rows._pad_topk_indices(...)allow padding to maintain#tokens(q) == #tokens(topk_indices), restoring correctness for FlashMLA-sparse under partial DP attention (e.g., TP 8, DP 4).Accuracy Tests
Launch with
Benchmarking and Profiling
Serving with one request:
TP mode:
DP Attention:
Checklist