diff --git a/unsloth_zoo/temporary_patches/utils.py b/unsloth_zoo/temporary_patches/utils.py index 5dee7d2e3..b1acd0c2d 100644 --- a/unsloth_zoo/temporary_patches/utils.py +++ b/unsloth_zoo/temporary_patches/utils.py @@ -143,6 +143,120 @@ def process_return( pass del _sys, _pil_mod +# ROCm on Windows ships PyTorch without the full torch.distributed C-extension +# stack (torch._C._distributed_c10d, DeviceMesh, etc.). +# torchao pulls the entire distributed chain in at module-import time +# (torch.distributed._functional_collectives → _tensor → DeviceMesh …), +# which cascades into ImportError even for code paths that never use +# distributed features (e.g. plain LoRA training). +# +# Fix: try to import torchao; if that fails due to the missing distributed +# stack, install a sys.meta_path hook that intercepts *all* "torchao" and +# "torchao.*" imports and returns self-contained stub modules, so subsequent +# imports (transformers → quantizer_torchao → torchao) succeed. Each stub +# satisfies `from torchao.X import Y` by returning a no-op sentinel class for +# any attribute, so transformers can define its TorchAoHfQuantizer class at +# import time without the full torchao stack. +# Any runtime call that actually needs a real torchao op or a real distributed +# collective will still fail loudly.
# Ref: https://github.com/ROCm/TheRock/issues/3284
import sys as _sys_rocm_stub
import types as _types_rocm_stub
from importlib.abc import MetaPathFinder as _MetaPathFinder, Loader as _Loader
from importlib.machinery import ModuleSpec as _ModuleSpec


def _rocm_is_dunder(name):
    """Return True when *name* is a ``__dunder__`` protocol attribute.

    Stubs must never fabricate these: answering ``__all__``, ``__wrapped__``,
    ``__file__`` ... with a sentinel class defeats ``hasattr`` probes and
    breaks ``from torchao.x import *``.
    """
    return len(name) > 4 and name.startswith("__") and name.endswith("__")


def _rocm_torchao_not_installed(exc):
    """Return True when *exc* means the torchao package itself is absent.

    A plain "torchao is not installed" must NOT be papered over with stubs:
    otherwise ``importlib.util.find_spec("torchao")`` would report a phantom
    package on every platform. Any other failure (e.g. a missing
    ``torch._C._distributed_c10d``) is the broken-import case the stubs are for.
    """
    return isinstance(exc, ModuleNotFoundError) and exc.name == "torchao"


# Metaclass that makes sentinel classes chainable via attribute access,
# e.g. AffineQuantizedTensor.subattr returns another sentinel class.
# This is needed because peft does isinstance(weight, AffineQuantizedTensor)
# which requires the second arg to be a real type, not a module.
class _ROCmSentinelMeta(type):
    def __getattr__(cls, name):
        # Only reached when normal lookup on the class fails.
        if _rocm_is_dunder(name):
            raise AttributeError(name)
        child = _ROCmSentinelMeta(name, (), {"__module__": cls.__module__})
        # Cache so repeated access yields the same class object.
        setattr(cls, name, child)
        return child


def _rocm_make_sentinel(attr, parent_name):
    """Return a sentinel class that is a proper type (works in isinstance())."""
    return _ROCmSentinelMeta(attr, (), {"__module__": parent_name})


def _rocm_make_torchao_stub(name):
    """
    Create a stub module for a torchao path.

    - Sub-module imports (torchao.dtypes, torchao.quantization …) are handled
      by the meta_path finder and get proper module stubs.
    - Direct attribute access on a stub (torchao.dtypes.AffineQuantizedTensor)
      returns a sentinel CLASS so that isinstance() checks succeed (returning
      False, since no real weight will ever be an instance of the sentinel).
    - ``__dunder__`` attributes that are not set explicitly raise
      AttributeError, exactly like a normal module.
    """
    # Imported locally: the module-level aliases are deleted after setup,
    # but this function keeps being called at import time later on.
    import sys as _s, types as _t
    from importlib.machinery import ModuleSpec as _MS

    mod = _t.ModuleType(name)
    mod.__path__ = []  # marks the stub as a package so sub-imports work
    mod.__package__ = name
    mod.__spec__ = _MS(name, loader=None)

    def _getattr(attr):
        if _rocm_is_dunder(attr):
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        full = f"{name}.{attr}"
        # If a sub-module was already imported (via meta_path), use that.
        if full in _s.modules:
            obj = _s.modules[full]
        else:
            # Otherwise return a sentinel class usable in isinstance().
            obj = _rocm_make_sentinel(attr, name)
        # Cache on the module so this hook is not hit again for `attr`.
        setattr(mod, attr, obj)
        return obj

    # PEP 562 module-level __getattr__: consulted only for missing attributes.
    mod.__getattr__ = _getattr
    return mod


class _ROCmTorchaoLoader(_Loader):
    """Loader that creates a recursive stub module for any torchao path."""

    def create_module(self, spec):
        return _rocm_make_torchao_stub(spec.name)

    def exec_module(self, module):
        pass  # _rocm_make_torchao_stub already configures the module


class _ROCmTorchaoFinder(_MetaPathFinder):
    """Intercepts torchao.* imports on Windows ROCm where torch.distributed is incomplete."""
    _loader = _ROCmTorchaoLoader()

    def find_spec(self, fullname, path, target=None):
        if fullname == "torchao" or fullname.startswith("torchao."):
            # Local import: the module-level ModuleSpec alias does not outlive setup.
            from importlib.machinery import ModuleSpec as _MS
            return _MS(fullname, self._loader, is_package=True)
        return None

    def find_module(self, fullname, path=None):  # Python < 3.12 compat shim
        return None


if "torchao" not in _sys_rocm_stub.modules:
    try:
        import torchao  # noqa: F401
    except Exception as _rocm_torchao_error:
        # torchao is installed but its import fails (Windows ROCm) -- install
        # the meta path hook so every subsequent "import torchao.*" gets a
        # harmless stub instead. When torchao is simply not installed, leave
        # the import system untouched so availability checks stay truthful.
        if not _rocm_torchao_not_installed(_rocm_torchao_error):
            _sys_rocm_stub.meta_path.insert(0, _ROCmTorchaoFinder())

# _rocm_make_torchao_stub, _rocm_make_sentinel, _rocm_is_dunder and
# _ROCmSentinelMeta are kept alive -- the loader and sentinel classes call
# them at runtime.
del _ROCmTorchaoLoader, _ROCmTorchaoFinder
del _MetaPathFinder, _Loader, _ModuleSpec, _sys_rocm_stub, _types_rocm_stub


# ---- unsloth_zoo/utils.py ----
import functools


def _bind_torch_distributed_helpers(torch_module):
    """Return ``(is_initialized, is_torchelastic_launched, get_rank)`` callables.

    ROCm on Windows ships a stubbed torch.distributed missing these attributes
    entirely (https://github.com/ROCm/TheRock/issues/3284). The real functions
    are bound directly when all of them are present; only the stub hits the
    AttributeError path, so no per-name getattr or throwaway lambdas are
    built on other platforms.

    Args:
        torch_module: the imported ``torch`` module (or any object exposing a
            ``distributed`` attribute).

    Returns:
        A 3-tuple of callables. If ``distributed`` or any of the three
        functions is missing, all three are single-process fallbacks that
        accept any arguments and return ``False``, ``False`` and ``0``.
    """
    try:
        # Resolved inside the guarded region: relying on an outer `dist` name
        # would raise NameError (not AttributeError) if it is not bound.
        dist = torch_module.distributed
        return (
            dist.is_initialized,
            dist.is_torchelastic_launched,
            dist.get_rank,
        )
    except AttributeError:
        return (
            lambda *args, **kwargs: False,
            lambda *args, **kwargs: False,
            lambda *args, **kwargs: 0,
        )


(
    torch_distributed_is_initialized,
    torch_distributed_is_torchelastic_launched,
    torch_distributed_get_rank,
) = _bind_torch_distributed_helpers(torch)