diff --git a/install.sh b/install.sh index 15755c589f..30571fa7bf 100755 --- a/install.sh +++ b/install.sh @@ -1972,6 +1972,58 @@ case "$TORCH_INDEX_URL" in fi ;; esac +# ── RDNA2 (gfx1030-gfx1036, e.g. RX 6600/6700/6800/6900): cap to rocm6.2 ──── +# ROCm 7.x PyTorch wheels for gfx103x are dev/nightly builds (version string +# contains a git hash like 2.10.0+rocm7.2.0.gitXXXXXXXX) that segfault during +# unsloth import on RDNA2. The last stable, tested wheel is rocm6.2 (torch +# 2.7.x). When the system has ROCm 7.x but the runtime GPU is RDNA2, override +# TORCH_INDEX_URL to the rocm6.2 index so users get a stable install. +case "$TORCH_INDEX_URL" in + */rocm7.*) + _rdna2_gfx_all="" + if command -v rocminfo >/dev/null 2>&1; then + _rdna2_gfx_all=$(rocminfo 2>/dev/null | grep -oE 'gfx[1-9][0-9a-z]{2,3}') + fi + if [ -z "$_rdna2_gfx_all" ] && command -v amd-smi >/dev/null 2>&1; then + _rdna2_gfx_all=$(amd-smi list 2>/dev/null | grep -oE 'gfx[1-9][0-9a-z]{2,3}') + if [ -z "$_rdna2_gfx_all" ]; then + _rdna2_gfx_all=$(amd-smi static --asic 2>/dev/null | grep -oE 'gfx[1-9][0-9a-z]{2,3}') + fi + fi + _rdna2_runtime_gfx="" + if [ -n "$_rdna2_gfx_all" ]; then + _vis="${HIP_VISIBLE_DEVICES:-${ROCR_VISIBLE_DEVICES:-}}" + _idx=0 + if [ -n "$_vis" ] && [ "$_vis" != "-1" ]; then + _first=${_vis%%,*} + case "$_first" in + ''|*[!0-9]*) _idx=0 ;; + *) _idx=$_first ;; + esac + fi + _rdna2_runtime_gfx=$(printf '%s\n' "$_rdna2_gfx_all" | awk -v idx="$_idx" ' + NF && !seen[$0]++ { vals[n++] = $0 } + END { + if (idx < 0 || idx >= n) idx = 0 + if (n > 0) print vals[idx] + }') + fi + case "$_rdna2_runtime_gfx" in + gfx1030|gfx1031|gfx1032|gfx1033|gfx1034|gfx1035|gfx1036) + _pytorch_base="${UNSLOTH_PYTORCH_MIRROR:-https://download.pytorch.org/whl}" + _pytorch_base="${_pytorch_base%/}" + echo "" >&2 + echo " [WARN] $_rdna2_runtime_gfx (RDNA2) + ROCm 7.x detected" >&2 + echo " [WARN] ROCm 7.x PyTorch wheels are dev/nightly builds on gfx103x" >&2 + echo " [WARN] and cause segfaults during 
unsloth import on RDNA2 hardware." >&2 + echo " [WARN] Capping to rocm6.2 (torch 2.7.x) -- the last stable wheel." >&2 + echo "" >&2 + TORCH_INDEX_URL="${_pytorch_base}/rocm6.2" + TORCH_CONSTRAINT="torch>=2.4,<2.11.0" + ;; + esac + ;; +esac _TAURI_TORCH_INDEX_FAMILY=$(_tauri_torch_index_family "$TORCH_INDEX_URL") if [ "$_amd_gpu_radeon" = true ] && [ "$SKIP_TORCH" = false ]; then _TAURI_TORCH_INDEX_FAMILY="radeon" diff --git a/studio/backend/core/inference/inference.py b/studio/backend/core/inference/inference.py index e1620f5ca3..4306495077 100644 --- a/studio/backend/core/inference/inference.py +++ b/studio/backend/core/inference/inference.py @@ -5,7 +5,7 @@ Core inference backend - streamlined """ -from unsloth import FastLanguageModel, FastVisionModel +from unsloth import FastLanguageModel, FastVisionModel, is_bfloat16_supported from unsloth.chat_templates import get_chat_template from transformers import TextStreamer from peft import PeftModel, PeftModelForCausalLM @@ -460,13 +460,26 @@ def load_model( logger.info(f"Loading {model_type} model{adapter_info}: {model_name}") log_gpu_memory(f"Before loading {model_name}") + # AMD ROCm hardware without native bfloat16 (e.g. RDNA2 / gfx103x) crashes + # with an LLVM error at the first bf16 kernel dispatch when dtype=None lets + # unsloth auto-pick bf16. Fall back to float16 there unless the caller set + # dtype; other devices keep dtype=None for unsloth's bf16/fp16/float32 auto-pick. 
+ _is_rocm = ( + bool(getattr(torch.version, "hip", None)) + or "rocm" in torch.__version__.lower() + ) + _auto_dtype = ( + torch.float16 if (_is_rocm and not is_bfloat16_supported()) else None + ) + _effective_dtype = dtype if dtype is not None else _auto_dtype + # Load model - same approach for base models and LoRA adapters if config.is_vision: # Vision model (or vision LoRA adapter) model, processor = FastVisionModel.from_pretrained( model_name = config.path, # Can be base model OR LoRA adapter path max_seq_length = max_seq_length, - dtype = dtype, + dtype = _effective_dtype, load_in_4bit = load_in_4bit, device_map = device_map, token = hf_token if hf_token and hf_token.strip() else None, @@ -523,7 +536,7 @@ def load_model( model, tokenizer = FastLanguageModel.from_pretrained( model_name = config.path, # Can be base model OR LoRA adapter path max_seq_length = max_seq_length, - dtype = dtype, + dtype = _effective_dtype, load_in_4bit = load_in_4bit, device_map = device_map, token = hf_token if hf_token and hf_token.strip() else None, diff --git a/studio/install_python_stack.py b/studio/install_python_stack.py index 5e785c934d..4d31719dd1 100644 --- a/studio/install_python_stack.py +++ b/studio/install_python_stack.py @@ -695,6 +695,30 @@ def _ensure_rocm_torch() -> None: rocm_torch_ready = has_hip_torch + # RDNA2 (gfx1030-gfx1036, e.g. RX 6600/6700/6800/6900) cap + # ───────────────────────────────────────────────────────────── + # ROCm 7.x PyTorch builds are unstable on RDNA2: the gfx103x PyTorch wheels + # for ROCm 7.x are dev/nightly builds (version suffix .gitXXXXXXXX) that + # segfault during unsloth import on these GPUs. The last stable, tested + # PyTorch for RDNA2 is the rocm6.2 wheel. When the runtime GPU is RDNA2 + # and the system has ROCm 7.x installed, override the tag selection to + # rocm6.2 so users get a stable torch regardless of their ROCm stack version. 
+ _RDNA2_GFX = {"gfx1030", "gfx1031", "gfx1032", "gfx1033", "gfx1034", "gfx1035", "gfx1036"} + _rdna2_cap_tag: "str | None" = None + if ver >= (7, 0): + gfx_codes = _detect_amd_gfx_codes() + _runtime_gfx = ( + gfx_codes[_pick_visible_index(len(gfx_codes))] if gfx_codes else None + ) + if _runtime_gfx in _RDNA2_GFX: + _rdna2_cap_tag = "rocm6.2" + print( + f"\n {_runtime_gfx} (RDNA2) detected with ROCm {ver[0]}.{ver[1]}.\n" + f" ROCm 7.x PyTorch wheels are unstable on RDNA2 (dev builds that\n" + f" segfault on import). Capping torch install to the last known-good\n" + f" wheel: pytorch.org/whl/rocm6.2 (torch 2.7.x).\n" + ) + # Strix Halo / Strix Point (gfx1151 / gfx1150) segfault under ROCm 7.1 # in torch._grouped_mm. AMD's per-gfx repo ships torch 2.11.0+rocm7.13.0 # with the real fix, so route those hosts there instead of the generic @@ -771,16 +795,22 @@ def _ensure_rocm_torch() -> None: constrain = False, ) rocm_torch_ready = True - elif not has_hip_torch: - # Select best matching wheel tag (newest ROCm version <= installed) - tag = next( - ( - t - for (maj, mn), t in sorted(_ROCM_TORCH_INDEX.items(), reverse = True) - if ver >= (maj, mn) - ), - None, - ) + elif _rdna2_cap_tag is not None or not has_hip_torch: + # RDNA2 with ROCm 7.x: force the capped tag regardless of whether + # a hip torch is already present -- the dev build that ends up there + # is exactly what we need to replace. + # Otherwise: select best matching wheel tag (newest ROCm version <= installed). + if _rdna2_cap_tag is not None: + tag = _rdna2_cap_tag + else: + tag = next( + ( + t + for (maj, mn), t in sorted(_ROCM_TORCH_INDEX.items(), reverse = True) + if ver >= (maj, mn) + ), + None, + ) if tag is None: print( f" No PyTorch wheel for ROCm {ver[0]}.{ver[1]} -- "