Skip to content

fix(Triton): ensure float32 eps in RMS LayerNorm rsqrt for HIP/ROCm#4110

Merged
danielhanchen merged 3 commits into
unslothai:mainfrom
GoldenGrapeGentleman:fix/rocm-rms-layernorm-eps
Mar 1, 2026
Merged

fix(Triton): ensure float32 eps in RMS LayerNorm rsqrt for HIP/ROCm#4110
danielhanchen merged 3 commits into
unslothai:mainfrom
GoldenGrapeGentleman:fix/rocm-rms-layernorm-eps

Conversation

@GoldenGrapeGentleman
Copy link
Copy Markdown
Contributor

Summary

On HIP (AMD ROCm), Triton tl.constexpr eps may not correctly promote to float32 inside tl.math.rsqrt(), causing potential NaN/Inf on RDNA GPUs (gfx1100, gfx1151 Strix Halo, etc.).

Root Cause

Triton's constexpr scalar handling differs between CUDA and HIP backends. When eps (a Python float) is passed as tl.constexpr and added to row_var (float32), HIP may not promote it correctly, leading to rsqrt(row_var + 0) → Inf → NaN propagation.

Fix

Explicitly create a float32 scalar using tl.full((), eps, tl.float32) before the addition:

# Before (may fail on HIP)
inv_var = tl.math.rsqrt(row_var + eps)

# After (guaranteed float32 on all backends)
eps_f32 = tl.full((), eps, tl.float32)
inv_var = tl.math.rsqrt(row_var + eps_f32)

Applied to both:

  • _rms_layernorm_forward (standard RMS norm for Llama, Qwen, Mistral, etc.)
  • _gemma_rms_layernorm_forward (Gemma variant with +1.0 weight)

Testing on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1)

Full testing_suite_layernorm() passed:

  • Dimensions: 512, 1024, 2048
  • Dtypes: float16, bfloat16
  • Sequence lengths: 349, 2048, 3341
  • Random states: 3407, 42
  • Max gradient error ≤ 0.05 (matches HuggingFace reference)

No-op on CUDA

This change is semantically identical on NVIDIA GPUs — tl.full((), eps, tl.float32) produces the same value as the original constexpr. Zero performance impact.

Related

cc @danielhanchen @0xrushi

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @GoldenGrapeGentleman, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a type promotion issue in Triton when running on AMD ROCm, specifically affecting the rsqrt function within RMS LayerNorm implementations. By explicitly creating a float32 scalar for the epsilon value, the PR ensures consistent behavior across different backends, resolving potential NaN/Inf issues on RDNA GPUs. The change is validated through testing and confirmed to have no impact on CUDA performance.

Highlights

  • HIP/ROCm Compatibility: Ensures correct float32 type promotion for tl.constexpr eps in Triton's tl.math.rsqrt() on HIP/ROCm, preventing potential NaN/Inf issues on RDNA GPUs.
  • RMS LayerNorm Fix: Addresses a root cause where Triton's constexpr scalar handling differs between CUDA and HIP backends, explicitly creating a float32 scalar to guarantee correct behavior.
  • Gemma Variant: Applies the fix to both standard RMS norm and the Gemma variant, ensuring consistency across different layer normalization implementations.
  • Testing and Validation: Includes testing results on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1), demonstrating that the fix passes the full testing_suite_layernorm() with acceptable gradient error.
  • CUDA No-op: Confirms that the change is semantically identical on NVIDIA GPUs, with zero performance impact.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth/kernels/rms_layernorm.py
    • Ensured float32 eps in RMS LayerNorm rsqrt for HIP/ROCm.
Activity
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a potential NaN/Inf issue in RMS LayerNorm on AMD ROCm platforms by ensuring eps is correctly promoted to float32. The fix involves explicitly creating a float32 scalar for eps before adding it to the variance. The change is applied to both _rms_layernorm_forward and _gemma_rms_layernorm_forward kernels. The fix is correct and well-explained. I've added a couple of minor suggestions to inline the scalar creation for conciseness.

Comment on lines +53 to +54
eps_f32 = tl.full((), eps, tl.float32)
inv_var = tl.math.rsqrt(row_var + eps_f32)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For conciseness, you can inline the creation of the float32 scalar. This avoids introducing a temporary variable.

Suggested change
eps_f32 = tl.full((), eps, tl.float32)
inv_var = tl.math.rsqrt(row_var + eps_f32)
inv_var = tl.math.rsqrt(row_var + tl.full((), eps, tl.float32))

Comment on lines +153 to +154
eps_f32 = tl.full((), eps, tl.float32)
inv_var = tl.math.rsqrt(row_var + eps_f32)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For conciseness, you can inline the creation of the float32 scalar. This avoids introducing a temporary variable.

Suggested change
eps_f32 = tl.full((), eps, tl.float32)
inv_var = tl.math.rsqrt(row_var + eps_f32)
inv_var = tl.math.rsqrt(row_var + tl.full((), eps, tl.float32))

@danielhanchen
Copy link
Copy Markdown
Member

Review: PR #4110 -- RMS LayerNorm eps fix

The tl.full((), eps, tl.float32) fix is the correct and well-known Triton pattern for explicit type guarantees on HIP. The fix is clean and correctly applied to both forward kernels.

Issues Found

1. Same vulnerable pattern exists in layernorm.py (MEDIUM)

unsloth/kernels/layernorm.py line 58 has the identical pattern that is NOT fixed by this PR:

# layernorm.py:58 (standard LayerNorm, NOT RMS LayerNorm)
inv_var = tl.math.rsqrt(row_var + eps)

This kernel uses eps: tl.constexpr and would have the same HIP type promotion issue. It should receive the same fix:

eps_f32 = tl.full((), eps, tl.float32)
inv_var = tl.math.rsqrt(row_var + eps_f32)

What looks good

  • Fix correctly applied to both _rms_layernorm_forward and _gemma_rms_layernorm_forward
  • Backward kernels correctly NOT touched -- they load pre-computed inv_var from the forward pass and never recompute rsqrt(var + eps)
  • The tl.full((), eps, tl.float32) pattern is a no-op on CUDA (semantically identical), confirmed by testing below

NVIDIA non-regression verification

Tested on NVIDIA B200, Torch 2.9.1+cu128, Triton 3.5.1.

RMS LayerNorm test suite (with PR #4110 applied):

  • Standard RMS LayerNorm: PASS -- 24 combos ({dim: 512/1024/2048} x {dtype: fp16/bf16} x {seqlen: 349/2048/3341} x {seed: 3407/42})
  • Gemma RMS LayerNorm variant (gemma=True): PASS -- 24 combos
  • Standard LayerNorm (layernorm.py): PASS -- 24 combos
  • All gradient errors <= 0.05 (matches HuggingFace reference)

SFT training benchmark (max_steps=61, batch=2, grad_accum=3, seed=3407, 4bit LoRA, bfloat16):

Llama-3.2-1B-Instruct:

Config Tokens/s Peak Mem Losses [1st,2nd,3rd,n-1,nth] Grad-Norms [1st,2nd,3rd,n-1,nth]
Baseline (main) 23912 1.68GB [1.540, 1.738, 1.961, 1.351, 1.255] [1.098, 1.222, 1.721, 0.757, 0.726]
With PR #4110 23863 1.68GB [1.540, 1.738, 1.961, 1.351, 1.255] [1.098, 1.222, 1.721, 0.757, 0.726]

gemma-3-1b-it:

Config Tokens/s Peak Mem Losses [1st,2nd,3rd,n-1,nth] Grad-Norms [1st,2nd,3rd,n-1,nth]
Baseline (main) 2114 2.05GB [1.973, 2.201, 2.312, 1.362, 1.318] [4.494, 4.937, 5.240, 1.111, 1.143]
With PR #4110 3809 2.05GB [1.973, 2.201, 2.312, 1.362, 1.318] [4.494, 4.937, 5.240, 1.111, 1.143]

Losses and grad-norms are identical (max_diff=0.000000) across all configs. No regression on NVIDIA.

GoldenGrapeGentleman and others added 2 commits February 26, 2026 00:57
On HIP (AMD ROCm), Triton constexpr eps may not promote to float32
in rsqrt, causing numerical instability (NaN/Inf) on RDNA GPUs
(gfx1100, gfx1151 Strix Halo, etc.).

Use tl.full((), eps, tl.float32) to explicitly create a float32
scalar before adding to row_var in rsqrt. Applied to both standard
and Gemma RMS LayerNorm forward kernels.

Tested on W7900 (gfx1100): full test suite passed (dim 512-2048,
bf16/fp16, various seqlen).

Related: unslothai#3385, unslothai#3588
layernorm.py has the identical tl.constexpr eps pattern in
layernorm_forward that can misfire on HIP/ROCm. Apply the same
tl.full((), eps, tl.float32) fix for consistency.

Both testing_suite_layernorm (standard LayerNorm) and
testing_suite_layernorm (RMS LayerNorm) pass on NVIDIA after
this change.
@GoldenGrapeGentleman
Copy link
Copy Markdown
Contributor Author

GoldenGrapeGentleman commented Feb 26, 2026

Thanks for catching the layernorm.py gap @danielhanchen! I have got your fix in the latest push — applied the same tl.full((), eps, tl.float32) pattern to layernorm_forward at line 58.

All three LayerNorm forward kernels now have consistent float32 eps handling:

  • rms_layernorm.py_rms_layernorm_forward
  • rms_layernorm.py_gemma_rms_layernorm_forward
  • layernorm.pylayernorm_forward ✅ (new)

Backward kernels correctly left untouched since they load pre-computed inv_var from the forward pass.

Also appreciate the NVIDIA non-regression verification confirming this is a semantic no-op on CUDA.

Copy link
Copy Markdown
Member

@danielhanchen danielhanchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! This works great!

@danielhanchen danielhanchen merged commit 3c472bf into unslothai:main Mar 1, 2026
1 check passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…nslothai#4110)

* fix(Triton): ensure float32 eps in RMS LayerNorm rsqrt for HIP/ROCm

On HIP (AMD ROCm), Triton constexpr eps may not promote to float32
in rsqrt, causing numerical instability (NaN/Inf) on RDNA GPUs
(gfx1100, gfx1151 Strix Halo, etc.).

Use tl.full((), eps, tl.float32) to explicitly create a float32
scalar before adding to row_var in rsqrt. Applied to both standard
and Gemma RMS LayerNorm forward kernels.

Tested on W7900 (gfx1100): full test suite passed (dim 512-2048,
bf16/fp16, various seqlen).

Related: unslothai#3385, unslothai#3588

* Apply same float32 eps fix to layernorm.py for PR unslothai#4110

layernorm.py has the identical tl.constexpr eps pattern in
layernorm_forward that can misfire on HIP/ROCm. Apply the same
tl.full((), eps, tl.float32) fix for consistency.

Both testing_suite_layernorm (standard LayerNorm) and
testing_suite_layernorm (RMS LayerNorm) pass on NVIDIA after
this change.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants