
Opt tp: tp attn support tp reduce scattered input #10568

Merged
ch-wan merged 6 commits into sgl-project:main from antgroup:xyf/qkv_opt on Nov 15, 2025

Conversation


@xu-yfei xu-yfei commented Sep 17, 2025

Motivation

In H20 (96GB) TP8 prefill, optimize the original combined operation:
embed/mlp all-reduce + RMSNorm + fused_qkv_a_proj_with_mqa
into:
embed/mlp reduce-scatter + RMSNorm + fused_qkv_a_proj_with_mqa + all-gather

Use the switch --enable-attn-tp-input-scattered to enable this feature.


This optimization primarily brings the following improvements:

  1. Computation and memory reduction: the amount of data that RMSNorm and fused_qkv_a_proj_with_mqa need to process is reduced to 1/8 of the original.
  2. Communication pattern optimization: the all-reduce is decomposed into reduce-scatter + all-gather. During the all-gather phase, after fused_qkv_a_proj_with_mqa, the last dimension shrinks from 7168 to 1536 + 512 + 64 = 2112, significantly reducing the communication data volume.
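As a rough illustration of the second point (this is not code from the PR; it assumes textbook ring-collective costs, where an all-reduce moves about 2*(N-1)/N of the tensor per rank while reduce-scatter and all-gather each move about (N-1)/N of theirs), the per-token traffic can be estimated as:

```python
N = 8                    # TP world size on H20
HIDDEN = 7168            # hidden size entering RMSNorm
QKV_A = 1536 + 512 + 64  # fused_qkv_a_proj_with_mqa output dim per token

def ring_cost(elems: int, world: int) -> float:
    """Per-rank elements moved by one ring collective over `elems` elements."""
    return (world - 1) / world * elems

# Before: a single all-reduce over the full hidden states.
before = 2 * ring_cost(HIDDEN, N)
# After: reduce-scatter over hidden states + all-gather over the projected dim.
after = ring_cost(HIDDEN, N) + ring_cost(QKV_A, N)

print(f"per-token elements moved: before={before:.0f}, after={after:.0f}")
print(f"after/before ratio: {after / before:.3f}")
```

Under these assumptions the decomposition moves roughly 0.647x the data of the original all-reduce, before even counting the 1/8 compute reduction.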

Based on the performance sampling data for 16K chunked prefill, the effects after optimization are as follows:

  • The latency of fused_qkv_a_proj_with_mqa decreased from 205.1 ms to 26.14 ms.

  • The total latency of communication decreased from 267.1 ms to 249.63 ms.

  • The total latency of RMSNorm decreased from 82.303 ms to 43.398 ms.

After: (profiler timeline screenshot)

Before: (profiler timeline screenshot)

Modifications

Accuracy Tests

#gsm8k
Accuracy: 0.955
Invalid: 0.000
Latency: 271.702 s
Output throughput: 467.895 token/s
#mmlu
subject: abstract_algebra, #q:100, acc: 0.760
subject: anatomy, #q:135, acc: 0.844
subject: astronomy, #q:152, acc: 0.934
subject: business_ethics, #q:100, acc: 0.890
subject: clinical_knowledge, #q:265, acc: 0.928
subject: college_biology, #q:144, acc: 0.965
subject: college_chemistry, #q:100, acc: 0.620
subject: college_computer_science, #q:100, acc: 0.830
subject: college_mathematics, #q:100, acc: 0.760
subject: college_medicine, #q:173, acc: 0.873
subject: college_physics, #q:102, acc: 0.804
subject: computer_security, #q:100, acc: 0.890
subject: conceptual_physics, #q:235, acc: 0.936
subject: econometrics, #q:114, acc: 0.763
subject: electrical_engineering, #q:145, acc: 0.869
subject: elementary_mathematics, #q:378, acc: 0.939
subject: formal_logic, #q:126, acc: 0.802
subject: global_facts, #q:100, acc: 0.670
subject: high_school_biology, #q:310, acc: 0.952
subject: high_school_chemistry, #q:203, acc: 0.857
subject: high_school_computer_science, #q:100, acc: 0.940
subject: high_school_european_history, #q:165, acc: 0.885
subject: high_school_geography, #q:198, acc: 0.965
subject: high_school_government_and_politics, #q:193, acc: 0.984
subject: high_school_macroeconomics, #q:390, acc: 0.926
subject: high_school_mathematics, #q:270, acc: 0.748
subject: high_school_microeconomics, #q:238, acc: 0.962
subject: high_school_physics, #q:151, acc: 0.834
subject: high_school_psychology, #q:545, acc: 0.971
subject: high_school_statistics, #q:216, acc: 0.861
subject: high_school_us_history, #q:204, acc: 0.961
subject: high_school_world_history, #q:237, acc: 0.949
subject: human_aging, #q:223, acc: 0.852
subject: human_sexuality, #q:131, acc: 0.939
subject: international_law, #q:121, acc: 0.942
subject: jurisprudence, #q:108, acc: 0.907
subject: logical_fallacies, #q:163, acc: 0.933
subject: machine_learning, #q:112, acc: 0.786
subject: management, #q:103, acc: 0.932
subject: marketing, #q:234, acc: 0.944
subject: medical_genetics, #q:100, acc: 0.940
subject: miscellaneous, #q:783, acc: 0.951

Benchmarking and Profiling

```shell
export SGL_ENABLE_JIT_DEEPGEMM=1
export TORCHINDUCTOR_CACHE_DIR=/home/admin/inductor_root_cache
export SGLANG_TORCH_PROFILER_DIR=/home/admin/torch_profiler
export SGL_CHUNKED_PREFIX_CACHE_USE_TUNED=1
model_path=/home/deepseek-ai__DeepSeek-R1

python3 -m sglang.launch_server --model-path $model_path \
    --host 0.0.0.0 --port 8000 --trust-remote-code \
    --enable-cache-report --quantization fp8 --log-level info --max-running-requests 32 \
    --mem-fraction-static 0.92 --chunked-prefill-size 16384 --context-length 65535 --chat-template /home/r1.jinja \
    --attention-backend fa3 \
    --disable-radix-cache \
    --tp-size 8 --enable-metrics --cuda-graph-max-bs 32

input_len=1000  # 2000, 4000
python3 -m sglang.bench_serving --backend sglang --dataset-name random \
    --random-input ${input_len} --random-output 1 --request-rate 1000 \
    --num-prompt 500 --random-range-ratio 1 --max-concurrency 16 --port 8000 \
    --dataset-path /home/ShareGPT_V3_unfiltered_cleaned_split.json
```

Request throughput (req/s):

| Input Length | Before PR | After PR |
|---|---|---|
| 1000 | 12.82 | 14.22 |
| 2000 | 6.52 | 7.33 |
| 4000 | 2.49 | 2.72 |
| 4096 | 2.41 | 2.63 |

Checklist

Summary by CodeRabbit

  • New Features

    • Added support for tensor-parallel scattered attention, enabling more efficient multi-GPU attention workflows.
    • Introduced a fused QKV projection path with automatic gathering when required.
    • Added an option to skip reduction in vocab-parallel embeddings for advanced workflows.
  • Performance

    • Reduced communication overhead in attention and embedding stages, improving scalability on tensor-parallel setups.
  • Bug Fixes

    • Prevented potential hangs in fused paths when processing empty micro-batches.
  • Refactor

    • Exposed additional runtime context to better coordinate tensor-parallel operations.

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @xu-yfei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization for the prefill phase of large language models, specifically targeting H20 (96GB) TP8 configurations. By refactoring the communication pattern around attention and MLP layers from a single all-reduce to a reduce-scatter followed by an all-gather, the changes drastically reduce the amount of data processed by intermediate layers and minimize communication overhead. This leads to notable improvements in overall latency and computational efficiency.

Highlights

  • Optimized Communication Pattern: The core optimization replaces a single all reduce operation with a reduce scatter followed by an all gather for embed/mlp and fused_qkv_a_proj_with_mqa in H20 (96GB) TP8 prefill. This reduces the data processed by RMSNorm and fused_qkv_a_proj_with_mqa to 1/8th of the original and significantly decreases communication data volume.
  • Performance Improvements: Benchmarking shows substantial latency reductions: fused_qkv_a_proj_with_mqa decreased from 205.1 ms to 26.14 ms, total communication from 267.1 ms to 249.63 ms, and RMSNorm from 82.303 ms to 43.398 ms for an input length of 4000.
  • Flexible Attention Input Handling: Introduced attn_input_tp_scattered flags and logic to dynamically determine whether attention inputs should be scattered across tensor parallel ranks, allowing for more efficient processing in specific configurations.
  • New Communication Primitives: Added a tp_all_gather function to facilitate the new communication pattern and a _tp_reduce_scatter_or_all_reduce method within the LayerCommunicator to manage the conditional application of reduce-scatter or all-reduce operations.
  • DeepseekV2 Model Integration: The DeepseekV2 model's attention and decoder layers have been updated to leverage this new communication strategy, with helper functions to enable the optimization under specific conditions (e.g., q_lora_rank and forward mode).
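The equivalence this communication pattern relies on can be checked with a tiny single-process simulation. Plain Python lists stand in for torch.distributed collectives here, and the helper names are illustrative, not sglang's:

```python
TP = 4  # simulated tensor-parallel world size

def tensor_split_sizes(n: int, k: int) -> list[int]:
    # torch.tensor_split-style sizing: the first n % k shards get one extra
    # element, so uneven lengths need no padding.
    base, extra = divmod(n, k)
    return [base + 1] * extra + [base] * (k - extra)

def all_reduce(per_rank):
    # Every rank ends with the elementwise sum across ranks.
    return [sum(vals) for vals in zip(*per_rank)]

def reduce_scatter(per_rank):
    # Each rank keeps only its shard of the summed result.
    total = all_reduce(per_rank)
    sizes = tensor_split_sizes(len(total), len(per_rank))
    shards, start = [], 0
    for s in sizes:
        shards.append(total[start:start + s])
        start += s
    return shards

def all_gather(shards):
    # Concatenate every rank's shard back into the full tensor.
    return [x for shard in shards for x in shard]

# Length 10 is deliberately not divisible by TP=4.
per_rank = [[r + t for t in range(10)] for r in range(TP)]
assert all_gather(reduce_scatter(per_rank)) == all_reduce(per_rank)
print("reduce_scatter + all_gather matches all_reduce")
```

The assert holds for any shard sizes, which is why the decomposition is numerically equivalent to the original all-reduce; the win comes from doing RMSNorm and the projection on 1/TP of the tokens between the two collectives.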

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a commendable optimization for tensor parallelism by replacing all_reduce with a reduce_scatter + all_gather pattern. This change, primarily affecting communicator.py and deepseek_v2.py, aims to reduce communication overhead and memory usage. While the optimization strategy is sound, my review identified a critical correctness issue in the new communication logic within communicator.py concerning the handling of scattered residuals. This flaw is likely to produce incorrect model outputs and needs to be addressed. I have also provided suggestions for a type hint correction and a refactoring opportunity in deepseek_v2.py to improve code maintainability.

@xu-yfei xu-yfei force-pushed the xyf/qkv_opt branch 3 times, most recently from 4ed160a to c357dd0 Compare September 25, 2025 07:28
@xu-yfei xu-yfei changed the title Opt fused_qkv_a_proj_with_mqa: tp attn support tp reduce scattered input Opt tp: tp attn support tp reduce scattered input Sep 25, 2025
@whybeyoung

Can you resolve the conflicts?


xu-yfei commented Sep 29, 2025

Can you resolve the conflicts?

done


miter6 commented Sep 29, 2025

Did you test uneven input ids?
There will be a crash when the number of ids is not a power of 2.


xu-yfei commented Sep 29, 2025

Did you test uneven input ids? There will be a crash when the number of ids is not a power of 2.

Both even and odd lengths are very common, and I have verified that they work correctly. What is your current scenario? TP8? What is the specific error?


miter6 commented Sep 29, 2025

The input sequence is [1023, 7168] with TP 8.


miter6 commented Sep 29, 2025

I have implemented a similar feature locally.
We must pad the sequence to fit the tp-size.


xu-yfei commented Sep 29, 2025

I have implemented a similar feature locally. We must pad the sequence to fit the tp-size.

```python
>>> a = torch.randn((1023, 7168), dtype=torch.bfloat16, device="cuda")
>>> b = a.tensor_split(8)
>>> b[0].shape
torch.Size([128, 7168])
>>> b[-1].shape
torch.Size([127, 7168])
```
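For reference, the shard sizes tensor_split chooses for this case can be reproduced with a few lines of plain Python (a sketch of the sizing rule only, no torch or GPU required):

```python
def tensor_split_sizes(n: int, k: int) -> list[int]:
    # torch.tensor_split-style sizing: the first n % k shards get one extra
    # row, so uneven sequence lengths need no padding.
    base, extra = divmod(n, k)
    return [base + 1] * extra + [base] * (k - extra)

sizes = tensor_split_sizes(1023, 8)
print(sizes)  # [128, 128, 128, 128, 128, 128, 128, 127]
```

This matches the REPL output above: seven shards of 128 rows and one of 127, summing to 1023 without padding.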


xu-yfei commented Oct 9, 2025

@merrymercy @yizhang2077 Could you please help review this PR?


xu-yfei commented Nov 10, 2025

@ch-wan I've updated a version according to the review comments. Could you please review it?

@xu-yfei xu-yfei force-pushed the xyf/qkv_opt branch 3 times, most recently from 15ba980 to 2febc76 Compare November 12, 2025 07:19
@ch-wan ch-wan left a comment


It's much better than the original version. Thank you for your continuous efforts.

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Nov 15, 2025
@ch-wan ch-wan merged commit d91b16e into sgl-project:main Nov 15, 2025
53 of 66 checks passed

Labels

deepseek documentation Improvements or additions to documentation run-ci


6 participants