fix(ROCm): Comprehensive RDNA GPU support - fix Gemma3 NaN & add is_rdna()#4109
Conversation
Summary of ChangesHello @GoldenGrapeGentleman, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances Unsloth's support for AMD RDNA GPUs by introducing robust detection for these architectures and resolving a critical numerical instability issue (NaNs) encountered when training Gemma3 models on ROCm. The changes ensure more reliable performance and pave the way for future RDNA-specific optimizations, improving the overall user experience for AMD GPU owners. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces comprehensive support for AMD RDNA GPUs. It adds an is_rdna() utility function to detect these GPUs and uses it to apply a targeted fix for a NaN issue with Gemma3 models on ROCm by disabling torch.compile. The changes also export the new detection function for broader use. My feedback focuses on making the fix even more precise by fully utilizing the new is_rdna() function.
| # (gfx1100, gfx1101, gfx1102, gfx1150, gfx1151, etc.). | ||
| # Disable torch.compile; eager path is numerically correct. | ||
| # See https://github.com/unslothai/unsloth/issues/3385 | ||
| if DEVICE_TYPE == "hip": |
There was a problem hiding this comment.
The check DEVICE_TYPE == "hip" is overly broad for an issue specific to RDNA GPUs. Using the newly introduced is_rdna() function provides a more targeted fix. This will prevent disabling torch.compile on other HIP devices like CDNA GPUs, where it might be beneficial. The is_rdna() function already includes a check for HIP.
| if DEVICE_TYPE == "hip": | |
| if __import__("unsloth.kernels").is_rdna(): |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 50ca55a095
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if DEVICE_TYPE == "hip": | ||
| os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" |
There was a problem hiding this comment.
Limit HIP compile disable to RDNA targets
FastModel.from_pretrained now disables torch.compile for every HIP device, but the regression described here is RDNA-specific. On ROCm CDNA systems (for example MI2xx/MI3xx), this path still executes and forces UNSLOTH_COMPILE_DISABLE=1, which can unnecessarily drop training/inference throughput for Gemma3 workloads even when they are not affected by the NaN issue. The guard should use RDNA detection instead of DEVICE_TYPE == "hip".
Useful? React with 👍 / 👎.
Review: PR #4109 -- RDNA GPU supportThanks for working on AMD RDNA support. The Issues Found1. Compile disable scope is too broad (HIGH)
# Current (disables compile on ALL AMD GPUs):
if DEVICE_TYPE == "hip":
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
# Should be (only disables on RDNA):
from unsloth.kernels.utils import is_rdna
if is_rdna():
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"2. Full compile disable vs partial (MEDIUM) The PR sets 3. Unused import (LOW)
What looks good
NVIDIA non-regression verificationTested on NVIDIA B200, Torch 2.9.1+cu128, Triton 3.5.1:
SFT training benchmark (max_steps=61, batch=2, grad_accum=3, seed=3407, 4bit LoRA, bfloat16): Llama-3.2-1B-Instruct:
gemma-3-1b-it:
Losses and grad-norms are identical (max_diff=0.000000) across all configs. No regression on NVIDIA. |
|
Thanks for the thorough review @danielhanchen! All 3 issues addressed in the latest push:
Also appreciate the NVIDIA non-regression verification — great to see identical losses and grad-norms across all configs.
|
Add is_rdna() detection for RDNA3/3.5/4 (gfx11xx, gfx12xx) and optimize
Triton kernel launch parameters based on W7900 (gfx1100) benchmarks:
- calculate_settings(): double num_warps (capped at 32) for RDNA.
RDNA's dual-issue SIMD32 CUs achieve higher occupancy with more warps.
- Cross entropy (chunked, large vocab): use 16 warps on RDNA (same as
CDNA), benchmarked ~10% faster than the NVIDIA default of 32.
- Export is_cdna/is_rdna from kernels for downstream use.
Benchmark results on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1):
RMS LayerNorm (n_rows=4096, bf16):
dim=2048: 8w 102us -> 16w 95us (+7.6%)
dim=4096: 8w 210us -> 16w 203us (+3.1%)
dim=8192: 16w 370us -> 32w 344us (+7.0%)
Cross Entropy chunked (vocab=128256):
32w 1930us -> 16w 1728us (+10.5%)
Correctness: all RMS LN, LN, CE tests pass (fwd+bwd, fp16/bf16).
Builds on unslothai#4109 (is_rdna detection for Gemma3 NaN fix).
Add is_rdna() detection for RDNA3/3.5/4 (gfx11xx, gfx12xx) and optimize
Triton kernel launch parameters based on W7900 (gfx1100) benchmarks:
- calculate_settings(): double num_warps (capped at 32) for RDNA.
RDNA's dual-issue SIMD32 CUs achieve higher occupancy with more warps.
- Cross entropy (chunked, large vocab): use 16 warps on RDNA (same as
CDNA), benchmarked ~10% faster than the NVIDIA default of 32.
- Export is_cdna/is_rdna from kernels for downstream use.
Benchmark results on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1):
RMS LayerNorm (n_rows=4096, bf16):
dim=2048: 8w 102us -> 16w 95us (+7.6%)
dim=4096: 8w 210us -> 16w 203us (+3.1%)
dim=8192: 16w 370us -> 32w 344us (+7.0%)
Cross Entropy chunked (vocab=128256):
32w 1930us -> 16w 1728us (+10.5%)
Correctness: all RMS LN, LN, CE tests pass (fwd+bwd, fp16/bf16).
Builds on unslothai#4109 (is_rdna detection for Gemma3 NaN fix).
Use 16 warps for RDNA in the chunked cross-entropy forward kernel
(large vocab > 65536), matching the existing CDNA optimization.
Benchmarked on W7900 (gfx1100) with actual unsloth kernels (5 trials, median):
- Chunked CE forward (BS=65536): 16 warps = 2.4-2.6x faster than 32
- All other kernels (LayerNorm, RoPE, SwiGLU): default heuristic is
already optimal for RDNA; no modification needed.
Depends on: unslothai#4109 (provides is_rdna() detection)
ce6d2da to
7da16e8
Compare
|
Updated: Rebased onto latest Current diff (3 files, clean):
All 3 review items from Daniel have been addressed. |
…dna() - Add is_rdna() detection for RDNA3/3.5/RDNA4 consumer GPUs (gfx11xx, gfx1151, gfx12xx) - Disable torch.compile for Gemma3 on HIP to fix NaN loss (fixes unslothai#3385, unslothai#4029) - Export is_cdna/is_rdna from kernels for downstream use - Import is_rdna into cross_entropy_loss for future RDNA-specific tuning Tested on AMD Radeon PRO W7900 (gfx1100) with ROCm 7.1: ✓ Gemma3-1B: loss 3.37→3.25 (no NaN) ✓ Llama-3.2-1B: loss 2.44→2.37 (no NaN) ✓ Qwen2.5-1.5B: loss 1.89→1.85 (no NaN) ✓ RMS LayerNorm Triton kernel: bf16/fp16 PASSED ✓ Cross Entropy Loss Triton kernel: 32K/256K vocab PASSED
for more information, see https://pre-commit.ci
… remove unused import Changes based on Daniel's review: 1. (HIGH) Replace DEVICE_TYPE=='hip' with is_rdna() to avoid disabling torch.compile on CDNA GPUs (MI250X/MI300X/MI350) where it works fine 2. (MEDIUM) Use 'partial' instead of '1' for UNSLOTH_COMPILE_DISABLE to only disable model forward compilation while keeping loss compilation, matching the existing Sesame pattern 3. (LOW) Remove unused is_rdna import from cross_entropy_loss.py (F401)
These functions are imported directly from .utils where needed (e.g. cross_entropy_loss.py, loader.py). No external code imports them from the unsloth.kernels namespace.
for more information, see https://pre-commit.ci
danielhanchen
left a comment
There was a problem hiding this comment.
Thank you! This works great!
…ai#4123) Use 16 warps for RDNA in the chunked cross-entropy forward kernel (large vocab > 65536), matching the existing CDNA optimization. Benchmarked on W7900 (gfx1100) with actual unsloth kernels (5 trials, median): - Chunked CE forward (BS=65536): 16 warps = 2.4-2.6x faster than 32 - All other kernels (LayerNorm, RoPE, SwiGLU): default heuristic is already optimal for RDNA; no modification needed. Depends on: unslothai#4109 (provides is_rdna() detection)
…dna() (unslothai#4109) * fix(ROCm): comprehensive RDNA GPU support - fix Gemma3 NaN & add is_rdna() - Add is_rdna() detection for RDNA3/3.5/RDNA4 consumer GPUs (gfx11xx, gfx1151, gfx12xx) - Disable torch.compile for Gemma3 on HIP to fix NaN loss (fixes unslothai#3385, unslothai#4029) - Export is_cdna/is_rdna from kernels for downstream use - Import is_rdna into cross_entropy_loss for future RDNA-specific tuning Tested on AMD Radeon PRO W7900 (gfx1100) with ROCm 7.1: ✓ Gemma3-1B: loss 3.37→3.25 (no NaN) ✓ Llama-3.2-1B: loss 2.44→2.37 (no NaN) ✓ Qwen2.5-1.5B: loss 1.89→1.85 (no NaN) ✓ RMS LayerNorm Triton kernel: bf16/fp16 PASSED ✓ Cross Entropy Loss Triton kernel: 32K/256K vocab PASSED * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address review: scope compile disable to RDNA only, use partial mode, remove unused import Changes based on Daniel's review: 1. (HIGH) Replace DEVICE_TYPE=='hip' with is_rdna() to avoid disabling torch.compile on CDNA GPUs (MI250X/MI300X/MI350) where it works fine 2. (MEDIUM) Use 'partial' instead of '1' for UNSLOTH_COMPILE_DISABLE to only disable model forward compilation while keeping loss compilation, matching the existing Sesame pattern 3. (LOW) Remove unused is_rdna import from cross_entropy_loss.py (F401) * Remove redundant is_cdna/is_rdna exports from kernels/__init__.py These functions are imported directly from .utils where needed (e.g. cross_entropy_loss.py, loader.py). No external code imports them from the unsloth.kernels namespace. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…ai#4123) Use 16 warps for RDNA in the chunked cross-entropy forward kernel (large vocab > 65536), matching the existing CDNA optimization. Benchmarked on W7900 (gfx1100) with actual unsloth kernels (5 trials, median): - Chunked CE forward (BS=65536): 16 warps = 2.4-2.6x faster than 32 - All other kernels (LayerNorm, RoPE, SwiGLU): default heuristic is already optimal for RDNA; no modification needed. Depends on: unslothai#4109 (provides is_rdna() detection)

Summary
Comprehensive fix for AMD RDNA consumer/workstation GPUs (RX 7000/9000 series, PRO W7900, Strix Halo, etc.).
Changes
is_rdna()detection (unsloth/kernels/utils.py)is_rdna()function detects all RDNA3/3.5/RDNA4 architectures (gfx11xx, gfx1151, gfx12xx)arch.startswith("gfx1")+not is_cdna()for future-proof coverageGemma3 NaN fix (
unsloth/models/loader.py)torch.compilefor Gemma3 on HIP (ROCm) devicesExport
is_cdna/is_rdna(unsloth/kernels/__init__.py)Import
is_rdnain cross_entropy_loss (unsloth/kernels/cross_entropy_loss.py)Test Results on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1)
Environment
Related Issues
cc @danielhanchen