Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions unsloth_zoo/temporary_patches/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,120 @@ def process_return(
pass
del _sys, _pil_mod

# ROCm on Windows ships PyTorch without the full torch.distributed C-extension
# stack (torch._C._distributed_c10d, DeviceMesh, etc.).
# torchao pulls the entire distributed chain in at module-import time
# (torch.distributed._functional_collectives → _tensor → DeviceMesh …),
# which cascades into ImportError even for code paths that never use
# distributed features (e.g. plain LoRA training).
# Fix: try to import torchao; if it fails due to the missing distributed
# stack, register a minimal stub so subsequent imports (transformers →
# quantizer_torchao → torchao) succeed. Any runtime path that actually
# needs a real distributed collective will still fail loudly.
# Ref: https://github.com/ROCm/TheRock/issues/3284
# ROCm on Windows ships PyTorch without the full torch.distributed C-extension
# stack. torchao pulls the entire distributed chain in at module-import time,
# cascading into ImportError even for code paths that never use distributed
# features (e.g. plain LoRA training).
#
# Fix: if torchao can't be imported, install a sys.meta_path hook that
# intercepts *all* "torchao" and "torchao.*" imports and returns self-contained
# stub modules. Each stub satisfies `from torchao.X import Y` by returning a
# no-op sentinel class for any attribute, so transformers can define its
# TorchAoHfQuantizer class at import time without the full torchao stack.
# Any runtime call that actually needs a real torchao op will still fail loudly.
# Ref: https://github.com/ROCm/TheRock/issues/3284
import sys as _sys_rocm_stub, types as _types_rocm_stub
from importlib.abc import MetaPathFinder as _MetaPathFinder, Loader as _Loader
from importlib.machinery import ModuleSpec as _ModuleSpec


# Metaclass that makes sentinel classes chainable via attribute access,
# e.g. AffineQuantizedTensor.subattr returns another sentinel class.
# This is needed because peft does isinstance(weight, AffineQuantizedTensor)
# which requires the second arg to be a real type, not a module.
class _ROCmSentinelMeta(type):
def __getattr__(cls, name):
child = _ROCmSentinelMeta(name, (), {"__module__": cls.__module__})
setattr(cls, name, child)
return child


def _rocm_make_sentinel(attr, parent_name):
"""Return a sentinel class that is a proper type (works in isinstance())."""
return _ROCmSentinelMeta(attr, (), {"__module__": parent_name})


def _rocm_make_torchao_stub(name):
"""
Create a stub module for a torchao path.

- Sub-module imports (torchao.dtypes, torchao.quantization …) are handled
by the meta_path finder and get proper module stubs.
- Direct attribute access on a stub (torchao.dtypes.AffineQuantizedTensor)
returns a sentinel CLASS so that isinstance() checks succeed (returning
False, since no real weight will ever be an instance of the sentinel).
"""
import sys as _s, types as _t
from importlib.machinery import ModuleSpec as _MS

mod = _t.ModuleType(name)
mod.__path__ = []
mod.__package__ = name
mod.__spec__ = _MS(name, loader=None)

def _getattr(attr):
full = f"{name}.{attr}"
# If a sub-module was already imported (via meta_path), use that.
if full in _s.modules:
obj = _s.modules[full]
else:
# Otherwise return a sentinel class usable in isinstance().
obj = _rocm_make_sentinel(attr, name)
setattr(mod, attr, obj)
return obj

mod.__getattr__ = _getattr
return mod


class _ROCmTorchaoLoader(_Loader):
"""Loader that creates a recursive stub module for any torchao path."""

def create_module(self, spec):
return _rocm_make_torchao_stub(spec.name)

def exec_module(self, module):
pass # _rocm_make_torchao_stub already configures the module


class _ROCmTorchaoFinder(_MetaPathFinder):
"""Intercepts torchao.* imports on Windows ROCm where torch.distributed is incomplete."""
_loader = _ROCmTorchaoLoader()

def find_spec(self, fullname, path, target=None):
if fullname == "torchao" or fullname.startswith("torchao."):
from importlib.machinery import ModuleSpec as _MS
return _MS(fullname, self._loader, is_package=True)
return None

def find_module(self, fullname, path=None): # Python < 3.12 compat shim
return None


if "torchao" not in _sys_rocm_stub.modules:
try:
import torchao # noqa: F401
except Exception:
# torchao import fails on Windows ROCm -- install the meta path hook
# so every subsequent "import torchao.*" gets a harmless stub instead.
_sys_rocm_stub.meta_path.insert(0, _ROCmTorchaoFinder())

# _rocm_make_torchao_stub, _rocm_make_sentinel, _ROCmSentinelMeta are kept
# alive -- the loader and sentinel classes call them at runtime.
del _ROCmTorchaoLoader, _ROCmTorchaoFinder
del _MetaPathFinder, _Loader, _ModuleSpec, _sys_rocm_stub, _types_rocm_stub

try:
from transformers.processing_utils import Unpack
assert \
Expand Down
16 changes: 13 additions & 3 deletions unsloth_zoo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,19 @@ def _get_dtype(dtype):


import functools
torch_distributed_is_initialized = torch.distributed.is_initialized
torch_distributed_is_torchelastic_launched = torch.distributed.is_torchelastic_launched
torch_distributed_get_rank = torch.distributed.get_rank
# ROCm on Windows ships a stubbed torch.distributed missing these attributes
# entirely (https://github.com/ROCm/TheRock/issues/3284). Bind the real
# functions directly when present; only the stub hits the AttributeError path.
# Avoids per-name getattr and the throwaway lambdas it builds on every other
# platform.
try:
torch_distributed_is_initialized = dist.is_initialized
torch_distributed_is_torchelastic_launched = dist.is_torchelastic_launched
torch_distributed_get_rank = dist.get_rank
except AttributeError:
torch_distributed_is_initialized = lambda *args, **kwargs: False
torch_distributed_is_torchelastic_launched = lambda *args, **kwargs: False
torch_distributed_get_rank = lambda *args, **kwargs: 0

def is_main_process():
if torch_distributed_is_initialized():
Expand Down
Loading