Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 148 additions & 3 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1055,8 +1055,109 @@ get_torch_index_url() {
elif [ "$_major" -ge 11 ]; then echo "$_base/cu118"
else echo "$_base/cpu"; fi
}

get_radeon_wheel_url() {
# Only meaningful on Linux. Picks a repo.radeon.com base URL whose listing
# contains torch wheels. Tries paths like rocm-rel-7.2.1/, rocm-rel-7.2/,
# rocm-rel-7.1.1/, rocm-rel-7.1/ (AMD publishes both M.m and M.m.p dirs).
case "$(uname -s)" in Linux) ;; *) echo ""; return ;; esac

# Detect full X.Y.Z version -- try amd-smi first, then /opt/rocm/.info/version, then hipconfig
_full_ver=""
_full_ver=$({ command -v amd-smi >/dev/null 2>&1 && \
amd-smi version 2>/dev/null | awk -F'ROCm version: ' \
'NF>1{if(match($2,/[0-9]+\.[0-9]+\.[0-9]+/)){print substr($2,RSTART,RLENGTH); ok=1; exit}} END{exit !ok}'; } || \
{ [ -r /opt/rocm/.info/version ] && \
awk -F'[.-]' 'NF>=3{print $1"."$2"."$3; exit}' /opt/rocm/.info/version; } || \
{ command -v hipconfig >/dev/null 2>&1 && \
hipconfig --version 2>/dev/null | awk 'NR==1 && /^[0-9]+\.[0-9]+\.[0-9]/{print $1}'; }) 2>/dev/null

# Validate: must be X.Y.Z with X >= 1
case "$_full_ver" in
[1-9]*.*[0-9].*[0-9]*) : ;;
*) echo ""; return ;;
esac
echo "https://repo.radeon.com/rocm/manylinux/rocm-rel-${_full_ver}/"
}

# ── Radeon repo wheel selection helpers ──────────────────────────────────────
# Fetches the Radeon repo directory listing once into _RADEON_LISTING (global).
# _RADEON_PYTAG holds the CPython tag for the running interpreter (e.g. cp312).
# _RADEON_BASE_URL holds the base URL for relative-href resolution.
_RADEON_LISTING=""
_RADEON_PYTAG=""
_RADEON_BASE_URL=""

_radeon_fetch_listing() {
# Usage: _radeon_fetch_listing BASE_URL
# Populates _RADEON_LISTING, _RADEON_PYTAG, _RADEON_BASE_URL.
_RADEON_BASE_URL="$1"
_RADEON_PYTAG=$("$_VENV_PY" -c "
import sys
print('cp{}{}'.format(sys.version_info.major, sys.version_info.minor))
" 2>/dev/null) || return 1
if command -v curl >/dev/null 2>&1; then
_RADEON_LISTING=$(curl -fsSL --max-time 20 "$_RADEON_BASE_URL" 2>/dev/null)
elif command -v wget >/dev/null 2>&1; then
_RADEON_LISTING=$(wget -qO- --timeout=20 "$_RADEON_BASE_URL" 2>/dev/null)
fi
[ -n "$_RADEON_LISTING" ] || return 1
}

_pick_radeon_wheel() {
# Usage: _pick_radeon_wheel PACKAGE_NAME
# Scans $_RADEON_LISTING for the newest wheel whose filename starts exactly
# with PACKAGE_NAME- and matches _RADEON_PYTAG + linux_x86_64.
# Prints the full URL (resolving relative hrefs against _RADEON_BASE_URL).
_pkg="$1"
[ -n "$_RADEON_LISTING" ] || return 1
[ -n "$_RADEON_PYTAG" ] || return 1
_tag="$_RADEON_PYTAG"
_href=$(printf '%s\n' "$_RADEON_LISTING" \
| grep -o 'href="[^"]*"' \
| sed 's/href="//;s/"//' \
| awk -F/ -v pkg="$_pkg" -v tag="$_tag" '
{
base = $NF
sub(/[?#].*/, "", base) # strip query / fragment
prefix = pkg "-"
suffix = "-" tag "-" tag "-linux_x86_64.whl"
if (substr(base, 1, length(prefix)) == prefix &&
substr(base, length(base) - length(suffix) + 1) == suffix)
print $0
}' \
| sort -V \
| tail -1)
[ -z "$_href" ] && return 1
case "$_href" in
http*) printf '%s\n' "$_href" ;;
*) printf '%s\n' "${_RADEON_BASE_URL%/}/${_href#/}" ;;
esac
}

TORCH_INDEX_URL=$(get_torch_index_url)

# Auto-detect GPU for AMD ROCm based
# get_torch_index_url must have chosen */rocm*
# (gfx in rocminfo or amd-smi list). Then require rocminfo "Marketing Name:.*Radeon".
case "$TORCH_INDEX_URL" in
*/rocm*)
_amd_gpu_here=false
_amd_gpu_radeon=false
if command -v rocminfo >/dev/null 2>&1 && \
rocminfo 2>/dev/null | awk '/Name:[[:space:]]*gfx[0-9]/{found=1} END{exit !found}'; then
_amd_gpu_here=true
elif command -v amd-smi >/dev/null 2>&1 && \
amd-smi list 2>/dev/null | awk 'NR>1 && NF{found=1} END{exit !found}'; then
_amd_gpu_here=true
fi
if [ "$_amd_gpu_here" = true ] && command -v rocminfo >/dev/null 2>&1 && \
rocminfo 2>/dev/null | grep -q 'Marketing Name:.*Radeon'; then
_amd_gpu_radeon=true
fi
;;
esac

# ── Print CPU-only hint when no GPU detected ──
case "$TORCH_INDEX_URL" in
*/cpu)
Expand All @@ -1072,7 +1173,11 @@ case "$TORCH_INDEX_URL" in
;;
*/rocm*)
echo ""
echo " AMD ROCm detected -- installing ROCm-enabled PyTorch ($TORCH_INDEX_URL)"
if [ "$_amd_gpu_radeon" = true ]; then
echo " AMD Radeon + ROCm detected -- installing PyTorch wheels from repo.radeon.com"
else
echo " AMD ROCm detected -- installing ROCm-enabled PyTorch ($TORCH_INDEX_URL)"
fi
echo ""
;;
esac
Expand Down Expand Up @@ -1108,6 +1213,46 @@ elif [ -n "$TORCH_INDEX_URL" ]; then
# Fresh: Step 1 - install torch from explicit index (skip when --no-torch or Intel Mac)
if [ "$SKIP_TORCH" = true ]; then
substep "skipping PyTorch (--no-torch or Intel Mac x86_64)." "$C_WARN"
elif [ "$_amd_gpu_radeon" = true ]; then
_radeon_url=$(get_radeon_wheel_url)
if [ -n "$_radeon_url" ]; then
_radeon_listing_ok=false
if _radeon_fetch_listing "$_radeon_url" 2>/dev/null; then
_radeon_listing_ok=true
else
# Try shorter X.Y path (AMD publishes both X.Y.Z and X.Y dirs)
_radeon_url_short=$(printf '%s\n' "$_radeon_url" \
| sed 's|rocm-rel-\([0-9]*\)\.\([0-9]*\)\.[0-9]*/|rocm-rel-\1.\2/|')
if [ "$_radeon_url_short" != "$_radeon_url" ] && \
_radeon_fetch_listing "$_radeon_url_short" 2>/dev/null; then
_radeon_listing_ok=true
fi
fi

if [ "$_radeon_listing_ok" = true ]; then
substep "installing PyTorch from Radeon repo (${_RADEON_BASE_URL})..."
_torch_arg="torch"; _tv_arg="torchvision"; _ta_arg="torchaudio"; _tri_arg=""
_torch_whl=$(_pick_radeon_wheel "torch" 2>/dev/null) && _torch_arg="$_torch_whl"
_tv_whl=$(_pick_radeon_wheel "torchvision" 2>/dev/null) && _tv_arg="$_tv_whl"
_ta_whl=$(_pick_radeon_wheel "torchaudio" 2>/dev/null) && _ta_arg="$_ta_whl"
Comment on lines +1234 to +1237

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep PyTorch constraints when Radeon wheel lookup fails

When _pick_radeon_wheel cannot find a matching wheel (for example, the Radeon repo has no wheel for the current cpXY tag), these defaults leave torch, torchvision, and torchaudio unconstrained, so uv pip install can resolve to latest PyPI builds instead of the installer’s intended <2.11.0 range. That can pull unsupported or non-ROCm builds and break the environment on Radeon machines, whereas the non-Radeon path still enforces version bounds.

Useful? React with 👍 / 👎.

_tri_whl=$(_pick_radeon_wheel "triton" 2>/dev/null) && _tri_arg="$_tri_whl"
run_install_cmd "install triton + PyTorch" uv pip install --python "$_VENV_PY" \
--find-links "$_RADEON_BASE_URL" \
"$_tri_arg" "$_torch_arg" "$_tv_arg" "$_ta_arg"
substep "installing bitsandbytes for AMD Radeon..."
run_install_cmd "install bitsandbytes (AMD)" uv pip install --python "$_VENV_PY" \
"bitsandbytes>=0.49.1"
else
substep "[WARN] Radeon repo unavailable; falling back to CPU-only PyTorch" "$C_WARN"
run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" \
"torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \
--index-url "${TORCH_INDEX_URL%/*}/cpu"
Comment on lines +1246 to +1249

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Fall back to ROCm index when Radeon repo is unavailable

When Radeon wheel discovery fails, this branch forces a CPU-only install (.../whl/cpu) even though ROCm GPU detection already succeeded and TORCH_INDEX_URL points to a usable ROCm index. In environments where repo.radeon.com is temporarily unreachable or its listing format changes, Radeon users are silently downgraded to CPU PyTorch instead of retaining GPU acceleration, which is a major regression from the previous ROCm path.

Useful? React with 👍 / 👎.

fi
else
substep "[WARN] Radeon GPU detected but could not detect full ROCm version; falling back to CPU-only PyTorch" "$C_WARN"
run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \
--index-url "${TORCH_INDEX_URL%/*}/cpu"
fi
else
substep "installing PyTorch ($TORCH_INDEX_URL)..."
run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \
Expand All @@ -1121,7 +1266,7 @@ elif [ -n "$TORCH_INDEX_URL" ]; then
esac
fi
# Fresh: Step 2 - install unsloth, preserving pre-installed torch
substep "installing unsloth (this may take a few minutes)..."
substep "installing unsloth (this may take a few minutes)..."
if [ "$SKIP_TORCH" = true ]; then
# No-torch: install unsloth + unsloth-zoo with --no-deps, then
# runtime deps (typer, safetensors, transformers, etc.) with --no-deps.
Expand Down Expand Up @@ -1292,4 +1437,4 @@ else
substep "source ${VENV_DIR}/bin/activate"
substep "unsloth studio -H 0.0.0.0 -p 8888"
echo ""
fi
fi