
[ROCm][Perf] Add fused AllReduce+RMSNorm for DeepSeek on MI355X #37891

Closed

attila-dusnoki-htec wants to merge 22 commits into vllm-project:main from attila-dusnoki-htec:dsr1-ar-rmsnorm

Conversation

@attila-dusnoki-htec commented Mar 23, 2026

Depends on #37646

Summary

Fuse tensor-parallel allreduce into RMSNorm layers using AITER's fused
allreduce+residual-add+rmsnorm kernel on gfx950 (MI355X). This reduces kernel
launch overhead and memory traffic for DeepSeek V2/V3/R1 models under FP4 and
FP8 quantization with TP > 1.

  • Moves allreduce out of o_proj and MoE projections into the subsequent
    RMSNorm layer, where AITER's fused kernel handles allreduce + residual-add +
    rmsnorm in a single operation (wiring sketched after this list).
  • Auto-enabled when gfx950, AITER, RMSNorm kernels, and AITER's
    CustomAllreduce communicator are available. No environment variable needed.
  • Only wired into deepseek_v2.py. Other models are unaffected. Models that
    inherit DeepseekV2DecoderLayer (Eagle, MTP, Mistral Large 3) benefit
    automatically.
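
A minimal sketch of that wiring inside the decoder layer's `__init__`, assuming a local flag (here called `use_fused_ar_rms`, a hypothetical name) derived from the auto-detection check. `reduce_results` is existing vLLM API; `fused_allreduce` is the parameter this PR adds; this is an illustration, not the PR's exact code:

```python
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear

# Fragment from an attention module's __init__ (illustrative).
# use_fused_ar_rms would come from is_fused_allreduce_rmsnorm_supported().
self.o_proj = RowParallelLinear(
    self.num_heads * self.v_head_dim,
    self.hidden_size,
    # Skip the eager allreduce inside the projection when fusing;
    # the norm layer below performs it instead.
    reduce_results=not use_fused_ar_rms,
)

# Fragment from the decoder layer's __init__ (illustrative).
self.post_attention_layernorm = RMSNorm(
    config.hidden_size,
    eps=config.rms_norm_eps,
    # New parameter from this PR: fold allreduce + residual-add + rmsnorm
    # into one AITER kernel when supported.
    fused_allreduce=use_fused_ar_rms,
)
```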

FP4 vs FP8 behavior

The fused op behaves differently depending on quantization:

  • FP4 / BF16: fused_allreduce_rmsnorm is preserved at compile time and
    executed as a single AITER kernel (custom_fused_ar_rms).
  • FP8: fused_allreduce_rmsnorm is decomposed at compile time into
    all_reduce + rmsnorm_with_add, then rmsnorm_with_add + fp8_quant are
    fused into one AITER op by the existing RocmAiterRMSNormQuantFusionPass
    (decomposition sketched after this list).
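
The following standalone sketch shows the FP8-path decomposition idea: replace a fused graph node with its two components so a later fusion pass can pair rmsnorm_with_add with fp8_quant. The stand-in functions are not the real vLLM custom ops (the single-process "all_reduce" here is an identity), and the graph-rewrite mechanics are illustrative, not vLLM's actual pass:

```python
import torch
import torch.fx as fx

# Treat these as opaque leaf calls during tracing.
fx.wrap("fused_allreduce_rmsnorm")

def all_reduce(x):
    return x  # placeholder for a tensor-parallel allreduce

def rmsnorm_with_add(x, residual, weight, eps):
    h = x + residual
    return h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps) * weight

def fused_allreduce_rmsnorm(x, residual, weight, eps):
    return rmsnorm_with_add(all_reduce(x), residual, weight, eps)

class Layer(torch.nn.Module):
    def forward(self, x, residual, weight):
        return fused_allreduce_rmsnorm(x, residual, weight, 1e-6)

gm = fx.symbolic_trace(Layer())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is fused_allreduce_rmsnorm:
        with gm.graph.inserting_before(node):
            # Decompose: explicit all_reduce followed by rmsnorm_with_add,
            # so a later pass can fuse rmsnorm_with_add + fp8_quant.
            ar = gm.graph.call_function(all_reduce, (node.args[0],))
            norm = gm.graph.call_function(
                rmsnorm_with_add, (ar,) + tuple(node.args[1:]))
        node.replace_all_uses_with(norm)
        gm.graph.erase_node(node)
gm.recompile()
print(gm.graph)  # now shows all_reduce -> rmsnorm_with_add
```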

Changes

| File | What |
| --- | --- |
| vllm/_aiter_ops.py | Add is_fused_allreduce_rmsnorm_supported() auto-detection |
| vllm/model_executor/models/deepseek_v2.py | Move allreduce from projections into RMSNorm layers |
| vllm/model_executor/layers/layernorm.py | Add fused_allreduce parameter to RMSNorm |
| vllm/distributed/parallel_state.py | Register fused_allreduce_rmsnorm custom op + graph capture |
| vllm/distributed/communication_op.py | Add TP wrapper function |
| vllm/distributed/device_communicators/cuda_communicator.py | Init AITER CustomAllreduce, implement fused kernel + fallback |
| vllm/compilation/passes/fusion/rocm_aiter_fusion.py | Decompose fused op for FP8 quant compatibility |
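
To make the auto-detection entry above concrete, here is a hedged sketch of what the gate in vllm/_aiter_ops.py could look like. The function name comes from this PR; the exact checks (and their order) are assumptions, and the real implementation also gates on the RMSNorm kernels and the AITER CustomAllreduce communicator per the description:

```python
import importlib.util
import torch

def is_fused_allreduce_rmsnorm_supported() -> bool:
    """Sketch (assumed): gate on gfx950 hardware plus an importable AITER."""
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(0)
    # ROCm builds expose e.g. "gfx950:sramecc+:xnack-"; CUDA builds lack the attr.
    arch = getattr(props, "gcnArchName", "")
    if "gfx950" not in arch:
        return False
    return importlib.util.find_spec("aiter") is not None
```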

Test plan

  • Unit tests (tests/rocm/aiter/test_fused_ar_rmsnorm.py):
    • Custom op registration and fake tensor shapes
    • Auto-detection gating (disabled without gfx950 / AITER / RMSNorm)
    • Graph-level decomposition for FP8 path
    • Graph-level preservation for FP4/BF16 path
  • Multi-GPU correctness (tests/distributed/test_fused_ar_rmsnorm.py):
    • Numerical parity: fused path vs split allreduce + rmsnorm (2 GPUs)
    • world_size=1 fallback to add + rmsnorm
  • Multi-GPU compiler integration (tests/compile/passes/distributed/test_rocm_fused_ar_rmsnorm.py):
    • FP4/BF16: fused_allreduce_rmsnorm preserved through torch.compile
    • FP8: fused_allreduce_rmsnorm decomposed, output matches unfused baseline
  • Benchmark (MI355X, 8×GPU TP=8):
    • vllm bench latency with deepseek-ai/DeepSeek-R1-0528 --quantization fp8
    • vllm bench latency with amd/DeepSeek-R1-0528-MXFP4 --quantization quark
    • Compare TPOT / ITL against upstream vLLM (baseline without this feature)

vllmellm added 10 commits March 6, 2026 11:02 (commit titles truncated in the page render: "… file", "…d-rmsnorm"; each commit signed off by vllmellm <vllm.ellm@embeddedllm.com>)
@mergify bot added labels deepseek (Related to DeepSeek models), nvidia, and rocm (Related to AMD ROCm) on Mar 23, 2026
@github-project-automation bot moved this to Todo in AMD on Mar 23, 2026
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@gemini-code-assist bot left a comment (Contributor)

Code Review

This pull request introduces a significant performance optimization for DeepSeek models on ROCm MI355X hardware by fusing AllReduce and RMSNorm operations. The implementation is well-structured, leveraging AITER's custom kernels and providing a fallback for unsupported cases. The changes correctly handle different execution paths for FP8 and FP4/BF16 quantization, including a clever decomposition pass for torch.compile compatibility. The integration into the model and communication layers is clean and follows existing patterns. My review includes one high-severity comment regarding undocumented magic numbers in the kernel dispatch logic, which should be addressed to improve maintainability.

Comment on lines +292 to +296:

```python
can_use_fused = (
    n <= 16384
    and total_bytes < 8 * 1024 * 8192
    and self.world_size != 6
)
```
Severity: high

These conditions for using the fused kernel contain several 'magic numbers' (16384, 8 * 1024 * 8192, and 6) that are not explained. Undocumented magic numbers make the code harder to understand, maintain, and debug.

Specifically, the condition self.world_size != 6 is concerning as it suggests a potential bug or limitation in the underlying AITER kernel for that specific configuration.

Please add comments explaining the origin and purpose of these values. For example:

  • Are they from performance tuning?
  • Are they hard limitations of the AITER kernel?
  • Is world_size != 6 a workaround for a known issue? If so, linking to the issue would be very helpful.

This documentation is critical for future developers to understand the constraints of this optimization and to know when these values might need to be updated.
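
For illustration only, the requested documentation could take the shape below. The rationale comments are placeholders, not facts from the PR or the AITER kernel; this fragment reuses the names from the snippet above and assumes the same surrounding context:

```python
# NOTE: the reasons below are ASSUMED placeholders -- the PR does not say
# where these limits come from. Replace with the real rationale and links.
MAX_FUSED_HIDDEN_SIZE = 16384            # assumed: AITER kernel per-row limit
MAX_FUSED_INPUT_BYTES = 8 * 1024 * 8192  # assumed: fused-path buffer budget
BROKEN_WORLD_SIZE = 6                    # assumed: known-bad AITER config; link the issue

can_use_fused = (
    n <= MAX_FUSED_HIDDEN_SIZE
    and total_bytes < MAX_FUSED_INPUT_BYTES
    and self.world_size != BROKEN_WORLD_SIZE
)
```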

(Commit message omitted: it repeats the PR description above.)
Signed-off-by: Attila Dusnoki <attila.dusnoki@htecgroup.com>
@attila-dusnoki-htec attila-dusnoki-htec marked this pull request as ready for review March 24, 2026 13:14
Signed-off-by: Attila Dusnoki <attila.dusnoki@htecgroup.com>
@tjtanaa (Collaborator) commented Mar 24, 2026

This PR enables broader usage of the fused allreduce + rmsnorm path and its performance gains from #37646.

@attila-dusnoki-htec (Author)

> This PR enables broader usage of the fused allreduce + rmsnorm path and its performance gains from #37646.

Thanks for the info! I checked it out, but sadly that alone will not pick up the kernels for DeepSeek-R1.
I will rebase my code on top of that PR and push an update when everything works.

…o dsr1-ar-rmsnorm

Signed-off-by: Attila Dusnoki <attila.dusnoki@htecgroup.com>
@andyluo7

✅ Tested on MI355X (gfx950) — All Tests Pass

Tested on 8x AMD Instinct MI355X (gfx950, ROCm 7.0.1) with VLLM_ROCM_USE_AITER=1.

Unit Tests — 55 passed, 0 failed

| Test Suite | Result |
| --- | --- |
| tests/rocm/aiter/test_fused_ar_rmsnorm.py | 11 passed |
| tests/compile/passes/distributed/test_rocm_fused_ar_rmsnorm.py | 2 passed |
| tests/distributed/test_fused_ar_rmsnorm.py (multi-GPU, Ray) | 2 passed |
| tests/compile/passes/distributed/test_fusion_all_reduce.py | 4 passed, 12 skipped ✅ |
| tests/compile/fusions_e2e/test_tp2_ar_rms.py | 36 passed, 116 skipped ✅ |

All skips are NVIDIA-only tests — no failures.

Feature Detection — Working

  • is_fused_allreduce_rmsnorm_supported() returns True on gfx950
  • All 8 workers log: "AITER CustomAllreduce initialized for fused AR+RMSNorm kernel"
  • All 8 workers log: "Fused AllReduce+RMSNorm enabled for MI355X"

End-to-End Benchmark — DeepSeek-R1 FP8 Dynamic, TP=8

Low concurrency (10 requests, input=128, output=128):

| Metric | Value |
| --- | --- |
| Output throughput | 444 tok/s |
| Median TTFT | 123 ms |
| Median TPOT | 21.6 ms |
| Median ITL | 15.1 ms |

High concurrency (32 requests, input=512, output=256):

| Metric | Value |
| --- | --- |
| Output throughput | 1,377 tok/s |
| Peak output throughput | 1,760 tok/s |
| Total throughput | 4,124 tok/s |
| Request throughput | 5.38 req/s |
| Median TTFT | 805 ms |
| Median TPOT | 20.1 ms |
| Median ITL | 18.2 ms |
| Failed requests | 0 |

Notes

  • Model loaded successfully: 81.73 GiB in ~629s
  • Backend: ROCM_AITER_MLA attention + AITER FP8 MoE
  • VLLM_ROCM_USE_AITER=1 is required to activate the fused path
  • No baseline (without this PR) comparison was run; this pass confirms correctness and stability on MI355X hardware

Great work on this PR! The fused allreduce+rmsnorm path works cleanly on MI355X with zero test failures. 🚀

Test environment: 8x MI355X (gfx950:sramecc+:xnack-), ROCm 7.0.1, vLLM built from PR branch (includes dep PR #37646)

ChuanLi1101 added a commit to ChuanLi1101/vllm-rocm-docker that referenced this pull request Mar 31, 2026
Base image: rocm/vllm-dev:base_custom_rocm_7.2.1_torch_triton_0330_vllm018

Patches applied:
- AITER SplitK bug fix (ROCm/aiter#2508)
- vLLM persistent MLA kernel (vllm-project/vllm#36574)
- vLLM fused AllReduce+RMSNorm (vllm-project/vllm#37891)

Made-with: Cursor
@attila-dusnoki-htec (Author)

Hey, any chance someone could review this? :)

@ppalanga

@gshtras: Can you please take a look at this PR? Thanks.

Review comment on the RMSNorm constructor signature:

```python
    var_hidden_size: int | None = None,
    has_weight: bool = True,
    dtype: torch.dtype | None = None,
    fused_allreduce: bool = False,
```

Collaborator: Can this be done through a pattern matching mechanism?

Contributor:

The dependent PR #37646 already implements a pure pattern matching mechanism for this — RocmAiterAllReduceFusionPass with AiterAllreduceFusedAddRMSNormPattern that matches all_reduce → fused_add_rmsnorm sequences in the compiled graph and replaces them with AITER's fused kernel. This mirrors how the existing CUDA AllReduceFusionPass works with flashinfer/trtllm.

However, as the PR author noted, #37646 alone doesn't pick up the kernels for DeepSeek because the all_reduce → rmsnorm pattern doesn't naturally exist as adjacent nodes in the FX graph. In DeepSeek's decoder layer, the allreduce and rmsnorm are structurally separated:

Attention path: all_reduce happens inside o_proj (a RowParallelLinear with default reduce_results=True), which performs the allreduce internally in its forward method. After self_attn returns, the next rmsnorm is post_attention_layernorm — separated by the attention return boundary and potential FP16 overflow scaling.

MoE path: The allreduce happens inside self.experts.maybe_all_reduce_tensor_model_parallel(), a method on SharedFusedMoE with conditional logic. After self.mlp() returns, the next rmsnorm (input_layernorm of layer N+1) is across the layer iteration boundary.

A compiler pass that could automatically move allreduce boundaries (detect allreduce inside a projection, find the subsequent rmsnorm across residual connections/layer boundaries, and rewrite the graph) would be significantly more complex. Even the CUDA flashinfer path relies on the all_reduce → rmsnorm pattern already being adjacent in the graph.
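
The structural separation described above can be illustrated with a runnable toy (stand-in modules, not vLLM's classes; requires PyTorch >= 2.4 for nn.RMSNorm). Tracing it would show no adjacent all_reduce → rmsnorm pair, because each "allreduce" is buried inside a submodule call:

```python
import torch
from torch import nn

class ToyDecoderLayer(nn.Module):
    """Structural stand-in (assumed), not vLLM's DeepseekV2DecoderLayer."""
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        # In vLLM, o_proj is a RowParallelLinear whose forward() performs the
        # allreduce internally (reduce_results=True), several call frames
        # away from the norm that follows it.
        self.self_attn = nn.Linear(hidden_size, hidden_size)
        self.post_attention_layernorm = nn.RMSNorm(hidden_size)
        # In vLLM, the MoE path's allreduce sits inside
        # experts.maybe_all_reduce_tensor_model_parallel(), behind conditionals.
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden: torch.Tensor, residual: torch.Tensor):
        attn_out = self.self_attn(hidden)   # allreduce hidden inside this call
        hidden = self.post_attention_layernorm(attn_out + residual)
        mlp_out = self.mlp(hidden)          # conditional allreduce inside this call
        # The next rmsnorm (input_layernorm of layer N+1) is across the layer
        # boundary, so all_reduce -> rmsnorm never appears as adjacent nodes.
        return mlp_out, hidden

layer = ToyDecoderLayer()
out, res = layer(torch.randn(2, 16), torch.randn(2, 16))
```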

Contributor:
@gshtras @dllehr-amd added an implementation without changing model definition. Please check: #38762

@ProExpertProg (Collaborator):
Btw I don't think this is right; this should be possible via pattern matching. Can you post the resulting fx graph and show the ops in between? Is it just the view? That should be eliminated by the existing NoOpEliminationPass, no?

Contributor:
@ProExpertProg Yes, it's just view. In #38762 I added the relevant FX graph section before/after adding _bypass_noop_views_after_allreduce() to RocmAiterAllReduceFusionPass, which shows the view going away.
The NoOpEliminationPass handles aten.reshape.default, aten.slice.Tensor, and aten.slice_scatter.default. Do you think it would be better to extend that pass to include aten.view.default instead of adding _bypass_noop_views_after_allreduce() to the ROCm pass?

@attila-dusnoki-htec (Author):
Hi @ProExpertProg
You are right, the view was not the problem.
I added a comment to the original PR #37646 (comment)

@mergify bot commented Apr 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @attila-dusnoki-htec.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@attila-dusnoki-htec (Author)

@rbrugaro-amd Thanks for figuring it out! Closing this in favour of #38762.

vllmellm and others added 5 commits April 9, 2026 10:05 (commit titles truncated in the page render: "…d-rmsnorm", "…o HEAD")
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: Attila Dusnoki <attila.dusnoki@htecgroup.com>
@attila-dusnoki-htec (Author)

I did not manage to make the pattern-matching version work, so I'm re-opening this solution.

@mergify bot commented Apr 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @attila-dusnoki-htec.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


Labels

ci/build, deepseek (Related to DeepSeek models), nvidia, rocm (Related to AMD ROCm)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

8 participants