Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions studio/install_python_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,102 @@

IS_WINDOWS = sys.platform == "win32"

# ── ROCm / AMD GPU support ─────────────────────────────────────────────────────
# Maps an installed ROCm (major, minor) release to the PyTorch wheel-index tag
# that is appended to _PYTORCH_WHL_BASE.
# Order matters: _ensure_rocm_torch walks this dict in insertion order and
# takes the FIRST entry whose key is <= the installed version, so entries must
# stay sorted newest first.
_ROCM_TORCH_WHL: dict[tuple[int, int], str] = {
    (7, 1): "rocm7.1",
    (7, 0): "rocm7.0",
    (6, 3): "rocm6.3",
    (6, 2): "rocm6.2.4",  # this tag carries a patch component, unlike the others
    (6, 1): "rocm6.1",
    (6, 0): "rocm6.0",
}
# Root URL of the PyTorch wheel indexes; a tag from the table above is joined
# to it as "<base>/<tag>" and passed to pip as --index-url.
_PYTORCH_WHL_BASE = "https://download.pytorch.org/whl"


def _rocm_version() -> tuple[int, int] | None:
"""Return (major, minor) of the installed ROCm stack, or None if absent."""
rocm_root = os.environ.get("ROCM_PATH", "/opt/rocm")
for path in (
os.path.join(rocm_root, ".info", "version"),
os.path.join(rocm_root, "lib", "rocm_version"),
):
try:
parts = open(path).read().strip().split("-")[0].split(".")
return int(parts[0]), int(parts[1])
except Exception:
pass
return None


def _ensure_rocm_torch() -> None:
    """Replace a CPU-only torch with ROCm wheels on hosts that have ROCm.

    Returns without installing anything when any of these holds:

    * the ROCm root (ROCM_PATH, default ``/opt/rocm``) is not a directory
      and ``hipcc`` is not on PATH;
    * ``_rocm_version()`` cannot determine the version (a notice is printed);
    * the torch importable by this interpreter already reports a HIP or
      CUDA version, so an existing GPU build is never overwritten;
    * no entry in ``_ROCM_TORCH_WHL`` is <= the installed ROCm version
      (a notice is printed).

    Otherwise torch, torchvision and torchaudio are force-reinstalled with
    pip from the matching PyTorch wheel index. pip's combined output is
    printed only when the install fails and ``VERBOSE`` is truthy.
    """
    root_dir = os.environ.get("ROCM_PATH", "/opt/rocm")
    if not (os.path.isdir(root_dir) or shutil.which("hipcc")):
        # No ROCm toolchain on this host.
        return

    installed = _rocm_version()
    if installed is None:
        print(" ROCm detected but version unreadable — skipping torch reinstall")
        return
    major, minor = installed

    # Ask the current interpreter's torch which GPU backend it was built for.
    # A non-empty answer (HIP or CUDA) means a GPU build is already in place;
    # CUDA counts too so that mixed CUDA+ROCm machines keep their CUDA torch.
    probe_cmd = [
        sys.executable,
        "-c",
        "import torch; print(torch.version.hip or torch.version.cuda or '')",
    ]
    probe = subprocess.run(
        probe_cmd,
        stdout = subprocess.PIPE,
        stderr = subprocess.DEVNULL,
    )
    if probe.returncode == 0:
        if probe.stdout.decode().strip():
            return

    # First table entry (iteration order) that does not exceed the installed
    # ROCm version.
    wheel_tag = None
    for floor, candidate in _ROCM_TORCH_WHL.items():
        if installed >= floor:
            wheel_tag = candidate
            break
    if wheel_tag is None:
        print(f" No PyTorch wheel for ROCm {major}.{minor} — skipping")
        return

    index_url = f"{_PYTORCH_WHL_BASE}/{wheel_tag}"
    print(f" ROCm {major}.{minor} — installing torch from {index_url}")
    pip_cmd = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--force-reinstall",
        "--no-cache-dir",
        "torch",
        "torchvision",
        "torchaudio",
        "--index-url",
        index_url,
    ]
    # stderr is folded into stdout so a single buffer holds the whole log.
    outcome = subprocess.run(
        pip_cmd,
        stdout = subprocess.PIPE,
        stderr = subprocess.STDOUT,
    )
    if outcome.returncode != 0:
        print(" ROCm torch reinstall failed — training may run on CPU")
        if VERBOSE:
            print(outcome.stdout.decode(errors = "replace"))
        return
    print(" ROCm torch installed")


# ── Verbosity control ──────────────────────────────────────────────────────────
# -- Verbosity control ----------------------------------------------------------
# By default the installer shows a minimal progress bar (one line, in-place).
# Set UNSLOTH_VERBOSE=1 in the environment to restore full per-step output:
Expand Down Expand Up @@ -341,6 +437,8 @@ def patch_package_file(package_name: str, relative_path: str, url: str) -> None:
def install_python_stack() -> int:
global USE_UV, _STEP, _TOTAL
_STEP = 0
# Step count: shared(9) + triton(non-Win) + ROCm-torch-check(non-Win) + finalize(1)
_TOTAL = 10 if IS_WINDOWS else 12

# When called from install.sh (which already installed unsloth into the venv),
# SKIP_STUDIO_BASE=1 is set to avoid redundant reinstallation of base packages.
Expand Down Expand Up @@ -444,6 +542,13 @@ def install_python_stack() -> int:
req = REQ_ROOT / "base.txt",
)

# 2b. AMD ROCm: reinstall torch with HIP wheels if the host has ROCm but the
# venv received CPU-only torch (common when pip resolves torch from PyPI).
# Must come immediately after base packages so torch is present for inspection.
if not IS_WINDOWS:
_progress("ROCm torch")
_ensure_rocm_torch()

# 3. Extra dependencies
_progress("unsloth extras")
pip_install(
Expand Down