Integrate MixedComm from FlashInfer to fuse communication kernels when attention TP and DP are both used #23713
Code Review
This pull request introduces support for fused communication using FlashInfer's MixedComm kernels, which optimizes performance by fusing reduce-scatter and all-gather operations when data parallelism is combined with attention tensor parallelism. The changes include the implementation of a MixedCommHandler for initialization, integration of fused kernels into the communication layers, and updated documentation. Additionally, the logic for MoE reduce-scatter operations was refined to ensure correctness when both tensor and data parallelism are active. I have no feedback to provide as the review comments were either explanatory or did not identify actionable issues in the current implementation.
…n attention TP and DP are both used Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
900ffdd to 49faeb9
Hi @ch-wan, it seems that the bot failed to assign this PR to a Merge Oncall. Since this PR is mainly related to attention DP, can you help review what modifications are needed and ping other relevant experts if necessary?
Motivation
When attention TP and DP are both used, the allreduce + allgather and reducescatter + allreduce communication patterns currently take two kernels each (both are implemented via a reducescatter kernel + an allgather kernel). Replacing each pair with a single fused kernel can improve performance in low-latency scenarios.
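To make the two-kernel pattern concrete, here is a minimal pure-Python simulation of why allreduce + allgather can be expressed as a reducescatter plus an allgather. It is only a sketch: the rank layout (`rank = dp_rank * attn_tp_size + tp_rank`), the function names, and the assumption that every rank holds the same number of elements are illustrative and are not taken from SGLang or FlashInfer.

```python
# Simulated collectives on plain Python lists; no GPU, NCCL, or FlashInfer involved.
# Assumed (hypothetical) layout: rank = dp_rank * attn_tp_size + tp_rank, and
# partials[rank] is that rank's partial result for its own DP group's tokens.

def allreduce_then_allgather(partials, dp_size, attn_tp_size):
    """Reference: allreduce inside each attention-TP group, then allgather across DP groups."""
    reduced = []
    for dp in range(dp_size):
        group = partials[dp * attn_tp_size:(dp + 1) * attn_tp_size]
        reduced.append([sum(vals) for vals in zip(*group)])
    gathered = [x for chunk in reduced for x in chunk]
    # Every rank ends up with the same full buffer.
    return [list(gathered) for _ in range(dp_size * attn_tp_size)]


def reducescatter_then_allgather(partials, dp_size, attn_tp_size):
    """Two-kernel form: reducescatter inside each attention-TP group, then allgather over all ranks."""
    shards = []
    for dp in range(dp_size):
        group = partials[dp * attn_tp_size:(dp + 1) * attn_tp_size]
        summed = [sum(vals) for vals in zip(*group)]
        shard_len = len(summed) // attn_tp_size  # assumes an even split
        for tp in range(attn_tp_size):
            # Each rank keeps only its own shard of the reduced buffer.
            shards.append(summed[tp * shard_len:(tp + 1) * shard_len])
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in range(len(partials))]


partials = [
    [1, 2, 3, 4], [10, 20, 30, 40],  # DP group 0, TP ranks 0 and 1
    [5, 6, 7, 8], [50, 60, 70, 80],  # DP group 1, TP ranks 0 and 1
]
reference = allreduce_then_allgather(partials, dp_size=2, attn_tp_size=2)
two_kernel = reducescatter_then_allgather(partials, dp_size=2, attn_tp_size=2)
```

Both functions produce the same buffer on every rank. A fused kernel would aim to deliver this result in a single launch instead of two.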
A small bug was found while testing the following case (a single reducescatter is not sufficient for this case), so this PR also fixes it.
Modifications
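Two of the changes in this section can be illustrated with a short sketch: the environment-variable gate and the preference for reduce_scatter when every rank has the same number of tokens. The helper names `mixed_comm_enabled` and `pick_reduce_scatter` are hypothetical, and the case-insensitive parsing is an assumption. The actual logic in dp_attention.py and communicator.py may differ.

```python
import os


def mixed_comm_enabled(environ=None):
    """Return True when SGLANG_USE_MIXED_COMM is set to 1 or true.

    Hypothetical helper; case-insensitive matching is an assumption.
    """
    env = os.environ if environ is None else environ
    return env.get("SGLANG_USE_MIXED_COMM", "").strip().lower() in ("1", "true")


def pick_reduce_scatter(num_tokens_per_rank):
    """Name the collective to use: the fixed-size variant when all ranks
    hold the same number of tokens, the variable-size variant otherwise.

    Hypothetical helper that only returns a label; it launches no kernel.
    """
    if len(set(num_tokens_per_rank)) <= 1:
        return "reduce_scatter"
    return "reduce_scatterv"
```

For example, `pick_reduce_scatter([8, 8, 8, 8])` selects the fixed-size variant, while `pick_reduce_scatter([8, 7, 8, 8])` falls back to the variable-size one.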
python/sglang/srt/layers/dp_attention.py is modified to use MixedComm from FlashInfer when the environment variable SGLANG_USE_MIXED_COMM is set to 1 or true.
python/sglang/srt/layers/communicator.py is modified to prefer reduce_scatter over reduce_scatterv when the numbers of tokens are identical on all ranks.
python/sglang/srt/layers/moe/utils.py is modified to fix the above-mentioned bug.
docs/advanced_features/dp_dpa_smg_guide.md is modified to update documentation.
test/registered/distributed/test_dp_attention.py and test/registered/distributed/test_dp_attention_large.py are modified to add unit tests.
Accuracy Tests
The accuracy has been verified by testing GSM8K and GPQA scores of Qwen3.5 on H200. The tested cases are:
Scores:
Scripts to reproduce the results:
Speed Tests and Profiling
The performance of Qwen3.5 is tested on H200 using the following cases:
Results of median ITL are obtained from the outputs of sglang.bench_serving. Results of kernel execution time are obtained by analyzing nsys timelines.
The performance is tested using the following commands: