Only install torchao ROCm stub on Windows ROCm #704

Merged
danielhanchen merged 1 commit into main from fix/torchao-stub-windows-rocm-only
May 30, 2026

@danielhanchen

Member

Problem

import unsloth_zoo crashes on any platform where torchao is not installed (the common case on Linux CI and most Linux/macOS setups):

TypeError: '_ROCmSentinelMeta' object is not iterable
  packaging/version.py ... in parse
  transformers/utils/import_utils.py ... in is_torchao_available

This is currently failing CI on multiple branches/PRs across both repos.

Root cause

The torchao Windows-ROCm stub in temporary_patches/utils.py (added in #703) was installed whenever import torchao raised:

if "torchao" not in sys.modules:
    try:
        import torchao
    except Exception:
        sys.meta_path.insert(0, _ROCmTorchaoFinder())

On a normal Linux machine import torchao raises simply because torchao is not installed. The guard then installs the meta-path finder anyway, so every later import torchao returns a sentinel stub. transformers' is_torchao_available() reads torchao.__version__, gets a _ROCmSentinelMeta class instead of a version string, and packaging.version.parse() crashes on it.
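The mechanism can be reproduced without torchao or transformers. The sketch below is illustrative only: `hypothetical_missing_pkg`, `_SentinelMeta`, `_Sentinel`, `_StubLoader`, `_StubFinder`, and `version_looks_valid` are made-up names, not the classes in temporary_patches/utils.py. It assumes the stub answers every attribute lookup with a sentinel class, and shows that once such a finder is on sys.meta_path, importing a package that is not installed succeeds and its `__version__` is no longer a string.

```python
import importlib
import importlib.abc
import importlib.machinery
import sys


class _SentinelMeta(type):
    """Hypothetical stand-in for the PR's _ROCmSentinelMeta."""


class _Sentinel(metaclass=_SentinelMeta):
    """Placeholder class handed back for every attribute of the stub."""


class _StubLoader(importlib.abc.Loader):
    def create_module(self, spec):
        return None  # fall back to the default module object

    def exec_module(self, module):
        # Module-level __getattr__ (PEP 562): any attribute the module does
        # not define, including __version__, resolves to the sentinel class.
        module.__getattr__ = lambda name: _Sentinel


class _StubFinder(importlib.abc.MetaPathFinder):
    def __init__(self, name):
        self._name = name

    def find_spec(self, fullname, path=None, target=None):
        if fullname == self._name:
            return importlib.machinery.ModuleSpec(fullname, _StubLoader())
        return None


def version_looks_valid(module_name):
    """Import a module and report whether its __version__ is a string."""
    module = importlib.import_module(module_name)
    return isinstance(getattr(module, "__version__", None), str)


NAME = "hypothetical_missing_pkg"  # made-up name, assumed not installed
finder = _StubFinder(NAME)
sys.meta_path.insert(0, finder)
try:
    # The import succeeds because of the stub, but the version is a class.
    stub_version_ok = version_looks_valid(NAME)
finally:
    sys.meta_path.remove(finder)
    sys.modules.pop(NAME, None)
```

Any caller that treats a successful import as "installed" and then parses `__version__` as a version string will fail on such a stub, which is the shape of the crash described above.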

The stub is only ever needed on Windows ROCm, where import torchao crashes on the incomplete torch.distributed C-extension stack. Everywhere else a failing import torchao just means "not installed", which transformers already handles correctly.

Fix

Gate the stub install on an actual Windows ROCm build before touching sys.meta_path:

_is_windows_rocm = False
if sys.platform == "win32":
    try:
        import torch
        _is_windows_rocm = bool(
            getattr(getattr(torch, "version", None), "hip", None)
            or "rocm" in getattr(torch, "__version__", "").lower()
        )
    except Exception:
        _is_windows_rocm = False

if _is_windows_rocm and "torchao" not in sys.modules:
    try:
        import torchao
    except Exception:
        sys.meta_path.insert(0, _ROCmTorchaoFinder())

The Windows ROCm path is unchanged. Every other platform keeps transformers' own torchao handling, so import unsloth_zoo no longer crashes when torchao is absent.
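For illustration, the detection expression can be pulled into a pure function and exercised with fake torch objects. This is a sketch, not code from the PR: the helper name `is_windows_rocm`, its parameters, and the version strings are assumptions chosen for the example.

```python
from types import SimpleNamespace


def is_windows_rocm(platform, torch_module):
    """Return True only for a Windows platform with a HIP/ROCm torch build.

    Hypothetical helper mirroring the gate's expression; `platform` plays
    the role of sys.platform and `torch_module` the imported torch module.
    """
    if platform != "win32" or torch_module is None:
        return False
    hip = getattr(getattr(torch_module, "version", None), "hip", None)
    version_string = str(getattr(torch_module, "__version__", ""))
    return bool(hip or "rocm" in version_string.lower())


# Fake torch builds; the version strings are illustrative, not real releases.
hip_torch = SimpleNamespace(version=SimpleNamespace(hip="6.2"), __version__="2.5.0")
rocm_tag_torch = SimpleNamespace(version=SimpleNamespace(hip=None), __version__="2.5.0+rocm6.2")
cuda_torch = SimpleNamespace(version=SimpleNamespace(hip=None), __version__="2.5.0+cu121")

checks = {
    "linux_hip": is_windows_rocm("linux", hip_torch),
    "win_hip": is_windows_rocm("win32", hip_torch),
    "win_rocm_tag": is_windows_rocm("win32", rocm_tag_torch),
    "win_cuda": is_windows_rocm("win32", cuda_torch),
    "win_no_torch": is_windows_rocm("win32", None),
}
```

Either signal (a truthy torch.version.hip or "rocm" in the version string) is enough on Windows; on any other platform the function returns False without inspecting torch at all.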

Verification

A scope-faithful simulation execs the real stub classes and the real patched gate under faked sys.platform / torch / torchao conditions:

| Scenario | Result |
| --- | --- |
| Linux, torchao absent, old gate | reproduces '_ROCmSentinelMeta' object is not iterable |
| Linux, torchao absent, new gate | no stub installed, no crash |
| Windows ROCm, import torchao fails, new gate | stub installed; isinstance(x, torchao.dtypes.AffineQuantizedTensor) returns False with no TypeError; deep torchao.a.b.c resolves |
| Windows ROCm, torchao importable, new gate | no stub (real torchao used) |
| Windows CUDA, new gate | no stub |
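One way such a simulation could be set up is sketched below. It is not the harness used for this PR: `StubFinder` is a hypothetical placeholder for `_ROCmTorchaoFinder`, `BlockTorchao` and `stub_installed` are invented for the example, and the fake torch version strings are illustrative. The gate source is executed under a patched sys.platform, a patched sys.meta_path, and a fake torch entry in sys.modules, so the real interpreter state is restored afterwards.

```python
import sys
from types import SimpleNamespace
from unittest import mock

GATE_SOURCE = '''
_is_windows_rocm = False
if sys.platform == "win32":
    try:
        import torch
        _is_windows_rocm = bool(
            getattr(getattr(torch, "version", None), "hip", None)
            or "rocm" in getattr(torch, "__version__", "").lower()
        )
    except Exception:
        _is_windows_rocm = False

if _is_windows_rocm and "torchao" not in sys.modules:
    try:
        import torchao
    except Exception:
        sys.meta_path.insert(0, StubFinder())
'''


class StubFinder:
    """Hypothetical placeholder for the PR's _ROCmTorchaoFinder."""

    def find_spec(self, fullname, path=None, target=None):
        return None


class BlockTorchao:
    """Meta-path finder that makes `import torchao` fail, as if unusable."""

    def find_spec(self, fullname, path=None, target=None):
        if fullname == "torchao":
            raise ImportError("simulated: torchao unavailable")
        return None


def stub_installed(platform, fake_torch):
    """Run the gate under faked conditions; report whether a stub was added."""
    fake_meta_path = [BlockTorchao()] + list(sys.meta_path)
    with mock.patch.object(sys, "platform", platform), \
         mock.patch.object(sys, "meta_path", fake_meta_path), \
         mock.patch.dict(sys.modules, {"torch": fake_torch}):
        sys.modules.pop("torchao", None)  # restored by patch.dict on exit
        exec(GATE_SOURCE, {"sys": sys, "StubFinder": StubFinder})
        return any(isinstance(f, StubFinder) for f in sys.meta_path)


rocm_torch = SimpleNamespace(version=SimpleNamespace(hip="6.2"), __version__="2.5.0+rocm6.2")
cuda_torch = SimpleNamespace(version=SimpleNamespace(hip=None), __version__="2.5.0+cu121")

results = {
    "linux": stub_installed("linux", rocm_torch),
    "windows_rocm": stub_installed("win32", rocm_torch),
    "windows_cuda": stub_installed("win32", cuda_torch),
}
```

The three entries correspond to the "Linux, new gate", "Windows ROCm, import torchao fails" and "Windows CUDA" rows; the "torchao importable" row would need a fake torchao placed in sys.modules instead of the blocking finder.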

python -m py_compile passes on the changed file.

The torchao meta_path stub was installed whenever `import torchao` failed,
including on Linux where torchao simply is not installed. transformers'
is_torchao_available() then read torchao.__version__, got a sentinel class,
and crashed in packaging.version.parse() with
"'_ROCmSentinelMeta' object is not iterable", breaking `import unsloth_zoo`.

Gate the stub install on actual Windows ROCm (sys.platform == "win32" plus a
HIP/ROCm torch build) so every other platform keeps transformers' own torchao
handling. The Windows ROCm behavior is unchanged.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request restricts the installation of the torchao ROCm stub to Windows + ROCm (HIP) PyTorch environments to prevent packaging errors on other platforms. Feedback was provided to simplify the platform check by utilizing the globally imported torch module instead of importing it locally with an alias.

Comment on lines +258 to +263
import torch as _torch_rocm_probe
_is_windows_rocm = bool(
    getattr(getattr(_torch_rocm_probe, "version", None), "hip", None)
    or "rocm" in getattr(_torch_rocm_probe, "__version__", "").lower()
)
del _torch_rocm_probe

medium

Since torch is already imported globally at the top of this file (on line 42), there is no need to import it again as a local alias (_torch_rocm_probe) and delete it afterwards. You can directly reference the globally imported torch module.

        _is_windows_rocm = bool(
            getattr(getattr(torch, "version", None), "hip", None)
            or "rocm" in getattr(torch, "__version__", "").lower()
        )

@danielhanchen danielhanchen merged commit ddd0310 into main May 30, 2026
15 checks passed