
fix/strix halo and windows AMD ROCm support #5301

Merged
danielhanchen merged 184 commits into unslothai:main from LeoBorcherding:fix/rocm-strix-halo-unified-memory on May 30, 2026

Conversation

@LeoBorcherding
Contributor

@LeoBorcherding LeoBorcherding commented May 6, 2026

Summary

This PR started as a fix for AMD iGPU unified memory reporting on Strix Halo and grew into full AMD ROCm support for Windows plus targeted Linux fixes. 27 files changed, ~3,700 lines added.


1. Strix Halo / AMD iGPU: unified memory detection (hardware.py, amd.py)

Problem: amd-smi on iGPUs with unified/shared memory (e.g. Radeon 8060S / gfx1151 on Strix Halo) reports only the dedicated VRAM slice (~512 MB - 1 GB) rather than the full unified pool (~128 GB). get_visible_gpu_utilization() was returning usable_gb ~= 0.35 GB, producing a spurious "Falling back to all visible GPUs -- model may not fit" warning even with 100+ GB of unified memory available.

Fix: Added _reconcile_rocm_unified_memory() and _reconcile_primary_rocm_unified_memory() (sharing a common _apply_unified_memory_correction() helper): after amd-smi returns a result on a ROCm device, cross-check each device's vram_total_gb against torch.cuda.mem_get_info(). When torch reports a larger total, replace the amd-smi VRAM fields in-place. No-op for discrete AMD GPUs where the two sources agree. Applied in both get_gpu_utilization and get_visible_gpu_utilization.

Verified on Linux: Radeon 8060S / gfx1151 / 128 GB unified (@Chigoma333, @h34v3nzc0dex)
Windows Strix Halo: code path implemented; hardware confirmation pending (Strix Halo devices do ship with Windows but no Windows Strix Halo tester has confirmed in this thread)
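
The per-device correction described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual helper: the function name, the `vram_used_gb` field, and the rounding are assumptions, and `torch_info` stands in for the `(free_bytes, total_bytes)` pair that `torch.cuda.mem_get_info()` returns.

```python
from typing import Optional, Tuple

BYTES_PER_GB = 1024 ** 3


def apply_unified_memory_correction(
    metrics: dict, torch_info: Optional[Tuple[int, int]]
) -> bool:
    """Overwrite amd-smi VRAM fields in place when torch sees a larger pool.

    Returns True if `metrics` was modified.
    """
    if not torch_info:
        return False
    free_bytes, total_bytes = torch_info
    torch_total_gb = total_bytes / BYTES_PER_GB
    smi_total_gb = metrics.get("vram_total_gb") or 0.0
    # Discrete GPUs: both sources agree, so the amd-smi numbers stay untouched.
    if torch_total_gb <= smi_total_gb:
        return False
    metrics["vram_total_gb"] = round(torch_total_gb, 2)
    # "vram_used_gb" is a hypothetical field name used for illustration.
    metrics["vram_used_gb"] = round((total_bytes - free_bytes) / BYTES_PER_GB, 2)
    return True
```

Mutating the dict in place keeps the callers unchanged: both the multi-device and primary-device paths can pass their existing metrics objects through the same helper.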


2. Windows: AMD GPU detection (install.ps1, setup.ps1)

Both scripts now detect AMD GPUs on Windows with the same reliability as the Linux installer.

  • Full detection waterfall: hipinfo (PATH) -> hipinfo (HIP_PATH\bin / ROCM_PATH\bin) -> amd-smi list -> amd-smi static --asic -> WMI marketing-name fallback. The AMD HIP SDK sets HIP_PATH/ROCM_PATH but does not always add the bin dir to PATH, so the env-var fallback handles this silently.
  • Three-way message branching:
    • HIP SDK installed + GPU ROCm-accessible: normal ROCm install path
    • HIP SDK installed + GPU not ROCm-accessible: "AMD GPU detected -- not ROCm-accessible (HIP 7.1.xxx)" with driver-issue explanation
    • No HIP SDK: "AMD GPU detected -- HIP SDK not found" with install link
  • Terminal visibility: HIP SDK path and full hipconfig --version string shown as substeps under the gpu step on successful detection.
  • Python 3.12 preference: AMD ROCm users are steered toward Python 3.12 since AMD's index does not yet ship 3.13 wheels. GPU arch is detected before Python selection so the installer does not create a 3.13 venv that immediately needs to be recreated.
  • gfx arch normalisation: gcnArchName tokens are lowercased and feature suffixes stripped (e.g. gfx1151:xnack- -> gfx1151) before wheel-index lookup. [regex]::Matches() is used instead of -match to capture all GPU entries on multi-GPU hosts, and results are wrapped with @() to prevent PowerShell unwrapping a single-item result to a scalar string (which was previously causing the arch token to be indexed by character, returning "g" instead of "gfx1200").
  • HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES respected when selecting which GPU arch to install for on multi-GPU hosts.

Verified on Windows: RX 9060 XT / gfx1200 (@LeoBorcherding), RX 7600 XT / gfx1102 (@VELICAN-HACKS)
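
The installers do the arch normalisation in PowerShell; the same idea is sketched here in Python for readability. The `gcnArchName: <token>` line format is an assumption about hipinfo output, and `parse_gfx_archs` is a hypothetical name.

```python
import re
from typing import List

# Assumed line format: "gcnArchName:   gfx1151:xnack-"
_GCN_ARCH_RE = re.compile(r"gcnArchName:\s*(\S+)", re.IGNORECASE)


def parse_gfx_archs(hipinfo_output: str) -> List[str]:
    """Return one normalised arch token per GPU entry, always as a list."""
    archs = []
    for token in _GCN_ARCH_RE.findall(hipinfo_output):
        # Lowercase, then drop feature suffixes such as ":xnack-".
        archs.append(token.lower().split(":", 1)[0])
    return archs
```

Always returning a list mirrors the `@()` wrapping in the scripts: indexing element 0 of a one-item list yields the whole arch token, whereas indexing a bare string yields its first character.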


3. Windows: AMD ROCm PyTorch wheel installation (install.ps1, install_python_stack.py)

  • Arch-aware wheel index: Detected gfx arch maps to AMD's pip index at https://repo.amd.com/rocm/whl/{arch_family}/. Supported families: gfx1200/gfx1201 (RDNA 4), gfx1100/gfx1101/gfx1102/gfx1103 (RDNA 3), gfx1150/gfx1151 (Strix). UNSLOTH_ROCM_WINDOWS_MIRROR overrides the base URL for air-gapped installs.
  • Newest release auto-selection: Always picks the newest available ROCm release for the detected arch (e.g. rocm7.13 when available) rather than hardcoding a version.
  • Full stack installed: torch, torchvision, torchaudio, rocm_sdk_core, and rocm_sdk_libraries_gfx* installed together; wheel version cross-checking ensures the three torch packages are a matched set.
  • ROCm torch repair: Detects when dependency resolution (e.g. torchao or unsloth) overwrites the ROCm torch with a CPU wheel and reinstalls from the AMD index.
  • HIP_PATH fallback in Python stack: _detect_windows_gfx_arch() probes HIP_PATH\bin / ROCM_PATH\bin when hipinfo is not on PATH, then falls back to amd-smi as a third tier, so _ensure_rocm_torch() does not silently skip ROCm wheels on runtime-only AMD installs.
  • rocm7.2 index enabled: Previously commented out of _ROCM_TORCH_INDEX pending a torch upper bound bump. Now enabled with a matching _ROCM_TORCH_PKG_SPECS entry (torch>=2.11.0,<2.12.0).

Verified on Windows: RX 9060 XT / gfx1200, torch==2.11.0+rocm7.13.0 from repo.amd.com (@LeoBorcherding); RX 7600 XT / gfx1102, torch==2.9.0+rocmsdk (@VELICAN-HACKS)
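
Picking the newest release needs a numeric comparison, because a plain string sort ranks rocm7.2 above rocm7.13. A minimal sketch, assuming release tags of the form `rocm<major>.<minor>[.<patch>]` (the real index layout may differ, and `newest_rocm_release` is a hypothetical name):

```python
import re
from typing import Iterable, Optional, Tuple

_TAG_RE = re.compile(r"^rocm(\d+)\.(\d+)(?:\.(\d+))?$")


def _version_key(tag: str) -> Optional[Tuple[int, int, int]]:
    match = _TAG_RE.match(tag)
    if not match:
        return None
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))


def newest_rocm_release(tags: Iterable[str]) -> Optional[str]:
    """Return the tag with the highest numeric version, ignoring unparseable tags."""
    best, best_key = None, None
    for tag in tags:
        key = _version_key(tag)
        if key is None:
            continue
        if best_key is None or key > best_key:
            best, best_key = tag, key
    return best
```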


4. Windows: ROCm runtime patches (worker.py, main.py, hardware.py)

Several issues in AMD's current Windows wheel require workarounds at startup. All patches are scoped to Windows ROCm only (gated on active HIP runtime via torch.version.hip or 'rocm' in torch.__version__) and self-retire when AMD ships fixes.

torch._C._distributed_c10d missing

torch._C is a compiled C extension on Windows ROCm. Python cannot do submodule imports from it, so torch._C._distributed_c10d does not exist, causing import failures when torch.distributed is loaded. A populated sys.modules stub is injected before import torch.distributed, exposing FakeProcessGroup, ProcessGroup, Work, Store, and all other symbols. The if name not in sys.modules guard means this self-retires the moment AMD fixes their build (tracked: ROCm/TheRock#3284).
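
The guard pattern can be illustrated without torch. The sketch below registers a stub under a hypothetical module name (`demo_distributed_c10d`) rather than the real `torch._C._distributed_c10d`, and exposes a single placeholder class; the PR's stub carries many more symbols.

```python
import importlib
import sys
import types


def install_stub_module(name: str, symbols: dict) -> bool:
    """Register a stub module under `name` unless one is already loaded."""
    if name in sys.modules:
        # A real module is present, so the workaround retires itself.
        return False
    stub = types.ModuleType(name)
    stub.__dict__.update(symbols)
    sys.modules[name] = stub
    return True


class FakeProcessGroup:
    """Placeholder standing in for the symbols the real stub exposes."""


install_stub_module("demo_distributed_c10d", {"FakeProcessGroup": FakeProcessGroup})
# The import system consults sys.modules first, so this returns the stub.
stub = importlib.import_module("demo_distributed_c10d")
```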

torchao stubs

torchao imports torch.distributed._functional_collectives at module level, which crashes on Windows ROCm via the missing C backend. torchao is stubbed with a custom _StubTypeMeta metaclass so isinstance(x, StubClass) returns False (not TypeError) since peft's LoRA torchao path calls isinstance() on the imported types. A _StubSubpackageFinder meta path finder auto-stubs any depth of torchao.xxx.yyy imports. Scoped to Windows ROCm only so real torchao is unaffected on Windows NVIDIA.
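
A rough sketch of the two pieces named above, using a hypothetical package name (`demo_torchao_stub`) so it cannot shadow a real torchao install. The class names `_StubTypeMeta` and `_StubSubpackageFinder` come from the description; their bodies here are assumptions about how such stubs could be written, not the PR's code.

```python
import importlib
import importlib.abc
import importlib.machinery
import sys
import types


class _StubTypeMeta(type):
    """Metaclass for placeholder classes: nothing is ever an instance."""

    def __instancecheck__(cls, instance):
        return False

    def __getattr__(cls, name):
        # Reached only when normal lookup fails; lets Stub.Inner chain.
        if name.startswith("__"):
            raise AttributeError(name)
        return _StubTypeMeta(name, (), {})


class _StubModule(types.ModuleType):
    """Module whose unknown attributes resolve to placeholder classes."""

    def __getattr__(self, name):
        if name.startswith("__"):
            raise AttributeError(name)
        placeholder = _StubTypeMeta(name, (), {})
        setattr(self, name, placeholder)
        return placeholder


class _StubSubpackageFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """Serve stub packages for `root` and any dotted name beneath it."""

    def __init__(self, root: str):
        self.root = root

    def find_spec(self, fullname, path=None, target=None):
        if fullname == self.root or fullname.startswith(self.root + "."):
            return importlib.machinery.ModuleSpec(fullname, self, is_package=True)
        return None

    def create_module(self, spec):
        return _StubModule(spec.name)

    def exec_module(self, module):
        pass  # nothing to execute; attributes are created lazily


sys.meta_path.insert(0, _StubSubpackageFinder("demo_torchao_stub"))
dtypes = importlib.import_module("demo_torchao_stub.dtypes")
AffineQuantizedTensor = dtypes.AffineQuantizedTensor
weight_is_quantized = isinstance(object(), AffineQuantizedTensor)
```

Because every placeholder is a real Python type, `isinstance()` accepts it as its second argument; a plain module or arbitrary object in that position would raise TypeError instead.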

BNB_ROCM_VERSION in server process

The server process (main.py) and training worker are separate processes, so env vars set in the worker are not visible to the server. main.py now scans libbitsandbytes_rocm*.dll at startup to detect the installed DLL suffix and sets BNB_ROCM_VERSION before any bitsandbytes import, eliminating the "Configured ROCm binary not found" error. DLL directory handles are retained at module scope in _ROCM_DLL_HANDLES in both main.py and worker.py so they are not garbage-collected. Falls back to "72" when no DLL is found.
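
The DLL scan could look roughly like this. The function name, the numeric-suffix regex, and the choice of the highest suffix when several DLLs are present are all assumptions; only the filename pattern and the "72" fallback come from the description.

```python
import re
from pathlib import Path

_BNB_DLL_RE = re.compile(r"^libbitsandbytes_rocm(\d+)\.dll$", re.IGNORECASE)
DEFAULT_BNB_ROCM_VERSION = "72"  # fallback named in the PR description


def detect_bnb_rocm_version(package_dir: str) -> str:
    """Return the numeric suffix of the installed bitsandbytes ROCm DLL."""
    suffixes = []
    for path in Path(package_dir).glob("libbitsandbytes_rocm*.dll"):
        match = _BNB_DLL_RE.match(path.name)
        if match:
            suffixes.append(match.group(1))
    if not suffixes:
        return DEFAULT_BNB_ROCM_VERSION
    # Assumption: prefer the highest suffix when more than one DLL exists.
    return max(suffixes, key=int)
```

The caller would assign the result to `os.environ["BNB_ROCM_VERSION"]` before the first bitsandbytes import, since the library reads the variable at import time.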

_grouped_mm null kernel (gfx1200, ROCm <= 7.12)

On Windows ROCm with HIP < 7.13, torch._grouped_mm dispatches to a null HIP kernel on gfx1200, causing a 0xC0000005 access violation (crash) during training. A Python fallback implementation is patched in via torch.library.Library on the CUDA dispatch key when HIP < 7.13. The fallback handles 2-D (mm), 3-D batched (bmm/matmul), and grouped-with-offsets inputs. Bypassed entirely on HIP >= 7.13 where AMD fixed the kernel. The torch.library.Library registration is kept at module scope in _WINDOWS_ROCM_GROUPED_MM_LIB to prevent garbage collection mid-run.

amd-smi circuit breaker

amd-smi on some Windows installs returns non-zero or opens console popups. GPU polling disables itself after 3 consecutive failures (_AMD_SMI_FAILURE_LIMIT = 3) and all amd-smi subprocess calls suppress console windows via windows_hidden_subprocess_kwargs(). The default timeout is also raised from 5s to 30s on Windows (10s on Linux) to accommodate the cold-hardware ROCm runtime initialisation cost.
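
The circuit breaker amounts to a consecutive-failure counter. A minimal sketch follows; the class and method names are hypothetical, while `_AMD_SMI_FAILURE_LIMIT = 3` is taken from the description. Resetting the counter on success is what makes the failures "consecutive".

```python
from typing import Callable, Optional

_AMD_SMI_FAILURE_LIMIT = 3


class PollCircuitBreaker:
    """Stop calling a flaky probe after N consecutive failures."""

    def __init__(self, limit: int = _AMD_SMI_FAILURE_LIMIT):
        self.limit = limit
        self.consecutive_failures = 0

    @property
    def is_open(self) -> bool:
        return self.consecutive_failures >= self.limit

    def call(self, probe: Callable[[], dict]) -> Optional[dict]:
        if self.is_open:
            return None  # polling disabled; the probe is not invoked
        try:
            result = probe()
        except Exception:
            self.consecutive_failures += 1
            return None
        self.consecutive_failures = 0  # any success resets the streak
        return result
```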

HIP_VISIBLE_DEVICES support

HIP_VISIBLE_DEVICES is now honoured in _get_parent_visible_gpu_spec and apply_gpu_ids sets HIP_VISIBLE_DEVICES for ROCm training workers, parallel to CUDA_VISIBLE_DEVICES on NVIDIA.
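
In simplified form, the two behaviours look like this. The sketch operates on a plain dict instead of `os.environ`, the function names are hypothetical, and the precedence order among the three variables, as well as treating an empty value as unset, are assumptions.

```python
from typing import Iterable, MutableMapping, Optional


def parent_visible_gpu_spec(env: MutableMapping[str, str]) -> Optional[str]:
    """Return the first visibility spec found, checking ROCm variables first."""
    for name in ("HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES"):
        value = env.get(name)
        if value:
            return value.strip()
    return None


def apply_gpu_ids_sketch(
    env: MutableMapping[str, str], gpu_ids: Iterable[int], is_rocm: bool
) -> MutableMapping[str, str]:
    """Pin a worker to `gpu_ids`; on ROCm, set the HIP variable as well."""
    spec = ",".join(str(i) for i in gpu_ids)
    env["CUDA_VISIBLE_DEVICES"] = spec
    if is_rocm:
        env["HIP_VISIBLE_DEVICES"] = spec
    return env
```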

Verified on Windows: full QLoRA training run on unsloth/Llama-3.2-1B-Instruct / unsloth/alpaca-cleaned completed successfully on RX 9060 XT / gfx1200 / ROCm 7.13 (@LeoBorcherding)


5. Windows: ROCm inference fix (llama_cpp.py)

Problem: The llama.cpp prebuilt bundles its own rocblas.dll next to the binary, but not the Tensile kernel library files that rocBLAS depends on at runtime (rocblas/library/TensileLibrary*.dat + .hsaco). When rocBLAS initialises for the first real GEMM (prefill), it searches for those files relative to its own location, where they do not exist in the prebuilt install tree, causing a crash with WinError 10054/10061. This is why the crash only appeared on the first generation request and not during model load.

Fix: Set ROCBLAS_TENSILE_LIBPATH in the llama-server subprocess env to <HIP_PATH>\bin\rocblas\library when the directory exists. Uses setdefault so a user-supplied value is never overwritten. No-op on CUDA/CPU and Linux.

Verified on Windows: gfx1200 / RX 9060 XT (@LeoBorcherding), gfx1102 / RX 7600 XT, inference no longer falling back to CPU after fix (@VELICAN-HACKS)
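
The environment tweak is small enough to sketch in full. The function name and the `is_windows_rocm` flag are hypothetical; the variable name, the directory layout under HIP_PATH, and the use of `setdefault` follow the description.

```python
import os
from typing import MutableMapping, Optional


def add_tensile_libpath(
    env: MutableMapping[str, str], hip_path: Optional[str], is_windows_rocm: bool
) -> MutableMapping[str, str]:
    """Point rocBLAS at the HIP SDK's Tensile library directory if it exists."""
    if not is_windows_rocm or not hip_path:
        return env  # no-op on CUDA/CPU and Linux
    library_dir = os.path.join(hip_path, "bin", "rocblas", "library")
    if os.path.isdir(library_dir):
        # setdefault keeps any value the user already exported.
        env.setdefault("ROCBLAS_TENSILE_LIBPATH", library_dir)
    return env
```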


6. Windows: install_llama_prebuilt.py

  • _resolve_exe() added: checks PATH then HIP_PATH\bin / ROCM_PATH\bin before giving up, so has_rocm is not silently False on machines where the HIP SDK bin dir is not on PATH.
  • HIP asset (llama-{tag}-bin-win-hip-radeon-x64.zip) added to the simple-policy Windows path so the HIP prebuilt is selected when ROCm is detected.
  • --has-rocm flag added: setup.ps1 passes its own detection result through so the prebuilt installer does not need to re-run hipinfo/amd-smi internally.

7. Linux: Strix Halo + ROCm 7.1 routing fix (install.sh, install_python_stack.py)

Problem: ROCm 7.1 has a driver bug causing a segfault in torch._grouped_mm (moe_utils.py:167) on gfx1151/gfx1150. The Radeon repo began shipping Python 3.13 wheels for rocm-rel-7.1, so the installer was silently landing on the broken combo even for users who expected rocm7.2.

Fix: When gfx1151 or gfx1150 is detected and TORCH_INDEX_URL is */rocm7.1, the installer overrides to https://repo.amd.com/rocm/whl/gfx1151/ (AMD's arch-specific index), landing users on torch==2.11.0+rocm7.13.0 which has the actual _grouped_mm kernel fix. UNSLOTH_AMD_ROCM_MIRROR overrides the base URL for air-gapped installs. HIP_VISIBLE_DEVICES/ROCR_VISIBLE_DEVICES are respected on mixed iGPU+dGPU setups so the override only fires when the runtime-target GPU is actually a Strix arch. Also: ROCm tag normalisation now uses explicit case arms (e.g. rocm7.1|rocm7.1.*) so a future rocm7.10 does not match rocm7.1*. install.sh also adds a sysfs KFD topology fallback (/sys/class/kfd/kfd/topology/nodes/*/properties) for AMD detection when neither rocminfo nor amd-smi is available, and explicit warning messages when ROCm version cannot be determined.

Verified on Linux: gfx1151 / Radeon 8060S / Ubuntu 24.04 / ROCm 7.13 nightly: _check_torch_grouped_mm_supported() == True, no segfault, pip download resolves torch==2.11.0+rocm7.13.0 from the AMD index (@h34v3nzc0dex); gfx1201 / Radeon AI PRO R9700 / CachyOS (@jgreenburg)
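
The routing rule and the exact-match tag check can be sketched in Python (install.sh implements them in shell). Function names, the way the tag is extracted from the URL, how the mirror value is joined, and the pytorch.org URL used in the usage line are illustrative assumptions.

```python
from typing import Optional

STRIX_ARCHS = frozenset({"gfx1150", "gfx1151"})
AMD_INDEX_BASE = "https://repo.amd.com/rocm/whl"


def is_rocm_7_1(tag: str) -> bool:
    """Match rocm7.1 and rocm7.1.x, but not rocm7.10 or later minors."""
    return tag == "rocm7.1" or tag.startswith("rocm7.1.")


def resolve_torch_index(arch: str, index_url: str, mirror: Optional[str] = None) -> str:
    """Reroute Strix + ROCm 7.1 to the arch-specific AMD index."""
    tag = index_url.rstrip("/").rsplit("/", 1)[-1]
    if arch in STRIX_ARCHS and is_rocm_7_1(tag):
        base = (mirror or AMD_INDEX_BASE).rstrip("/")
        return f"{base}/gfx1151/"
    return index_url


routed = resolve_torch_index("gfx1151", "https://download.pytorch.org/whl/rocm7.1")
```

A prefix test such as `tag.startswith("rocm7.1")` would also accept rocm7.10, which is the failure mode the explicit case arms avoid.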


8. Linux: Ubuntu 24.04 + ROCm 7.x llama.cpp build fix (setup.sh)

Problem: ROCm 7.x ships clang-20, which on Ubuntu 24.04 selects /usr/lib/gcc/x86_64-linux-gnu/14 (a runtime-only dir with no C++ headers) as its default GCC install dir, causing fatal error: 'cstdlib' file not found during the llama.cpp HIP source build.

Fix: setup.sh uses gcc -print-multiarch to determine the correct multiarch path (falling back to $(uname -m)-linux-gnu), then iterates gcc 14->11 to find the first version with both a runtime dir and /usr/include/c++/<ver> headers, and passes --gcc-install-dir=<path> via CMAKE_HIP_FLAGS.

Verified on Linux: picks /usr/lib/gcc/x86_64-linux-gnu/13 on Ubuntu 24.04; clean 417/417 llama.cpp build for gfx1151 (@h34v3nzc0dex)
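
The version search in setup.sh is written in shell; the same logic in Python, as a sketch with a hypothetical function name and a `root` parameter added purely so it can be exercised against a fake directory tree:

```python
import os
from typing import Iterable, Optional


def find_gcc_install_dir(
    multiarch: str, root: str = "/", versions: Iterable[int] = (14, 13, 12, 11)
) -> Optional[str]:
    """Return the newest GCC dir that has both runtime files and C++ headers."""
    for version in versions:
        runtime_dir = os.path.join(root, "usr", "lib", "gcc", multiarch, str(version))
        headers_dir = os.path.join(root, "usr", "include", "c++", str(version))
        if os.path.isdir(runtime_dir) and os.path.isdir(headers_dir):
            return runtime_dir
    return None
```

The returned path is what would be passed as `--gcc-install-dir=<path>`. A version whose runtime dir exists without headers, like gcc 14 on the Ubuntu 24.04 case above, is skipped.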


9. GPU OOM guard (worker.py)

Problem: On ROCm GPUs (confirmed on RDNA 4 gfx1200/gfx1201 and Strix Halo gfx1151), exhausting VRAM/unified memory can cause the HIP driver to hang the GPU ring buffer rather than raising a Python OutOfMemoryError, resulting in a full system freeze requiring a hard reboot.

Fix: torch.cuda.set_per_process_memory_fraction() is called at worker startup on ROCm only:

  • Discrete cards: 0.90 -- leaves ~1.6 GB headroom on a 16 GB card for HIP context and allocator fragmentation.
  • Unified-memory APUs (gfx1150/gfx1151): 0.80 -- the unified pool is shared with the host OS and page cache. 0.90 x 128 GB would leave only ~13 GB for the entire system, reproducing the freeze rather than preventing it.

Classification uses gcnArchName stripped of feature suffixes (e.g. gfx1151:xnack- -> gfx1151), checked against {"gfx1150", "gfx1151"}. This is naming-independent and avoids the marketing-name ambiguity between Strix Point (890M) and Strix Halo (8060S). The OOM exception handler also surfaces a clear, actionable UI message with concrete remediation steps instead of a raw HIP error string.

Verified on Windows: gfx1200 / RX 9060 XT: training session that previously froze the machine now terminates cleanly with an OOM error (@LeoBorcherding)
Verified on Linux: gfx1151 / Radeon 8060S / 128 GB unified: set_per_process_memory_fraction(0.80) accepted, _is_unified=True, 102.4 GiB cap, 25.6 GiB left for OS (@h34v3nzc0dex)
Windows Strix Halo: gcnArchName classification will apply the 0.80 fraction; hardware confirmation pending
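
The classification reduces to a small lookup. The function names below are hypothetical; the arch set and the two fractions come from the description. In the worker, the result would be handed to `torch.cuda.set_per_process_memory_fraction()`.

```python
UNIFIED_MEMORY_ARCHS = frozenset({"gfx1150", "gfx1151"})


def base_arch(gcn_arch_name: str) -> str:
    """Strip feature suffixes: 'gfx1151:xnack-' becomes 'gfx1151'."""
    return gcn_arch_name.strip().lower().split(":", 1)[0]


def memory_fraction_for(gcn_arch_name: str) -> float:
    """0.80 for unified-memory APUs, 0.90 for discrete cards."""
    if base_arch(gcn_arch_name) in UNIFIED_MEMORY_ARCHS:
        return 0.80
    return 0.90


# Headroom left for the OS on a 128 GB unified pool at each fraction.
headroom_unified_gb = 128 * (1 - memory_fraction_for("gfx1151:xnack-"))
headroom_if_discrete_rule_gb = 128 * (1 - 0.90)
```

With 128 GB, the 0.80 cap leaves about 25.6 GB for the host, compared with about 12.8 GB under the discrete-card rule.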


10. Tests (tests/studio/install/test_rocm_support.py)

~1,231 lines of new test coverage:

  • TestStrixHaloGfxArchDetection -- arch detection waterfall (hipinfo PATH / HIP_PATH / amd-smi)
  • TestHipSdkEnvPathResolution -- HIP_PATH/ROCM_PATH fallback for hipinfo/hipconfig
  • TestHipSdkDetectedSubstep -- terminal output substeps on AMD detection
  • TestStrixRocm71Override -- Strix routing to AMD arch-specific index; UNSLOTH_AMD_ROCM_MIRROR; non-Strix ROCm 7.1 stays on pytorch.org
  • TestSetupShGccInstallDir -- Ubuntu 24.04 gcc-install-dir fix
  • TestServerStartupRocmFixes -- BNB_ROCM_VERSION, distributed stubs, FakeProcessGroup
  • TestHipSdkInstalledButDeviceInaccessible -- SDK-found-but-not-ROCm-accessible branch
  • Multi-GPU selection, HIP_VISIBLE_DEVICES/ROCR_VISIBLE_DEVICES, Strix sibling handling

231 passed, 1 skipped.


Files changed

| File | What changed |
| --- | --- |
| install.ps1 | AMD GPU detection waterfall, arch-aware wheel index, Python 3.12 preference, multi-GPU arch selection |
| studio/setup.ps1 | Same as install.ps1 for the setup path; llama.cpp kind mismatch detection; verbose PyTorch output |
| install.sh | Strix + ROCm 7.1 -> AMD arch-specific index; ROCm tag normalisation; sysfs KFD fallback; explicit version warnings |
| studio/install_python_stack.py | _ensure_rocm_torch(), _detect_windows_gfx_arch(), rocm7.2 enabled, Strix routing parity with install.sh |
| studio/setup.sh | Ubuntu 24.04 --gcc-install-dir fix for HIP llama.cpp builds |
| studio/backend/main.py | BNB_ROCM_VERSION auto-detection; ROCm DLL directory registration (_ROCM_DLL_HANDLES) |
| studio/backend/core/training/worker.py | torchao stubs; _grouped_mm Python fallback; _ROCM_DLL_HANDLES; OOM guard; BNB_ROCM_VERSION detection |
| studio/backend/utils/hardware/hardware.py | _reconcile_rocm_unified_memory(); torch.distributed stubs; HIP_VISIBLE_DEVICES in apply_gpu_ids and _get_parent_visible_gpu_spec; broad ROCm detection |
| studio/backend/utils/hardware/amd.py | amd-smi circuit breaker; CREATE_NO_WINDOW suppression; raised timeout; non-integer GPU id warning |
| studio/backend/core/inference/llama_cpp.py | ROCBLAS_TENSILE_LIBPATH for Windows ROCm inference |
| studio/install_llama_prebuilt.py | _resolve_exe() HIP_PATH fallback; HIP asset in simple policy; --has-rocm flag |
| tests/studio/install/test_rocm_support.py | ~1,231 lines new test coverage |

LeoBorcherding and others added 4 commits May 5, 2026 21:33
…ng workers

Training workers are spawned via multiprocessing spawn before detect_hardware()
runs, so IS_ROCM is still False. If the user never set HIP_VISIBLE_DEVICES in
their shell, _inherits_rocm_visibility is also False, leaving the worker with
only CUDA_VISIBLE_DEVICES set. On ROCm hosts the HIP runtime honors
HIP_VISIBLE_DEVICES over CUDA_VISIBLE_DEVICES, so the worker saw the full
device list and torch raised "no usable HIP accelerator" on some setups.

Fall back to probing torch.version.hip (a build-time attribute, safe to read
before GPU init) to detect ROCm when neither IS_ROCM nor inherited env vars
are available. Mirrors the existing fix in llama_cpp.py for llama-server
subprocess GPU pinning.

Fixes unslothai#5180
Replace loose OR chain with exact string matches, split into three
focused tests, and add a guard check for the try/except wrapper.
…lback

amd-smi on iGPUs with shared/unified memory (e.g. Radeon 8060S on Strix
Halo) reports only the dedicated VRAM slice (~512 MB) in its metric output,
so get_visible_gpu_utilization() was returning usable_gb ≈ 0.35 GB instead
of the full GTT pool (~128 GB).  torch.cuda.mem_get_info() already surfaces
the correct unified-pool size.

Add _reconcile_rocm_unified_memory(): after amd-smi returns a valid result
on a ROCm device, cross-check each device's vram_total_gb against
torch.cuda.mem_get_info().  When torch reports a larger total, replace the
amd-smi VRAM fields in-place.  No-op for discrete AMD GPUs where the two
sources agree.

Fixes: "Falling back to all visible GPUs -- model may not fit" on AMD iGPU
machines even when 100+ GB of unified memory is available.
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a _reconcile_rocm_unified_memory function to correctly report VRAM on AMD iGPUs by reconciling amd-smi data with torch memory information, and integrates it into get_visible_gpu_utilization. Feedback suggests applying this same reconciliation logic to get_gpu_utilization to ensure consistent reporting across the module.

Comment thread studio/backend/utils/hardware/hardware.py Outdated
danielhanchen and others added 25 commits May 6, 2026 12:21
The visible-GPU path was already corrected for AMD iGPUs with unified memory
(Strix Halo / Radeon 8060S), but get_gpu_utilization was still returning the
raw 512 MB amd-smi VRAM slice. Studio's /api/train/hardware endpoint and the
live GPU monitor read from this primary path, so users continued seeing the
wrong total even after auto_select_gpu_ids picked the right device.

Refactor to share the per-device correction:
  * _apply_unified_memory_correction(metrics, torch_info) -- the actual
    replacement logic, in-place on a single metrics dict.
  * _reconcile_rocm_unified_memory(...)                   -- multi-device,
    iterates utilization["devices"] (visible-GPU path).
  * _reconcile_primary_rocm_unified_memory(...)           -- single flat
    metrics dict (primary-GPU path), uses parent_visible_spec to pick the
    primary index, falls back to ordinal 0 when no visibility env is set.

get_gpu_utilization now calls the primary reconciler under IS_ROCM, so both
endpoints surface the real unified-memory pool on iGPUs while leaving
discrete AMD GPUs untouched (torch_total <= smi_total -> no replace).
Two small follow-ups to the apply_gpu_ids ROCm fallback:

1. Match detect_hardware()'s 'getattr(torch.version, "hip", None) is not None'
   form so the entire codebase has one canonical 'this torch was built with
   HIP' check. On every shipping torch wheel hip is either None or a non-empty
   version string, so the new form agrees with the old bool() form on every
   real install.

2. Log the probe failure at debug level instead of swallowing it silently.
   The broad 'except Exception' is intentional (we never want apply_gpu_ids
   to crash a worker over a probe), but the silent pass made it impossible
   to tell whether the fallback was firing or being skipped.
…ec before IS_ROCM is set

When a user has HIP_VISIBLE_DEVICES set in their shell (e.g. "1" to select
GPU 1) but detect_hardware() has not yet run in the Studio parent process,
IS_ROCM is still False.  _get_parent_visible_gpu_spec() was gated on IS_ROCM
so it fell through to CUDA_VISIBLE_DEVICES (unset), saw all physical GPUs,
and auto-selected index 0.  apply_gpu_ids then overwrote HIP_VISIBLE_DEVICES
with "0", making the intended GPU invisible to ROCm torch in the worker,
which triggered the "no usable HIP accelerator" error (issue unslothai#5180).

Apply the same _inherits_rocm_visibility pattern already used in
apply_gpu_ids: check for HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES in the
environment regardless of IS_ROCM so the correct GPU index is preserved.
…tered setups

The previous rocminfo awk pattern could miss discrete GPUs on machines
where HIP_VISIBLE_DEVICES/ROCR_VISIBLE_DEVICES is used to mask an
integrated GPU — the env vars filter rocminfo output but may not
propagate into the install script subprocess, causing detection to
fail entirely.

Two changes:
- Tighten rocminfo pattern from /gfx[0-9]/ && !/gfx000/ to
  /gfx[1-9][0-9]/, which is simpler and still excludes the CPU agent
  (gfx000) without needing the separate negated match
- Add sysfs KFD topology fallback: reads
  /sys/class/kfd/kfd/topology/nodes/*/gpu_id which is a kernel-level
  view unaffected by HIP_VISIBLE_DEVICES or ROCR_VISIBLE_DEVICES

Fixes detection failure reported in Discord by Chains (gfx1201 + iGPU
machine where env var exclusion of the iGPU caused rocminfo to return
no usable device).
The fallback added by this PR reads /sys/class/kfd/kfd/topology/nodes/*/gpu_id
files but matches the literal token 'gpu_id' against their content. Those
files contain only a single decimal value (e.g. '0' for CPU agents, '50432'
for GPU agents), so the regex never matches and 'found' stays 0, making the
fallback a no-op on every host. The properties file in the same directory
contains key/value lines like 'gpu_id 50432' which is what the existing awk
pattern expects.

Reproduced with a synthetic sysfs layout: against gpu_id files awk exits 1;
against properties files awk exits 0 when any node reports gpu_id > 0.
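
The difference between the two sysfs files is easy to see in a small parser. This Python sketch (hypothetical function name; the script itself uses awk) looks for a `gpu_id <value>` key/value line, which only the properties file contains:

```python
def node_has_gpu(properties_text: str) -> bool:
    """True when a KFD topology node reports a non-zero gpu_id."""
    for line in properties_text.splitlines():
        parts = line.split()
        if len(parts) == 2 and parts[0] == "gpu_id":
            try:
                return int(parts[1]) > 0
            except ValueError:
                return False
    return False
```

Fed the content of a bare gpu_id file (a lone number such as `50432`), the function finds no key and returns False, which is the no-op behaviour the commit describes.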
…setup.sh

setup.ps1 only checked nvidia-smi and fell straight to "gpu: none" on AMD
machines. setup.sh already probed rocminfo/amd-smi/hipconfig/hipinfo.

Add three-tier detection mirroring install_llama_prebuilt.py's detect_host():
1. hipinfo: gcnArchName in output confirms a real HIP GPU (not just SDK)
2. amd-smi list: "GPU: <digit>" data rows as fallback
3. WMI Win32_VideoController: last resort -- detects AMD GPU even without
   HIP SDK, then guides user to install it rather than silently going CPU

Also corrects the "none" message to mention AMD ROCm alongside NVIDIA so
users with AMD hardware understand the requirement.

Fixes: rohit-style install where Strix Halo (Radeon 8060S) showed
"gpu: none" even with the HIP SDK present.
…h setup.ps1

install.ps1 had the same nvidia-smi-only GPU detection as setup.ps1 before
the setup.ps1 fix. Applies the same three-tier AMD detection:
1. hipinfo: gcnArchName confirms real HIP GPU
2. amd-smi list: GPU data rows as fallback
3. WMI Win32_VideoController: detects AMD GPU without HIP SDK and guides
   user to install it

Fixes: install.ps1 showing "gpu: none" while setup.ps1 correctly showed
"AMD GPU detected" on the same machine (reported by rohit, RX 7600 XT).
install_python_stack.py:
- Add _ROCM_WINDOWS_WHEEL_BASE and _ROCM_WINDOWS_RELEASES constants
  pointing to AMD repo.radeon.com (ROCm 7.2 -> torch 2.9.1+rocm7.2.1)
- Extend _ensure_rocm_torch() with a Windows branch: detects ROCm via
  _has_rocm_gpu() / _detect_rocm_version(), requires Python 3.12 (cp312
  is the only ABI AMD publishes for Windows), installs the direct wheel
  URL from repo.radeon.com

install.ps1:
- Capture ROCmVersion during AMD detection via hipconfig --version /
  amd-smi version (needed for wheel URL selection)
- After Get-TorchIndexUrl, add an AMD wheel override block: when HasROCm
  and Python 3.12 detected, set ROCmTorchWheelUrl to AMD wheel URL
- Expand torch install branch to handle ROCmTorchWheelUrl with
  uv pip install --force-reinstall --no-cache-dir
AMD publishes matching torchvision-0.24.1+rocm7.2.1 and
torchaudio-2.9.1+rocm7.2.1 cp312 wheels at the same repo.radeon.com
release folder. Install all three in both install.ps1 and
install_python_stack.py Windows ROCm path.
AMD uses a different version string for 7.1.1 wheels:
2.9.0+rocmsdk20251116 (date-tagged) instead of +rocm7.1.1.
Adds the 7.1.1 release folder to both install.ps1 and
install_python_stack.py so users with ROCm 7.1 get ROCm
torch instead of falling back to CPU.
The AMD Windows torch wheels declare rocm[libraries]==<ver> as a hard
dependency. Without installing rocm_sdk_core and rocm_sdk_libraries_custom
from the same AMD release folder, uv cannot resolve the dependency and
fails with 'No solution found'. Include all 5 wheels in one install call.
@array splatting inside a scriptblock only works when the native command
is prefixed with '&'. Invoke-InstallCommand uses '& $Command' to run the
block, so @ROCmAllWheelUrls was not being expanded. Extract to scalar
variables $rw0-$rw4 which are captured correctly by the closure.
uv's resolver looks up rocm[libraries]==0.1.dev0 on PyPI during
dependency resolution before downloading any wheels, and fails because
the package doesn't exist on PyPI. --no-deps skips resolution entirely
and installs all 5 AMD wheels directly. The GPU runtime dependency is
satisfied by the HIP SDK, not a Python package.
…Windows

setup.ps1 was always setting CuTag='cpu' for non-NVIDIA hosts and installing
cpu-only PyTorch, overwriting the ROCm torch installed by install.ps1.
Adds the same AMD wheel selection logic (ROCm version detection, Python 3.12
check, 5-wheel install with --no-deps) to setup.ps1's torch install block.

install_python_stack.py: remove IS_WINDOWS guard from _ensure_rocm_torch()
call site so the Windows path in _ensure_rocm_torch() is reachable during
'unsloth studio update' as well.
… fix progress counter

- Gate the 'must be installed manually' warning on torch.version.hip being empty
  so it doesn't fire when our ROCm torch install succeeded
- Update _TOTAL counter to include the 3 ROCm steps on Windows now that
  _ensure_rocm_torch() is called there (fixes 10/9 display)
…unter

- Add 'rocm' step after 'cuda' in setup.ps1 showing ROCm version or HIP SDK missing
- Move ROCm version detection up to GPU detection block so it's available early
- Suppress 'must be installed manually' warning when torch.version.hip is set
- Fix _TOTAL counter to include ROCm steps on Windows (fixes 10/9 display)
… is unset

AMD's repo.radeon.com wheels (e.g. 2.9.0+rocmsdk20251116) do not set
torch.version.hip, leaving it None. All three probes that relied solely on
torch.version.hip now also check for 'rocm' in torch.__version__.lower():

- hardware.py detect_hardware(): IS_ROCM was never set, causing the studio
  to report 'Hardware detected: CPU' even after AMD wheels were installed
  and HIP DLLs were on PATH.
- install_python_stack.py _ensure_rocm_torch(): skip-if-already-installed
  probe would always reinstall on subsequent runs.
- install_python_stack.py Windows AMD warning: suppression check always
  failed, so the 'must be installed manually' note kept appearing after
  a successful AMD wheel install.
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

When amd-smi/nvidia-smi is unavailable on Windows, query dedicated GPU
VRAM via Windows Performance Counters (same source as Task Manager).
This gives system-wide cross-process usage, fixing the near-zero reading
caused by torch.cuda.mem_get_info only seeing the Studio server process.

Linux fallback path unchanged (mem_get_info is system-wide on ROCm).
Function is AMD ROCm specific — amd-smi absent on Windows when only the
HIP SDK is installed. Scoped to IS_ROCM so NVIDIA Windows path is
untouched (nvidia-smi handles that case).
LeoBorcherding and others added 2 commits May 29, 2026 14:59
Linux: read /sys/class/drm/card*/device/mem_info_vram_used|total for
system-wide GPU memory across all processes. No tools required, always
present on Linux AMD systems.

Windows: Windows Performance Counter API (already added).

Both paths are gated on IS_ROCM and only fire when amd-smi is absent.
torch mem_get_info remains as last resort (process-local).
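
A sketch of the Linux sysfs read, assuming the amdgpu counters are reported in bytes. The function name, the returned dict keys, and the `drm_root` parameter (added so the code can be pointed at a fake tree) are hypothetical.

```python
import glob
import os
from typing import Dict, List

BYTES_PER_GB = 1024 ** 3


def read_drm_vram(drm_root: str = "/sys/class/drm") -> List[Dict[str, float]]:
    """Read system-wide VRAM usage for each card that exposes the counters."""
    results = []
    for device_dir in sorted(glob.glob(os.path.join(drm_root, "card*", "device"))):
        try:
            with open(os.path.join(device_dir, "mem_info_vram_used")) as f:
                used = int(f.read().strip())
            with open(os.path.join(device_dir, "mem_info_vram_total")) as f:
                total = int(f.read().strip())
        except (OSError, ValueError):
            continue  # connector entries and non-AMD cards lack these files
        results.append(
            {"used_gb": used / BYTES_PER_GB, "total_gb": total / BYTES_PER_GB}
        )
    return results
```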
@LeoBorcherding
Contributor Author

The VRAM monitor is now AMD-compatible on Windows and Linux. Tested on Windows; Linux still needs validation.


LeoBorcherding added a commit to LeoBorcherding/unsloth that referenced this pull request May 29, 2026
LeoBorcherding and others added 3 commits May 29, 2026 21:07
…s and Linux fallback paths

- Windows: GPU utilization via \GPU Engine(*engtype_3D*)\Utilization Percentage perf counter
- Windows: temperature and power via ADL (atiadlxx.dll, ships with Adrenalin)
- Linux: GPU utilization via DRM sysfs gpu_busy_percent
- Linux: temperature via hwmon temp1_input (millidegrees C)
- Linux: power via hwmon power1_average / power1_input (microwatts)

All paths are no-op fallbacks (None) when the source is unavailable.
Mirrors what nvidia-smi provides on the CUDA path.
@danielhanchen danielhanchen merged commit b6d5636 into unslothai:main May 30, 2026
43 of 44 checks passed
danielhanchen pushed a commit to unslothai/unsloth-zoo that referenced this pull request May 30, 2026
* fix: guard torch.distributed attrs missing on ROCm Windows

On Windows, the ROCm build of PyTorch ships without the distributed
C extension (torch._C._distributed_c10d), so torch.distributed loads
as a partial stub with several attributes missing entirely, including
is_initialized, is_torchelastic_launched, and get_rank.

unsloth_zoo/utils.py grabbed all three at module import time with bare
attribute access, causing an AttributeError the moment the module was
imported -- even in code paths like GGUF export that never use
distributed features at all.

Fix: replace the three bare grabs with getattr(..., lambda: False/0)
fallbacks so importing unsloth_zoo never crashes on platforms where
torch.distributed is unavailable or stubbed.

Root cause tracked in ROCm/TheRock#3284 (libuv / torch.distributed
missing on Windows ROCm builds). Companion fix for the export subprocess
in unslothai/unsloth#5301.

Ref: ROCm/TheRock#3284
Ref: unslothai/unsloth#5301

* fix: stub torchao on Windows ROCm to unblock LoRA training

On Windows, the ROCm PyTorch build ships without the full
torch.distributed C-extension stack (torch._C._distributed_c10d,
DeviceMesh, ProcessGroup, etc.).  torchao imports the entire distributed
chain at module level, so `import torchao` crashes even in code paths
that never touch distributed features (plain LoRA / GRPO training).

Fix: if torchao can't be imported, install a sys.meta_path hook
(_ROCmTorchaoFinder) that intercepts every "torchao" and "torchao.*"
import and returns a self-contained stub:

* Sub-module imports (torchao.dtypes, torchao.quantization …) get proper
  ModuleType stubs registered in sys.modules so the import machinery is
  satisfied.
* Direct attribute access on a stub (e.g. AffineQuantizedTensor) returns
  a sentinel class created via _ROCmSentinelMeta, which is a real Python
  type.  This makes isinstance(weight, AffineQuantizedTensor) return False
  (as expected -- no weight is ever an instance of the sentinel) instead of
  raising TypeError.
* Sentinel classes are chainable via metaclass __getattr__, so patterns
  like torchao.quantization.Float8WeightOnlyConfig resolve cleanly.

The hook is only installed when `import torchao` actually fails, so
Linux / CUDA environments are completely unaffected.

Companion to the torch.distributed attribute-guard in unsloth_zoo/utils.py
(commit 115d849) and the torchao export-subprocess stub in
unslothai/unsloth#5301.

Ref: ROCm/TheRock#3284
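
A rough sketch of the import-hook idea described in this commit message, not the actual `_ROCmTorchaoFinder` / `_ROCmSentinelMeta` code. All class and function names below are hypothetical, and the package name is passed in as a parameter so the sketch can be tried against a package that does not exist rather than against torchao itself.

```python
import importlib
import importlib.abc
import importlib.machinery
import sys
import types


class _SentinelMeta(type):
    """Metaclass: attribute access on a sentinel class yields another sentinel class."""

    def __getattr__(cls, name):
        if name.startswith("__"):
            raise AttributeError(name)
        child = _SentinelMeta(name, (), {})
        setattr(cls, name, child)  # cache so repeated lookups return the same class
        return child


class _StubModule(types.ModuleType):
    """Module whose unknown attributes resolve to sentinel classes."""

    def __getattr__(self, name):
        if name.startswith("__"):
            raise AttributeError(name)
        sentinel = _SentinelMeta(name, (), {})
        setattr(self, name, sentinel)
        return sentinel


class _StubLoader(importlib.abc.Loader):
    def create_module(self, spec):
        return _StubModule(spec.name)

    def exec_module(self, module):
        pass  # nothing to execute; the stub is empty by design


class StubFinder(importlib.abc.MetaPathFinder):
    """Intercepts `root` and every `root.*` import and returns a stub package."""

    def __init__(self, root):
        self.root = root

    def find_spec(self, fullname, path=None, target=None):
        if fullname == self.root or fullname.startswith(self.root + "."):
            # is_package=True so that sub-module imports are attempted at all.
            return importlib.machinery.ModuleSpec(
                fullname, _StubLoader(), is_package=True
            )
        return None


def install_stub_if_missing(root):
    """Install the hook only when importing `root` actually fails."""
    try:
        importlib.import_module(root)
        return False
    except Exception:
        # Drop any half-initialised modules left behind by the failed import.
        for key in [k for k in sys.modules if k == root or k.startswith(root + ".")]:
            del sys.modules[key]
        sys.meta_path.insert(0, StubFinder(root))
        return True
```

Because each sentinel is a real class, `isinstance(weight, SomeSentinel)` evaluates to `False` instead of raising `TypeError`, which is the behaviour the commit message relies on.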

* fix: accept *args/**kwargs in torch.distributed fallback lambdas

The three no-op lambdas used as getattr fallbacks when torch.distributed
attrs are missing on Windows ROCm only accepted zero arguments. Any caller
passing a group argument (e.g. get_rank(group=...)) would hit a TypeError.

Use lambda *args, **kwargs instead, and switch to the already-imported
dist alias for consistency.

Suggested by Gemini code review on PR #703.

* Optimize torch.distributed attr binding for PR #703

Bind dist.is_initialized/is_torchelastic_launched/get_rank directly in a
try/except instead of per-name getattr, so the common path allocates no
fallback lambdas; only the ROCm Windows stub hits AttributeError.
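
The binding pattern from the last two commits could look roughly like the sketch below. It is an illustration under assumptions, not the code in unsloth_zoo/utils.py: `bind_distributed_helpers` is a hypothetical name, and the module is passed in as `dist` so the pattern can be exercised with a stand-in object instead of a real `torch.distributed`.

```python
import types


def bind_distributed_helpers(dist):
    """Return (is_initialized, is_torchelastic_launched, get_rank) callables.

    Common path: bind the real functions directly, allocating nothing extra.
    Stub path: any missing attribute raises AttributeError and all three
    names fall back to no-op callables that accept any arguments.
    """
    try:
        return dist.is_initialized, dist.is_torchelastic_launched, dist.get_rank
    except AttributeError:
        def _always_false(*args, **kwargs):
            return False

        def _rank_zero(*args, **kwargs):
            return 0

        return _always_false, _always_false, _rank_zero


# Stand-in for a stubbed torch.distributed with no attributes at all.
stub_dist = types.SimpleNamespace()
is_initialized, is_torchelastic_launched, get_rank = bind_distributed_helpers(stub_dist)
```

With the stand-in above, `get_rank(group=...)` returns `0` rather than raising `TypeError`, which is the case the `*args`/`**kwargs` commit addresses.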

---------

Co-authored-by: danielhanchen <michaelhan2050@gmail.com>
shaneholloman pushed a commit to shaneholloman/unsloth that referenced this pull request May 30, 2026
Follow-up cleanups to the merged AMD ROCm support PR unslothai#5301:

1. De-duplicate the torchao Windows-ROCm import stub into a single shared
   module (studio/backend/core/_torchao_stub.py); both workers call one
   install_torchao_windows_rocm_stub() entrypoint.
2. Align the gfx name/arch comment columns in setup.sh and setup.ps1.
3. Isolate the float16 dtype fallback to AMD without native bf16; NVIDIA
   keeps dtype=None so unsloth's own bf16/fp16/FORCE_FLOAT32 detection is
   honored.
4. Hoist unconditional stdlib imports (gc, glob, re, subprocess, copy,
   types, sys, importlib.metadata) from function bodies to module top
   across the PR unslothai#5301-touched files; heavy/optional/relative imports stay
   lazy.
5. bitsandbytes Windows-ROCm install now uses plain pip (force_pip=True)
   instead of UV_SKIP_WHEEL_FILENAME_CHECK, per the AMD hackathon docs.

Also adds scripts/verify_import_hoist.py (a scope-aware LEGB AST resolver
that catches dangling-alias and rename-clash bugs in import-hoist
refactors) and wires it into the Lint CI source-lint job as a self-test
plus a pull_request compare gate.

10 participants