fix(ROCm): Comprehensive RDNA GPU support - fix Gemma3 NaN & add is_rdna() #4109

Merged: danielhanchen merged 5 commits into unslothai:main from GoldenGrapeGentleman:rdna-comprehensive-fix on Mar 1, 2026

@GoldenGrapeGentleman
Contributor

Summary

Comprehensive fix for AMD RDNA consumer/workstation GPUs (RX 7000/9000 series, PRO W7900, Strix Halo, etc.).

Changes

  1. is_rdna() detection (unsloth/kernels/utils.py)

    • New is_rdna() function detects all RDNA3/3.5/RDNA4 architectures (gfx11xx, gfx1151, gfx12xx)
    • Uses arch.startswith("gfx1") + not is_cdna() for future-proof coverage
  2. Gemma3 NaN fix (unsloth/models/loader.py)

    • Disables torch.compile for Gemma3 on HIP (ROCm) devices to avoid NaN losses
  3. Export is_cdna/is_rdna (unsloth/kernels/__init__.py)

    • Makes both functions available for downstream kernel tuning
  4. Import is_rdna in cross_entropy_loss (unsloth/kernels/cross_entropy_loss.py)

    • Preparation for future RDNA-specific Triton kernel optimizations
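
The detection rule in item 1 can be sketched roughly as follows. This is a minimal illustration, not the code added in unsloth/kernels/utils.py: the helper names (`arch_is_rdna`, `arch_is_cdna`, `is_rdna_sketch`) and the CDNA target list are assumptions, and reading the arch string from `gcnArchName` on the device properties is assumed to be available on ROCm builds of PyTorch. Note that a bare "gfx1" prefix also matches gfx10xx parts.

```python
from functools import lru_cache

# Assumed CDNA targets (illustrative only, not the project's actual list).
_CDNA_ARCHS = ("gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx950")


def _base_arch(arch_name: str) -> str:
    """Drop feature suffixes such as ':sramecc+:xnack-' from an arch string."""
    return arch_name.split(":", 1)[0].strip().lower()


def arch_is_cdna(arch_name: str) -> bool:
    """True if the arch string names one of the assumed CDNA targets."""
    return _base_arch(arch_name) in _CDNA_ARCHS


def arch_is_rdna(arch_name: str) -> bool:
    """Prefix rule from the PR description: starts with 'gfx1' and is not CDNA."""
    base = _base_arch(arch_name)
    return base.startswith("gfx1") and not arch_is_cdna(base)


@lru_cache(maxsize=None)
def is_rdna_sketch() -> bool:
    """Query the first visible GPU once and cache the answer (hypothetical wrapper)."""
    try:
        import torch
    except ImportError:
        return False
    # torch.version.hip is None on non-ROCm builds.
    if getattr(torch.version, "hip", None) is None or not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(0)
    return arch_is_rdna(getattr(props, "gcnArchName", ""))
```

Keeping the string check separate from the device query makes the rule testable on machines without an AMD GPU.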

Test Results on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1)

| Model | Loss Trend | NaN? | Status |
|---|---|---|---|
| Gemma3-1B (4bit SFT) | 3.37 → 3.25 | No | ✅ PASS |
| Llama-3.2-1B (4bit SFT) | 2.44 → 2.37 | No | ✅ PASS |
| Qwen2.5-1.5B (4bit SFT) | 1.89 → 1.85 | No | ✅ PASS |
| RMS LayerNorm Triton (bf16/fp16) | - | - | ✅ PASS |
| Cross Entropy Loss Triton (32K/256K) | - | - | ✅ PASS |

Environment

  • GPU: AMD Radeon PRO W7900 (gfx1100, 48GB VRAM)
  • ROCm: 7.1.25424
  • PyTorch: 2.8.0+rocm7.1
  • Triton: 3.4.0+rocm7.1
  • Transformers: 4.57.6

Related Issues

cc @danielhanchen

@gemini-code-assist
Contributor

Summary of Changes

This pull request significantly enhances Unsloth's support for AMD RDNA GPUs by introducing robust detection for these architectures and resolving a critical numerical instability issue (NaNs) encountered when training Gemma3 models on ROCm. The changes ensure more reliable performance and pave the way for future RDNA-specific optimizations, improving the overall user experience for AMD GPU owners.

Highlights

  • RDNA GPU Detection: A new is_rdna() function has been added to detect AMD RDNA consumer/workstation GPUs (RDNA3, RDNA3.5, RDNA4 architectures like gfx11xx, gfx1151, gfx12xx) by checking arch.startswith("gfx1") and ensuring it's not a CDNA GPU.
  • Gemma3 NaN Fix on ROCm: To address numerical instability and NaN issues observed with Gemma3 models on RDNA GPUs when using torch.compile, torch.compile is now conditionally disabled for Gemma3 on HIP (ROCm) devices.
  • Exported GPU Detection Functions: Both is_cdna and the newly introduced is_rdna functions are now exported from unsloth/kernels/__init__.py, making them accessible for downstream kernel tuning and other modules.
  • Preparation for RDNA-specific Optimizations: The is_rdna function has been imported into unsloth/kernels/cross_entropy_loss.py, laying the groundwork for future RDNA-specific Triton kernel optimizations.

Changelog
  • unsloth/kernels/__init__.py
    • Exported is_cdna and is_rdna functions to make them publicly available.
  • unsloth/kernels/cross_entropy_loss.py
    • Imported the is_rdna function, preparing for future RDNA-specific kernel optimizations.
  • unsloth/kernels/utils.py
    • Added a new is_rdna() function, decorated with functools.lru_cache, to detect RDNA consumer/workstation GPUs based on their gfx architecture string and by excluding CDNA GPUs.
  • unsloth/models/loader.py
    • Implemented a conditional check within the Gemma3 model loading logic to disable torch.compile for HIP (ROCm) devices, preventing NaN issues on RDNA GPUs.
    • Set the UNSLOTH_COMPILE_DISABLE environment variable and unsloth_zoo.compiler.UNSLOTH_COMPILE_DISABLE flag to True when Gemma3 is loaded on HIP.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces comprehensive support for AMD RDNA GPUs. It adds an is_rdna() utility function to detect these GPUs and uses it to apply a targeted fix for a NaN issue with Gemma3 models on ROCm by disabling torch.compile. The changes also export the new detection function for broader use. My feedback focuses on making the fix even more precise by fully utilizing the new is_rdna() function.

Comment thread unsloth/models/loader.py Outdated
# (gfx1100, gfx1101, gfx1102, gfx1150, gfx1151, etc.).
# Disable torch.compile; eager path is numerically correct.
# See https://github.com/unslothai/unsloth/issues/3385
if DEVICE_TYPE == "hip":

medium

The check DEVICE_TYPE == "hip" is overly broad for an issue specific to RDNA GPUs. Using the newly introduced is_rdna() function provides a more targeted fix. This will prevent disabling torch.compile on other HIP devices like CDNA GPUs, where it might be beneficial. The is_rdna() function already includes a check for HIP.

Suggested change
- if DEVICE_TYPE == "hip":
+ if __import__("unsloth.kernels").is_rdna():

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 50ca55a095

Comment thread unsloth/models/loader.py Outdated
Comment on lines +1053 to +1054
if DEVICE_TYPE == "hip":
    os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

P2: Limit HIP compile disable to RDNA targets

FastModel.from_pretrained now disables torch.compile for every HIP device, but the regression described here is RDNA-specific. On ROCm CDNA systems (for example MI2xx/MI3xx), this path still executes and forces UNSLOTH_COMPILE_DISABLE=1, which can unnecessarily drop training/inference throughput for Gemma3 workloads even when they are not affected by the NaN issue. The guard should use RDNA detection instead of DEVICE_TYPE == "hip".

@danielhanchen
Member

Review: PR #4109 -- RDNA GPU support

Thanks for working on AMD RDNA support. The is_rdna() function design is solid and complements is_cdna() well. A few items to address before merge:

Issues Found

1. Compile disable scope is too broad (HIGH)

loader.py uses DEVICE_TYPE == "hip" to disable compile for Gemma3, but this catches ALL HIP devices including CDNA (MI250X/MI300X/MI350) where compile works fine. The PR adds is_rdna() specifically for this purpose but does not use it here:

# Current (disables compile on ALL AMD GPUs):
if DEVICE_TYPE == "hip":
    os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

# Should be (only disables on RDNA):
from unsloth.kernels.utils import is_rdna
if is_rdna():
    os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

2. Full compile disable vs partial (MEDIUM)

The PR sets UNSLOTH_COMPILE_DISABLE = "1" (full disable), but existing model-specific code (e.g., Sesame at loader.py:1059) uses "partial". Full disable also blocks loss function compilation. Unless Gemma3 loss compilation is also broken on RDNA, consider using "partial" to match the existing pattern.

3. Unused import (LOW)

cross_entropy_loss.py imports is_rdna but never uses it. Linters will flag this as F401. Suggest removing the import until it is actually needed in a follow-up PR for RDNA-specific Triton kernel tuning.

What looks good

NVIDIA non-regression verification

Tested on NVIDIA B200, Torch 2.9.1+cu128, Triton 3.5.1:

  • is_rdna() returns False on NVIDIA
  • is_cdna() returns False on NVIDIA
  • UNSLOTH_COMPILE_DISABLE is not set after loading Gemma3 on NVIDIA (correct -- the DEVICE_TYPE == "hip" guard prevents it)

SFT training benchmark (max_steps=61, batch=2, grad_accum=3, seed=3407, 4bit LoRA, bfloat16):

Llama-3.2-1B-Instruct:

| Config | Tokens/s | Peak Mem | Losses [1st, 2nd, 3rd, n-1, nth] | Grad-Norms [1st, 2nd, 3rd, n-1, nth] |
|---|---|---|---|---|
| Baseline (main) | 23912 | 1.68GB | [1.540, 1.738, 1.961, 1.351, 1.255] | [1.098, 1.222, 1.721, 0.757, 0.726] |
| With PR #4109 | 24355 | 1.68GB | [1.540, 1.738, 1.961, 1.351, 1.255] | [1.098, 1.222, 1.721, 0.757, 0.726] |

gemma-3-1b-it:

| Config | Tokens/s | Peak Mem | Losses [1st, 2nd, 3rd, n-1, nth] | Grad-Norms [1st, 2nd, 3rd, n-1, nth] |
|---|---|---|---|---|
| Baseline (main) | 2114 | 2.05GB | [1.973, 2.201, 2.312, 1.362, 1.318] | [4.494, 4.937, 5.240, 1.111, 1.143] |
| With PR #4109 | 3834 | 2.05GB | [1.973, 2.201, 2.312, 1.362, 1.318] | [4.494, 4.937, 5.240, 1.111, 1.143] |

Losses and grad-norms are identical (max_diff=0.000000) across all configs. No regression on NVIDIA.
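
For readers who want to reproduce this kind of check, a minimal sketch of the comparison is shown here. It only uses the five sampled gemma-3-1b-it values reported in the table above, not the full 61-step logs, and the helper name `max_abs_diff` is hypothetical.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


# Sampled values copied from the gemma-3-1b-it rows above.
baseline_losses = [1.973, 2.201, 2.312, 1.362, 1.318]
pr_losses = [1.973, 2.201, 2.312, 1.362, 1.318]
baseline_grad_norms = [4.494, 4.937, 5.240, 1.111, 1.143]
pr_grad_norms = [4.494, 4.937, 5.240, 1.111, 1.143]

print(f"loss max_diff={max_abs_diff(baseline_losses, pr_losses):.6f}")
print(f"grad-norm max_diff={max_abs_diff(baseline_grad_norms, pr_grad_norms):.6f}")
```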

@GoldenGrapeGentleman
Contributor Author

GoldenGrapeGentleman commented Feb 26, 2026

Thanks for the thorough review @danielhanchen! All 3 issues addressed in the latest push:

  1. HIGH — Compile disable scope: Replaced DEVICE_TYPE == "hip" with is_rdna() so CDNA GPUs (MI250X/MI300X/MI350) are no longer affected.
  2. MEDIUM — Partial vs full disable: Changed to "partial" to match the Sesame pattern — only model forward compilation is disabled, loss compilation remains active.
  3. LOW — Unused import: Removed is_rdna import from cross_entropy_loss.py.
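
The resulting guard can be sketched as below. This is an illustration under stated assumptions rather than the code in loader.py: the function names are hypothetical, the `model_type` prefix test is an assumed way of recognising Gemma3, and only the `UNSLOTH_COMPILE_DISABLE` variable and its "partial" value come from this thread.

```python
import os
from typing import MutableMapping, Optional


def compile_disable_mode(model_type: str, on_rdna: bool) -> Optional[str]:
    """Return the compile-disable mode for this combination, or None to leave compile on.

    Assumption: Gemma3 checkpoints report a model_type starting with 'gemma3'.
    """
    if on_rdna and model_type.startswith("gemma3"):
        # "partial" skips model forward compilation but keeps loss compilation.
        return "partial"
    return None


def apply_compile_guard(
    model_type: str,
    on_rdna: bool,
    env: MutableMapping[str, str] = os.environ,
) -> Optional[str]:
    """Set UNSLOTH_COMPILE_DISABLE only when the guard fires (hypothetical helper)."""
    mode = compile_disable_mode(model_type, on_rdna)
    if mode is not None:
        env["UNSLOTH_COMPILE_DISABLE"] = mode
    return mode
```

Passing a plain dict as `env` makes it possible to check the behaviour without touching the real process environment; CDNA and NVIDIA devices would pass `on_rdna=False` and keep compilation enabled.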

Also appreciate the NVIDIA non-regression verification — great to see identical losses and grad-norms across all configs.

GoldenGrapeGentleman added a commit to GoldenGrapeGentleman/unsloth that referenced this pull request Feb 27, 2026
Add is_rdna() detection for RDNA3/3.5/4 (gfx11xx, gfx12xx) and optimize
Triton kernel launch parameters based on W7900 (gfx1100) benchmarks:

- calculate_settings(): double num_warps (capped at 32) for RDNA.
  RDNA's dual-issue SIMD32 CUs achieve higher occupancy with more warps.
- Cross entropy (chunked, large vocab): use 16 warps on RDNA (same as
  CDNA), benchmarked ~10% faster than the NVIDIA default of 32.
- Export is_cdna/is_rdna from kernels for downstream use.

Benchmark results on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1):

  RMS LayerNorm (n_rows=4096, bf16):
    dim=2048:  8w 102us -> 16w  95us (+7.6%)
    dim=4096:  8w 210us -> 16w 203us (+3.1%)
    dim=8192: 16w 370us -> 32w 344us (+7.0%)

  Cross Entropy chunked (vocab=128256):
    32w 1930us -> 16w 1728us (+10.5%)

Correctness: all RMS LN, LN, CE tests pass (fwd+bwd, fp16/bf16).

Builds on unslothai#4109 (is_rdna detection for Gemma3 NaN fix).
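
The warp selection described in this commit message can be sketched as follows. This is a rough illustration, not the actual `calculate_settings()` implementation: the block-size thresholds are assumptions chosen so that the baseline warp counts match the benchmark lines above (8 warps at dim=2048 and 4096, 16 at dim=8192), and the function names are hypothetical.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (pure-Python stand-in)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def calculate_settings_sketch(n: int, on_rdna: bool = False):
    """Pick (block_size, num_warps) for a row of width n.

    The thresholds below are assumed, not taken from the project source.
    """
    block_size = next_power_of_2(n)
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    if on_rdna:
        # Double the warp count on RDNA, capped at 32 as the commit describes.
        num_warps = min(num_warps * 2, 32)
    return block_size, num_warps


def chunked_ce_num_warps(on_amd: bool) -> int:
    """Chunked cross entropy: 16 warps on RDNA/CDNA, otherwise the default of 32."""
    return 16 if on_amd else 32


for dim in (2048, 4096, 8192):
    print(dim, calculate_settings_sketch(dim), calculate_settings_sketch(dim, on_rdna=True))
```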
GoldenGrapeGentleman added a commit to GoldenGrapeGentleman/unsloth that referenced this pull request Feb 28, 2026
Use 16 warps for RDNA in the chunked cross-entropy forward kernel
(large vocab > 65536), matching the existing CDNA optimization.

Benchmarked on W7900 (gfx1100) with actual unsloth kernels (5 trials, median):
  - Chunked CE forward (BS=65536): 16 warps = 2.4-2.6x faster than 32
  - All other kernels (LayerNorm, RoPE, SwiGLU): default heuristic is
    already optimal for RDNA; no modification needed.

Depends on: unslothai#4109 (provides is_rdna() detection)
@GoldenGrapeGentleman
Contributor Author

Updated: Rebased onto latest main and removed redundant is_cdna/is_rdna exports from kernels/__init__.py — no external code imports them from that namespace (all internal usage is from .utils import is_rdna).

Current diff (3 files, clean):

  • kernels/utils.py: add is_rdna() with lru_cache
  • models/loader.py: disable compile for Gemma3 on RDNA only (partial mode)
  • kernels/__init__.py: no changes (exports removed)

All 3 review items from Daniel have been addressed.

GoldenGrapeGentleman and others added 5 commits February 28, 2026 23:39
…dna()

- Add is_rdna() detection for RDNA3/3.5/RDNA4 consumer GPUs (gfx11xx, gfx1151, gfx12xx)
- Disable torch.compile for Gemma3 on HIP to fix NaN loss (fixes unslothai#3385, unslothai#4029)
- Export is_cdna/is_rdna from kernels for downstream use
- Import is_rdna into cross_entropy_loss for future RDNA-specific tuning

Tested on AMD Radeon PRO W7900 (gfx1100) with ROCm 7.1:
  ✓ Gemma3-1B: loss 3.37→3.25 (no NaN)
  ✓ Llama-3.2-1B: loss 2.44→2.37 (no NaN)
  ✓ Qwen2.5-1.5B: loss 1.89→1.85 (no NaN)
  ✓ RMS LayerNorm Triton kernel: bf16/fp16 PASSED
  ✓ Cross Entropy Loss Triton kernel: 32K/256K vocab PASSED

… remove unused import

Changes based on Daniel's review:
1. (HIGH) Replace DEVICE_TYPE=='hip' with is_rdna() to avoid disabling
   torch.compile on CDNA GPUs (MI250X/MI300X/MI350) where it works fine
2. (MEDIUM) Use 'partial' instead of '1' for UNSLOTH_COMPILE_DISABLE to
   only disable model forward compilation while keeping loss compilation,
   matching the existing Sesame pattern
3. (LOW) Remove unused is_rdna import from cross_entropy_loss.py (F401)

These functions are imported directly from .utils where needed
(e.g. cross_entropy_loss.py, loader.py). No external code imports
them from the unsloth.kernels namespace.
Member

@danielhanchen danielhanchen left a comment

Thank you! This works great!

@danielhanchen danielhanchen merged commit ccf942f into unslothai:main Mar 1, 2026
1 check passed
Development

Successfully merging this pull request may close these issues.

Training on ROCm (gfx1151, Strix Halo) results in NaN losses with Gemma3 fine-tuning