diff --git a/unsloth_zoo/temporary_patches/utils.py b/unsloth_zoo/temporary_patches/utils.py index b1acd0c2d..b165a72c2 100644 --- a/unsloth_zoo/temporary_patches/utils.py +++ b/unsloth_zoo/temporary_patches/utils.py @@ -244,7 +244,27 @@ def find_module(self, fullname, path=None): # Python < 3.12 compat shim return None -if "torchao" not in _sys_rocm_stub.modules: +# Only Windows + ROCm (HIP) PyTorch actually needs this stub: that is the one +# build where `import torchao` crashes on the missing torch.distributed +# C-extension stack. On every other platform a failing `import torchao` simply +# means torchao is not installed -- transformers handles that correctly on its +# own. Installing the stub there is actively harmful: transformers' +# is_torchao_available() would then read torchao.__version__, get a sentinel +# class, and crash in packaging.version.parse() with +# "'_ROCmSentinelMeta' object is not iterable". +_is_windows_rocm = False +if _sys_rocm_stub.platform == "win32": + try: + import torch as _torch_rocm_probe + _is_windows_rocm = bool( + getattr(getattr(_torch_rocm_probe, "version", None), "hip", None) + or "rocm" in getattr(_torch_rocm_probe, "__version__", "").lower() + ) + del _torch_rocm_probe + except Exception: + _is_windows_rocm = False + +if _is_windows_rocm and "torchao" not in _sys_rocm_stub.modules: try: import torchao # noqa: F401 except Exception: @@ -256,6 +276,7 @@ def find_module(self, fullname, path=None): # Python < 3.12 compat shim # alive -- the loader and sentinel classes call them at runtime. del _ROCmTorchaoLoader, _ROCmTorchaoFinder del _MetaPathFinder, _Loader, _ModuleSpec, _sys_rocm_stub, _types_rocm_stub +del _is_windows_rocm try: from transformers.processing_utils import Unpack