Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for CONDA_OVERRIDE_CUDA #818

Merged
merged 11 commits into from
Aug 20, 2024
2 changes: 1 addition & 1 deletion crates/rattler_lock/src/utils/serde/pep440_map_or_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<'de> DeserializeAs<'de, Vec<Requirement>> for Pep440MapOrVec {
} else {
Some(VersionOrUrl::VersionSpecifier(spec))
},
marker: Default::default(),
marker: Option::default(),
origin: None,
})
})
Expand Down
2 changes: 1 addition & 1 deletion crates/rattler_package_streaming/src/reqwest/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use zip::result::ZipError;
/// to find the compressed data length.
/// Since we stream the package over a non seekable HTTP connection, this condition will cause an error during
/// decompression. In this case, we fallback to reading the whole data to a buffer before attempting decompression.
/// Read more in https://github.com/conda/rattler/issues/794
/// Read more in <https://github.com/conda/rattler/issues/794>
const DATA_DESCRIPTOR_ERROR_MESSAGE: &str = "The file length is not available in the local header";

fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
Expand Down
141 changes: 137 additions & 4 deletions crates/rattler_virtual_packages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,51 @@ pub mod osx;

use archspec::cpu::Microarchitecture;
use once_cell::sync::OnceCell;
use rattler_conda_types::{GenericVirtualPackage, PackageName, Platform, Version};
use rattler_conda_types::{
GenericVirtualPackage, PackageName, ParseVersionError, Platform, Version,
};
use std::env;
use std::hash::{Hash, Hasher};
use std::str::FromStr;
use std::sync::Arc;

use crate::osx::ParseOsxVersionError;
use libc::DetectLibCError;
use linux::ParseLinuxVersionError;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Traits for overridable virtual packages
/// Use as `Cuda::from_default_env_var.unwrap_or(Cuda::current().into()).unwrap()`
pub trait EnvOverride: Sized {
/// Parse `env_var_value`
fn from_env_var_name_with_var(
env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError>;

/// Read the environment variable and if it exists, try to parse it with [`EnvOverride::from_env_var_name_with_var`]
/// If the output is:
/// - `None`, then the environment variable did not exist,
/// - `Some(Err(None))`, then the environment variable exist but was set to zero, so the package should be disabled
/// - `Some(Ok(pkg))`, then the override was for the package.
fn from_env_var_name(env_var_name: &str) -> Option<Result<Self, Option<ParseVersionError>>> {
let var = env::var(env_var_name).ok()?;
if var.is_empty() {
Some(Err(None))
} else {
Some(Self::from_env_var_name_with_var(env_var_name, &var).map_err(Some))
}
}

/// Default name of the environment variable that overrides the virtual package.
const DEFAULT_ENV_NAME: &'static str;

/// Shortcut for `EnvOverride::from_env_var_name(EnvOverride::DEFAULT_ENV_NAME)`.
fn from_default_env_var() -> Option<Result<Self, Option<ParseVersionError>>> {
Self::from_env_var_name(Self::DEFAULT_ENV_NAME)
}
}

/// An enum that represents all virtual package types provided by this library.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub enum VirtualPackage {
Expand Down Expand Up @@ -95,8 +131,8 @@ impl VirtualPackage {
/// Returns virtual packages detected for the current system or an error if the versions could
/// not be properly detected.
pub fn current() -> Result<&'static [Self], DetectVirtualPackageError> {
static DETECED_VIRTUAL_PACKAGES: OnceCell<Vec<VirtualPackage>> = OnceCell::new();
DETECED_VIRTUAL_PACKAGES
static DETECTED_VIRTUAL_PACKAGES: OnceCell<Vec<VirtualPackage>> = OnceCell::new();
DETECTED_VIRTUAL_PACKAGES
.get_or_try_init(try_detect_virtual_packages)
.map(Vec::as_slice)
}
Expand Down Expand Up @@ -188,6 +224,12 @@ impl From<Linux> for VirtualPackage {
}
}

impl From<Version> for Linux {
fn from(version: Version) -> Self {
Linux { version }
}
}

/// `LibC` virtual package description
#[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)]
pub struct LibC {
Expand Down Expand Up @@ -229,6 +271,20 @@ impl From<LibC> for VirtualPackage {
}
}

impl EnvOverride for LibC {
const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_GLIBC";

fn from_env_var_name_with_var(
_env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self {
family: "glibc".into(),
version,
})
}
}

/// Cuda virtual package description
#[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)]
pub struct Cuda {
Expand All @@ -243,6 +299,23 @@ impl Cuda {
}
}

impl From<Version> for Cuda {
fn from(version: Version) -> Self {
Self { version }
}
}

impl EnvOverride for Cuda {
fn from_env_var_name_with_var(
_env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self { version })
}

const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_CUDA";
}

impl From<Cuda> for GenericVirtualPackage {
fn from(cuda: Cuda) -> Self {
GenericVirtualPackage {
Expand Down Expand Up @@ -359,7 +432,7 @@ impl From<Archspec> for GenericVirtualPackage {
GenericVirtualPackage {
name: PackageName::new_unchecked("__archspec"),
version: Version::major(1),
build_string: archspec.spec.name().to_string(),
build_string: archspec.spec.name().into(),
}
}
}
Expand Down Expand Up @@ -403,13 +476,73 @@ impl From<Osx> for VirtualPackage {
}
}

impl From<Version> for Osx {
fn from(version: Version) -> Self {
Self { version }
}
}

impl EnvOverride for Osx {
fn from_env_var_name_with_var(
_env_var_name: &str,
env_var_value: &str,
) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self { version })
}

const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_OSX";
}

#[cfg(test)]
mod test {
use std::env;
use std::str::FromStr;

use rattler_conda_types::Version;

use crate::Cuda;
use crate::EnvOverride;
use crate::LibC;
use crate::Osx;
use crate::VirtualPackage;

#[test]
fn doesnt_crash() {
let virtual_packages = VirtualPackage::current().unwrap();
println!("{virtual_packages:?}");
}
#[test]
fn parse_libc() {
let v = "1.23";
let res = LibC {
version: Version::from_str(v).unwrap(),
family: "glibc".into(),
};
env::set_var(LibC::DEFAULT_ENV_NAME, v);
assert_eq!(LibC::from_default_env_var(), Some(Ok(res)));
env::set_var(LibC::DEFAULT_ENV_NAME, "");
assert_eq!(LibC::from_default_env_var(), Some(Err(None)));
env::remove_var(LibC::DEFAULT_ENV_NAME);
assert_eq!(LibC::from_default_env_var(), None);
}

#[test]
fn parse_cuda() {
let v = "1.234";
let res = Cuda {
version: Version::from_str(v).unwrap(),
};
env::set_var(Cuda::DEFAULT_ENV_NAME, v);
assert_eq!(Cuda::from_default_env_var(), Some(Ok(res)));
}

#[test]
fn parse_osx() {
let v = "2.345";
let res = Osx {
version: Version::from_str(v).unwrap(),
};
env::set_var(Osx::DEFAULT_ENV_NAME, v);
assert_eq!(Osx::from_default_env_var(), Some(Ok(res)));
}
}