Skip to content

Fix Gemma3 NaN losses on ROCm by disabling torch.compile for RDNA GPUs#4029

Closed
GoldenGrapeGentleman wants to merge 2 commits into
unslothai:mainfrom
GoldenGrapeGentleman:fix-gemma3-nan-on-rocm
Closed

Fix Gemma3 NaN losses on ROCm by disabling torch.compile for RDNA GPUs#4029
GoldenGrapeGentleman wants to merge 2 commits into
unslothai:mainfrom
GoldenGrapeGentleman:fix-gemma3-nan-on-rocm

Conversation

@GoldenGrapeGentleman
Copy link
Copy Markdown
Contributor

@GoldenGrapeGentleman GoldenGrapeGentleman commented Feb 11, 2026

Summary

Gemma3 training produces NaN losses from the first step on all RDNA GPUs (gfx1100, gfx1151). The compiled forward path is numerically unstable on the ROCm/Triton backend, while the eager path trains correctly.

This adds a targeted compile disable for Gemma3 on HIP, following the same UNSLOTH_COMPILE_DISABLE pattern already used for Sesame/CSM models.

Reproduction

from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import Dataset

ds = Dataset.from_list([{"text": "Q: Hi\nA: Hello!"}] * 20)
m, t = FastLanguageModel.from_pretrained(
    "unsloth/gemma-3-1b-pt", max_seq_length=256, dtype=None, load_in_4bit=True,
)
m = FastLanguageModel.get_peft_model(m, r=16,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    use_gradient_checkpointing="unsloth",
)
trainer = SFTTrainer(model=m, tokenizer=t, train_dataset=ds,
    args=SFTConfig(output_dir="./out", max_steps=3, bf16=True,
        per_device_train_batch_size=2, dataset_text_field="text",
        max_seq_length=256, report_to="none"))
trainer.train()  # loss=nan on every step

Root Cause

The Triton/ROCm compiler backend generates numerically unstable code for Gemma3. Other architectures (Llama, Qwen, Mistral) compile and train correctly on the same GPU. The eager path for Gemma3 is numerically correct — disabling compile is sufficient.

The module-level constant unsloth_zoo.compiler.UNSLOTH_COMPILE_DISABLE must also be overridden at runtime because it is evaluated at import time, before the model-specific configuration in loader.py runs.

Testing

Tested on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1, PyTorch 2.8.0):

Model Before After
Gemma-3-1B SFT (4bit bf16) ❌ NaN (all steps) ✅ loss=3.81
Llama-3.2-1B SFT (regression) ✅ loss=4.34 ✅ loss=4.40
Qwen2.5-1.5B SFT (regression) ✅ loss=2.88 ✅ (unaffected)

Also verified on a fresh container with no pre-existing compile cache.

Closes #3385

Related: PR #3588 (broader RDNA3 fix attempt, open since Nov 2025)

Co-authored-by: billishyahao bill.he@amd.com
Co-authored-by: yueyuan yueyuan@amd.com

cc @danielhanchen @0xrushi @kyuz0

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @GoldenGrapeGentleman, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical numerical instability issue affecting Gemma3 model training on AMD RDNA GPUs with ROCm, where torch.compile led to NaN losses. By selectively disabling compilation for Gemma3 on HIP devices, the PR ensures stable and correct training, leveraging the numerically sound eager execution path without impacting other models or architectures.

Highlights

  • Fixes NaN losses for Gemma3 on ROCm/RDNA GPUs: Addresses a numerical instability issue where Gemma3 training produced NaN losses from the first step on RDNA GPUs (gfx1100, gfx1151) when using the compiled forward path.
  • Disables torch.compile for Gemma3 on HIP: Implements a targeted disable of torch.compile specifically for Gemma3 models when running on HIP (ROCm) devices, as the eager path is numerically stable.
  • Overrides UNSLOTH_COMPILE_DISABLE at runtime: Ensures the UNSLOTH_COMPILE_DISABLE flag is correctly set at runtime, as it's evaluated at import time before model-specific configurations.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth/models/loader.py
    • Added conditional logic within the from_pretrained function to check for HIP devices when loading Gemma3 models.
    • Set the UNSLOTH_COMPILE_DISABLE environment variable to "1".
    • Imported unsloth_zoo.compiler and explicitly set unsloth_zoo.compiler.UNSLOTH_COMPILE_DISABLE to True to ensure the flag is active at runtime.
Activity
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a targeted fix to disable torch.compile for Gemma3 models on ROCm platforms, addressing an issue with NaN losses during training on RDNA GPUs. The change correctly identifies the specific hardware (hip) and model type (gemma3) and applies a workaround by setting UNSLOTH_COMPILE_DISABLE and patching the unsloth_zoo.compiler module at runtime. The implementation is sound and directly addresses the root cause described. I have one minor suggestion to improve code style by moving the local import to the top of the file.

Comment thread unsloth/models/loader.py
# See https://github.com/unslothai/unsloth/issues/3385
if DEVICE_TYPE == "hip":
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
import unsloth_zoo.compiler
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve code style and adhere to PEP 8, it's better to have imports at the top of the file. Please move this import to the top of unsloth/models/loader.py and remove it from here.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 18a6c6ad78

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/models/loader.py
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
import unsloth_zoo.compiler

unsloth_zoo.compiler.UNSLOTH_COMPILE_DISABLE = True
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid persisting compile-disable globally after Gemma3 load

This assignment mutates a process-wide compiler flag and there is no corresponding reset path in FastLanguageModel.from_pretrained, so once a HIP Gemma3 load runs, later loads of other model families in the same Python process can inherit compile-disabled behavior unintentionally. That creates a silent cross-model regression (not just Gemma3) where subsequent calls into unsloth_compile_transformers may stay on eager paths and lose expected optimizations.

Useful? React with 👍 / 👎.

GoldenGrapeGentleman and others added 2 commits February 12, 2026 02:57
On RDNA GPUs (gfx1100, gfx1151, etc.), Gemma3's compiled forward path
produces NaN losses from the first training step. The eager (uncompiled)
path is numerically correct.

Root cause: the Triton/ROCm compiler backend generates numerically
unstable code for Gemma3's architecture. Other model architectures
(Llama, Qwen) compile and train correctly on the same hardware.

Fix: set UNSLOTH_COMPILE_DISABLE=1 for Gemma3 on HIP, following the
same pattern used for Sesame/CSM models. The module-level constant in
unsloth_zoo.compiler is also updated since it is evaluated at import
time before loader.py's model-specific block runs.

Tested on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1, PyTorch 2.8.0):
  - Gemma-3-1B SFT 4bit bf16: NaN → loss=3.81 (fixed)
  - Llama-3.2-1B SFT (regression): loss=4.40 (unaffected)

Closes unslothai#3385
Copy link
Copy Markdown
Member

@danielhanchen danielhanchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested on NVIDIA B200 (compute capability 10.0, CUDA 12.8, torch 2.9.1) -- no regressions observed.

Code review notes:

  • The DEVICE_TYPE == "hip" guard is consistent with the existing pattern used across the codebase (6+ existing usages).
  • Both os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" and the module-level unsloth_zoo.compiler.UNSLOTH_COMPILE_DISABLE = True are set, which covers the critical compile decision path in compiler.py.
  • Minor gap: common.py binds UNSLOTH_COMPILE_DISABLE at import time before loader runs, so its local copy is not patched. However, the env var read in compiler.py:2848 is done at call time and IS affected, so the fix works in practice.
  • No-op on NVIDIA -- the entire block is gated behind DEVICE_TYPE == "hip".

LGTM.

@GoldenGrapeGentleman
Copy link
Copy Markdown
Contributor Author

Tested on NVIDIA B200 (compute capability 10.0, CUDA 12.8, torch 2.9.1) -- no regressions observed.

Code review notes:

  • The DEVICE_TYPE == "hip" guard is consistent with the existing pattern used across the codebase (6+ existing usages).
  • Both os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" and the module-level unsloth_zoo.compiler.UNSLOTH_COMPILE_DISABLE = True are set, which covers the critical compile decision path in compiler.py.
  • Minor gap: common.py binds UNSLOTH_COMPILE_DISABLE at import time before loader runs, so its local copy is not patched. However, the env var read in compiler.py:2848 is done at call time and IS affected, so the fix works in practice.
  • No-op on NVIDIA -- the entire block is gated behind DEVICE_TYPE == "hip".

LGTM.

Thanks @danielhanchen ~

GoldenGrapeGentleman added a commit to GoldenGrapeGentleman/unsloth that referenced this pull request Feb 25, 2026
…dna()

- Add is_rdna() detection for RDNA3/3.5/RDNA4 consumer GPUs (gfx11xx, gfx1151, gfx12xx)
- Disable torch.compile for Gemma3 on HIP to fix NaN loss (fixes unslothai#3385, unslothai#4029)
- Export is_cdna/is_rdna from kernels for downstream use
- Import is_rdna into cross_entropy_loss for future RDNA-specific tuning

Tested on AMD Radeon PRO W7900 (gfx1100) with ROCm 7.1:
  ✓ Gemma3-1B: loss 3.37→3.25 (no NaN)
  ✓ Llama-3.2-1B: loss 2.44→2.37 (no NaN)
  ✓ Qwen2.5-1.5B: loss 1.89→1.85 (no NaN)
  ✓ RMS LayerNorm Triton kernel: bf16/fp16 PASSED
  ✓ Cross Entropy Loss Triton kernel: 32K/256K vocab PASSED
@GoldenGrapeGentleman
Copy link
Copy Markdown
Contributor Author

Closing in favor of #4109, which supersedes this PR with improvements based on Daniel's review:

  • Scoped compile disable to RDNA only (is_rdna()) instead of all HIP devices
  • Uses "partial" mode instead of full disable, preserving loss compilation
  • Adds is_rdna() detection for future RDNA-specific optimizations

All fixes from this PR are fully contained in #4109.

GoldenGrapeGentleman added a commit to GoldenGrapeGentleman/unsloth that referenced this pull request Feb 28, 2026
…dna()

- Add is_rdna() detection for RDNA3/3.5/RDNA4 consumer GPUs (gfx11xx, gfx1151, gfx12xx)
- Disable torch.compile for Gemma3 on HIP to fix NaN loss (fixes unslothai#3385, unslothai#4029)
- Export is_cdna/is_rdna from kernels for downstream use
- Import is_rdna into cross_entropy_loss for future RDNA-specific tuning

Tested on AMD Radeon PRO W7900 (gfx1100) with ROCm 7.1:
  ✓ Gemma3-1B: loss 3.37→3.25 (no NaN)
  ✓ Llama-3.2-1B: loss 2.44→2.37 (no NaN)
  ✓ Qwen2.5-1.5B: loss 1.89→1.85 (no NaN)
  ✓ RMS LayerNorm Triton kernel: bf16/fp16 PASSED
  ✓ Cross Entropy Loss Triton kernel: 32K/256K vocab PASSED
pull Bot pushed a commit to edisplay/unsloth that referenced this pull request Mar 1, 2026
…dna() (unslothai#4109)

* fix(ROCm): comprehensive RDNA GPU support - fix Gemma3 NaN & add is_rdna()

- Add is_rdna() detection for RDNA3/3.5/RDNA4 consumer GPUs (gfx11xx, gfx1151, gfx12xx)
- Disable torch.compile for Gemma3 on HIP to fix NaN loss (fixes unslothai#3385, unslothai#4029)
- Export is_cdna/is_rdna from kernels for downstream use
- Import is_rdna into cross_entropy_loss for future RDNA-specific tuning

Tested on AMD Radeon PRO W7900 (gfx1100) with ROCm 7.1:
  ✓ Gemma3-1B: loss 3.37→3.25 (no NaN)
  ✓ Llama-3.2-1B: loss 2.44→2.37 (no NaN)
  ✓ Qwen2.5-1.5B: loss 1.89→1.85 (no NaN)
  ✓ RMS LayerNorm Triton kernel: bf16/fp16 PASSED
  ✓ Cross Entropy Loss Triton kernel: 32K/256K vocab PASSED

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: scope compile disable to RDNA only, use partial mode, remove unused import

Changes based on Daniel's review:
1. (HIGH) Replace DEVICE_TYPE=='hip' with is_rdna() to avoid disabling
   torch.compile on CDNA GPUs (MI250X/MI300X/MI350) where it works fine
2. (MEDIUM) Use 'partial' instead of '1' for UNSLOTH_COMPILE_DISABLE to
   only disable model forward compilation while keeping loss compilation,
   matching the existing Sesame pattern
3. (LOW) Remove unused is_rdna import from cross_entropy_loss.py (F401)

* Remove redundant is_cdna/is_rdna exports from kernels/__init__.py

These functions are imported directly from .utils where needed
(e.g. cross_entropy_loss.py, loader.py). No external code imports
them from the unsloth.kernels namespace.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…dna() (unslothai#4109)

* fix(ROCm): comprehensive RDNA GPU support - fix Gemma3 NaN & add is_rdna()

- Add is_rdna() detection for RDNA3/3.5/RDNA4 consumer GPUs (gfx11xx, gfx1151, gfx12xx)
- Disable torch.compile for Gemma3 on HIP to fix NaN loss (fixes unslothai#3385, unslothai#4029)
- Export is_cdna/is_rdna from kernels for downstream use
- Import is_rdna into cross_entropy_loss for future RDNA-specific tuning

Tested on AMD Radeon PRO W7900 (gfx1100) with ROCm 7.1:
  ✓ Gemma3-1B: loss 3.37→3.25 (no NaN)
  ✓ Llama-3.2-1B: loss 2.44→2.37 (no NaN)
  ✓ Qwen2.5-1.5B: loss 1.89→1.85 (no NaN)
  ✓ RMS LayerNorm Triton kernel: bf16/fp16 PASSED
  ✓ Cross Entropy Loss Triton kernel: 32K/256K vocab PASSED

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: scope compile disable to RDNA only, use partial mode, remove unused import

Changes based on Daniel's review:
1. (HIGH) Replace DEVICE_TYPE=='hip' with is_rdna() to avoid disabling
   torch.compile on CDNA GPUs (MI250X/MI300X/MI350) where it works fine
2. (MEDIUM) Use 'partial' instead of '1' for UNSLOTH_COMPILE_DISABLE to
   only disable model forward compilation while keeping loss compilation,
   matching the existing Sesame pattern
3. (LOW) Remove unused is_rdna import from cross_entropy_loss.py (F401)

* Remove redundant is_cdna/is_rdna exports from kernels/__init__.py

These functions are imported directly from .utils where needed
(e.g. cross_entropy_loss.py, loader.py). No external code imports
them from the unsloth.kernels namespace.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Training on ROCm (gfx1151, Strix Halo) results in NaN losses with Gemma3 fine-tuning

2 participants