Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/uv-resolver/src/resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
.options
.torch_backend
.as_ref()
.filter(|torch_backend| matches!(torch_backend, TorchStrategy::Auto { .. }))
.filter(|torch_backend| matches!(torch_backend, TorchStrategy::Cuda { .. }))
.and_then(|_| pins.get(name, version).and_then(ResolvedDist::index))
.map(IndexUrl::url)
.and_then(SystemDependency::from_index)
Expand Down
6 changes: 5 additions & 1 deletion crates/uv-static/src/env_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,10 +718,14 @@ impl EnvVars {
/// This is a quasi-standard variable, described, e.g., in `ncurses(3x)`.
pub const COLUMNS: &'static str = "COLUMNS";

/// The CUDA driver version to assume when inferring the PyTorch backend.
/// The CUDA driver version to assume when inferring the PyTorch backend (e.g., `550.144.03`).
#[attr_hidden]
pub const UV_CUDA_DRIVER_VERSION: &'static str = "UV_CUDA_DRIVER_VERSION";

/// The AMD GPU architecture to assume when inferring the PyTorch backend (e.g., `gfx1100`).
#[attr_hidden]
pub const UV_AMD_GPU_ARCHITECTURE: &'static str = "UV_AMD_GPU_ARCHITECTURE";

/// Equivalent to the `--torch-backend` command-line argument (e.g., `cpu`, `cu126`, or `auto`).
pub const UV_TORCH_BACKEND: &'static str = "UV_TORCH_BACKEND";

Expand Down
110 changes: 109 additions & 1 deletion crates/uv-torch/src/accelerator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,30 @@ pub enum AcceleratorError {
Version(#[from] uv_pep440::VersionParseError),
#[error(transparent)]
Utf8(#[from] std::string::FromUtf8Error),
#[error("Unknown AMD GPU architecture: {0}")]
UnknownAmdGpuArchitecture(String),
}

/// The accelerator reported by [`Accelerator::detect`], identified either by a CUDA driver
/// version or by an AMD GPU architecture.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Accelerator {
/// The CUDA driver version (e.g., `550.144.03`).
///
/// This is in contrast to the CUDA toolkit version (e.g., `12.8.0`).
Cuda { driver_version: Version },
/// The AMD GPU architecture (e.g., `gfx906`).
///
/// This is in contrast to the user-space ROCm version (e.g., `6.4.0-47`) or the kernel-mode
/// driver version (e.g., `6.12.12`).
Amd {
gpu_architecture: AmdGpuArchitecture,
},
}

impl std::fmt::Display for Accelerator {
    /// Renders the accelerator as a vendor label followed by its identifying value
    /// (e.g., `CUDA 550.144.03` or `AMD gfx906`).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Resolve the label and the value first so the output shape lives in one place.
        let (vendor, detail): (&str, &dyn std::fmt::Display) = match self {
            Self::Cuda { driver_version } => ("CUDA", driver_version),
            Self::Amd { gpu_architecture } => ("AMD", gpu_architecture),
        };
        write!(f, "{vendor} {detail}")
    }
}
Expand All @@ -33,9 +46,11 @@ impl Accelerator {
///
/// Query, in order:
/// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
/// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
/// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
/// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
/// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
/// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
pub fn detect() -> Result<Option<Self>, AcceleratorError> {
// Read from `UV_CUDA_DRIVER_VERSION`.
if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
Expand All @@ -44,6 +59,15 @@ impl Accelerator {
return Ok(Some(Self::Cuda { driver_version }));
}

// Read from `UV_AMD_GPU_ARCHITECTURE`.
if let Ok(gpu_architecture) = std::env::var(EnvVars::UV_AMD_GPU_ARCHITECTURE) {
let gpu_architecture = AmdGpuArchitecture::from_str(&gpu_architecture)?;
debug!(
"Detected AMD GPU architecture from `UV_AMD_GPU_ARCHITECTURE`: {gpu_architecture}"
);
return Ok(Some(Self::Amd { gpu_architecture }));
}

// Read from `/sys/module/nvidia/version`.
match fs_err::read_to_string("/sys/module/nvidia/version") {
Ok(content) => {
Expand Down Expand Up @@ -100,7 +124,34 @@ impl Accelerator {
);
}

debug!("Failed to detect CUDA driver version");
// Query `rocm_agent_enumerator` to detect the AMD GPU architecture.
//
// See: https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html
if let Ok(output) = std::process::Command::new("rocm_agent_enumerator").output() {
if output.status.success() {
let stdout = String::from_utf8(output.stdout)?;
if let Some(gpu_architecture) = stdout
.lines()
.map(str::trim)
.filter_map(|line| AmdGpuArchitecture::from_str(line).ok())
.min()
{
debug!(
"Detected AMD GPU architecture from `rocm_agent_enumerator`: {gpu_architecture}"
);
return Ok(Some(Self::Amd { gpu_architecture }));
}
} else {
debug!(
"Failed to query AMD GPU architecture with `rocm_agent_enumerator` with status `{}`: {}",
output.status,
String::from_utf8_lossy(&output.stderr)
);
}
}

debug!("Failed to detect GPU driver version");

Ok(None)
}
}
Expand Down Expand Up @@ -129,6 +180,63 @@ fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, Ac
Ok(Some(driver_version))
}

/// A GPU architecture for AMD GPUs.
///
/// Variant order is significant: the derived `Ord` compares variants by declaration order, and
/// [`Accelerator::detect`] takes the minimum of the architectures parsed from
/// `rocm_agent_enumerator` output, so reordering variants changes which architecture is
/// selected when several are reported.
///
/// See: <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html>
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum AmdGpuArchitecture {
Gfx900,
Gfx906,
Gfx908,
Gfx90a,
Gfx942,
Gfx1030,
Gfx1100,
Gfx1101,
Gfx1102,
Gfx1200,
Gfx1201,
}

impl FromStr for AmdGpuArchitecture {
    type Err = AcceleratorError;

    /// Parses an architecture identifier such as `gfx90a`.
    ///
    /// The comparison is exact: the input is neither trimmed nor case-folded. Any string
    /// outside the known set yields [`AcceleratorError::UnknownAmdGpuArchitecture`] carrying
    /// the rejected input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let architecture = match s {
            "gfx900" => Self::Gfx900,
            "gfx906" => Self::Gfx906,
            "gfx908" => Self::Gfx908,
            "gfx90a" => Self::Gfx90a,
            "gfx942" => Self::Gfx942,
            "gfx1030" => Self::Gfx1030,
            "gfx1100" => Self::Gfx1100,
            "gfx1101" => Self::Gfx1101,
            "gfx1102" => Self::Gfx1102,
            "gfx1200" => Self::Gfx1200,
            "gfx1201" => Self::Gfx1201,
            unknown => {
                return Err(AcceleratorError::UnknownAmdGpuArchitecture(
                    unknown.to_owned(),
                ));
            }
        };
        Ok(architecture)
    }
}

impl std::fmt::Display for AmdGpuArchitecture {
    /// Writes the lowercase architecture identifier (e.g., `gfx90a`), i.e. the same spelling
    /// that the `FromStr` implementation accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identifier = match self {
            Self::Gfx900 => "gfx900",
            Self::Gfx906 => "gfx906",
            Self::Gfx908 => "gfx908",
            Self::Gfx90a => "gfx90a",
            Self::Gfx942 => "gfx942",
            Self::Gfx1030 => "gfx1030",
            Self::Gfx1100 => "gfx1100",
            Self::Gfx1101 => "gfx1101",
            Self::Gfx1102 => "gfx1102",
            Self::Gfx1200 => "gfx1200",
            Self::Gfx1201 => "gfx1201",
        };
        // Written verbatim: width/fill flags on the formatter are intentionally not applied.
        f.write_str(identifier)
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading
Loading