Opt tp: tp attn support tp reduce scattered input #10568
ch-wan merged 6 commits into sgl-project:main from
Conversation
Summary of Changes
Hello @xu-yfei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant optimization for the prefill phase of large language models, specifically targeting H20 (96GB) TP8 configurations. By refactoring the communication pattern around attention and MLP layers from a single all-reduce to a reduce-scatter followed by an all-gather, the changes drastically reduce the amount of data processed by intermediate layers and minimize communication overhead. This leads to notable improvements in overall latency and computational efficiency.
Highlights
- Optimized Communication Pattern: The core optimization replaces a single `all_reduce` operation with a `reduce_scatter` followed by an `all_gather` for `embed/mlp` and `fused_qkv_a_proj_with_mqa` in H20 (96GB) TP8 prefill. This reduces the data processed by `RMSNorm` and `fused_qkv_a_proj_with_mqa` to 1/8th of the original and significantly decreases communication data volume.
- Performance Improvements: Benchmarking shows substantial latency reductions: `fused_qkv_a_proj_with_mqa` decreased from 205.1 ms to 26.14 ms, total communication from 267.1 ms to 249.63 ms, and `RMSNorm` from 82.303 ms to 43.398 ms for an input length of 4000.
- Flexible Attention Input Handling: Introduced `attn_input_tp_scattered` flags and logic to dynamically determine whether attention inputs should be scattered across tensor-parallel ranks, allowing more efficient processing in specific configurations.
- New Communication Primitives: Added a `tp_all_gather` function to facilitate the new communication pattern and a `_tp_reduce_scatter_or_all_reduce` method within the `LayerCommunicator` to manage the conditional application of reduce-scatter or all-reduce operations.
- DeepseekV2 Model Integration: The DeepseekV2 model's attention and decoder layers have been updated to leverage this new communication strategy, with helper functions to enable the optimization under specific conditions (e.g., `q_lora_rank` and forward mode).
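To see why the two patterns are numerically equivalent, the communication flow can be simulated on a single process. The sketch below is illustrative only (it is not the PR's code): it models each TP rank as a NumPy array and shows that reduce-scatter, per-rank work on a 1/TP slice, then all-gather reproduces the all-reduce result.

```python
# Sketch (not the PR's actual code): simulate the TP communication pattern on a
# single process with NumPy, one array per simulated TP rank.
import numpy as np

TP = 8          # tensor-parallel world size
TOKENS = 16     # token count, divisible by TP for the scattered layout
HIDDEN = 32     # toy hidden size

rng = np.random.default_rng(0)
# Each rank holds a partial sum of the MLP output (the pre-all_reduce state).
partials = [rng.standard_normal((TOKENS, HIDDEN)) for _ in range(TP)]

# Baseline: all_reduce -> every rank holds the full [TOKENS, HIDDEN] sum and
# runs RMSNorm + the qkv projection over all TOKENS rows.
full = sum(partials)

# Optimized: reduce_scatter -> rank r holds only its TOKENS/TP row slice.
chunk = TOKENS // TP
scattered = [full[r * chunk:(r + 1) * chunk] for r in range(TP)]

# Each rank applies RMSNorm + fused_qkv_a_proj_with_mqa to 1/TP of the rows
# (identity stand-ins here), then an all_gather restores the full batch.
gathered = np.concatenate(scattered, axis=0)

assert np.allclose(gathered, full)  # same result, 1/TP per-rank compute between
```

In the real model the per-rank step between the two collectives is where the savings come from: `RMSNorm` and `fused_qkv_a_proj_with_mqa` only touch their own slice.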
Code Review
This pull request introduces a commendable optimization for tensor parallelism by replacing all_reduce with a reduce_scatter + all_gather pattern. This change, primarily affecting communicator.py and deepseek_v2.py, aims to reduce communication overhead and memory usage. While the optimization strategy is sound, my review identified a critical correctness issue in the new communication logic within communicator.py concerning the handling of scattered residuals. This flaw is likely to produce incorrect model outputs and needs to be addressed. I have also provided suggestions for a type hint correction and a refactoring opportunity in deepseek_v2.py to improve code maintainability.
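The scattered-residual issue the review flags can be illustrated in isolation. This is my own sketch of the pitfall, not the PR's code: once the hidden state has been reduce-scattered, any residual added to it must be the matching per-rank slice, not the full tensor.

```python
# Sketch of the scattered-residual pitfall (illustration, not the PR's code):
# after reduce_scatter, rank r only holds rows [r*chunk, (r+1)*chunk) of the
# hidden state, so the residual must be sliced identically before the add.
import numpy as np

TP, TOKENS, HIDDEN = 4, 8, 16
rng = np.random.default_rng(1)
hidden = rng.standard_normal((TOKENS, HIDDEN))
residual = rng.standard_normal((TOKENS, HIDDEN))

chunk = TOKENS // TP
for rank in range(TP):
    rows = slice(rank * chunk, (rank + 1) * chunk)
    scattered_hidden = hidden[rows]
    # Correct: slice the residual to match the scattered rows. Adding the
    # full residual (or another rank's slice) would silently corrupt output.
    out = scattered_hidden + residual[rows]
    assert np.allclose(out, (hidden + residual)[rows])
```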
can you resolve the conflicts?
done
Did you test even input ids?
Whether the input length is even or odd, both are very common; I have verified that both work correctly. What is your current scenario? TP8? What is the specific error?
The input sequence is [1023, 7168] with TP 8.
I have implemented a similar feature locally.
@merrymercy @yizhang2077 Could you please help review this PR?
@ch-wan I've updated a version according to the review comments. Could you please review it? |
ch-wan left a comment
It's much better than the original version. Thank you for your continuous efforts.
Motivation
In H20 (96GB) TP8 prefill, this PR optimizes the original combined operation

`embed/mlp all_reduce + RMSNorm + fused_qkv_a_proj_with_mqa`

into

`embed/mlp reduce_scatter + RMSNorm + fused_qkv_a_proj_with_mqa + all_gather`

Use the switch `--enable-attn-tp-input-scattered` to enable this feature. This optimization primarily brings the following improvements:

- The data that `RMSNorm` and `fused_qkv_a_proj_with_mqa` need to process is reduced to 1/8 of the original.
- The `all_reduce` communication is decomposed into `reduce_scatter` + `all_gather`. During the `all_gather` phase, after `fused_qkv_a_proj_with_mqa`, the last dimension is reduced from 7168 to (1536 + 512 + 64), significantly reducing the communication data volume.

Based on the performance sampling data for 16K chunked prefill, the effects after optimization are as follows:
The latency of fused_qkv_a_proj_with_mqa decreased from 205.1 ms to 26.14 ms.
The total latency of communication has decreased from 267.1 ms to 249.63 ms.
The total latency of RMSNorm has decreased from 82.303 ms to 43.398 ms.
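A back-of-envelope estimate helps explain the communication saving. The sketch below is my own arithmetic under standard ring-collective assumptions (a ring all-reduce moves roughly twice the data of a reduce-scatter of the same tensor); the 7168 and (1536 + 512 + 64) dimensions come from the PR description, the rest is illustrative.

```python
# Back-of-envelope sketch (assumed ring-collective costs, not measured):
# per-token elements moved by the old vs. new communication patterns.
TP = 8
HIDDEN = 7168              # hidden size communicated by the old all_reduce
QKV_OUT = 1536 + 512 + 64  # last dim after fused_qkv_a_proj_with_mqa

ring = (TP - 1) / TP       # a ring collective moves ~(N-1)/N of the tensor

# Old: one all_reduce over HIDDEN (~2x the traffic of a reduce_scatter).
old = 2 * ring * HIDDEN

# New: reduce_scatter over HIDDEN plus all_gather over the much smaller QKV_OUT.
new = ring * HIDDEN + ring * QKV_OUT

print(f"old={old:.0f} elems/token, new={new:.0f} elems/token, "
      f"ratio={new / old:.2f}")
```

Under these assumptions the new pattern moves roughly two-thirds of the data per token, consistent in direction with the measured drop in total communication latency.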
(Before/after profiling screenshots omitted.)
Modifications
Accuracy Tests
Benchmarking and Profiling
Request throughput (req/s) :
Checklist