fix(ROCm): restrict is_rdna() to ROCm-officially-supported GPUs #4136
Summary of Changes
This pull request refines the detection logic for RDNA GPUs.
Code Review
This pull request correctly refines the is_rdna() function to whitelist only ROCm-officially-supported RDNA GPUs, preventing the use of unverified kernels on unsupported hardware. The change is concise and aligns well with the existing is_cdna() implementation. I have one minor suggestion to improve the readability of the new function body.
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
        "gfx1100",
        "gfx1101",
        "gfx1200",
        "gfx1201",
    )
For improved readability and to avoid a very long line of code, consider extracting the architecture string into a local variable. This makes the code's intent clearer and is more robust to different code formatters, while being consistent with the structure of the previous implementation.
Suggested change:
    arch = triton.runtime.driver.active.get_current_target().arch
    return is_hip() and arch in (
        "gfx1100",
        "gfx1101",
        "gfx1200",
        "gfx1201",
    )
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b82d0972c4
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
        "gfx1100",
        "gfx1101",
        "gfx1200",
        "gfx1201",
Split RDNA support check from RDNA hardware detection
is_rdna() now encodes a ROCm support whitelist, but it is also used as a general RDNA feature gate; in unsloth/models/loader.py (Gemma3 path around lines 1134-1141), the NaN workaround is conditioned on is_rdna() and explicitly calls out affected gfx1102/gfx115x devices. With this whitelist, those GPUs now return False, so the compile-disable workaround is skipped and those users can regress to NaN outputs. A separate predicate for "ROCm-officially-supported RDNA" should be introduced for tuning decisions, while preserving RDNA-family detection for correctness workarounds.
-    arch = triton.runtime.driver.active.get_current_target().arch
-    return arch.startswith("gfx1") and not is_cdna()
+    """Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA3, RDNA4)."""
+    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
NIT: What if triton isn't installed?
Hi @Datta0 ~ Good catch in spirit, but we're actually safe here on two levels! 😄
- triton is already imported unconditionally at the module top level (lines 16 and 56 of utils.py), so if triton isn't installed, the whole module blows up long before anyone calls is_rdna().
- Even at runtime, is_hip() and ... short-circuits, so on a non-ROCm machine (where triton might not ship the HIP driver), we never touch triton.runtime.driver at all.
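The short-circuit point can be checked in isolation. The following is a minimal sketch, not the unsloth code: fake_is_hip and fake_get_arch are hypothetical stand-ins for is_hip() and the triton driver lookup, and the second one raises to simulate a machine without a HIP driver.

```python
# Hypothetical stand-ins; neither function exists in unsloth or triton.
calls = []

def fake_is_hip():
    calls.append("is_hip")
    return False  # simulate a non-ROCm machine

def fake_get_arch():
    calls.append("get_arch")
    raise RuntimeError("no HIP driver available")

SUPPORTED = ("gfx1100", "gfx1101", "gfx1200", "gfx1201")

def is_rdna_whitelisted():
    # `and` evaluates its right operand only when the left one is truthy,
    # so fake_get_arch() is never reached here.
    return fake_is_hip() and fake_get_arch() in SUPPORTED

result = is_rdna_whitelisted()
```

The guarantee depends on the lookup staying on the right-hand side of the and. Reading the arch into a local variable before the is_hip() check would run the lookup unconditionally.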
Current arch.startswith("gfx1") incorrectly matches:
- RDNA1 (gfx10xx) and RDNA2 (gfx103x): not ROCm supported
- gfx1102 (RX 7600), gfx1103 (Phoenix APU): not in ROCm support matrix
- gfx1150/1151/1152 (RDNA3.5 APUs): not in ROCm support matrix
Replace with explicit whitelist aligned to the ROCm Linux support matrix:
https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html
gfx1100 - RDNA3 discrete (RX 7900 series, PRO W7900/W7800)
gfx1101 - RDNA3 discrete (RX 7800/7700 series, PRO W7700)
gfx1200 - RDNA4 discrete (RX 9060 series)
gfx1201 - RDNA4 discrete (RX 9070 series, AI PRO R9700)
Mirrors the existing is_cdna() pattern. Avoids silently applying
unverified Triton kernel tuning to unsupported hardware.
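A small sketch of the difference between the two checks, using only arch strings (no triton or GPU required). EXAMPLE_ARCHS is an illustrative list assembled from the targets named above, with gfx1030 chosen as one representative of the gfx103x group; it is not an exhaustive list of AMD targets.

```python
# Illustrative sample of arch strings; not a complete list of gfx targets.
EXAMPLE_ARCHS = [
    "gfx1030",  # RDNA2 (one example from the gfx103x group)
    "gfx1100", "gfx1101", "gfx1102", "gfx1103",
    "gfx1150", "gfx1151", "gfx1152",
    "gfx1200", "gfx1201",
]

# The four targets in the proposed whitelist.
WHITELIST = ("gfx1100", "gfx1101", "gfx1200", "gfx1201")

# Targets accepted by the old prefix check.
prefix_matched = [a for a in EXAMPLE_ARCHS if a.startswith("gfx1")]

# Targets the old check accepted that the whitelist rejects.
dropped = [a for a in prefix_matched if a not in WHITELIST]
```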
btw @GoldenGrapeGentleman I see your other PR #4139
danielhanchen left a comment
The whitelist approach makes sense for performance tuning, but after PR #4139 reverts the only performance use of is_rdna(), the function's sole remaining caller is the Gemma3 NaN correctness workaround in loader.py:1135-1142:
    # ROCm/HIP: Gemma3 compiled forward produces NaN on RDNA GPUs
    # (gfx1100, gfx1101, gfx1102, gfx1150, gfx1151, etc.).
    from unsloth.kernels.utils import is_rdna
    if is_rdna():
        os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"

That comment explicitly lists gfx1102 (RX 7600), gfx1150, gfx1151 (Strix Halo) as affected -- all excluded by this whitelist. Issue #3385 was filed from a gfx1151 system. This PR would cause the NaN workaround to be skipped on the very hardware it was designed to protect.
For correctness workarounds, broad matching (startswith("gfx1")) is the safer approach. I'd recommend closing this after #4139 merges. If a narrow "officially supported RDNA" check is needed later for perf tuning, it should be a separate function (e.g. is_rdna_supported()) so the correctness path stays broad.
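One way the suggested split could look, as a sketch only. The predicates take the arch string as a parameter so the example runs without triton. is_rdna_supported() is the name floated above; is_rdna_family() and compile_disable_mode() are hypothetical names introduced here for illustration, not existing unsloth functions.

```python
from typing import Optional

# Whitelist from this PR (ROCm Linux support matrix targets).
ROCM_SUPPORTED_RDNA = ("gfx1100", "gfx1101", "gfx1200", "gfx1201")

def is_rdna_family(arch: str) -> bool:
    """Broad match for correctness workarounds (hypothetical name).

    Mirrors the prefix check that this PR proposes to remove.
    """
    return arch.startswith("gfx1")

def is_rdna_supported(arch: str) -> bool:
    """Narrow match intended only for performance tuning decisions."""
    return arch in ROCM_SUPPORTED_RDNA

def compile_disable_mode(arch: str) -> Optional[str]:
    """Hypothetical helper: value for the Gemma3 NaN workaround, if any."""
    # The correctness path stays on the broad predicate, so targets
    # such as gfx1102 and gfx1151 keep the workaround.
    return "partial" if is_rdna_family(arch) else None
```

With this shape, a gfx1151 system still gets the "partial" workaround while being excluded from any tuning gated on is_rdna_supported().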
Better to merge #4139 first -- it's a clean revert, ready now, and has zero regression risk. #4136 needs reconsideration: after #4139 lands, the only remaining
Problem
The current arch.startswith("gfx1") implementation in is_rdna() over-matches hardware that is not in the ROCm Linux support matrix.
Fix
Replace it with an explicit whitelist of the 4 officially ROCm-supported RDNA gfx targets, mirroring the existing is_cdna() style.
Tested on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1).
cc @danielhanchen