Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion unsloth_zoo/temporary_patches/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,27 @@ def find_module(self, fullname, path=None): # Python < 3.12 compat shim
return None


if "torchao" not in _sys_rocm_stub.modules:
# Only Windows + ROCm (HIP) PyTorch actually needs this stub: that is the one
# build where `import torchao` crashes on the missing torch.distributed
# C-extension stack. On every other platform a failing `import torchao` simply
# means torchao is not installed -- transformers handles that correctly on its
# own. Installing the stub there is actively harmful: transformers'
# is_torchao_available() would then read torchao.__version__, get a sentinel
# class, and crash in packaging.version.parse() with
# "'_ROCmSentinelMeta' object is not iterable".
_is_windows_rocm = False
if _sys_rocm_stub.platform == "win32":
try:
import torch as _torch_rocm_probe
_is_windows_rocm = bool(
getattr(getattr(_torch_rocm_probe, "version", None), "hip", None)
or "rocm" in getattr(_torch_rocm_probe, "__version__", "").lower()
)
del _torch_rocm_probe
Comment on lines +258 to +263

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since torch is already imported globally at the top of this file (on line 42), there is no need to import it again as a local alias (_torch_rocm_probe) and delete it afterwards. You can directly reference the globally imported torch module.

        _is_windows_rocm = bool(
            getattr(getattr(torch, "version", None), "hip", None)
            or "rocm" in getattr(torch, "__version__", "").lower()
        )

except Exception:
_is_windows_rocm = False

if _is_windows_rocm and "torchao" not in _sys_rocm_stub.modules:
try:
import torchao # noqa: F401
except Exception:
Expand All @@ -256,6 +276,7 @@ def find_module(self, fullname, path=None): # Python < 3.12 compat shim
# alive -- the loader and sentinel classes call them at runtime.
del _ROCmTorchaoLoader, _ROCmTorchaoFinder
del _MetaPathFinder, _Loader, _ModuleSpec, _sys_rocm_stub, _types_rocm_stub
del _is_windows_rocm

try:
from transformers.processing_utils import Unpack
Expand Down
Loading