-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix Gemma3 NaN losses on ROCm by disabling torch.compile for RDNA GPUs #4029
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1046,6 +1046,14 @@ def from_pretrained( | |
| # Set norms to float32 since anyways they get upcasted to float32 | ||
| # common in both gemma-3 and gemma-3n | ||
| os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1" | ||
| # ROCm: Gemma3 compiled forward produces NaN on RDNA GPUs (gfx11xx). | ||
| # Disable compilation; eager path is numerically correct. | ||
| # See https://github.com/unslothai/unsloth/issues/3385 | ||
| if DEVICE_TYPE == "hip": | ||
| os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" | ||
| import unsloth_zoo.compiler | ||
|
|
||
| unsloth_zoo.compiler.UNSLOTH_COMPILE_DISABLE = True | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This assignment mutates a process-wide compiler flag and there is no corresponding reset path in Useful? React with 👍 / 👎. |
||
| # Cohere | ||
| elif "cohere2" in model_types_all and transformers_version < Version( | ||
| "4.50.0.dev0" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve code style and adhere to PEP 8, it's better to have imports at the top of the file. Please move this import to the top of
unsloth/models/loader.pyand remove it from here.