From e5d58f5151635cec343b2d7d40e462fec295036b Mon Sep 17 00:00:00 2001 From: GoldenGrapeGentleman Date: Wed, 18 Mar 2026 21:17:26 -0500 Subject: [PATCH 1/2] fix(studio): ensure ROCm-enabled torch in venv on AMD hosts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pip resolves torch from PyPI during base package installation, pulling CPU-only wheels regardless of the host GPU. AMD ROCm users end up with a venv that cannot use their GPU for training. Add _ensure_rocm_torch() which runs immediately after base packages: - detects ROCm via $ROCM_PATH / /opt/rocm / hipcc - reads the installed version from $ROCM_PATH/.info/version (default /opt/rocm), falling back to lib/rocm_version - maps (major, minor) to the newest PyTorch wheel index that is not newer than the installed ROCm, via tuple comparison - skips if torch is already GPU-enabled (checks both torch.version.hip and torch.version.cuda to avoid clobbering CUDA torch on mixed hosts) - force-reinstalls torch + torchvision + torchaudio from the matched index URL Tested on 8×AMD MI355X (ROCm 7.1) — version detection, wheel mapping, and no-op behaviour all verified. Fixes the issue raised by andyluo7 in #4390. 
Co-authored-by: billishyahao --- studio/install_python_stack.py | 93 +++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/studio/install_python_stack.py b/studio/install_python_stack.py index a141c64425..4f499272f3 100644 --- a/studio/install_python_stack.py +++ b/studio/install_python_stack.py @@ -22,6 +22,89 @@ IS_WINDOWS = sys.platform == "win32" +# ── ROCm / AMD GPU support ───────────────────────────────────────────────────── +_ROCM_TORCH_WHL: dict[tuple[int, int], str] = { + (7, 1): "rocm7.1", + (7, 0): "rocm7.0", + (6, 3): "rocm6.3", + (6, 2): "rocm6.2.4", + (6, 1): "rocm6.1", + (6, 0): "rocm6.0", +} +_PYTORCH_WHL_BASE = "https://download.pytorch.org/whl" + + +def _rocm_version() -> tuple[int, int] | None: + """Return (major, minor) of the installed ROCm stack, or None if absent.""" + rocm_root = os.environ.get("ROCM_PATH", "/opt/rocm") + for path in ( + os.path.join(rocm_root, ".info", "version"), + os.path.join(rocm_root, "lib", "rocm_version"), + ): + try: + parts = open(path).read().strip().split("-")[0].split(".") + return int(parts[0]), int(parts[1]) + except Exception: + pass + return None + + +def _ensure_rocm_torch() -> None: + """Reinstall torch with ROCm wheels when the venv received CPU-only torch. + + Triggered only on Linux hosts where ROCm is actually installed (checks + /opt/rocm or ROCM_PATH). No-op on Windows, NVIDIA-only hosts, and when + torch already links against HIP (i.e. ROCm wheels already in place). + Runs after base packages so torch is present before we inspect it. + """ + rocm_root = os.environ.get("ROCM_PATH", "/opt/rocm") + if not os.path.isdir(rocm_root) and not shutil.which("hipcc"): + return # no ROCm toolchain → nothing to do + + ver = _rocm_version() + if ver is None: + print(" ROCm detected but version unreadable — skipping torch reinstall") + return + + # Skip if torch is already GPU-enabled: HIP (AMD ROCm) or CUDA (NVIDIA). 
+ # Checking both prevents accidentally overwriting CUDA torch on hosts where + # /opt/rocm also exists (e.g. mixed CUDA+ROCm development machines). + probe = subprocess.run( + [sys.executable, "-c", + "import torch; print(torch.version.hip or torch.version.cuda or '')"], + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, + ) + if probe.returncode == 0 and probe.stdout.decode().strip(): + return # torch already GPU-enabled (HIP or CUDA) — nothing to reinstall + + # Select the best matching wheel tag (≤ installed ROCm, newest first) + tag = next( + (t for (maj, mn), t in _ROCM_TORCH_WHL.items() if ver >= (maj, mn)), + None, + ) + if tag is None: + print(f" No PyTorch wheel for ROCm {ver[0]}.{ver[1]} — skipping") + return + + index_url = f"{_PYTORCH_WHL_BASE}/{tag}" + print(f" ROCm {ver[0]}.{ver[1]} — installing torch from {index_url}") + result = subprocess.run( + [ + sys.executable, "-m", "pip", "install", + "--force-reinstall", "--no-cache-dir", + "torch", "torchvision", "torchaudio", + "--index-url", index_url, + ], + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + if result.returncode == 0: + print(" ROCm torch installed") + else: + print(" ROCm torch reinstall failed — training may run on CPU") + if VERBOSE: + print(result.stdout.decode(errors="replace")) + + # ── Verbosity control ────────────────────────────────────────────────────────── # By default the installer shows a minimal progress bar (one line, in-place). # Set UNSLOTH_VERBOSE=1 in the environment to restore full per-step output: @@ -291,7 +374,8 @@ def patch_package_file(package_name: str, relative_path: str, url: str) -> None: def install_python_stack() -> int: global USE_UV, _STEP, _TOTAL _STEP = 0 - _TOTAL = 10 if IS_WINDOWS else 11 + # Step count: shared(9) + triton(non-Win) + ROCm-torch-check(non-Win) + finalize(1) + _TOTAL = 10 if IS_WINDOWS else 12 # 1. 
Upgrade pip (needed even with uv as fallback and for bootstrapping) _progress("pip upgrade") @@ -308,6 +392,13 @@ def install_python_stack() -> int: req = REQ_ROOT / "base.txt", ) + # 2b. AMD ROCm: reinstall torch with HIP wheels if the host has ROCm but the + # venv received CPU-only torch (common when pip resolves torch from PyPI). + # Must come immediately after base packages so torch is present for inspection. + if not IS_WINDOWS: + _progress("ROCm torch") + _ensure_rocm_torch() + # 3. Extra dependencies _progress("unsloth extras") pip_install( From c9bce09351ddef1b993d9c3aee0b713c93a2d587 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 07:54:06 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/install_python_stack.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/studio/install_python_stack.py b/studio/install_python_stack.py index 4f499272f3..d73380d9ee 100644 --- a/studio/install_python_stack.py +++ b/studio/install_python_stack.py @@ -70,9 +70,13 @@ def _ensure_rocm_torch() -> None: # Checking both prevents accidentally overwriting CUDA torch on hosts where # /opt/rocm also exists (e.g. mixed CUDA+ROCm development machines). 
probe = subprocess.run( - [sys.executable, "-c", - "import torch; print(torch.version.hip or torch.version.cuda or '')"], - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, + [ + sys.executable, + "-c", + "import torch; print(torch.version.hip or torch.version.cuda or '')", + ], + stdout = subprocess.PIPE, + stderr = subprocess.DEVNULL, ) if probe.returncode == 0 and probe.stdout.decode().strip(): return # torch already GPU-enabled (HIP or CUDA) — nothing to reinstall @@ -90,19 +94,27 @@ def _ensure_rocm_torch() -> None: print(f" ROCm {ver[0]}.{ver[1]} — installing torch from {index_url}") result = subprocess.run( [ - sys.executable, "-m", "pip", "install", - "--force-reinstall", "--no-cache-dir", - "torch", "torchvision", "torchaudio", - "--index-url", index_url, + sys.executable, + "-m", + "pip", + "install", + "--force-reinstall", + "--no-cache-dir", + "torch", + "torchvision", + "torchaudio", + "--index-url", + index_url, ], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + stdout = subprocess.PIPE, + stderr = subprocess.STDOUT, ) if result.returncode == 0: print(" ROCm torch installed") else: print(" ROCm torch reinstall failed — training may run on CPU") if VERBOSE: - print(result.stdout.decode(errors="replace")) + print(result.stdout.decode(errors = "replace")) # ── Verbosity control ──────────────────────────────────────────────────────────