-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add Radeon detection and download qualified .whls from Radeon repo #4770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1055,8 +1055,109 @@ get_torch_index_url() { | |
| elif [ "$_major" -ge 11 ]; then echo "$_base/cu118" | ||
| else echo "$_base/cpu"; fi | ||
| } | ||
|
|
||
# get_radeon_wheel_url
#   Prints the repo.radeon.com manylinux base URL for the locally detected
#   ROCm release, e.g.
#     https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2.1/
#   Prints an empty line when the host is not Linux or when no valid X.Y.Z
#   version (X >= 1) can be detected.
#
#   This function only builds the URL; it does not contact the server. The
#   caller is responsible for probing it (and for retrying with the shorter
#   rocm-rel-X.Y/ directory if the X.Y.Z one is missing).
#
#   Version sources, first non-empty numeric result wins:
#     1. `amd-smi version`          ("ROCm version: X.Y.Z")
#     2. /opt/rocm/.info/version    ("X.Y.Z" or "X.Y.Z-build")
#     3. `hipconfig --version`      (leading numeric X.Y.Z of the first line)
#   Writes: _full_ver (no `local`, to stay POSIX sh compatible).
get_radeon_wheel_url() {
  case "$(uname -s)" in
    Linux) ;;
    *) echo ""; return ;;
  esac

  _full_ver=""

  # 1. amd-smi: take the first X.Y.Z that follows "ROCm version: ".
  if command -v amd-smi >/dev/null 2>&1; then
    _full_ver=$(amd-smi version 2>/dev/null | awk -F'ROCm version: ' '
      NF > 1 && match($2, /[0-9]+\.[0-9]+\.[0-9]+/) {
        print substr($2, RSTART, RLENGTH)
        exit
      }') || _full_ver=""
  fi

  # 2. ROCm version file: only accept a line whose first three fields are
  #    purely numeric, so a malformed file falls through to hipconfig.
  if [ -z "$_full_ver" ] && [ -r /opt/rocm/.info/version ]; then
    _full_ver=$(awk -F'[.-]' '
      NF >= 3 && $1 ~ /^[0-9]+$/ && $2 ~ /^[0-9]+$/ && $3 ~ /^[0-9]+$/ {
        print $1 "." $2 "." $3
        exit
      }' /opt/rocm/.info/version 2>/dev/null) || _full_ver=""
  fi

  # 3. hipconfig: keep only the numeric X.Y.Z prefix. The first field may
  #    carry a "-<hash>" suffix that must not end up in the URL.
  if [ -z "$_full_ver" ] && command -v hipconfig >/dev/null 2>&1; then
    _full_ver=$(hipconfig --version 2>/dev/null | awk '
      NR == 1 && match($0, /^[0-9]+\.[0-9]+\.[0-9]+/) {
        print substr($0, RSTART, RLENGTH)
      }') || _full_ver=""
  fi

  # Validate strictly: exactly three dot-separated digit groups, X >= 1.
  # The value is interpolated into a URL, so reject anything else.
  case "$_full_ver" in
    '' | *[!0-9.]* | .* | *. | *..* | *.*.*.*) echo ""; return ;;
    [1-9]*.*.*) : ;;
    *) echo ""; return ;;
  esac

  printf '%s\n' "https://repo.radeon.com/rocm/manylinux/rocm-rel-${_full_ver}/"
}
|
|
||
# ── Radeon repo wheel selection helpers ──────────────────────────────────────
# Shared state, written by _radeon_fetch_listing and read by the wheel picker:
#   _RADEON_LISTING   raw body of the repo directory listing (empty = none)
#   _RADEON_PYTAG     CPython tag of the interpreter at $_VENV_PY (e.g. cp312)
#   _RADEON_BASE_URL  URL the listing was requested from; relative hrefs are
#                     resolved against it
_RADEON_LISTING=""
_RADEON_PYTAG=""
_RADEON_BASE_URL=""

# _radeon_fetch_listing BASE_URL
#   Downloads BASE_URL with curl (preferred) or wget, 20 s timeout, and fills
#   _RADEON_LISTING, _RADEON_PYTAG and _RADEON_BASE_URL.
#   Reads:    _VENV_PY (interpreter used to compute the cpXY tag).
#   Returns:  0 when both the tag and a non-empty listing were obtained;
#             1 otherwise (interpreter failed, no downloader available,
#             download failed, or empty response).
_radeon_fetch_listing() {
  _RADEON_BASE_URL="$1"
  # Reset first so a failed call can never leave an earlier listing or tag
  # paired with the new base URL.
  _RADEON_LISTING=""
  _RADEON_PYTAG=""

  _RADEON_PYTAG=$("$_VENV_PY" -c "
import sys
print('cp{}{}'.format(sys.version_info.major, sys.version_info.minor))
" 2>/dev/null) || return 1
  [ -n "$_RADEON_PYTAG" ] || return 1

  # On a failed transfer discard whatever partial body was received.
  if command -v curl >/dev/null 2>&1; then
    _RADEON_LISTING=$(curl -fsSL --max-time 20 "$_RADEON_BASE_URL" 2>/dev/null) \
      || _RADEON_LISTING=""
  elif command -v wget >/dev/null 2>&1; then
    _RADEON_LISTING=$(wget -qO- --timeout=20 "$_RADEON_BASE_URL" 2>/dev/null) \
      || _RADEON_LISTING=""
  fi

  [ -n "$_RADEON_LISTING" ] || return 1
}
|
|
||
# _pick_radeon_wheel PACKAGE_NAME
#   Scans $_RADEON_LISTING for href targets whose file name starts exactly
#   with "PACKAGE_NAME-" and ends with
#   "-<tag>-<tag>-linux_x86_64.whl" (tag = $_RADEON_PYTAG), picks the highest
#   one by version sort, and prints it as a full URL.
#   href resolution against $_RADEON_BASE_URL:
#     http://… / https://…  printed unchanged
#     //host/path           scheme of the base URL is prepended
#     /path                 resolved against scheme://host of the base URL
#     anything else         appended to the base URL directory
#   Returns 1 (printing nothing) when the package name, listing or tag is
#   empty, or when no wheel matches.
#   Writes: _pkg, _tag, _href, _base_rest (no `local`, POSIX sh compatible).
_pick_radeon_wheel() {
  _pkg="$1"
  [ -n "$_pkg" ] || return 1
  [ -n "$_RADEON_LISTING" ] || return 1
  [ -n "$_RADEON_PYTAG" ] || return 1
  _tag="$_RADEON_PYTAG"

  # Splitting on "/" makes $NF the file-name part of each href; the full href
  # is what gets printed so directory components survive for resolution below.
  _href=$(printf '%s\n' "$_RADEON_LISTING" \
    | grep -o 'href="[^"]*"' \
    | sed 's/^href="//; s/"$//' \
    | awk -F/ -v pkg="$_pkg" -v tag="$_tag" '
        BEGIN {
          prefix = pkg "-"
          suffix = "-" tag "-" tag "-linux_x86_64.whl"
        }
        {
          base = $NF
          sub(/[?#].*/, "", base)   # strip query / fragment
          if (index(base, prefix) == 1 &&
              length(base) > length(suffix) &&
              substr(base, length(base) - length(suffix) + 1) == suffix)
            print $0
        }' \
    | sort -V \
    | tail -n 1)
  [ -n "$_href" ] || return 1

  case "$_href" in
    http://* | https://*)
      printf '%s\n' "$_href"
      ;;
    //*)
      printf '%s\n' "${_RADEON_BASE_URL%%://*}:${_href}"
      ;;
    /*)
      # Root-relative: keep only scheme://host from the base URL, otherwise
      # the directory path would be duplicated.
      _base_rest=${_RADEON_BASE_URL#*://}
      printf '%s\n' "${_RADEON_BASE_URL%%://*}://${_base_rest%%/*}${_href}"
      ;;
    *)
      printf '%s\n' "${_RADEON_BASE_URL%/}/${_href}"
      ;;
  esac
}
|
|
||
| TORCH_INDEX_URL=$(get_torch_index_url) | ||
|
|
||
# Classify the AMD GPU when the ROCm PyTorch index was selected.
#   _amd_gpu_here    true when rocminfo reports an agent named gfx<N>, or,
#                    failing that, when `amd-smi list` prints any non-empty
#                    line after its first line.
#   _amd_gpu_radeon  true when a GPU is present AND rocminfo output contains
#                    "Marketing Name:.*Radeon" (rocminfo is required for this).
# Both flags are initialised unconditionally: later code reads them on every
# path, so they must not be unset or inherited from the caller's environment
# when TORCH_INDEX_URL is not a ROCm index.
_amd_gpu_here=false
_amd_gpu_radeon=false
case "$TORCH_INDEX_URL" in
  */rocm*)
    # Run rocminfo once and reuse its output for both checks. `|| :` keeps
    # whatever was printed even if rocminfo exits non-zero.
    _rocminfo_out=""
    if command -v rocminfo >/dev/null 2>&1; then
      _rocminfo_out=$(rocminfo 2>/dev/null) || :
    fi

    if printf '%s\n' "$_rocminfo_out" \
        | awk '/Name:[[:space:]]*gfx[0-9]/{found=1} END{exit !found}'; then
      _amd_gpu_here=true
    elif command -v amd-smi >/dev/null 2>&1 && \
        amd-smi list 2>/dev/null | awk 'NR>1 && NF{found=1} END{exit !found}'; then
      _amd_gpu_here=true
    fi

    if [ "$_amd_gpu_here" = true ] && \
        printf '%s\n' "$_rocminfo_out" | grep -q 'Marketing Name:.*Radeon'; then
      _amd_gpu_radeon=true
    fi
    ;;
esac
|
|
||
| # ── Print CPU-only hint when no GPU detected ── | ||
| case "$TORCH_INDEX_URL" in | ||
| */cpu) | ||
|
|
@@ -1072,7 +1173,11 @@ case "$TORCH_INDEX_URL" in | |
| ;; | ||
| */rocm*) | ||
| echo "" | ||
| echo " AMD ROCm detected -- installing ROCm-enabled PyTorch ($TORCH_INDEX_URL)" | ||
| if [ "$_amd_gpu_radeon" = true ]; then | ||
| echo " AMD Radeon + ROCm detected -- installing PyTorch wheels from repo.radeon.com" | ||
| else | ||
| echo " AMD ROCm detected -- installing ROCm-enabled PyTorch ($TORCH_INDEX_URL)" | ||
| fi | ||
| echo "" | ||
| ;; | ||
| esac | ||
|
|
@@ -1108,6 +1213,46 @@ elif [ -n "$TORCH_INDEX_URL" ]; then | |
| # Fresh: Step 1 - install torch from explicit index (skip when --no-torch or Intel Mac) | ||
| if [ "$SKIP_TORCH" = true ]; then | ||
| substep "skipping PyTorch (--no-torch or Intel Mac x86_64)." "$C_WARN" | ||
| elif [ "$_amd_gpu_radeon" = true ]; then | ||
| _radeon_url=$(get_radeon_wheel_url) | ||
| if [ -n "$_radeon_url" ]; then | ||
| _radeon_listing_ok=false | ||
| if _radeon_fetch_listing "$_radeon_url" 2>/dev/null; then | ||
| _radeon_listing_ok=true | ||
| else | ||
| # Try shorter X.Y path (AMD publishes both X.Y.Z and X.Y dirs) | ||
| _radeon_url_short=$(printf '%s\n' "$_radeon_url" \ | ||
| | sed 's|rocm-rel-\([0-9]*\)\.\([0-9]*\)\.[0-9]*/|rocm-rel-\1.\2/|') | ||
| if [ "$_radeon_url_short" != "$_radeon_url" ] && \ | ||
| _radeon_fetch_listing "$_radeon_url_short" 2>/dev/null; then | ||
| _radeon_listing_ok=true | ||
| fi | ||
| fi | ||
|
|
||
| if [ "$_radeon_listing_ok" = true ]; then | ||
# Install triton + PyTorch using the Radeon repo listing fetched just above,
# then bitsandbytes.
substep "installing PyTorch from Radeon repo (${_RADEON_BASE_URL})..."
# Fallback requirement specs, used for any package with no matching wheel in
# the listing. They carry the same version bounds as the other install paths
# so the resolver cannot drift to an arbitrary newer release.
_torch_arg="torch>=2.4,<2.11.0"
_tv_arg="torchvision<0.26.0"
_ta_arg="torchaudio<2.11.0"
_tri_arg=""
# Each assignment only overrides the fallback when the picker succeeds.
_torch_whl=$(_pick_radeon_wheel "torch" 2>/dev/null) && _torch_arg="$_torch_whl"
_tv_whl=$(_pick_radeon_wheel "torchvision" 2>/dev/null) && _tv_arg="$_tv_whl"
_ta_whl=$(_pick_radeon_wheel "torchaudio" 2>/dev/null) && _ta_arg="$_ta_whl"
_tri_whl=$(_pick_radeon_wheel "triton" 2>/dev/null) && _tri_arg="$_tri_whl"
# ${_tri_arg:+"$_tri_arg"} expands to nothing when no triton wheel was found;
# a plain "$_tri_arg" would pass an empty string as a requirement.
run_install_cmd "install triton + PyTorch" uv pip install --python "$_VENV_PY" \
  --find-links "$_RADEON_BASE_URL" \
  ${_tri_arg:+"$_tri_arg"} "$_torch_arg" "$_tv_arg" "$_ta_arg"
substep "installing bitsandbytes for AMD Radeon..."
run_install_cmd "install bitsandbytes (AMD)" uv pip install --python "$_VENV_PY" \
  "bitsandbytes>=0.49.1"
| else | ||
| substep "[WARN] Radeon repo unavailable; falling back to CPU-only PyTorch" "$C_WARN" | ||
| run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" \ | ||
| "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \ | ||
| --index-url "${TORCH_INDEX_URL%/*}/cpu" | ||
|
Comment on lines
+1246
to
+1249
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
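When Radeon wheel discovery fails, this branch forces a CPU-only install (`--index-url "${TORCH_INDEX_URL%/*}/cpu"`) […the remainder of this review comment was lost in extraction…] Useful? React with 👍 / 👎. |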
||
| fi | ||
| else | ||
| substep "[WARN] Radeon GPU detected but could not detect full ROCm version; falling back to CPU-only PyTorch" "$C_WARN" | ||
| run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \ | ||
| --index-url "${TORCH_INDEX_URL%/*}/cpu" | ||
| fi | ||
| else | ||
| substep "installing PyTorch ($TORCH_INDEX_URL)..." | ||
| run_install_cmd "install PyTorch" uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" "torchvision<0.26.0" "torchaudio<2.11.0" \ | ||
|
|
@@ -1121,7 +1266,7 @@ elif [ -n "$TORCH_INDEX_URL" ]; then | |
| esac | ||
| fi | ||
| # Fresh: Step 2 - install unsloth, preserving pre-installed torch | ||
| substep "installing unsloth (this may take a few minutes)..." | ||
| substep "installing unsloth (this may take a few minutes)..." | ||
| if [ "$SKIP_TORCH" = true ]; then | ||
| # No-torch: install unsloth + unsloth-zoo with --no-deps, then | ||
| # runtime deps (typer, safetensors, transformers, etc.) with --no-deps. | ||
|
|
@@ -1292,4 +1437,4 @@ else | |
| substep "source ${VENV_DIR}/bin/activate" | ||
| substep "unsloth studio -H 0.0.0.0 -p 8888" | ||
| echo "" | ||
| fi | ||
| fi | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When `_pick_radeon_wheel` cannot find a matching wheel (for example, the Radeon repo has no wheel for the current `cpXY` tag), these defaults leave `torch`, `torchvision`, and `torchaudio` unconstrained, so `uv pip install` can resolve to the latest PyPI builds instead of the installer’s intended `<2.11.0` range. That can pull unsupported or non-ROCm builds and break the environment on Radeon machines, whereas the non-Radeon path still enforces version bounds.
Useful? React with 👍 / 👎.