From 7018eab1de5848f14f6f469c226807b8c6af3dea Mon Sep 17 00:00:00 2001 From: Iswarya Alex Date: Thu, 2 Apr 2026 13:32:19 -0700 Subject: [PATCH] Support AMD Radeon for studio --- install.sh | 151 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 148 insertions(+), 3 deletions(-) diff --git a/install.sh b/install.sh index 56f6e8e9de..56dc24a5d8 100755 --- a/install.sh +++ b/install.sh @@ -1055,8 +1055,109 @@ get_torch_index_url() { elif [ "$_major" -ge 11 ]; then echo "$_base/cu118" else echo "$_base/cpu"; fi } + +get_radeon_wheel_url() { + # Only meaningful on Linux. Picks a repo.radeon.com base URL whose listing + # contains torch wheels. Tries paths like rocm-rel-7.2.1/, rocm-rel-7.2/, + # rocm-rel-7.1.1/, rocm-rel-7.1/ (AMD publishes both M.m and M.m.p dirs). + case "$(uname -s)" in Linux) ;; *) echo ""; return ;; esac + + # Detect full X.Y.Z version -- try amd-smi first, then /opt/rocm/.info/version, then hipconfig + _full_ver="" + _full_ver=$({ command -v amd-smi >/dev/null 2>&1 && \ + amd-smi version 2>/dev/null | awk -F'ROCm version: ' \ + 'NF>1{if(match($2,/[0-9]+\.[0-9]+\.[0-9]+/)){print substr($2,RSTART,RLENGTH); ok=1; exit}} END{exit !ok}'; } || \ + { [ -r /opt/rocm/.info/version ] && \ + awk -F'[.-]' 'NF>=3{print $1"."$2"."$3; exit}' /opt/rocm/.info/version; } || \ + { command -v hipconfig >/dev/null 2>&1 && \ + hipconfig --version 2>/dev/null | awk 'NR==1 && /^[0-9]+\.[0-9]+\.[0-9]/{print $1}'; }) 2>/dev/null + + # Validate: must be X.Y.Z with X >= 1 + case "$_full_ver" in + [1-9]*.*[0-9].*[0-9]*) : ;; + *) echo ""; return ;; + esac + echo "https://repo.radeon.com/rocm/manylinux/rocm-rel-${_full_ver}/" +} + +# ── Radeon repo wheel selection helpers ────────────────────────────────────── +# Fetches the Radeon repo directory listing once into _RADEON_LISTING (global). +# _RADEON_PYTAG holds the CPython tag for the running interpreter (e.g. cp312). +# _RADEON_BASE_URL holds the base URL for relative-href resolution. +_RADEON_LISTING="" +_RADEON_PYTAG="" +_RADEON_BASE_URL="" + +_radeon_fetch_listing() { + # Usage: _radeon_fetch_listing BASE_URL + # Populates _RADEON_LISTING, _RADEON_PYTAG, _RADEON_BASE_URL. + _RADEON_BASE_URL="$1" + _RADEON_PYTAG=$("$_VENV_PY" -c " +import sys +print('cp{}{}'.format(sys.version_info.major, sys.version_info.minor)) +" 2>/dev/null) || return 1 + if command -v curl >/dev/null 2>&1; then + _RADEON_LISTING=$(curl -fsSL --max-time 20 "$_RADEON_BASE_URL" 2>/dev/null) + elif command -v wget >/dev/null 2>&1; then + _RADEON_LISTING=$(wget -qO- --timeout=20 "$_RADEON_BASE_URL" 2>/dev/null) + fi + [ -n "$_RADEON_LISTING" ] || return 1 +} + +_pick_radeon_wheel() { + # Usage: _pick_radeon_wheel PACKAGE_NAME + # Scans $_RADEON_LISTING for the newest wheel whose filename starts exactly + # with PACKAGE_NAME- and matches _RADEON_PYTAG + linux_x86_64. + # Prints the full URL (resolving relative hrefs against _RADEON_BASE_URL). + _pkg="$1" + [ -n "$_RADEON_LISTING" ] || return 1 + [ -n "$_RADEON_PYTAG" ] || return 1 + _tag="$_RADEON_PYTAG" + _href=$(printf '%s\n' "$_RADEON_LISTING" \ + | grep -o 'href="[^"]*"' \ + | sed 's/href="//;s/"//' \ + | awk -F/ -v pkg="$_pkg" -v tag="$_tag" ' + { + base = $NF + sub(/[?#].*/, "", base) # strip query / fragment + prefix = pkg "-" + suffix = "-" tag "-" tag "-linux_x86_64.whl" + if (substr(base, 1, length(prefix)) == prefix && + substr(base, length(base) - length(suffix) + 1) == suffix) + print $0 + }' \ + | sort -V \ + | tail -1) + [ -z "$_href" ] && return 1 + case "$_href" in + http*) printf '%s\n' "$_href" ;; + *) printf '%s\n' "${_RADEON_BASE_URL%/}/${_href#/}" ;; + esac +} + TORCH_INDEX_URL=$(get_torch_index_url) +# Auto-detect GPU for AMD ROCm based +# get_torch_index_url must have chosen */rocm* +# (gfx in rocminfo or amd-smi list). Then require rocminfo "Marketing Name:.*Radeon". +case "$TORCH_INDEX_URL" in + */rocm*) + _amd_gpu_here=false + _amd_gpu_radeon=false + if command -v rocminfo >/dev/null 2>&1 && \ + rocminfo 2>/dev/null | awk '/Name:[[:space:]]*gfx[0-9]/{found=1} END{exit !found}'; then + _amd_gpu_here=true + elif command -v amd-smi >/dev/null 2>&1 && \ + amd-smi list 2>/dev/null | awk 'NR>1 && NF{found=1} END{exit !found}'; then + _amd_gpu_here=true + fi + if [ "$_amd_gpu_here" = true ] && command -v rocminfo >/dev/null 2>&1 && \ + rocminfo 2>/dev/null | grep -q 'Marketing Name:.*Radeon'; then + _amd_gpu_radeon=true + fi + ;; +esac + # ── Print CPU-only hint when no GPU detected ── case "$TORCH_INDEX_URL" in */cpu) @@ -1072,7 +1173,11 @@ case "$TORCH_INDEX_URL" in ;; */rocm*) echo "" - echo " AMD ROCm detected -- installing ROCm-enabled PyTorch ($TORCH_INDEX_URL)" + if [ "$_amd_gpu_radeon" = true ]; then + echo " AMD Radeon + ROCm detected -- installing PyTorch wheels from repo.radeon.com" + else + echo " AMD ROCm detected -- installing ROCm-enabled PyTorch ($TORCH_INDEX_URL)" + fi echo "" ;; esac @@ -1108,6 +1213,46 @@ elif [ -n "$TORCH_INDEX_URL" ]; then # Fresh: Step 1 - install torch from explicit index (skip when --no-torch or Intel Mac) if [ "$SKIP_TORCH" = true ]; then substep "skipping PyTorch (--no-torch or Intel Mac x86_64)." "$C_WARN" + elif [ "$_amd_gpu_radeon" = true ]; then + _radeon_url=$(get_radeon_wheel_url) + if [ -n "$_radeon_url" ]; then + _radeon_listing_ok=false + if _radeon_fetch_listing "$_radeon_url" 2>/dev/null; then + _radeon_listing_ok=true + else + # Try shorter X.Y path (AMD publishes both X.Y.Z and X.Y dirs) + _radeon_url_short=$(printf '%s\n' "$_radeon_url" \ + | sed 's|rocm-rel-\([0-9]*\)\.\([0-9]*\)\.[0-9]*/|rocm-rel-\1.\2/|') + if [ "$_radeon_url_short" != "$_radeon_url" ] && \ + _radeon_fetch_listing "$_radeon_url_short" 2>/dev/null; then + _radeon_listing_ok=true + fi + fi + + if [ "$_radeon_listing_ok" = true ]; then + substep "installing PyTorch from Radeon repo (${_RADEON_BASE_URL})..." + _torch_arg="torch"; _tv_arg="torchvision"; _ta_arg="torchaudio"; _tri_arg="" + _torch_whl=$(_pick_radeon_wheel "torch" 2>/dev/null) && _torch_arg="$_torch_whl" + _tv_whl=$(_pick_radeon_wheel "torchvision" 2>/dev/null) && _tv_arg="$_tv_whl" + _ta_whl=$(_pick_radeon_wheel "torchaudio" 2>/dev/null) && _ta_arg="$_ta_whl" + _tri_whl=$(_pick_radeon_wheel "triton" 2>/dev/null) && _tri_arg="$_tri_whl" + run_install_cmd "install triton + PyTorch" uv pip install --python "$_VENV_PY" \ + --find-links "$_RADEON_BASE_URL" \ + "$_tri_arg" "$_torch_arg" "$_tv_arg" "$_ta_arg" + substep "installing bitsandbytes for AMD Radeon..." + run_install_cmd "install bitsandbytes (AMD)" uv pip install --python "$_VENV_PY" \ + "bitsandbytes>=0.49.1" + else + substep "[WARN] Radeon repo unavailable; falling back to CPU-only PyTorch" "$C_WARN" + run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" \ + "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \ + --index-url "${TORCH_INDEX_URL%/*}/cpu" + fi + else + substep "[WARN] Radeon GPU detected but could not detect full ROCm version; falling back to CPU-only PyTorch" "$C_WARN" + run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \ + --index-url "${TORCH_INDEX_URL%/*}/cpu" + fi else substep "installing PyTorch ($TORCH_INDEX_URL)..." run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \ @@ -1121,7 +1266,7 @@ elif [ -n "$TORCH_INDEX_URL" ]; then esac fi # Fresh: Step 2 - install unsloth, preserving pre-installed torch - substep "installing unsloth (this may take a few minutes)..." + substep "installing unsloth (this may take a few minutes)..." if [ "$SKIP_TORCH" = true ]; then # No-torch: install unsloth + unsloth-zoo with --no-deps, then # runtime deps (typer, safetensors, transformers, etc.) with --no-deps. @@ -1292,4 +1437,4 @@ else substep "source ${VENV_DIR}/bin/activate" substep "unsloth studio -H 0.0.0.0 -p 8888" echo "" -fi +fi \ No newline at end of file