diff --git a/crates/rattler_lock/src/utils/serde/pep440_map_or_vec.rs b/crates/rattler_lock/src/utils/serde/pep440_map_or_vec.rs index 1826581cf..0adf3103b 100644 --- a/crates/rattler_lock/src/utils/serde/pep440_map_or_vec.rs +++ b/crates/rattler_lock/src/utils/serde/pep440_map_or_vec.rs @@ -35,7 +35,7 @@ impl<'de> DeserializeAs<'de, Vec> for Pep440MapOrVec { } else { Some(VersionOrUrl::VersionSpecifier(spec)) }, - marker: Default::default(), + marker: Option::default(), origin: None, }) }) diff --git a/crates/rattler_package_streaming/src/reqwest/tokio.rs b/crates/rattler_package_streaming/src/reqwest/tokio.rs index 965428f82..ba273510b 100644 --- a/crates/rattler_package_streaming/src/reqwest/tokio.rs +++ b/crates/rattler_package_streaming/src/reqwest/tokio.rs @@ -19,7 +19,7 @@ use zip::result::ZipError; /// to find the compressed data length. /// Since we stream the package over a non seekable HTTP connection, this condition will cause an error during /// decompression. In this case, we fallback to reading the whole data to a buffer before attempting decompression. -/// Read more in https://github.com/conda/rattler/issues/794 +/// Read more in const DATA_DESCRIPTOR_ERROR_MESSAGE: &str = "The file length is not available in the local header"; fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result { diff --git a/crates/rattler_virtual_packages/src/lib.rs b/crates/rattler_virtual_packages/src/lib.rs index e0e0027eb..c9d5219bc 100644 --- a/crates/rattler_virtual_packages/src/lib.rs +++ b/crates/rattler_virtual_packages/src/lib.rs @@ -35,8 +35,12 @@ pub mod osx; use archspec::cpu::Microarchitecture; use once_cell::sync::OnceCell; -use rattler_conda_types::{GenericVirtualPackage, PackageName, Platform, Version}; +use rattler_conda_types::{ + GenericVirtualPackage, PackageName, ParseVersionError, Platform, Version, +}; +use std::env; use std::hash::{Hash, Hasher}; +use std::str::FromStr; use std::sync::Arc; use crate::osx::ParseOsxVersionError; @@ -44,6 +48,38 @@ use libc::DetectLibCError; use linux::ParseLinuxVersionError; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +/// Traits for overridable virtual packages +/// Use as `Cuda::from_default_env_var.unwrap_or(Cuda::current().into()).unwrap()` +pub trait EnvOverride: Sized { + /// Parse `env_var_value` + fn from_env_var_name_with_var( + env_var_name: &str, + env_var_value: &str, + ) -> Result; + + /// Read the environment variable and if it exists, try to parse it with [`EnvOverride::from_env_var_name_with_var`] + /// If the output is: + /// - `None`, then the environment variable did not exist, + /// - `Some(Err(None))`, then the environment variable exist but was set to zero, so the package should be disabled + /// - `Some(Ok(pkg))`, then the override was for the package. + fn from_env_var_name(env_var_name: &str) -> Option>> { + let var = env::var(env_var_name).ok()?; + if var.is_empty() { + Some(Err(None)) + } else { + Some(Self::from_env_var_name_with_var(env_var_name, &var).map_err(Some)) + } + } + + /// Default name of the environment variable that overrides the virtual package. + const DEFAULT_ENV_NAME: &'static str; + + /// Shortcut for `EnvOverride::from_env_var_name(EnvOverride::DEFAULT_ENV_NAME)`. + fn from_default_env_var() -> Option>> { + Self::from_env_var_name(Self::DEFAULT_ENV_NAME) + } +} + /// An enum that represents all virtual package types provided by this library. #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub enum VirtualPackage { @@ -95,8 +131,8 @@ impl VirtualPackage { /// Returns virtual packages detected for the current system or an error if the versions could /// not be properly detected. pub fn current() -> Result<&'static [Self], DetectVirtualPackageError> { - static DETECED_VIRTUAL_PACKAGES: OnceCell> = OnceCell::new(); - DETECED_VIRTUAL_PACKAGES + static DETECTED_VIRTUAL_PACKAGES: OnceCell> = OnceCell::new(); + DETECTED_VIRTUAL_PACKAGES .get_or_try_init(try_detect_virtual_packages) .map(Vec::as_slice) } @@ -188,6 +224,12 @@ impl From for VirtualPackage { } } +impl From for Linux { + fn from(version: Version) -> Self { + Linux { version } + } +} + /// `LibC` virtual package description #[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)] pub struct LibC { @@ -229,6 +271,20 @@ impl From for VirtualPackage { } } +impl EnvOverride for LibC { + const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_GLIBC"; + + fn from_env_var_name_with_var( + _env_var_name: &str, + env_var_value: &str, + ) -> Result { + Version::from_str(env_var_value).map(|version| Self { + family: "glibc".into(), + version, + }) + } +} + /// Cuda virtual package description #[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)] pub struct Cuda { @@ -243,6 +299,23 @@ impl Cuda { } } +impl From for Cuda { + fn from(version: Version) -> Self { + Self { version } + } +} + +impl EnvOverride for Cuda { + fn from_env_var_name_with_var( + _env_var_name: &str, + env_var_value: &str, + ) -> Result { + Version::from_str(env_var_value).map(|version| Self { version }) + } + + const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_CUDA"; +} + impl From for GenericVirtualPackage { fn from(cuda: Cuda) -> Self { GenericVirtualPackage { @@ -359,7 +432,7 @@ impl From for GenericVirtualPackage { GenericVirtualPackage { name: PackageName::new_unchecked("__archspec"), version: Version::major(1), - build_string: archspec.spec.name().to_string(), + build_string: archspec.spec.name().into(), } } } @@ -403,8 +476,34 @@ impl From for VirtualPackage { } } +impl From for Osx { + fn from(version: Version) -> Self { + Self { version } + } +} + +impl EnvOverride for Osx { + fn from_env_var_name_with_var( + _env_var_name: &str, + env_var_value: &str, + ) -> Result { + Version::from_str(env_var_value).map(|version| Self { version }) + } + + const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_OSX"; +} + #[cfg(test)] mod test { + use std::env; + use std::str::FromStr; + + use rattler_conda_types::Version; + + use crate::Cuda; + use crate::EnvOverride; + use crate::LibC; + use crate::Osx; use crate::VirtualPackage; #[test] @@ -412,4 +511,38 @@ mod test { let virtual_packages = VirtualPackage::current().unwrap(); println!("{virtual_packages:?}"); } + #[test] + fn parse_libc() { + let v = "1.23"; + let res = LibC { + version: Version::from_str(v).unwrap(), + family: "glibc".into(), + }; + env::set_var(LibC::DEFAULT_ENV_NAME, v); + assert_eq!(LibC::from_default_env_var(), Some(Ok(res))); + env::set_var(LibC::DEFAULT_ENV_NAME, ""); + assert_eq!(LibC::from_default_env_var(), Some(Err(None))); + env::remove_var(LibC::DEFAULT_ENV_NAME); + assert_eq!(LibC::from_default_env_var(), None); + } + + #[test] + fn parse_cuda() { + let v = "1.234"; + let res = Cuda { + version: Version::from_str(v).unwrap(), + }; + env::set_var(Cuda::DEFAULT_ENV_NAME, v); + assert_eq!(Cuda::from_default_env_var(), Some(Ok(res))); + } + + #[test] + fn parse_osx() { + let v = "2.345"; + let res = Osx { + version: Version::from_str(v).unwrap(), + }; + env::set_var(Osx::DEFAULT_ENV_NAME, v); + assert_eq!(Osx::from_default_env_var(), Some(Ok(res))); + } }