Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 51 additions & 18 deletions studio/backend/utils/hardware/hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DeviceType(str, Enum):
"""Supported compute backends. Inherits from str so it serializes cleanly in JSON."""

CUDA = "cuda"
ROCM = "rocm"
MLX = "mlx"
CPU = "cpu"

Expand Down Expand Up @@ -79,21 +80,27 @@ def detect_hardware() -> DeviceType:

Detection order:
1. CUDA (NVIDIA GPU, requires torch)
2. MLX (Apple Silicon via MLX framework)
3. CPU (fallback)
2. ROCm (AMD GPU, requires torch with HIP)
3. MLX (Apple Silicon via MLX framework)
4. CPU (fallback)
"""
global DEVICE, CHAT_ONLY
CHAT_ONLY = True # reset -- only CUDA sets it to False
CHAT_ONLY = True # reset -- only CUDA/ROCm sets it to False

# --- CUDA: try PyTorch ---
# --- CUDA / ROCm: try PyTorch ---
if _has_torch():
import torch

if torch.cuda.is_available():
DEVICE = DeviceType.CUDA
CHAT_ONLY = False
device_name = torch.cuda.get_device_properties(0).name
print(f"Hardware detected: CUDA — {device_name}")

if getattr(torch.version, "hip", None) is None:
DEVICE = DeviceType.CUDA
print(f"Hardware detected: CUDA — {device_name}")
else:
DEVICE = DeviceType.ROCM
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Keep the PyTorch device string as "cuda" on ROCm

On ROCm hosts this branch returns DeviceType.ROCM, whose .value is "rocm". Downstream inference code, however, uses get_device().value as a PyTorch device string: InferenceBackend.__init__ stores it in self.device, and the generation paths then call .to(self.device). PyTorch's HIP builds still use the "cuda" device namespace, so "rocm" is not a valid target for Tensor.to(...), and this can break model load/inference specifically in ROCm environments.

Useful? React with 👍 / 👎.

print(f"Hardware detected: ROCm {torch.version.hip} — {device_name}")
return DEVICE

# --- MLX: Apple Silicon ---
Expand Down Expand Up @@ -134,7 +141,7 @@ def clear_gpu_cache():

device = get_device()

if device == DeviceType.CUDA:
if device in (DeviceType.CUDA, DeviceType.ROCM):
import torch

torch.cuda.synchronize()
Expand All @@ -153,8 +160,8 @@ def get_gpu_memory_info() -> Dict[str, Any]:
"""
device = get_device()

# ---- CUDA path ----
if device == DeviceType.CUDA:
# ---- CUDA / ROCm path ----
if device in (DeviceType.CUDA, DeviceType.ROCM):
try:
import torch

Expand Down Expand Up @@ -269,13 +276,15 @@ def get_package_versions() -> Dict[str, Optional[str]]:
except PackageNotFoundError:
versions[name] = None

# CUDA toolkit version bundled with torch
# GPU runtime version bundled with torch
try:
import torch

versions["cuda"] = getattr(torch.version, "cuda", None)
versions["rocm"] = getattr(torch.version, "hip", None)
except Exception:
versions["cuda"] = None
versions["rocm"] = None

return versions

Expand Down Expand Up @@ -305,7 +314,7 @@ def get_gpu_utilization() -> Dict[str, Any]:
"""
device = get_device()

if device != DeviceType.CUDA:
if device not in (DeviceType.CUDA, DeviceType.ROCM):
return {"available": False, "backend": device.value}

def _parse_smi_value(raw: str):
Expand Down Expand Up @@ -419,19 +428,20 @@ def _parse_smi_value(raw: str):

def get_physical_gpu_count() -> int:
"""
Return the number of physical NVIDIA GPUs on the machine.
Return the number of physical GPUs on the machine.

Uses ``nvidia-smi -L`` which is NOT affected by CUDA_VISIBLE_DEVICES,
so it always reflects the true hardware count.
Tries ``nvidia-smi -L`` (NVIDIA), then ``rocm-smi`` (AMD),
then falls back to ``torch.cuda.device_count()``.
Result is cached after the first call.
"""
global _physical_gpu_count
if _physical_gpu_count is not None:
return _physical_gpu_count

try:
import subprocess
import subprocess

# --- NVIDIA ---
try:
result = subprocess.run(
["nvidia-smi", "-L"],
capture_output = True,
Expand All @@ -440,8 +450,31 @@ def get_physical_gpu_count() -> int:
)
if result.returncode == 0 and result.stdout.strip():
_physical_gpu_count = len(result.stdout.strip().splitlines())
else:
_physical_gpu_count = 1
return _physical_gpu_count
except Exception:
pass

# --- AMD ROCm ---
try:
result = subprocess.run(
["rocm-smi", "--showserial"],
capture_output = True,
text = True,
timeout = 5,
)
if result.returncode == 0 and result.stdout.strip():
count = sum(1 for l in result.stdout.splitlines() if l.startswith("GPU["))
if count > 0:
_physical_gpu_count = count
return _physical_gpu_count
except Exception:
pass

# --- Fallback ---
try:
import torch

_physical_gpu_count = max(torch.cuda.device_count(), 1)
except Exception:
_physical_gpu_count = 1

Expand Down