Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions ramalama/gpu_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,54 @@ def get_nvidia_gpu(self):
if platform.system() != "Linux":
return # Skip on macOS and other platforms

gpus = []
try:
result = subprocess.run(
["nvidia-smi", "--query-gpu=index,memory.total", "--format=csv,noheader,nounits"],
capture_output=True,
text=True,
check=True,
)
nameresult = subprocess.run(
[
"nvidia-smi",
"-L",
],
capture_output=True,
text=True,
check=True,
)
nameline = nameresult.stdout.strip().split('\n')
output = result.stdout.strip()

ctr = 0
for line in output.split('\n'):
try:
index, memory_mib = line.split(',')
memory_mib = int(memory_mib.strip())
self._update_best_gpu(memory_mib, index.strip(), "CUDA_VISIBLE_DEVICES")
gpu_info = {
"GPU": "NVIDIA GPU",
"Details": nameline[ctr].strip(),
"VRAM": f"{memory_mib} MiB",
"Env": "CUDA_VISIBLE_DEVICES",
}
gpus.append(gpu_info)
except ValueError:
raise RuntimeError(f"Error parsing Nvidia GPU info: {line}")
ctr += 1

except FileNotFoundError:
raise RuntimeError("`nvidia-smi` not found. No NVIDIA GPU detected or drivers missing.")
except subprocess.CalledProcessError as e:
error_msg = e.stderr.strip() if e.stderr else "Unknown error (check if NVIDIA drivers are loaded)."
raise RuntimeError(f"Unable to detect NVIDIA GPU(s). Error: {error_msg}")
return gpus

def get_amd_gpu(self):
"""Detects AMD GPUs using sysfs on Linux or system_profiler on macOS."""
if platform.system() == "Linux":
self._read_gpu_memory('/sys/bus/pci/devices/*/mem_info_vram_total', "AMD GPU", "HIP_VISIBLE_DEVICES")
return self._read_gpu_memory('/sys/bus/pci/devices/*/mem_info_vram_total', "AMD GPU", "HIP_VISIBLE_DEVICES")
return None

def _read_gpu_memory(self, path_pattern, gpu_name, env_var):
"""Helper function to read GPU VRAM from `/sys/class/drm/`."""
Expand All @@ -78,7 +99,7 @@ def _read_gpu_memory(self, path_pattern, gpu_name, env_var):
return {"GPU": gpu_name, "VRAM": f"{vram_total} MiB", "Env": env_var}
except Exception as e:
return {"GPU": gpu_name, "VRAM": "Unknown", "Env": env_var, "Error": str(e)}
return {"GPU": gpu_name, "VRAM": "Unknown", "Env": env_var}
return None

def get_intel_gpu(self):
"""Detect Intel GPUs using `lspci` and `/sys/class/drm/` for VRAM info."""
Expand Down
Loading