fix(studio): ensure ROCm-enabled torch in venv on AMD hosts #4448

Closed
GoldenGrapeGentleman wants to merge 5 commits into unslothai:main from GoldenGrapeGentleman:fix/studio-rocm-venv-torch

@GoldenGrapeGentleman
Contributor

Problem

During Studio venv setup, pip resolves torch from PyPI and installs CPU-only wheels regardless of the host GPU. AMD ROCm users end up with a venv that cannot use their GPU for training.

First reported by @andyluo7 in #4390.

Fix

Add _ensure_rocm_torch(), called immediately after the base packages are installed:

  • Detects ROCm via $ROCM_PATH / /opt/rocm / hipcc
  • Reads installed version from /opt/rocm/.info/version
  • Maps (major, minor) → PyTorch wheel index (ROCm 6.0–7.x covered)
  • Skips if torch already links against HIP or CUDA (no-op on NVIDIA hosts)
  • Force-reinstalls torch + torchvision + torchaudio from the matched index URL
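
The detection and version-reading steps above could look roughly like the following. This is a minimal sketch rather than the code in this PR: the helper names `find_rocm_root`, `parse_rocm_version`, and `read_rocm_version` are hypothetical, and it assumes the version file starts with a `major.minor` string such as `7.1.0`.

```python
import os
import re
import shutil
from pathlib import Path
from typing import Mapping, Optional, Tuple


def find_rocm_root(env: Optional[Mapping[str, str]] = None) -> Optional[Path]:
    """Return the ROCm install root, or None if no ROCm install is found."""
    env = os.environ if env is None else env
    # Check $ROCM_PATH first, then the conventional /opt/rocm location.
    for candidate in (env.get("ROCM_PATH"), "/opt/rocm"):
        if candidate and Path(candidate).is_dir():
            return Path(candidate)
    # Fall back to hipcc on PATH; assumed to live at <root>/bin/hipcc.
    hipcc = shutil.which("hipcc")
    if hipcc:
        return Path(hipcc).resolve().parent.parent
    return None


def parse_rocm_version(text: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from a version string such as '7.1.0'."""
    match = re.match(r"\s*(\d+)\.(\d+)", text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def read_rocm_version(root: Path) -> Optional[Tuple[int, int]]:
    """Read <root>/.info/version; return None if it is missing or unreadable."""
    try:
        return parse_rocm_version((root / ".info" / "version").read_text())
    except OSError:
        return None
```

Returning `None` instead of raising keeps the caller's "not an AMD host, do nothing" path simple.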

Testing

Verified on 8×AMD MI355X (ROCm 7.1, 288 GB HBM3e):

  • _rocm_version() returns (7, 1)
  • Wheel mapping (7, 1) → rocm6.3 is correct ✅
  • Already-HIP-torch detected → no-op ✅
  • CPU-only mock → would trigger reinstall with correct URL ✅
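
For readers who want to see how (7, 1) can resolve to rocm6.3, here is a hedged sketch of a tuple-comparison lookup. The table entries and the names `ROCM_WHEEL_TAGS` and `wheel_index_url` are illustrative assumptions, not the table shipped in this PR; the only pairing taken from the results above is (7, 1) → rocm6.3.

```python
from typing import Optional, Tuple

PYTORCH_WHEEL_BASE = "https://download.pytorch.org/whl"

# (minimum ROCm version, wheel index tag), newest first.
# Hypothetical thresholds for illustration only.
ROCM_WHEEL_TAGS = [
    ((6, 3), "rocm6.3"),
    ((6, 2), "rocm6.2"),
    ((6, 1), "rocm6.1"),
    ((6, 0), "rocm6.0"),
]


def wheel_index_url(version: Tuple[int, int]) -> Optional[str]:
    """Pick the newest wheel index whose minimum version is <= the host version."""
    for minimum, tag in ROCM_WHEEL_TAGS:
        # Python compares tuples element by element, so (7, 1) >= (6, 3) is True.
        if version >= minimum:
            return f"{PYTORCH_WHEEL_BASE}/{tag}"
    return None  # host ROCm is older than anything in the table


print(wheel_index_url((7, 1)))  # https://download.pytorch.org/whl/rocm6.3
```

Ordering the table newest first means a host newer than every entry falls through to the newest available index instead of failing.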

Co-authored-by: billishyahao <bill.he@amd.com>

GoldenGrapeGentleman force-pushed the fix/studio-rocm-venv-torch branch from 60fc627 to 2f87648 on March 19, 2026 02:43

pip resolves torch from PyPI during base package installation, pulling
CPU-only wheels regardless of the host GPU.  AMD ROCm users end up with
a venv that cannot use their GPU for training.

Add _ensure_rocm_torch(), which runs immediately after base packages:
- detects ROCm via $ROCM_PATH / /opt/rocm / hipcc
- reads the installed version from /opt/rocm/.info/version
- maps (major, minor) to the correct PyTorch wheel index via tuple comparison
- skips if torch is already GPU-enabled (checks both torch.version.hip and
  torch.version.cuda to avoid clobbering CUDA torch on mixed hosts)
- force-reinstalls torch + torchvision + torchaudio from the matched index URL
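
The skip check and reinstall step listed above might be sketched as follows. This is an assumption-laden illustration, not the PR's implementation: `probe_torch_backend`, `needs_rocm_reinstall`, and `build_reinstall_cmd` are hypothetical names, and the sketch assumes the venv is managed with plain `pip`.

```python
import subprocess
from typing import List

# One-liner run inside the venv's interpreter; prints the HIP or CUDA version
# the installed torch was built against, or an empty line for CPU-only builds.
PROBE = "import torch; print(torch.version.hip or torch.version.cuda or '')"


def probe_torch_backend(python: str) -> str:
    """Return torch's HIP/CUDA build version, or '' if CPU-only or not installed."""
    result = subprocess.run([python, "-c", PROBE], capture_output=True, text=True)
    return result.stdout.strip() if result.returncode == 0 else ""


def needs_rocm_reinstall(backend: str) -> bool:
    """Reinstall only when torch has no GPU backend at all."""
    return backend == ""


def build_reinstall_cmd(python: str, index_url: str) -> List[str]:
    """Build the pip command that replaces torch with wheels from index_url."""
    return [
        python, "-m", "pip", "install",
        "--force-reinstall",
        "--index-url", index_url,
        "torch", "torchvision", "torchaudio",
    ]
```

Treating a CUDA build as "already GPU-enabled" is what keeps the step from overwriting CUDA torch on mixed hosts. A caller would run `build_reinstall_cmd(...)` through `subprocess.run` only when `needs_rocm_reinstall(probe_torch_backend(python))` is true.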

Tested on 8×AMD MI355X (ROCm 7.1) — version detection, wheel mapping,
and no-op behaviour all verified.

Fixes the issue raised by andyluo7 in unslothai#4390.

Co-authored-by: billishyahao <bill.he@amd.com>

GoldenGrapeGentleman force-pushed the fix/studio-rocm-venv-torch branch from e46884c to e5d58f5 on March 19, 2026 07:53