fix(Triton): ensure float32 eps in RMS LayerNorm rsqrt for HIP/ROCm#4110
Conversation
Summary of ChangesHello @GoldenGrapeGentleman, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a type promotion issue in Triton when running on AMD ROCm, specifically affecting the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request addresses a potential NaN/Inf issue in RMS LayerNorm on AMD ROCm platforms by ensuring eps is correctly promoted to float32. The fix involves explicitly creating a float32 scalar for eps before adding it to the variance. The change is applied to both _rms_layernorm_forward and _gemma_rms_layernorm_forward kernels. The fix is correct and well-explained. I've added a couple of minor suggestions to inline the scalar creation for conciseness.
| eps_f32 = tl.full((), eps, tl.float32) | ||
| inv_var = tl.math.rsqrt(row_var + eps_f32) |
There was a problem hiding this comment.
| eps_f32 = tl.full((), eps, tl.float32) | ||
| inv_var = tl.math.rsqrt(row_var + eps_f32) |
There was a problem hiding this comment.
Review: PR #4110 -- RMS LayerNorm eps fixThe Issues Found1. Same vulnerable pattern exists in
# layernorm.py:58 (standard LayerNorm, NOT RMS LayerNorm)
inv_var = tl.math.rsqrt(row_var + eps)This kernel uses eps_f32 = tl.full((), eps, tl.float32)
inv_var = tl.math.rsqrt(row_var + eps_f32)What looks good
NVIDIA non-regression verificationTested on NVIDIA B200, Torch 2.9.1+cu128, Triton 3.5.1. RMS LayerNorm test suite (with PR #4110 applied):
SFT training benchmark (max_steps=61, batch=2, grad_accum=3, seed=3407, 4bit LoRA, bfloat16): Llama-3.2-1B-Instruct:
gemma-3-1b-it:
Losses and grad-norms are identical (max_diff=0.000000) across all configs. No regression on NVIDIA. |
On HIP (AMD ROCm), Triton constexpr eps may not promote to float32 in rsqrt, causing numerical instability (NaN/Inf) on RDNA GPUs (gfx1100, gfx1151 Strix Halo, etc.). Use tl.full((), eps, tl.float32) to explicitly create a float32 scalar before adding to row_var in rsqrt. Applied to both standard and Gemma RMS LayerNorm forward kernels. Tested on W7900 (gfx1100): full test suite passed (dim 512-2048, bf16/fp16, various seqlen). Related: unslothai#3385, unslothai#3588
layernorm.py has the identical tl.constexpr eps pattern in layernorm_forward that can misfire on HIP/ROCm. Apply the same tl.full((), eps, tl.float32) fix for consistency. Both testing_suite_layernorm (standard LayerNorm) and testing_suite_layernorm (RMS LayerNorm) pass on NVIDIA after this change.
|
Thanks for catching the All three LayerNorm forward kernels now have consistent float32 eps handling:
Backward kernels correctly left untouched since they load pre-computed Also appreciate the NVIDIA non-regression verification confirming this is a semantic no-op on CUDA. |
danielhanchen
left a comment
There was a problem hiding this comment.
Thank you! This works great!
…nslothai#4110) * fix(Triton): ensure float32 eps in RMS LayerNorm rsqrt for HIP/ROCm On HIP (AMD ROCm), Triton constexpr eps may not promote to float32 in rsqrt, causing numerical instability (NaN/Inf) on RDNA GPUs (gfx1100, gfx1151 Strix Halo, etc.). Use tl.full((), eps, tl.float32) to explicitly create a float32 scalar before adding to row_var in rsqrt. Applied to both standard and Gemma RMS LayerNorm forward kernels. Tested on W7900 (gfx1100): full test suite passed (dim 512-2048, bf16/fp16, various seqlen). Related: unslothai#3385, unslothai#3588 * Apply same float32 eps fix to layernorm.py for PR unslothai#4110 layernorm.py has the identical tl.constexpr eps pattern in layernorm_forward that can misfire on HIP/ROCm. Apply the same tl.full((), eps, tl.float32) fix for consistency. Both testing_suite_layernorm (standard LayerNorm) and testing_suite_layernorm (RMS LayerNorm) pass on NVIDIA after this change. --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Summary
On HIP (AMD ROCm), Triton
tl.constexpreps may not correctly promote tofloat32insidetl.math.rsqrt(), causing potential NaN/Inf on RDNA GPUs (gfx1100, gfx1151 Strix Halo, etc.).Root Cause
Triton's constexpr scalar handling differs between CUDA and HIP backends. When
eps(a Python float) is passed astl.constexprand added torow_var(float32), HIP may not promote it correctly, leading torsqrt(row_var + 0)→ Inf → NaN propagation.Fix
Explicitly create a
float32scalar usingtl.full((), eps, tl.float32)before the addition:Applied to both:
_rms_layernorm_forward(standard RMS norm for Llama, Qwen, Mistral, etc.)_gemma_rms_layernorm_forward(Gemma variant with +1.0 weight)Testing on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1)
Full
testing_suite_layernorm()passed:No-op on CUDA
This change is semantically identical on NVIDIA GPUs —
tl.full((), eps, tl.float32)produces the same value as the original constexpr. Zero performance impact.Related
cc @danielhanchen @0xrushi