[ROCm][Perf] Add fused AllReduce+RMSNorm for DeepSeek on MI355X #37891
attila-dusnoki-htec wants to merge 22 commits into vllm-project:main
Conversation
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Code Review
This pull request introduces a significant performance optimization for DeepSeek models on ROCm MI355X hardware by fusing AllReduce and RMSNorm operations. The implementation is well-structured, leveraging AITER's custom kernels and providing a fallback for unsupported cases. The changes correctly handle different execution paths for FP8 and FP4/BF16 quantization, including a clever decomposition pass for torch.compile compatibility. The integration into the model and communication layers is clean and follows existing patterns. My review includes one high-severity comment regarding undocumented magic numbers in the kernel dispatch logic, which should be addressed to improve maintainability.
```python
can_use_fused = (
    n <= 16384
    and total_bytes < 8 * 1024 * 8192
    and self.world_size != 6
)
```
These conditions for using the fused kernel contain several 'magic numbers' (16384, 8 * 1024 * 8192, and 6) that are not explained. Undocumented magic numbers make the code harder to understand, maintain, and debug.
Specifically, the condition self.world_size != 6 is concerning as it suggests a potential bug or limitation in the underlying AITER kernel for that specific configuration.
Please add comments explaining the origin and purpose of these values. For example:
- Are they from performance tuning?
- Are they hard limitations of the AITER kernel?
- Is `world_size != 6` a workaround for a known issue? If so, linking to the issue would be very helpful.
This documentation is critical for future developers to understand the constraints of this optimization and to know when these values might need to be updated.
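One possible shape for that documentation — a hypothetical sketch only; the constant names and the stated rationales are assumptions to be confirmed against the AITER kernel, not facts established by this PR:

```python
# Hypothetical sketch: names and rationales below are placeholders to be
# confirmed by the kernel authors, not taken from the actual PR.

# Hidden-size cap (assumed kernel tiling limit; DeepSeek V3/R1 hidden
# size 7168 sits well inside it).
MAX_HIDDEN_SIZE = 16384
# Payload cap in bytes (assumed crossover point beyond which the unfused
# allreduce path performs better).
MAX_PAYLOAD_BYTES = 8 * 1024 * 8192
# World sizes excluded as a workaround for a known kernel issue
# (link the tracking issue here).
UNSUPPORTED_WORLD_SIZES = {6}

def can_use_fused_allreduce_rmsnorm(n: int, total_bytes: int,
                                    world_size: int) -> bool:
    """Gate for dispatching to the fused AITER allreduce+rmsnorm kernel."""
    return (
        n <= MAX_HIDDEN_SIZE
        and total_bytes < MAX_PAYLOAD_BYTES
        and world_size not in UNSUPPORTED_WORLD_SIZES
    )
```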
Fuse tensor-parallel allreduce into RMSNorm layers using AITER's custom allreduce+residual-add+rmsnorm kernel on gfx950 (MI355X). Reduces kernel launch overhead and memory traffic for DeepSeek V2/V3/R1 models under FP4 and FP8 quantization.

How it works:
- Moves allreduce out of `o_proj` and MoE projections into the subsequent RMSNorm, where AITER's fused kernel handles allreduce + residual-add + rmsnorm in a single operation.
- FP4/BF16: the fused op is preserved at compile time and executed as one AITER kernel (`custom_fused_ar_rms`).
- FP8: the fused op is decomposed at compile time into `all_reduce` + `rmsnorm_with_add`, then `rmsnorm_with_add` + `fp8_quant` are fused by the existing `RocmAiterRMSNormQuantFusionPass` pattern matcher.

Auto-detection:
- Automatically enabled when gfx950, AITER, RMSNorm kernels, and AITER CustomAllreduce are available with TP > 1. No environment variable needed.
- Only wired into `deepseek_v2.py`; other models are unaffected. Models inheriting `DeepseekV2DecoderLayer` (Eagle, MTP, Mistral Large 3) benefit automatically.

Files changed:
- `_aiter_ops.py`: add `is_fused_allreduce_rmsnorm_supported()`
- `deepseek_v2.py`: move allreduce from projections into RMSNorm layers
- `layernorm.py`: add `fused_allreduce` parameter to `RMSNorm`
- `parallel_state.py`: register `fused_allreduce_rmsnorm` custom op
- `communication_op.py`: add TP wrapper function
- `cuda_communicator.py`: init AITER CustomAllreduce, implement fused op
- `rocm_aiter_fusion.py`: decompose fused op for FP8 quant compatibility

Signed-off-by: Attila Dusnoki <attila.dusnoki@htecgroup.com>
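To make the fused operation concrete, here is a minimal pure-Python sketch of the math the single kernel performs. This is reference semantics only — the real AITER kernel operates on device tensors, and the function name here is illustrative:

```python
# Reference semantics of the fused op: allreduce + residual-add + rmsnorm
# in one launch. Function name is illustrative, not vLLM's actual API.
import math

def fused_allreduce_rmsnorm(partials, residual, weight, eps=1e-6):
    # 1) allreduce: sum the per-rank partial outputs of o_proj / MoE experts
    x = [sum(vals) for vals in zip(*partials)]
    # 2) residual add (the updated residual is also returned)
    x = [xi + ri for xi, ri in zip(x, residual)]
    # 3) rmsnorm
    rms = math.sqrt(sum(xi * xi for xi in x) / len(x) + eps)
    out = [wi * xi / rms for wi, xi in zip(weight, x)]
    return out, x

# Two "ranks", hidden size 2
out, new_residual = fused_allreduce_rmsnorm(
    partials=[[1.0, 2.0], [3.0, 4.0]],
    residual=[0.0, 0.0],
    weight=[1.0, 1.0],
)
```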
Force-pushed from 8c02d27 to 65a9a35
This PR enables broader usage of allreduce + rmsnorm and shows its performance gains: #37646

Thanks for the info! I checked it out, but that alone will not pick up the kernels for DeepSeek-R1, sadly.
✅ Tested on MI355X (gfx950) — All Tests Pass

Tested on 8x AMD Instinct MI355X (gfx950, ROCm 7.0.1).

Unit Tests — 55 passed, 0 failed. All skips are NVIDIA-only tests — no failures.

Feature Detection — Working

End-to-End Benchmark — DeepSeek-R1 FP8 Dynamic, TP=8:
- Low concurrency (10 requests, input=128, output=128): [results table not preserved]
- High concurrency (32 requests, input=512, output=256): [results table not preserved]

Notes

Great work on this PR! The fused allreduce+rmsnorm path works cleanly on MI355X with zero test failures. 🚀

Test environment: 8x MI355X (gfx950:sramecc+:xnack-), ROCm 7.0.1, vLLM built from PR branch (includes dep PR #37646)
Base image: rocm/vllm-dev:base_custom_rocm_7.2.1_torch_triton_0330_vllm018

Patches applied:
- AITER SplitK bug fix (ROCm/aiter#2508)
- vLLM persistent MLA kernel (vllm-project/vllm#36574)
- vLLM fused AllReduce+RMSNorm (vllm-project/vllm#37891)

Made-with: Cursor
Hey, any chance someone could review this? :)

@gshtras: Can you please look at this PR? Thanks
```python
var_hidden_size: int | None = None,
has_weight: bool = True,
dtype: torch.dtype | None = None,
fused_allreduce: bool = False,
```
Can this be done through a pattern matching mechanism?
The dependent PR #37646 already implements a pure pattern matching mechanism for this — RocmAiterAllReduceFusionPass with AiterAllreduceFusedAddRMSNormPattern that matches all_reduce → fused_add_rmsnorm sequences in the compiled graph and replaces them with AITER's fused kernel. This mirrors how the existing CUDA AllReduceFusionPass works with flashinfer/trtllm.
However, as the PR author noted, #37646 alone doesn't pick up the kernels for DeepSeek because the all_reduce → rmsnorm pattern doesn't naturally exist as adjacent nodes in the FX graph. In DeepSeek's decoder layer, the allreduce and rmsnorm are structurally separated:
Attention path: all_reduce happens inside o_proj (a RowParallelLinear with default reduce_results=True), which performs the allreduce internally in its forward method. After self_attn returns, the next rmsnorm is post_attention_layernorm — separated by the attention return boundary and potential FP16 overflow scaling.
MoE path: The allreduce happens inside self.experts.maybe_all_reduce_tensor_model_parallel(), a method on SharedFusedMoE with conditional logic. After self.mlp() returns, the next rmsnorm (input_layernorm of layer N+1) is across the layer iteration boundary.
A compiler pass that could automatically move allreduce boundaries (detect allreduce inside a projection, find the subsequent rmsnorm across residual connections/layer boundaries, and rewrite the graph) would be significantly more complex. Even the CUDA flashinfer path relies on the all_reduce → rmsnorm pattern already being adjacent in the graph.
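To illustrate, here is a toy sketch of the relocation. The classes below are stand-ins, not vLLM's real `RowParallelLinear`/`RMSNorm`; only op placement is modeled (the matmul and normalization math are omitted):

```python
# Toy stubs illustrating the relocation: with reduce_results=True the
# all_reduce is buried inside o_proj's forward, separated from the next
# rmsnorm by the attention return boundary; the PR disables it and lets
# the norm perform the (fused) allreduce instead.

def all_reduce(x):             # stand-in for tensor-parallel allreduce
    return [2 * v for v in x]  # pretend world_size == 2 with equal partials

class ToyRowParallelLinear:
    def __init__(self, reduce_results=True):
        self.reduce_results = reduce_results
    def forward(self, x):
        y = x  # (matmul omitted)
        return all_reduce(y) if self.reduce_results else y

class ToyRMSNorm:
    def __init__(self, fused_allreduce=False):
        self.fused_allreduce = fused_allreduce
    def forward(self, x, residual):
        if self.fused_allreduce:
            x = all_reduce(x)  # one AITER kernel with the add+norm in the PR
        return [xi + ri for xi, ri in zip(x, residual)]  # (norm omitted)

# Before: allreduce inside the projection, far from the norm in the graph.
before = ToyRMSNorm().forward(ToyRowParallelLinear(True).forward([1.0]), [0.0])
# After: projection keeps partial sums; norm does allreduce + add fused.
after = ToyRMSNorm(fused_allreduce=True).forward(
    ToyRowParallelLinear(False).forward([1.0]), [0.0])
```

Numerically the two placements are equivalent; only the graph structure (and hence what a pattern matcher can see) differs.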
@gshtras @dllehr-amd I added an implementation without changing the model definition. Please check: #38762
Btw I don't think this is right; this should be possible via pattern matching. Can you post the resulting fx graph and show the ops in between? Is it just the view? That should be eliminated by the existing NoOpEliminationPass, no?
@ProExpertProg Yes, it's just view. I added the fx graph relevant section in #38762 before/after adding _bypass_noop_views_after_allreduce() to RocmAiterAllReduceFusionPass that shows view goes away.
The `NoOpEliminationPass` handles `aten.reshape.default`, `aten.slice.Tensor`, and `aten.slice_scatter.default`. Do you think it would be better to extend that pass to include `aten.view.default` instead of adding `_bypass_noop_views_after_allreduce()` in the ROCm pass?
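For illustration, the idea under discussion — dropping shape-preserving view/reshape nodes so that `all_reduce` and `rmsnorm_with_add` become adjacent for the pattern matcher — can be sketched on a toy op list (this is not the real `NoOpEliminationPass`, which operates on a torch.fx graph):

```python
# Toy model of no-op view elimination: each node is (op_name, in_shape,
# out_shape); a view/reshape whose output shape equals its input shape
# carries no information and can be removed.
NOOP_CANDIDATES = {"aten.view.default", "aten.reshape.default"}

def eliminate_noop_views(graph):
    return [(op, i, o) for op, i, o in graph
            if not (op in NOOP_CANDIDATES and i == o)]

g = [("all_reduce",        (8, 7168), (8, 7168)),
     ("aten.view.default", (8, 7168), (8, 7168)),  # no-op, removed
     ("rmsnorm_with_add",  (8, 7168), (8, 7168))]
g2 = eliminate_noop_views(g)
# all_reduce and rmsnorm_with_add are now adjacent, so an
# all_reduce -> rmsnorm pattern can match.
```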
Hi @ProExpertProg
You are right, the view was not the problem.
I added a comment to the original PR (#37646).
This pull request has merge conflicts that must be resolved before it can be merged.
@rbrugaro-amd Thanks for figuring it out! Closing this in favour of #38762
I did not manage to make the pattern-matching version work, so I'm re-opening this solution.

This pull request has merge conflicts that must be resolved before it can be merged.
Depends on #37646
Summary
Fuse tensor-parallel allreduce into RMSNorm layers using AITER's fused
allreduce+residual-add+rmsnorm kernel on gfx950 (MI355X). This reduces kernel
launch overhead and memory traffic for DeepSeek V2/V3/R1 models under FP4 and
FP8 quantization with TP > 1.
- Moves the allreduce out of `o_proj` and MoE projections into the subsequent `RMSNorm` layer, where AITER's fused kernel handles allreduce + residual-add + rmsnorm in a single operation.
- Auto-detected: enabled when gfx950, AITER, and the `CustomAllreduce` communicator are available. No environment variable needed.
- Only wired into `deepseek_v2.py`. Other models are unaffected. Models that inherit `DeepseekV2DecoderLayer` (Eagle, MTP, Mistral Large 3) benefit automatically.

FP4 vs FP8 behavior

The fused op behaves differently depending on quantization:
- FP4/BF16: `fused_allreduce_rmsnorm` is preserved at compile time and executed as a single AITER kernel (`custom_fused_ar_rms`).
- FP8: `fused_allreduce_rmsnorm` is decomposed at compile time into `all_reduce` + `rmsnorm_with_add`, then `rmsnorm_with_add` + `fp8_quant` are fused into one AITER op by the existing `RocmAiterRMSNormQuantFusionPass`.

Changes
- `vllm/_aiter_ops.py`: `is_fused_allreduce_rmsnorm_supported()` auto-detection
- `vllm/model_executor/models/deepseek_v2.py`: move allreduce from projections into RMSNorm layers
- `vllm/model_executor/layers/layernorm.py`: add `fused_allreduce` parameter to `RMSNorm`
- `vllm/distributed/parallel_state.py`: register `fused_allreduce_rmsnorm` custom op + graph capture
- `vllm/distributed/communication_op.py`: add TP wrapper function
- `vllm/distributed/device_communicators/cuda_communicator.py`: init AITER `CustomAllreduce`, implement fused kernel + fallback
- `vllm/compilation/passes/fusion/rocm_aiter_fusion.py`: decompose fused op for FP8 quant compatibility

Test plan
- Unit tests (`tests/rocm/aiter/test_fused_ar_rmsnorm.py`)
- Distributed tests (`tests/distributed/test_fused_ar_rmsnorm.py`)
- Compile-pass tests (`tests/compile/passes/distributed/test_rocm_fused_ar_rmsnorm.py`):
  - FP4: `fused_allreduce_rmsnorm` preserved through `torch.compile`
  - FP8: `fused_allreduce_rmsnorm` decomposed, output matches unfused baseline
- E2E: `vllm bench latency` with `deepseek-ai/DeepSeek-R1-0528 --quantization fp8`
- E2E: `vllm bench latency` with `amd/DeepSeek-R1-0528-MXFP4 --quantization quark`
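The FP4-vs-FP8 lowering decision described above can be summarized in a toy dispatcher. The function name and the returned op names as plain strings are illustrative, not vLLM's actual pass API:

```python
# Toy summary of the compile-time dispatch: FP8 decomposes the fused op
# so the existing rmsnorm+fp8_quant pattern can still match; FP4/BF16
# keeps it whole as one AITER kernel. Names are illustrative only.
def lower_fused_allreduce_rmsnorm(quant_dtype: str) -> list[str]:
    if quant_dtype == "fp8":
        # Decomposed here; RocmAiterRMSNormQuantFusionPass later fuses
        # rmsnorm_with_add + fp8_quant into one AITER op.
        return ["all_reduce", "rmsnorm_with_add"]
    # FP4 / BF16: preserved as a single fused AITER kernel.
    return ["custom_fused_ar_rms"]
```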