studio: ROCm cleanups follow-up to #5301#5874
Conversation
De-duplicate the torchao Windows-ROCm import stub into studio/backend/core/_torchao_stub.py (it was copy-pasted into the export and training workers); both workers now call install_torchao_windows_rocm_stub(). Align the gfx name/arch tables in setup.sh and setup.ps1 (whitespace only). trainer: only force the float16 dtype fallback on AMD ROCm hosts without native bf16 (e.g. RDNA2/gfx103x). NVIDIA now keeps dtype=None so unsloth's own bf16/fp16/float32 auto-detection (including FORCE_FLOAT32 models) is honored; older NVIDIA without bf16 (T4/V100) is no longer wrongly coerced to float16. Hoist function-internal stdlib imports (gc, glob, re, subprocess, copy, types, sys, os, signal, importlib.metadata) to module level across the source files that #5301 touched. Lazy heavy/optional imports (torch, mlx, psutil, transformers, huggingface_hub, unsloth) and circular-risk local imports are kept inline on purpose. bnb AMD Windows wheel: install with plain pip (force_pip=True) instead of setting UV_SKIP_WHEEL_FILENAME_CHECK, per the AMD install guide, since uv mangles the bitsandbytes wheel.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Code Review
This pull request refactors the torchao Windows-ROCm stub logic by extracting it into a shared module to eliminate duplication across the export and training workers. It also cleans up local imports across several backend files and updates the bitsandbytes installation process on AMD Windows to force plain pip instead of uv. The review feedback highlights a critical NameError in studio/backend/main.py due to a removed import, and suggests improvements to prevent duplicate stub finder registration and to centralize the ROCm detection logic.
| import re as _re | ||
|
|
||
| html = html_bytes.decode("utf-8") | ||
| html = _re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html) |
There was a problem hiding this comment.
Since import re as _re was removed from this function, using _re here will raise a NameError at runtime. It should be updated to use re (assuming re is imported at the module level).
| html = _re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html) | |
| html = re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html) |
| if _is_win32_rocm: | ||
| # Register the finder only on Windows ROCm -- on other platforms there | ||
| # are no stub modules seeded, so appending is a pure accumulation. | ||
| sys.meta_path.append(_StubSubpackageFinder()) |
There was a problem hiding this comment.
To prevent duplicate finders from being appended to sys.meta_path if install_torchao_windows_rocm_stub() is called multiple times, we should check if an instance of _StubSubpackageFinder is already registered.
| if _is_win32_rocm: | |
| # Register the finder only on Windows ROCm -- on other platforms there | |
| # are no stub modules seeded, so appending is a pure accumulation. | |
| sys.meta_path.append(_StubSubpackageFinder()) | |
| if _is_win32_rocm: | |
| # Register the finder only on Windows ROCm -- on other platforms there | |
| # are no stub modules seeded, so appending is a pure accumulation. | |
| if not any(isinstance(f, _StubSubpackageFinder) for f in sys.meta_path): | |
| sys.meta_path.append(_StubSubpackageFinder()) |
| # bf16/fp16/float32 auto-detection (including FORCE_FLOAT32 models) is | ||
| # honored -- older NVIDIA without bf16 (T4/V100) must NOT be coerced to | ||
| # float16 here, which the previous unconditional branch did wrongly. | ||
| _is_rocm = ( |
There was a problem hiding this comment.
Instead of duplicating complex logical checks to determine if ROCm is used across multiple files (such as _torchao_stub.py and trainer.py), we should centralize this recurring check into a single helper function and reuse it across the codebase to ensure consistency and simplify maintenance.
| _is_rocm = ( | |
| _is_rocm = is_rocm() |
References
- Centralize recurring or complex logical checks (like determining if a host is external) into a single helper function and reuse it across the codebase to ensure consistency and simplify maintenance.
- Add test_export_worker_calls_shared_torchao_stub for symmetry with the training-worker test (guards against silently dropping the export call). - Note why trainer derives ROCm inline instead of reusing hardware.IS_ROCM.
scripts/verify_import_hoist.py is a scope-aware (LEGB) AST resolver that
catches two refactor bugs ruff and pyflakes both miss when stdlib imports
are hoisted to module top:
- dangling alias: `from a import b as _b` hoisted to `from a import b`
while a `_b` reference is left un-normalized (resolves to nothing or
to a different module-level _b).
- rename clash: `_b -> b` silently re-points at another object already
named b in that scope.
Wire it into the source-lint job in lint-ci.yml as two steps: a hermetic
--self-test (8 negative controls) that fails if the verifier itself stops
catching a known bug class, and a pull_request-only compare gate that
diffs each in-place-modified .py (base vs PR head) and fails on a blocker.
INFO findings (a helper relocated to another file) do not fail.
for more information, see https://pre-commit.ci
Follow-up cleanups on the AMD ROCm support that landed in #5301. No new features; five focused changes, two of which are intentional behavior changes (tasks 3 and 5 below).
1. De-duplicate the torchao Windows-ROCm import stub
The ~100 line torchao stub (
_StubTypeMeta,_make_mod_stub,_StubSubpackageFinder, the win32-ROCm probe and the torchao submodule seeding) was copy-pasted into bothrun_export_process()andrun_training_process(). It now lives once instudio/backend/core/_torchao_stub.py; both workers callinstall_torchao_windows_rocm_stub(). Behavior is unchanged (the two copies were logically identical). Source-string tests intest_rocm_support.pywere repointed at the new module.2. Align the gfx name/arch tables (whitespace only)
The
_setup_gfxcase block insetup.shand the$nameArchTable/ arch-family-map tables insetup.ps1had ragged value/comment columns. Re-aligned; no logic change.3. Isolate the float16 dtype fallback to AMD
trainer.pypreviously did_auto_dtype = None if is_bfloat16_supported() else torch.float16, which forced fp16 on any host lacking bf16 including older NVIDIA (T4/V100) and overrode FORCE_FLOAT32 models there. The fp16 fallback is now gated on AMD ROCm (torch.version.hip/rocmin the build string); NVIDIA keepsdtype=Noneso unsloth's own bf16/fp16/float32 detection (including FORCE_FLOAT32) is honored. RDNA2/gfx103x behavior is preserved.4. Hoist function-internal stdlib imports to module level
Across the source files #5301 touched, inline stdlib imports (
gc,glob,re,subprocess,copy,types,sys,os,signal,importlib.metadata) were moved to the top of their modules. Lazy heavy/optional imports (torch,mlx.core,psutil,transformers,huggingface_hub,unsloth*) and circular-risk local imports were deliberately left inline. Driven and verified by a deterministic AST scan;pyflakesreports no new undefined names and no new unused imports.5. bnb AMD Windows wheel: plain pip instead of UV_SKIP_WHEEL_FILENAME_CHECK
_install_bnb_windows_rocm()setUV_SKIP_WHEEL_FILENAME_CHECK=1to make uv accept the intentionally version-mismatched wheel. Per the AMD install guide (https://unsloth.ai/docs/get-started/install/amd/amd-hackathon) uv mangles the bitsandbytes install, so this now passesforce_pip=Truetopip_install_try(plain pip performs no wheel filename/metadata check) and drops the env-var manipulation. Tests updated accordingly.Testing
tests/studio/install/503 passed, 1 skipped (unchanged from base).tests/studio/andstudio/backend/tests/pre-existing failures are identical with and without this branch (test-isolation / missing-dep / API-server cases), so no regressions.py_compileplus an AST inline-import gate clean across all changed files.