diff --git a/studio/install_python_stack.py b/studio/install_python_stack.py index 2c64695241..647572165a 100644 --- a/studio/install_python_stack.py +++ b/studio/install_python_stack.py @@ -22,6 +22,102 @@ IS_WINDOWS = sys.platform == "win32" +# ── ROCm / AMD GPU support ───────────────────────────────────────────────────── +_ROCM_TORCH_WHL: dict[tuple[int, int], str] = { + (7, 1): "rocm7.1", + (7, 0): "rocm7.0", + (6, 3): "rocm6.3", + (6, 2): "rocm6.2.4", + (6, 1): "rocm6.1", + (6, 0): "rocm6.0", +} +_PYTORCH_WHL_BASE = "https://download.pytorch.org/whl" + + +def _rocm_version() -> tuple[int, int] | None: + """Return (major, minor) of the installed ROCm stack, or None if absent.""" + rocm_root = os.environ.get("ROCM_PATH", "/opt/rocm") + for path in ( + os.path.join(rocm_root, ".info", "version"), + os.path.join(rocm_root, "lib", "rocm_version"), + ): + try: + parts = open(path).read().strip().split("-")[0].split(".") + return int(parts[0]), int(parts[1]) + except Exception: + pass + return None + + +def _ensure_rocm_torch() -> None: + """Reinstall torch with ROCm wheels when the venv received CPU-only torch. + + Triggered only on Linux hosts where ROCm is actually installed (checks + /opt/rocm or ROCM_PATH). No-op on Windows, NVIDIA-only hosts, and when + torch already links against HIP (i.e. ROCm wheels already in place). + Runs after base packages so torch is present before we inspect it. + """ + rocm_root = os.environ.get("ROCM_PATH", "/opt/rocm") + if not os.path.isdir(rocm_root) and not shutil.which("hipcc"): + return # no ROCm toolchain → nothing to do + + ver = _rocm_version() + if ver is None: + print(" ROCm detected but version unreadable — skipping torch reinstall") + return + + # Skip if torch is already GPU-enabled: HIP (AMD ROCm) or CUDA (NVIDIA). + # Checking both prevents accidentally overwriting CUDA torch on hosts where + # /opt/rocm also exists (e.g. mixed CUDA+ROCm development machines). + probe = subprocess.run( + [ + sys.executable, + "-c", + "import torch; print(torch.version.hip or torch.version.cuda or '')", + ], + stdout = subprocess.PIPE, + stderr = subprocess.DEVNULL, + ) + if probe.returncode == 0 and probe.stdout.decode().strip(): + return # torch already GPU-enabled (HIP or CUDA) — nothing to reinstall + + # Select the best matching wheel tag (≤ installed ROCm, newest first) + tag = next( + (t for (maj, mn), t in _ROCM_TORCH_WHL.items() if ver >= (maj, mn)), + None, + ) + if tag is None: + print(f" No PyTorch wheel for ROCm {ver[0]}.{ver[1]} — skipping") + return + + index_url = f"{_PYTORCH_WHL_BASE}/{tag}" + print(f" ROCm {ver[0]}.{ver[1]} — installing torch from {index_url}") + result = subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "--force-reinstall", + "--no-cache-dir", + "torch", + "torchvision", + "torchaudio", + "--index-url", + index_url, + ], + stdout = subprocess.PIPE, + stderr = subprocess.STDOUT, + ) + if result.returncode == 0: + print(" ROCm torch installed") + else: + print(" ROCm torch reinstall failed — training may run on CPU") + if VERBOSE: + print(result.stdout.decode(errors = "replace")) + + +# ── Verbosity control ────────────────────────────────────────────────────────── # -- Verbosity control ---------------------------------------------------------- # By default the installer shows a minimal progress bar (one line, in-place). # Set UNSLOTH_VERBOSE=1 in the environment to restore full per-step output: @@ -341,6 +437,8 @@ def patch_package_file(package_name: str, relative_path: str, url: str) -> None: def install_python_stack() -> int: global USE_UV, _STEP, _TOTAL _STEP = 0 + # Step count: shared(9) + triton(non-Win) + ROCm-torch-check(non-Win) + finalize(1) + _TOTAL = 10 if IS_WINDOWS else 12 # When called from install.sh (which already installed unsloth into the venv), # SKIP_STUDIO_BASE=1 is set to avoid redundant reinstallation of base packages. @@ -444,6 +542,13 @@ def install_python_stack() -> int: req = REQ_ROOT / "base.txt", ) + # 2b. AMD ROCm: reinstall torch with HIP wheels if the host has ROCm but the + # venv received CPU-only torch (common when pip resolves torch from PyPI). + # Must come immediately after base packages so torch is present for inspection. + if not IS_WINDOWS: + _progress("ROCm torch") + _ensure_rocm_torch() + # 3. Extra dependencies _progress("unsloth extras") pip_install(