Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,58 @@ case "$TORCH_INDEX_URL" in
fi
;;
esac
# ── RDNA2 (gfx1030-gfx1036, e.g. RX 6600/6700/6800/6900): cap to rocm6.2 ────
# ROCm 7.x PyTorch wheels for gfx103x are dev/nightly builds (version string
# contains a git hash like 2.10.0+rocm7.2.0.gitXXXXXXXX) that segfault during
# unsloth import on RDNA2. The last stable, tested wheel is rocm6.2 (torch
# 2.7.x). When the system has ROCm 7.x but the runtime GPU is RDNA2, override
# TORCH_INDEX_URL to the rocm6.2 index so users get a stable install.
case "$TORCH_INDEX_URL" in
*/rocm7.*)
_rdna2_gfx_all=""
if command -v rocminfo >/dev/null 2>&1; then
_rdna2_gfx_all=$(rocminfo 2>/dev/null | grep -oE 'gfx[1-9][0-9a-z]{2,3}')
fi
if [ -z "$_rdna2_gfx_all" ] && command -v amd-smi >/dev/null 2>&1; then
_rdna2_gfx_all=$(amd-smi list 2>/dev/null | grep -oE 'gfx[1-9][0-9a-z]{2,3}')
if [ -z "$_rdna2_gfx_all" ]; then
_rdna2_gfx_all=$(amd-smi static --asic 2>/dev/null | grep -oE 'gfx[1-9][0-9a-z]{2,3}')
fi
fi
_rdna2_runtime_gfx=""
if [ -n "$_rdna2_gfx_all" ]; then
_vis="${HIP_VISIBLE_DEVICES:-${ROCR_VISIBLE_DEVICES:-}}"
_idx=0
if [ -n "$_vis" ] && [ "$_vis" != "-1" ]; then
_first=${_vis%%,*}
case "$_first" in
''|*[!0-9]*) _idx=0 ;;
*) _idx=$_first ;;
esac
fi
_rdna2_runtime_gfx=$(printf '%s\n' "$_rdna2_gfx_all" | awk -v idx="$_idx" '
NF && !seen[$0]++ { vals[n++] = $0 }
END {
if (idx < 0 || idx >= n) idx = 0
if (n > 0) print vals[idx]
}')
fi
case "$_rdna2_runtime_gfx" in
gfx1030|gfx1031|gfx1032|gfx1033|gfx1034|gfx1035|gfx1036)
_pytorch_base="${UNSLOTH_PYTORCH_MIRROR:-https://download.pytorch.org/whl}"
_pytorch_base="${_pytorch_base%/}"
echo "" >&2
echo " [WARN] $_rdna2_runtime_gfx (RDNA2) + ROCm 7.x detected" >&2
echo " [WARN] ROCm 7.x PyTorch wheels are dev/nightly builds on gfx103x" >&2
echo " [WARN] and cause segfaults during unsloth import on RDNA2 hardware." >&2
echo " [WARN] Capping to rocm6.2 (torch 2.7.x) -- the last stable wheel." >&2
echo "" >&2
TORCH_INDEX_URL="${_pytorch_base}/rocm6.2"
TORCH_CONSTRAINT="torch>=2.4,<2.11.0"
;;
esac
;;
esac
_TAURI_TORCH_INDEX_FAMILY=$(_tauri_torch_index_family "$TORCH_INDEX_URL")
if [ "$_amd_gpu_radeon" = true ] && [ "$SKIP_TORCH" = false ]; then
_TAURI_TORCH_INDEX_FAMILY="radeon"
Expand Down
19 changes: 16 additions & 3 deletions studio/backend/core/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Core inference backend - streamlined
"""

from unsloth import FastLanguageModel, FastVisionModel
from unsloth import FastLanguageModel, FastVisionModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer
from peft import PeftModel, PeftModelForCausalLM
Expand Down Expand Up @@ -460,13 +460,26 @@ def load_model(
logger.info(f"Loading {model_type} model{adapter_info}: {model_name}")
log_gpu_memory(f"Before loading {model_name}")

# AMD ROCm hardware without native bfloat16 (e.g. RDNA2 / gfx103x) crashes
# with an LLVM error at the first bf16 kernel dispatch when dtype=None lets
# unsloth auto-pick bf16. Force float16 there. NVIDIA keeps dtype=None so
# unsloth's own bf16/fp16/float32 auto-detection is honored.
_is_rocm = (
bool(getattr(torch.version, "hip", None))
or "rocm" in torch.__version__.lower()
)
_auto_dtype = (
torch.float16 if (_is_rocm and not is_bfloat16_supported()) else None
)
_effective_dtype = dtype if dtype is not None else _auto_dtype

# Load model - same approach for base models and LoRA adapters
if config.is_vision:
# Vision model (or vision LoRA adapter)
model, processor = FastVisionModel.from_pretrained(
model_name = config.path, # Can be base model OR LoRA adapter path
max_seq_length = max_seq_length,
dtype = dtype,
dtype = _effective_dtype,
load_in_4bit = load_in_4bit,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
Expand Down Expand Up @@ -523,7 +536,7 @@ def load_model(
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = config.path, # Can be base model OR LoRA adapter path
max_seq_length = max_seq_length,
dtype = dtype,
dtype = _effective_dtype,
load_in_4bit = load_in_4bit,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
Expand Down
50 changes: 40 additions & 10 deletions studio/install_python_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,30 @@ def _ensure_rocm_torch() -> None:

rocm_torch_ready = has_hip_torch

# RDNA2 (gfx1030-gfx1036, e.g. RX 6600/6700/6800/6900) cap
# ─────────────────────────────────────────────────────────────
# ROCm 7.x PyTorch builds are unstable on RDNA2: the gfx103x PyTorch wheels
# for ROCm 7.x are dev/nightly builds (version suffix .gitXXXXXXXX) that
# segfault during unsloth import on these GPUs. The last stable, tested
# PyTorch for RDNA2 is the rocm6.2 wheel. When the runtime GPU is RDNA2
# and the system has ROCm 7.x installed, override the tag selection to
# rocm6.2 so users get a stable torch regardless of their ROCm stack version.
_RDNA2_GFX = {"gfx1030", "gfx1031", "gfx1032", "gfx1033", "gfx1034", "gfx1035", "gfx1036"}
_rdna2_cap_tag: "str | None" = None
if ver >= (7, 0):
gfx_codes = _detect_amd_gfx_codes()
_runtime_gfx = (
gfx_codes[_pick_visible_index(len(gfx_codes))] if gfx_codes else None
)
if _runtime_gfx in _RDNA2_GFX:
_rdna2_cap_tag = "rocm6.2"
print(
f"\n {_runtime_gfx} (RDNA2) detected with ROCm {ver[0]}.{ver[1]}.\n"
f" ROCm 7.x PyTorch wheels are unstable on RDNA2 (dev builds that\n"
f" segfault on import). Capping torch install to the last known-good\n"
f" wheel: pytorch.org/whl/rocm6.2 (torch 2.7.x).\n"
)

# Strix Halo / Strix Point (gfx1151 / gfx1150) segfault under ROCm 7.1
# in torch._grouped_mm. AMD's per-gfx repo ships torch 2.11.0+rocm7.13.0
# with the real fix, so route those hosts there instead of the generic
Expand Down Expand Up @@ -771,16 +795,22 @@ def _ensure_rocm_torch() -> None:
constrain = False,
)
rocm_torch_ready = True
elif not has_hip_torch:
# Select best matching wheel tag (newest ROCm version <= installed)
tag = next(
(
t
for (maj, mn), t in sorted(_ROCM_TORCH_INDEX.items(), reverse = True)
if ver >= (maj, mn)
),
None,
)
elif _rdna2_cap_tag is not None or not has_hip_torch:
# RDNA2 with ROCm 7.x: force the capped tag regardless of whether
# a hip torch is already present -- the dev build that ends up there
# is exactly what we need to replace.
# Otherwise: select best matching wheel tag (newest ROCm version <= installed).
if _rdna2_cap_tag is not None:
tag = _rdna2_cap_tag
else:
tag = next(
(
t
for (maj, mn), t in sorted(_ROCM_TORCH_INDEX.items(), reverse = True)
if ver >= (maj, mn)
),
None,
)
if tag is None:
print(
f" No PyTorch wheel for ROCm {ver[0]}.{ver[1]} -- "
Expand Down
Loading