studio: ROCm cleanups follow-up to #5301 (#5874)

Merged
danielhanchen merged 5 commits into main from cleanup/rocm-5301-followups on May 30, 2026

@danielhanchen
Member

Follow-up cleanups on the AMD ROCm support that landed in #5301. No new features; five focused changes, two of which are intentional behavior changes (tasks 3 and 5 below).

1. De-duplicate the torchao Windows-ROCm import stub

The ~100 line torchao stub (_StubTypeMeta, _make_mod_stub, _StubSubpackageFinder, the win32-ROCm probe and the torchao submodule seeding) was copy-pasted into both run_export_process() and run_training_process(). It now lives once in studio/backend/core/_torchao_stub.py; both workers call install_torchao_windows_rocm_stub(). Behavior is unchanged (the two copies were logically identical). Source-string tests in test_rocm_support.py were repointed at the new module.
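The de-duplication pattern can be sketched as follows. This is a minimal illustration, not the contents of _torchao_stub.py: `install_stub_package`, the two worker functions and the package name `demo_stub_pkg` are hypothetical stand-ins, and the `enabled` flag takes the place of the real win32-ROCm probe.

```python
import sys
import types


def install_stub_package(name, submodules, enabled):
    """Seed empty placeholder modules for `name` and its submodules.

    Returns the module names that were newly added to sys.modules.
    """
    if not enabled:
        return []  # hosts that do not need the stub are left untouched
    seeded = []
    if name not in sys.modules:
        root = types.ModuleType(name)
        root.__path__ = []  # an (empty) __path__ marks the module as a package
        sys.modules[name] = root
        seeded.append(name)
    parent = sys.modules[name]
    for sub in submodules:
        full = f"{name}.{sub}"
        if full in sys.modules:
            continue  # already seeded by an earlier call
        mod = types.ModuleType(full)
        sys.modules[full] = mod
        setattr(parent, sub, mod)  # make `parent.sub` attribute access work
        seeded.append(full)
    return seeded


def run_export_worker(enabled):
    # Hypothetical worker entry point: delegates to the shared helper.
    return install_stub_package("demo_stub_pkg", ["quantization"], enabled)


def run_training_worker(enabled):
    # Same single call as the export worker, so the two cannot drift apart.
    return install_stub_package("demo_stub_pkg", ["quantization"], enabled)
```

Because both workers go through one function, a second call in the same process finds the modules already seeded and adds nothing.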

2. Align the gfx name/arch tables (whitespace only)

The _setup_gfx case block in setup.sh and the $nameArchTable / arch-family-map tables in setup.ps1 had ragged value/comment columns. Re-aligned; no logic change.

3. Isolate the float16 dtype fallback to AMD

trainer.py previously did _auto_dtype = None if is_bfloat16_supported() else torch.float16, which forced fp16 on any host lacking bf16, including older NVIDIA GPUs (T4/V100), and overrode FORCE_FLOAT32 models there. The fp16 fallback is now gated on AMD ROCm (torch.version.hip / rocm in the build string); NVIDIA keeps dtype=None so unsloth's own bf16/fp16/float32 detection (including FORCE_FLOAT32) is honored. RDNA2/gfx103x behavior is preserved.
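The decision table behind this change can be sketched without torch. `pick_auto_dtype` is a hypothetical name, and the string "float16" stands in for torch.float16 so the example runs anywhere; the real code in trainer.py may be structured differently.

```python
def pick_auto_dtype(is_rocm, bf16_supported):
    """Return the dtype override to pass to the loader, or None for auto-detection.

    Only AMD ROCm hosts without native bf16 get the float16 fallback.
    Every other host returns None so downstream detection decides.
    """
    if is_rocm and not bf16_supported:
        return "float16"  # stand-in for torch.float16 in this sketch
    return None


# Old behavior for comparison: any host without bf16 was forced to float16.
def pick_auto_dtype_old(bf16_supported):
    return None if bf16_supported else "float16"
```

Under this sketch an NVIDIA T4 (no bf16, not ROCm) yields None where the old expression yielded "float16", while an RDNA2 host yields "float16" in both versions.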

4. Hoist function-internal stdlib imports to module level

Across the source files #5301 touched, inline stdlib imports (gc, glob, re, subprocess, copy, types, sys, os, signal, importlib.metadata) were moved to the top of their modules. Lazy heavy/optional imports (torch, mlx.core, psutil, transformers, huggingface_hub, unsloth*) and circular-risk local imports were deliberately left inline. Driven and verified by a deterministic AST scan; pyflakes reports no new undefined names and no new unused imports.
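An AST scan of this kind might look like the sketch below. It is not the scan used for this PR: `inline_stdlib_imports` is a hypothetical helper, and it relies on `sys.stdlib_module_names`, which exists only on Python 3.10 and later.

```python
import ast
import sys

# Available on Python 3.10+; falls back to an empty set on older interpreters.
STDLIB = getattr(sys, "stdlib_module_names", frozenset())


def inline_stdlib_imports(source, stdlib=STDLIB):
    """Return sorted (lineno, module) pairs for stdlib imports inside functions."""
    found = set()
    for func in ast.walk(ast.parse(source)):
        if not isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for node in ast.walk(func):
            if isinstance(node, ast.Import):
                names = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
                names = [node.module]
            else:
                continue
            for name in names:
                # Compare on the top-level package, e.g. "importlib.metadata" -> "importlib".
                if name.split(".")[0] in stdlib:
                    found.add((node.lineno, name))
    return sorted(found)


SAMPLE = (
    "import os\n"
    "\n"
    "def f():\n"
    "    import re as _re\n"
    "    import torch\n"
    "    from importlib import metadata\n"
    "    return _re\n"
)
```

Third-party imports such as torch are not in the stdlib set, so they are left out of the report, which matches the intent of keeping heavy or optional imports inline.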

5. bnb AMD Windows wheel: plain pip instead of UV_SKIP_WHEEL_FILENAME_CHECK

_install_bnb_windows_rocm() set UV_SKIP_WHEEL_FILENAME_CHECK=1 to make uv accept the intentionally version-mismatched wheel. Per the AMD install guide (https://unsloth.ai/docs/get-started/install/amd/amd-hackathon), uv mangles the bitsandbytes install, so this now passes force_pip=True to pip_install_try (plain pip performs no wheel filename/metadata check) and drops the env-var manipulation. Tests updated accordingly.
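The installer switch can be illustrated roughly as below. The signature of pip_install_try is not shown in this PR, so `build_install_command` is a hypothetical stand-in that only demonstrates choosing plain pip over uv without touching the environment.

```python
import sys


def build_install_command(wheel_spec, force_pip):
    """Build the argv for installing `wheel_spec` (hypothetical helper).

    force_pip=True selects plain pip run by the current interpreter;
    otherwise the command goes through `uv pip install`.
    """
    if force_pip:
        return [sys.executable, "-m", "pip", "install", wheel_spec]
    return ["uv", "pip", "install", "--python", sys.executable, wheel_spec]
```

With the pip path, no UV_SKIP_WHEEL_FILENAME_CHECK override is needed, so the caller never has to set or restore an environment variable around the install.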

Testing

  • tests/studio/install/: 503 passed, 1 skipped (unchanged from base).
  • Torchao-stub and bnb test classes pass.
  • tests/studio/ and studio/backend/tests/ pre-existing failures are identical with and without this branch (test-isolation / missing-dep / API-server cases), so no regressions.
  • py_compile plus an AST inline-import gate clean across all changed files.

De-duplicate the torchao Windows-ROCm import stub into
studio/backend/core/_torchao_stub.py (it was copy-pasted into the export
and training workers); both workers now call
install_torchao_windows_rocm_stub().

Align the gfx name/arch tables in setup.sh and setup.ps1 (whitespace only).

trainer: only force the float16 dtype fallback on AMD ROCm hosts without
native bf16 (e.g. RDNA2/gfx103x). NVIDIA now keeps dtype=None so unsloth's
own bf16/fp16/float32 auto-detection (including FORCE_FLOAT32 models) is
honored; older NVIDIA without bf16 (T4/V100) is no longer wrongly coerced
to float16.

Hoist function-internal stdlib imports (gc, glob, re, subprocess, copy,
types, sys, os, signal, importlib.metadata) to module level across the
source files that #5301 touched. Lazy heavy/optional imports (torch, mlx,
psutil, transformers, huggingface_hub, unsloth) and circular-risk local
imports are kept inline on purpose.

bnb AMD Windows wheel: install with plain pip (force_pip=True) instead of
setting UV_SKIP_WHEEL_FILENAME_CHECK, per the AMD install guide, since uv
mangles the bitsandbytes wheel.
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the torchao Windows-ROCm stub logic by extracting it into a shared module to eliminate duplication across the export and training workers. It also cleans up local imports across several backend files and updates the bitsandbytes installation process on AMD Windows to force plain pip instead of uv. The review feedback highlights a critical NameError in studio/backend/main.py due to a removed import, and suggests improvements to prevent duplicate stub finder registration and to centralize the ROCm detection logic.

Comment thread studio/backend/main.py
import re as _re

html = html_bytes.decode("utf-8")
html = _re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html)

high

Since import re as _re was removed from this function, using _re here will raise a NameError at runtime. It should be updated to use re (assuming re is imported at the module level).

Suggested change
- html = _re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html)
+ html = re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html)
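The failure mode described in this comment can be reproduced in isolation. The two function names below are hypothetical; only the regular expression is taken from the snippet under review.

```python
import re


def strip_crossorigin_broken(html):
    # The inline "import re as _re" was hoisted away, but the alias is still used.
    # Nothing binds _re any more, so this line raises NameError when called.
    return _re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html)


def strip_crossorigin_fixed(html):
    # Uses the module-level name that the hoisted import actually binds.
    return re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html)
```

The broken variant imports and compiles cleanly; the error only appears when the function runs, which is why a check like py_compile alone does not catch it.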

Comment on lines +129 to +132
if _is_win32_rocm:
    # Register the finder only on Windows ROCm -- on other platforms there
    # are no stub modules seeded, so appending is a pure accumulation.
    sys.meta_path.append(_StubSubpackageFinder())

medium

To prevent duplicate finders from being appended to sys.meta_path if install_torchao_windows_rocm_stub() is called multiple times, we should check if an instance of _StubSubpackageFinder is already registered.

Suggested change
- if _is_win32_rocm:
-     # Register the finder only on Windows ROCm -- on other platforms there
-     # are no stub modules seeded, so appending is a pure accumulation.
-     sys.meta_path.append(_StubSubpackageFinder())
+ if _is_win32_rocm:
+     # Register the finder only on Windows ROCm -- on other platforms there
+     # are no stub modules seeded, so appending is a pure accumulation.
+     if not any(isinstance(f, _StubSubpackageFinder) for f in sys.meta_path):
+         sys.meta_path.append(_StubSubpackageFinder())
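The idempotent registration suggested here can be tried with a toy finder. `DemoStubFinder` and `register_once` are hypothetical names; the real _StubSubpackageFinder is not reproduced.

```python
import sys


class DemoStubFinder:
    """Toy meta-path finder that never claims a module."""

    def find_spec(self, fullname, path, target=None):
        return None  # defer to the remaining finders on sys.meta_path


def register_once(finder_cls):
    """Append one instance of `finder_cls` to sys.meta_path unless one is present.

    Returns True if a finder was added, False if registration was skipped.
    """
    if any(isinstance(f, finder_cls) for f in sys.meta_path):
        return False
    sys.meta_path.append(finder_cls())
    return True
```

Calling `register_once(DemoStubFinder)` repeatedly leaves exactly one instance on sys.meta_path, so repeated installs in one process do not lengthen the import search.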

# bf16/fp16/float32 auto-detection (including FORCE_FLOAT32 models) is
# honored -- older NVIDIA without bf16 (T4/V100) must NOT be coerced to
# float16 here, which the previous unconditional branch did wrongly.
_is_rocm = (

medium

Instead of duplicating complex logical checks to determine if ROCm is used across multiple files (such as _torchao_stub.py and trainer.py), we should centralize this recurring check into a single helper function and reuse it across the codebase to ensure consistency and simplify maintenance.

Suggested change
- _is_rocm = (
+ _is_rocm = is_rocm()
References
  1. Centralize recurring or complex logical checks (like determining if a host is external) into a single helper function and reuse it across the codebase to ensure consistency and simplify maintenance.
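A centralized helper along the lines of this suggestion could look like the sketch below. The `is_rocm` signature here is an assumption: it takes the torch module as an argument so the example runs without torch installed, and it mirrors the two signals named in the PR description (torch.version.hip and "rocm" in the build string).

```python
from types import SimpleNamespace


def is_rocm(torch_module):
    """Best-effort ROCm check on a torch-like module object (hypothetical helper)."""
    hip = getattr(getattr(torch_module, "version", None), "hip", None)
    build = str(getattr(torch_module, "__version__", "")).lower()
    return bool(hip) or "rocm" in build


# Fake torch-like objects, used only to exercise the helper.
fake_rocm = SimpleNamespace(version=SimpleNamespace(hip="6.1.0"), __version__="2.4.0+rocm6.1")
fake_cuda = SimpleNamespace(version=SimpleNamespace(hip=None), __version__="2.4.0+cu121")
fake_tagged = SimpleNamespace(version=SimpleNamespace(hip=None), __version__="2.4.0+rocm6.1")
```

Whether such a helper fits trainer.py is a separate question: a later commit in this PR notes that the trainer derives ROCm inline on purpose rather than reusing hardware.IS_ROCM.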

danielhanchen and others added 3 commits May 30, 2026 07:32
- Add test_export_worker_calls_shared_torchao_stub for symmetry with the
  training-worker test (guards against silently dropping the export call).
- Note why trainer derives ROCm inline instead of reusing hardware.IS_ROCM.

scripts/verify_import_hoist.py is a scope-aware (LEGB) AST resolver that
catches two refactor bugs ruff and pyflakes both miss when stdlib imports
are hoisted to module top:
  - dangling alias: `from a import b as _b` hoisted to `from a import b`
    while a `_b` reference is left un-normalized (resolves to nothing or
    to a different module-level _b).
  - rename clash: `_b -> b` silently re-points at another object already
    named b in that scope.

Wire it into the source-lint job in lint-ci.yml as two steps: a hermetic
--self-test (8 negative controls) that fails if the verifier itself stops
catching a known bug class, and a pull_request-only compare gate that
diffs each in-place-modified .py (base vs PR head) and fails on a blocker.
INFO findings (a helper relocated to another file) do not fail.
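The dangling-alias class of bug can be approximated with a much simpler check than the one described above. The sketch below is not scripts/verify_import_hoist.py: it works module-wide rather than per scope (no LEGB resolution), does not detect rename clashes, and all function names are hypothetical.

```python
import ast


def _inline_import_aliases(tree):
    """Collect `as` aliases bound by imports that sit inside functions."""
    aliases = set()
    for func in ast.walk(tree):
        if isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for node in ast.walk(func):
                if isinstance(node, (ast.Import, ast.ImportFrom)):
                    aliases.update(a.asname for a in node.names if a.asname)
    return aliases


def _bound_names(tree):
    """Collect every name bound anywhere in the module (scope-insensitive)."""
    bound = set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            bound.update(a.asname or a.name.split(".")[0] for a in node.names)
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            bound.add(node.name)
        elif isinstance(node, ast.arg):
            bound.add(node.arg)
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
            bound.add(node.id)
    return bound


def dangling_aliases(base_source, head_source):
    """Aliases from base inline imports that head still reads but no longer binds."""
    base_aliases = _inline_import_aliases(ast.parse(base_source))
    head_tree = ast.parse(head_source)
    bound = _bound_names(head_tree)
    loaded = {
        n.id for n in ast.walk(head_tree)
        if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
    }
    return sorted(a for a in base_aliases if a in loaded and a not in bound)


BASE = "def f(h):\n    import re as _re\n    return _re.sub('a', 'b', h)\n"
BAD_HEAD = "import re\n\ndef f(h):\n    return _re.sub('a', 'b', h)\n"
GOOD_HEAD = "import re\n\ndef f(h):\n    return re.sub('a', 'b', h)\n"
```

Because this version ignores scopes, a module-level `_re` bound for an unrelated purpose would hide the bug from it, which is the kind of case a scope-aware resolver is meant to handle.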
@danielhanchen danielhanchen merged commit 8ec9a74 into main May 30, 2026
33 of 37 checks passed
@danielhanchen danielhanchen deleted the cleanup/rocm-5301-followups branch May 30, 2026 10:06