fix(ROCm): restrict is_rdna() to ROCm-officially-supported GPUs #4136

Merged
danielhanchen merged 2 commits into unslothai:main from GoldenGrapeGentleman:fix/is-rdna-rocm-official-list on Mar 3, 2026

@GoldenGrapeGentleman
Contributor

@GoldenGrapeGentleman GoldenGrapeGentleman commented Mar 2, 2026

Problem

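The current arch.startswith("gfx1") implementation in is_rdna() over-matches hardware that is not in the ROCm Linux support matrix:

  • RDNA1 (gfx10xx) and RDNA2 (gfx103x): not ROCm supported
  • gfx1102 (RX 7600), gfx1103 (Phoenix APU): not in ROCm support matrix
  • gfx1150/1151/1152 (RDNA3.5 APUs): not in ROCm support matrix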

Fix

Replace with an explicit whitelist of the 4 officially ROCm-supported RDNA gfx targets, mirroring the existing is_cdna() style:

gfx1100 / gfx1101  — RDNA3 discrete (RX 7900/7800/7700 series, PRO W7900/W7800/W7700)
gfx1200 / gfx1201  — RDNA4 discrete (RX 9060/9070 series, AI PRO R9700)

Tested on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1).
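
To make the difference between the two checks concrete, here is a minimal sketch that compares them on plain arch strings. The function names and the `ROCM_SUPPORTED_RDNA` constant are hypothetical and exist only for illustration; the real is_rdna() reads the arch from Triton's active driver, and the old version was additionally combined with `not is_cdna()`, which this sketch omits.

```python
# Hypothetical pure-string versions of the two checks, for illustration only.
# The real is_rdna() takes no argument and queries Triton for the arch.
ROCM_SUPPORTED_RDNA = ("gfx1100", "gfx1101", "gfx1200", "gfx1201")


def matches_prefix(arch: str) -> bool:
    """Old behaviour: any target whose name starts with gfx1 counts as RDNA."""
    return arch.startswith("gfx1")


def matches_whitelist(arch: str) -> bool:
    """New behaviour: only the four listed targets count."""
    return arch in ROCM_SUPPORTED_RDNA


# Targets named in this PR, plus gfx1030 as one example of the gfx103x family.
for arch in ("gfx1030", "gfx1100", "gfx1102", "gfx1103", "gfx1151", "gfx1201"):
    print(f"{arch}: prefix={matches_prefix(arch)} whitelist={matches_whitelist(arch)}")
```

Every target in the loop passes the prefix check, while only gfx1100 and gfx1201 pass the whitelist, which is the narrowing this PR proposes.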

cc @danielhanchen

@gemini-code-assist
Contributor

Summary of Changes

This pull request refines the detection logic for RDNA GPUs within the is_rdna() function to align strictly with AMD's official ROCm support matrix. The change addresses a problem where the previous implementation could over-match unsupported hardware, risking the application of unverified Triton kernel tunings. By implementing an explicit whitelist, the system now ensures that Triton kernels are only applied to genuinely supported RDNA architectures, enhancing stability and correctness.

Highlights

  • ROCm GPU Detection: Modified the is_rdna() function to explicitly whitelist only ROCm-officially-supported RDNA GPUs (gfx1100, gfx1101, gfx1200, gfx1201). This prevents the application of unverified Triton kernel tuning to unsupported hardware.

Changelog
  • unsloth/kernels/utils.py
    • Updated the is_rdna function to use an explicit whitelist for ROCm-supported RDNA architectures.
    • Revised the docstring for is_rdna to reflect the new detection scope.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request correctly refines the is_rdna() function to whitelist only ROCm-officially-supported RDNA GPUs, preventing the use of unverified kernels on unsupported hardware. The change is concise and aligns well with the existing is_cdna() implementation. I have one minor suggestion to improve the readability of the new function body.

Comment thread unsloth/kernels/utils.py
Comment on lines +92 to +97
return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
    "gfx1100",
    "gfx1101",
    "gfx1200",
    "gfx1201",
)

medium

For improved readability and to avoid a very long line of code, consider extracting the architecture string into a local variable. This makes the code's intent clearer and is more robust to different code formatters, while being consistent with the structure of the previous implementation.

Suggested change
-return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
-    "gfx1100",
-    "gfx1101",
-    "gfx1200",
-    "gfx1201",
-)
+arch = triton.runtime.driver.active.get_current_target().arch
+return is_hip() and arch in (
+    "gfx1100",
+    "gfx1101",
+    "gfx1200",
+    "gfx1201",
+)

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b82d0972c4

Comment thread unsloth/kernels/utils.py
Comment on lines +92 to +96
return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
    "gfx1100",
    "gfx1101",
    "gfx1200",
    "gfx1201",

P1: Split RDNA support check from RDNA hardware detection

is_rdna() now encodes a ROCm support whitelist, but it is also used as a general RDNA feature gate; in unsloth/models/loader.py (Gemma3 path around lines 1134-1141), the NaN workaround is conditioned on is_rdna() and explicitly calls out affected gfx1102/gfx115x devices. With this whitelist, those GPUs now return False, so the compile-disable workaround is skipped and those users can regress to NaN outputs. A separate predicate for "ROCm-officially-supported RDNA" should be introduced for tuning decisions, while preserving RDNA-family detection for correctness workarounds.

Comment thread unsloth/kernels/utils.py
-    arch = triton.runtime.driver.active.get_current_target().arch
-    return arch.startswith("gfx1") and not is_cdna()
+    """Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA3, RDNA4)."""
+    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
Collaborator

NIT: What if triton isn't installed?

Contributor Author

@GoldenGrapeGentleman GoldenGrapeGentleman Mar 2, 2026

Hi @Datta0 ~ Good catch in spirit, but we're actually safe here on two levels! 😄

  1. triton is already imported unconditionally at the module top-level (line 16 & 56 of utils.py) — if triton isn't installed, the whole module blows up long before anyone calls is_rdna().

  2. Even at runtime, is_hip() and ... short-circuits — so on a non-ROCm machine (where triton might not ship the HIP driver), we never touch triton.runtime.driver at all.
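
The short-circuit point in item 2 can be sketched with stand-in functions. Everything below is hypothetical (`fake_is_hip`, `fake_current_arch`, and the `calls` log are not part of unsloth or Triton); it only shows that Python evaluates the right-hand side of `and` when the left-hand side is truthy, and that hoisting the lookup into a local variable, as in the suggested change earlier in this thread, moves it ahead of the is_hip() guard.

```python
# Stand-ins for illustration only; none of these exist in unsloth or Triton.
calls = []  # records whether the (fake) driver lookup was reached

WHITELIST = ("gfx1100", "gfx1101", "gfx1200", "gfx1201")


def fake_is_hip() -> bool:
    """Pretend we are on a non-ROCm machine."""
    return False


def fake_current_arch() -> str:
    """Pretend driver lookup that logs each time it is called."""
    calls.append("driver")
    return "gfx1100"


def inline_check() -> bool:
    # Shape of the diff under review: the lookup sits to the right of `and`,
    # so it is skipped whenever fake_is_hip() returns False.
    return fake_is_hip() and fake_current_arch() in WHITELIST


def hoisted_check() -> bool:
    # Shape of the suggested change: the lookup runs before the guard.
    arch = fake_current_arch()
    return fake_is_hip() and arch in WHITELIST


inline_result = inline_check()
calls_after_inline = list(calls)
hoisted_result = hoisted_check()
calls_after_hoisted = list(calls)

print(inline_result, calls_after_inline)    # False []
print(hoisted_result, calls_after_hoisted)  # False ['driver']
```

Both forms return the same value here; the only difference is whether the driver is touched on a machine where is_hip() is false.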

Current arch.startswith("gfx1") incorrectly matches:
  - RDNA1 (gfx10xx) and RDNA2 (gfx103x): not ROCm supported
  - gfx1102 (RX 7600), gfx1103 (Phoenix APU): not in ROCm support matrix
  - gfx1150/1151/1152 (RDNA3.5 APUs): not in ROCm support matrix

Replace with explicit whitelist aligned to the ROCm Linux support matrix:
  https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html

  gfx1100 - RDNA3 discrete (RX 7900 series, PRO W7900/W7800)
  gfx1101 - RDNA3 discrete (RX 7800/7700 series, PRO W7700)
  gfx1200 - RDNA4 discrete (RX 9060 series)
  gfx1201 - RDNA4 discrete (RX 9070 series, AI PRO R9700)

Mirrors the existing is_cdna() pattern. Avoids silently applying
unverified Triton kernel tuning to unsupported hardware.
@Datta0
Collaborator

Datta0 commented Mar 3, 2026

btw @GoldenGrapeGentleman I see your other PR #4139
Do you wanna club these two into one? These are small changes, and combining them would simplify review

Member

@danielhanchen danielhanchen left a comment

The whitelist approach makes sense for performance tuning, but after PR #4139 reverts the only performance use of is_rdna(), the function's sole remaining caller is the Gemma3 NaN correctness workaround in loader.py:1135-1142:

# ROCm/HIP: Gemma3 compiled forward produces NaN on RDNA GPUs
# (gfx1100, gfx1101, gfx1102, gfx1150, gfx1151, etc.).
from unsloth.kernels.utils import is_rdna
if is_rdna():
    os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"

That comment explicitly lists gfx1102 (RX 7600), gfx1150, gfx1151 (Strix Halo) as affected -- all excluded by this whitelist. Issue #3385 was filed from a gfx1151 system. This PR would cause the NaN workaround to be skipped on the very hardware it was designed to protect.

For correctness workarounds, broad matching (startswith("gfx1")) is the safer approach. I'd recommend closing this after #4139 merges. If a narrow "officially supported RDNA" check is needed later for perf tuning, it should be a separate function (e.g. is_rdna_supported()) so the correctness path stays broad.
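
One way the split recommended above might look, as a rough sketch rather than the project's actual code: a broad family check gates the correctness workaround, and a narrow check is reserved for tuning. The name is_rdna_supported() comes from the review comment; `is_rdna_family`, `apply_gemma3_workaround`, the explicit `arch` argument, and the `env` dict standing in for os.environ are all assumptions made so the example runs without Triton or a GPU.

```python
# Hypothetical signatures: the arch string is passed in explicitly so this
# sketch runs anywhere. Real code would query the Triton driver instead.
ROCM_SUPPORTED_RDNA = frozenset({"gfx1100", "gfx1101", "gfx1200", "gfx1201"})


def is_rdna_family(arch: str) -> bool:
    """Broad check for correctness workarounds (prefix match on gfx1)."""
    return arch.startswith("gfx1")


def is_rdna_supported(arch: str) -> bool:
    """Narrow check for performance tuning (support-matrix targets only)."""
    return arch in ROCM_SUPPORTED_RDNA


def apply_gemma3_workaround(arch: str, env: dict) -> None:
    """Set the compile-disable flag on any RDNA-family GPU.

    `env` stands in for os.environ so the sketch has no side effects.
    """
    if is_rdna_family(arch):
        env["UNSLOTH_COMPILE_DISABLE"] = "partial"


env = {}
apply_gemma3_workaround("gfx1151", env)
print(env)                           # {'UNSLOTH_COMPILE_DISABLE': 'partial'}
print(is_rdna_supported("gfx1151"))  # False
```

With this split, gfx1151 still receives the NaN workaround even though it is excluded from the tuning whitelist.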

@danielhanchen
Member

Better to merge #4139 first -- it's a clean revert, ready now, and has zero regression risk.

#4136 needs reconsideration: after #4139 lands, the only remaining is_rdna() caller is the Gemma3 NaN correctness workaround in loader.py. The whitelist here excludes gfx1102, gfx1150, gfx1151 -- the exact hardware that workaround protects (issue #3385 was filed from gfx1151). Combining would hold up the good revert while the whitelist issue gets sorted out.

@danielhanchen danielhanchen merged commit f737858 into unslothai:main Mar 3, 2026
1 check passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…nslothai#4136)
