fix: guard torch.distributed attrs missing on ROCm Windows#703

Merged
danielhanchen merged 4 commits into
unslothai:mainfrom
LeoBorcherding:fix/rocm-windows-distributed-stub
May 30, 2026

@LeoBorcherding

Contributor

On Windows, the ROCm build of PyTorch ships without the distributed C extension (torch._C._distributed_c10d), so torch.distributed loads as a partial stub with several attributes missing entirely, including is_initialized, is_torchelastic_launched, and get_rank.

unsloth_zoo/utils.py grabbed all three at module import time with bare attribute access, causing an AttributeError the moment the module was imported, even in code paths like GGUF export that never use distributed features at all.

Fix: replace the three bare grabs with getattr(..., lambda: False/0) fallbacks so importing unsloth_zoo never crashes on platforms where torch.distributed is unavailable or stubbed.
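A minimal sketch of the failure and the guard, not the actual unsloth_zoo code. It assumes a hypothetical empty module object (`stub_distributed`) standing in for the partial torch.distributed stub, so it runs without PyTorch installed:

```python
import types

# Hypothetical stand-in for the stubbed torch.distributed: the module object
# exists, but the usual functions are missing from it.
stub_distributed = types.ModuleType("stub_distributed")


def bare_grab(module):
    # Mirrors the old pattern: plain attribute access at import time.
    return module.is_initialized


def guarded_grab(module):
    # Mirrors the fix: fall back to a no-op when the attribute is missing.
    return getattr(module, "is_initialized", lambda: False)


try:
    bare_grab(stub_distributed)
    bare_failed = False
except AttributeError:
    bare_failed = True  # this is the crash that used to happen on import

is_initialized = guarded_grab(stub_distributed)
```

With the guard in place, `is_initialized()` simply reports `False` on the stub, which is the sensible answer for a platform that cannot run distributed jobs.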

Root cause tracked in ROCm/TheRock#3284 (libuv / torch.distributed missing on Windows ROCm builds). Companion fix for the export subprocess in unsloth/unsloth#5301.

Ref: ROCm/TheRock#3284
Ref: unslothai/unsloth#5301

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces safe fallbacks using getattr for torch.distributed attributes to prevent import crashes on ROCm for Windows, where torch.distributed is stubbed. The reviewer recommended updating the fallback lambdas to accept *args and **kwargs to avoid potential TypeError exceptions if external callers invoke these functions with arguments, and suggested using the existing dist alias for cleaner code.

Comment thread unsloth_zoo/utils.py Outdated
Comment on lines +129 to +131
torch_distributed_is_initialized = getattr(torch.distributed, "is_initialized", lambda: False)
torch_distributed_is_torchelastic_launched = getattr(torch.distributed, "is_torchelastic_launched", lambda: False)
torch_distributed_get_rank = getattr(torch.distributed, "get_rank", lambda: 0)

medium

Since torch_distributed_get_rank is exported in __all__ and is part of the public API, external callers might call it with arguments (such as group). The fallback lambda: 0 does not accept any arguments and would raise a TypeError if called with one. Using lambda *args, **kwargs: 0 (and similarly for the other fallbacks) ensures robustness. Additionally, we can use the already imported dist alias instead of torch.distributed for cleaner attribute access.

Suggested change
- torch_distributed_is_initialized = getattr(torch.distributed, "is_initialized", lambda: False)
- torch_distributed_is_torchelastic_launched = getattr(torch.distributed, "is_torchelastic_launched", lambda: False)
- torch_distributed_get_rank = getattr(torch.distributed, "get_rank", lambda: 0)
+ torch_distributed_is_initialized = getattr(dist, "is_initialized", lambda *args, **kwargs: False)
+ torch_distributed_is_torchelastic_launched = getattr(dist, "is_torchelastic_launched", lambda *args, **kwargs: False)
+ torch_distributed_get_rank = getattr(dist, "get_rank", lambda *args, **kwargs: 0)
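
The review point can be reproduced in isolation. This is an illustrative sketch only: `stub_dist` is a hypothetical empty module standing in for the stubbed torch.distributed, and `group=None` stands in for whatever a real caller might pass.

```python
import types

# Hypothetical stub module that, like the ROCm Windows build, lacks get_rank.
stub_dist = types.ModuleType("stub_dist")

# Fallback that takes no arguments (the first version of the fix).
zero_arg = getattr(stub_dist, "get_rank", lambda: 0)

# Fallback that tolerates any call signature (the reviewer's suggestion).
any_arg = getattr(stub_dist, "get_rank", lambda *args, **kwargs: 0)

try:
    zero_arg(group=None)
    zero_arg_error = None
except TypeError as exc:
    zero_arg_error = exc  # unexpected keyword argument 'group'

rank = any_arg(group=None)  # accepted and ignored; the fallback returns 0
```

Both fallbacks behave the same when called with no arguments; they differ only once a caller passes something.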

On Windows, the ROCm PyTorch build ships without the full
torch.distributed C-extension stack (torch._C._distributed_c10d,
DeviceMesh, ProcessGroup, etc.).  torchao imports the entire distributed
chain at module level, so `import torchao` crashes even in code paths
that never touch distributed features (plain LoRA / GRPO training).

Fix: if torchao can't be imported, install a sys.meta_path hook
(_ROCmTorchaoFinder) that intercepts every "torchao" and "torchao.*"
import and returns a self-contained stub:

* Sub-module imports (torchao.dtypes, torchao.quantization …) get proper
  ModuleType stubs registered in sys.modules so the import machinery is
  satisfied.
* Direct attribute access on a stub (e.g. AffineQuantizedTensor) returns
  a sentinel class created via _ROCmSentinelMeta, which is a real Python
  type.  This makes isinstance(weight, AffineQuantizedTensor) return False
  (as expected -- no weight is ever an instance of the sentinel) instead of
  raising TypeError.
* Sentinel classes are chainable via metaclass __getattr__, so patterns
  like torchao.quantization.Float8WeightOnlyConfig resolve cleanly.

The hook is only installed when `import torchao` actually fails, so
Linux / CUDA environments are completely unaffected.

Companion to the torch.distributed attribute-guard in unsloth_zoo/utils.py
(commit 115d849) and the torchao export-subprocess stub in
unsloth/unsloth#5301.

Ref: ROCm/TheRock#3284
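
The mechanism described in this commit message can be sketched roughly as follows. This is not the code from the PR: the class names (`SentinelMeta`, `StubModule`, `StubFinder`) and the package name `missing_pkg_demo` are hypothetical stand-ins for `_ROCmSentinelMeta`, `_ROCmTorchaoFinder`, and torchao, and catching every `Exception` on the probe import is an assumption.

```python
import importlib
import importlib.abc
import importlib.machinery
import importlib.util
import sys
import types

STUB_ROOT = "missing_pkg_demo"  # hypothetical name standing in for torchao


class SentinelMeta(type):
    """Metaclass: attribute access on a sentinel class yields another sentinel."""

    def __getattr__(cls, name):
        if name.startswith("__"):
            raise AttributeError(name)  # keep dunder probes behaving normally
        return SentinelMeta(name, (), {})


class StubModule(types.ModuleType):
    """Module whose unknown attributes resolve to sentinel classes."""

    def __getattr__(self, name):
        if name.startswith("__"):
            raise AttributeError(name)
        sentinel = SentinelMeta(name, (), {})
        setattr(self, name, sentinel)  # cache: repeated lookups return the same class
        return sentinel


class StubFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """Claims STUB_ROOT and every STUB_ROOT.* import, loading each as a stub package."""

    def find_spec(self, fullname, path=None, target=None):
        if fullname == STUB_ROOT or fullname.startswith(STUB_ROOT + "."):
            return importlib.machinery.ModuleSpec(fullname, self, is_package=True)
        return None

    def create_module(self, spec):
        return StubModule(spec.name)

    def exec_module(self, module):
        pass  # nothing to execute; the stub has no real code


def install_stub_if_missing():
    """Install the finder only when the real import fails."""
    try:
        importlib.import_module(STUB_ROOT)
    except Exception:  # assumption: any import-time failure counts as "cannot import"
        sys.meta_path.append(StubFinder())
        return True
    return False


installed = install_stub_if_missing()
quantization = importlib.import_module(STUB_ROOT + ".quantization")
Config = quantization.Float8WeightOnlyConfig  # a sentinel class, not a module
dtypes_spec = importlib.util.find_spec(STUB_ROOT + ".dtypes")
```

Because `Config` is a real type, `isinstance(weight, Config)` evaluates to `False` for any object instead of raising `TypeError`, and `Config.Anything` resolves to a further sentinel class. Appending the finder to the end of `sys.meta_path` means it is consulted only after the regular finders have failed to locate the package.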
@LeoBorcherding

Contributor Author

Added two Windows ROCm import fixes in the latest commit.

The first guards bare attribute access on torch.distributed (is_initialized, get_rank, etc.), which the ROCm Windows build strips out entirely.

The second installs a meta path hook that intercepts all torchao imports and returns self-contained stubs. On Windows ROCm, torchao tries to pull in the full distributed stack at module load time, which cascades into a hard crash even for plain LoRA training that never touches distributed features. The stubs give each submodule a proper spec so importlib.util.find_spec works, and return sentinel classes (not modules) for leaf attributes so isinstance() checks in peft evaluate to False instead of raising TypeError. The hook only activates when torchao actually fails to import, so Linux and CUDA are unaffected.

LeoBorcherding and others added 2 commits May 29, 2026 18:19
The three no-op lambdas used as getattr fallbacks when torch.distributed
attrs are missing on Windows ROCm only accepted zero arguments. Any caller
passing a group argument (e.g. get_rank(group=...)) would hit a TypeError.

Use lambda *args, **kwargs instead, and switch to the already-imported
dist alias for consistency.

Suggested by Gemini code review on PR unslothai#703.

Bind dist.is_initialized/is_torchelastic_launched/get_rank directly in a
try/except instead of per-name getattr, so the common path allocates no
fallback lambdas; only the ROCm Windows stub hits AttributeError.
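
A rough sketch of the try/except binding described in the last commit message, under stated assumptions: the helper name `bind_distributed_helpers` is hypothetical, the merged code may be laid out differently, and in this sketch a single missing attribute sends all three names to their fallbacks.

```python
import types


def bind_distributed_helpers(dist):
    """Return (is_initialized, is_torchelastic_launched, get_rank) for `dist`."""
    try:
        # Common path: plain attribute access, no fallback objects are created.
        return dist.is_initialized, dist.is_torchelastic_launched, dist.get_rank
    except AttributeError:
        # Stubbed module: use no-ops that tolerate any arguments.
        def _false(*args, **kwargs):
            return False

        def _zero(*args, **kwargs):
            return 0

        return _false, _false, _zero


# Hypothetical stand-ins: a fully populated module and an empty stub.
full = types.SimpleNamespace(
    is_initialized=lambda: True,
    is_torchelastic_launched=lambda: False,
    get_rank=lambda group=None: 3,
)
stub = types.ModuleType("stub_dist")

full_init, full_elastic, full_rank = bind_distributed_helpers(full)
stub_init, stub_elastic, stub_rank = bind_distributed_helpers(stub)
```

On the populated namespace the returned objects are the original functions themselves; only the stub path allocates the fallbacks.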
@danielhanchen danielhanchen merged commit 0b8df5f into unslothai:main May 30, 2026
11 checks passed