From ce0e0ec628f56ab408b5f3add2b5fe93c1e6560b Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 19 Jul 2024 16:40:16 -0400 Subject: [PATCH 1/2] respect local versions for all user requirements --- .../uv-resolver/src/pubgrub/dependencies.rs | 49 +- crates/uv-resolver/src/resolver/locals.rs | 88 ++- crates/uv-resolver/src/resolver/mod.rs | 87 +-- crates/uv/tests/pip_compile.rs | 501 +++++++++++++++++- 4 files changed, 631 insertions(+), 94 deletions(-) diff --git a/crates/uv-resolver/src/pubgrub/dependencies.rs b/crates/uv-resolver/src/pubgrub/dependencies.rs index 08dd39e2852e..97a1a1e4a684 100644 --- a/crates/uv-resolver/src/pubgrub/dependencies.rs +++ b/crates/uv-resolver/src/pubgrub/dependencies.rs @@ -12,37 +12,37 @@ use pypi_types::{ use uv_normalize::{ExtraName, PackageName}; use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner}; -use crate::resolver::ForkLocals; use crate::{PubGrubSpecifier, ResolveError}; #[derive(Clone, Debug)] pub(crate) struct PubGrubDependency { pub(crate) package: PubGrubPackage, pub(crate) version: Range, + + /// The original version specifiers from the requirement. + pub(crate) specifier: Option, + /// This field is set if the [`Requirement`] had a URL. We still use a URL from [`Urls`] /// even if this field is None where there is an override with a URL or there is a different /// requirement or constraint for the same package that has a URL. pub(crate) url: Option, - /// The local version for this requirement, if specified. - pub(crate) local: Option, } impl PubGrubDependency { pub(crate) fn from_requirement<'a>( requirement: &'a Requirement, source_name: Option<&'a PackageName>, - fork_locals: &'a ForkLocals, ) -> impl Iterator> + 'a { // Add the package, plus any extra variants. iter::once(None) .chain(requirement.extras.clone().into_iter().map(Some)) - .map(|extra| PubGrubRequirement::from_requirement(requirement, extra, fork_locals)) + .map(|extra| PubGrubRequirement::from_requirement(requirement, extra)) .filter_map_ok(move |requirement| { let PubGrubRequirement { package, version, + specifier, url, - local, } = requirement; match &*package { PubGrubPackageInner::Package { name, .. } => { @@ -55,15 +55,15 @@ impl PubGrubDependency { Some(PubGrubDependency { package: package.clone(), version: version.clone(), + specifier, url, - local, }) } PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency { package: package.clone(), version: version.clone(), + specifier, url, - local, }), PubGrubPackageInner::Extra { name, .. } => { debug_assert!( @@ -73,8 +73,8 @@ impl PubGrubDependency { Some(PubGrubDependency { package: package.clone(), version: version.clone(), + specifier, url: None, - local: None, }) } _ => None, @@ -88,8 +88,8 @@ impl PubGrubDependency { pub(crate) struct PubGrubRequirement { pub(crate) package: PubGrubPackage, pub(crate) version: Range, + pub(crate) specifier: Option, pub(crate) url: Option, - pub(crate) local: Option, } impl PubGrubRequirement { @@ -98,11 +98,10 @@ impl PubGrubRequirement { pub(crate) fn from_requirement( requirement: &Requirement, extra: Option, - fork_locals: &ForkLocals, ) -> Result { let (verbatim_url, parsed_url) = match &requirement.source { RequirementSource::Registry { specifier, .. } => { - return Self::from_registry_requirement(specifier, extra, requirement, fork_locals); + return Self::from_registry_requirement(specifier, extra, requirement); } RequirementSource::Url { subdirectory, @@ -165,11 +164,11 @@ impl PubGrubRequirement { requirement.marker.clone(), ), version: Range::full(), + specifier: None, url: Some(VerbatimParsedUrl { parsed_url, verbatim: verbatim_url.clone(), }), - local: None, }) } @@ -177,26 +176,8 @@ impl PubGrubRequirement { specifier: &VersionSpecifiers, extra: Option, requirement: &Requirement, - fork_locals: &ForkLocals, ) -> Result { - // If the specifier is an exact version and the user requested a local version for this - // fork that's more precise than the specifier, use the local version instead. - let version = if let Some(local) = fork_locals.get(&requirement.name) { - specifier - .iter() - .map(|specifier| { - ForkLocals::map(local, specifier) - .map_err(ResolveError::InvalidVersion) - .and_then(|specifier| { - Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?) - }) - }) - .fold_ok(Range::full(), |range, specifier| { - range.intersection(&specifier.into()) - })? - } else { - PubGrubSpecifier::from_pep440_specifiers(specifier)?.into() - }; + let version = PubGrubSpecifier::from_pep440_specifiers(specifier)?.into(); let requirement = Self { package: PubGrubPackage::from_package( @@ -204,9 +185,9 @@ impl PubGrubRequirement { extra, requirement.marker.clone(), ), - version, + specifier: Some(specifier.clone()), url: None, - local: None, + version, }; Ok(requirement) diff --git a/crates/uv-resolver/src/resolver/locals.rs b/crates/uv-resolver/src/resolver/locals.rs index c078118b15e9..7e1a09a7681d 100644 --- a/crates/uv-resolver/src/resolver/locals.rs +++ b/crates/uv-resolver/src/resolver/locals.rs @@ -3,24 +3,74 @@ use std::str::FromStr; use distribution_filename::{SourceDistFilename, WheelFilename}; use distribution_types::RemoteSource; use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError}; -use pep508_rs::PackageName; +use pep508_rs::{MarkerEnvironment, MarkerTree, PackageName}; use pypi_types::RequirementSource; use rustc_hash::FxHashMap; -/// A map of package names to their associated, required local versions in a given fork. +use crate::{marker::is_disjoint, DependencyMode, Manifest, ResolverMarkers}; + +/// A map of package names to their associated, required local versions across all forks. #[derive(Debug, Default, Clone)] -pub(crate) struct ForkLocals(FxHashMap); +pub(crate) struct Locals(FxHashMap, Version)>>); + +impl Locals { + /// Determine the set of permitted local versions in the [`Manifest`]. + pub(crate) fn from_manifest( + manifest: &Manifest, + markers: Option<&MarkerEnvironment>, + dependencies: DependencyMode, + ) -> Self { + let mut required: FxHashMap> = FxHashMap::default(); + + // Add all direct requirements and constraints. There's no need to look for conflicts, + // since conflicts will be enforced by the solver. + for requirement in manifest.requirements(markers, dependencies) { + if let Some(local) = from_source(&requirement.source) { + required + .entry(requirement.name.clone()) + .or_default() + .push((requirement.marker.clone(), local)); + } + } -impl ForkLocals { - /// Insert the local [`Version`] to which a package is pinned for this fork. - pub(crate) fn insert(&mut self, package_name: PackageName, local: Version) { - assert!(local.is_local()); - self.0.insert(package_name, local); + Self(required) } - /// Return the local [`Version`] to which a package is pinned in this fork, if any. - pub(crate) fn get(&self, package_name: &PackageName) -> Option<&Version> { - self.0.get(package_name) + /// Return a list of local versions that are compatible with a package in the given fork. + pub(crate) fn get( + &self, + package_name: &PackageName, + markers: &ResolverMarkers, + ) -> Vec<&Version> { + let Some(locals) = self.0.get(package_name) else { + return Vec::new(); + }; + + match markers { + // If we are solving for a specific environment we already filtered + // compatible requirements `from_manifest`. + ResolverMarkers::SpecificEnvironment(_) => { + locals.first().map(|(_, local)| local).into_iter().collect() + } + + // Return all locals that were requested with markers that are compatible + // with the current fork. + // + // Compatibility implies that the markers are not disjoint. The resolver will + // choose the most compatible local when it narrows to the specific fork. + ResolverMarkers::Fork(fork) => locals + .iter() + .filter(|(marker, _)| { + !marker + .as_ref() + .is_some_and(|marker| is_disjoint(fork, marker)) + }) + .map(|(_, local)| local) + .collect(), + + // If we haven't forked yet, all locals are potentially compatible. + ResolverMarkers::Universal => locals.iter().map(|(_, local)| local).collect(), + } } /// Given a specifier that may include the version _without_ a local segment, return a specifier @@ -190,7 +240,7 @@ mod tests { use pypi_types::ParsedUrl; use pypi_types::RequirementSource; - use super::{from_source, ForkLocals}; + use super::{from_source, Locals}; #[test] fn extract_locals() -> Result<()> { @@ -251,7 +301,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -260,7 +310,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)? ); @@ -269,7 +319,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -278,7 +328,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)? ); @@ -287,7 +337,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)? ); @@ -296,7 +346,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -305,7 +355,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)? ); diff --git a/crates/uv-resolver/src/resolver/mod.rs b/crates/uv-resolver/src/resolver/mod.rs index e8dff8de3973..6000f31dd675 100644 --- a/crates/uv-resolver/src/resolver/mod.rs +++ b/crates/uv-resolver/src/resolver/mod.rs @@ -26,7 +26,7 @@ use distribution_types::{ IncompatibleWheel, IndexLocations, InstalledDist, PythonRequirementKind, RemoteSource, ResolvedDist, ResolvedDistRef, SourceDist, VersionOrUrlRef, }; -pub(crate) use locals::ForkLocals; +pub(crate) use locals::Locals; use pep440_rs::{Version, MIN_VERSION}; use pep508_rs::MarkerTree; use platform_tags::Tags; @@ -96,6 +96,7 @@ struct ResolverState { git: GitResolver, exclusions: Exclusions, urls: Urls, + locals: Locals, dependency_mode: DependencyMode, hasher: HashStrategy, markers: ResolverMarkers, @@ -215,6 +216,11 @@ impl git, options.dependency_mode, )?, + locals: Locals::from_manifest( + &manifest, + markers.marker_environment(), + options.dependency_mode, + ), project: manifest.project, requirements: manifest.requirements, constraints: manifest.constraints, @@ -309,7 +315,6 @@ impl ResolverState ResolverState ResolverState ResolverState ResolverState ResolverState, ) -> Result { - let result = self.get_dependencies( - package, - version, - fork_urls, - fork_locals, - markers, - requires_python, - ); + let result = self.get_dependencies(package, version, fork_urls, markers, requires_python); match markers { ResolverMarkers::SpecificEnvironment(_) => result.map(|deps| match deps { Dependencies::Available(deps) => ForkedDependencies::Unforked(deps), @@ -1128,7 +1126,6 @@ impl ResolverState, ) -> Result { @@ -1148,14 +1145,7 @@ impl ResolverState, _>>()? } PubGrubPackageInner::Package { @@ -1283,7 +1273,7 @@ impl ResolverState, _>>()?; @@ -1302,8 +1292,8 @@ impl ResolverState ResolverState ResolverState ResolverState, version: &Version, urls: &Urls, - dependencies: Vec, + locals: &Locals, + mut dependencies: Vec, git: &GitResolver, resolution_strategy: &ResolutionStrategy, ) -> Result<(), ResolveError> { - for dependency in &dependencies { + for dependency in &mut dependencies { let PubGrubDependency { package, version, + specifier, url, - local, } = dependency; let mut has_url = false; @@ -2057,11 +2046,29 @@ impl ForkState { has_url = true; }; - // `PubGrubDependency` also gives us a local version if specified by the user. - // Keep track of which local version we will be using in this fork for transitive - // dependencies. - if let Some(local) = local { - self.fork_locals.insert(name.clone(), local.clone()); + // If the specifier is an exact version and the user requested a local version for this + // fork that's more precise than the specifier, use the local version instead. + if let Some(specifier) = specifier { + // It's possible that there are multiple matching local versions requested with + // different marker expressions. All of these are potentially compatible until we + // narrow to a specific fork. + for local in locals.get(name, &self.markers) { + let local = specifier + .iter() + .map(|specifier| { + Locals::map(local, specifier) + .map_err(ResolveError::InvalidVersion) + .and_then(|specifier| { + Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?) + }) + }) + .fold_ok(Range::full(), |range, specifier| { + range.intersection(&specifier.into()) + })?; + + // Add the local version. + *version = version.union(&local); + } } } @@ -2096,8 +2103,8 @@ impl ForkState { let PubGrubDependency { package, version, + specifier: _, url: _, - local: _, } = dependency; (package, version) }), diff --git a/crates/uv/tests/pip_compile.rs b/crates/uv/tests/pip_compile.rs index 683da129638a..2ce23fe2510a 100644 --- a/crates/uv/tests/pip_compile.rs +++ b/crates/uv/tests/pip_compile.rs @@ -6708,6 +6708,7 @@ fn universal_multi_version() -> Result<()> { Ok(()) } +// Requested distinct local versions with disjoint markers. #[test] fn universal_disjoint_locals() -> Result<()> { let context = TestContext::new("3.12"); @@ -6764,6 +6765,8 @@ fn universal_disjoint_locals() -> Result<()> { Ok(()) } +// Requested distinct local versions with disjoint markers of a package +// that is also present as a transitive dependency. #[test] fn universal_transitive_disjoint_locals() -> Result<()> { let context = TestContext::new("3.12"); @@ -6776,7 +6779,7 @@ fn universal_transitive_disjoint_locals() -> Result<()> { torchvision==0.15.1 "})?; - // The marker expressions on the output here are incorrect due to https://github.com/astral-sh/uv/issues/5086, + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, // but the local versions are still respected correctly. uv_snapshot!(context.filters(), windows_filters=false, context.pip_compile() .arg("requirements.in") @@ -6842,6 +6845,502 @@ fn universal_transitive_disjoint_locals() -> Result<()> { Ok(()) } +/// Prefer local versions for dependencies of path requirements. +#[test] +fn universal_local_path_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 12 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests a local version with an overlapping marker expression, +/// we should prefer the local in all cases. +#[test] +fn universal_overlapping_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 12 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests distinct local versions with distinct marker expressions, +/// we should fork the root requirement. +#[test] +fn universal_disjoint_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", + "torch==2.0.0+cpu ; platform_machine != 'x86_64'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, + // but the local versions are still respected correctly. + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cpu + # via + # -r requirements.in + # example + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 13 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests a local version with an overlapping marker expression +/// that form a nested fork, we should prefer the local in both children of the outer +/// fork. +#[test] +fn universal_nested_overlapping_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64' and os_name == 'Linux'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 ; platform_machine == 'x86_64' + torch==2.3.0 ; platform_machine != 'x86_64' + . + "})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + fsspec==2024.3.1 ; platform_machine != 'x86_64' + # via torch + intel-openmp==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mkl==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via torch + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + torch==2.3.0 ; platform_machine != 'x86_64' + # via -r requirements.in + torch==2.0.0+cu118 ; platform_machine == 'x86_64' + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 17 packages in [TIME] + "### + ); + + // A similar case, except the nested marker is now on the path requirement. + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 ; platform_machine == 'x86_64' + torch==2.3.0 ; platform_machine != 'x86_64' + . ; os_name == 'Linux' + "})?; + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", + ] + requires-python = ">=3.11" + "#})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . ; os_name == 'Linux' + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + fsspec==2024.3.1 ; platform_machine != 'x86_64' + # via torch + intel-openmp==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mkl==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via torch + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + torch==2.3.0 ; platform_machine != 'x86_64' + # via -r requirements.in + torch==2.0.0+cu118 ; platform_machine == 'x86_64' + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 17 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests distinct local versions with disjoint marker expressions +/// that form a nested fork, we should create a nested fork. +#[test] +fn universal_nested_disjoint_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", + "torch==2.0.0+cpu ; platform_machine != 'x86_64'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 ; os_name == 'Linux' + torch==2.3.0 ; os_name != 'Linux' + . ; os_name == 'Linux' + "})?; + + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, + // but the local versions are still respected correctly. + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . ; os_name == 'Linux' + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + fsspec==2024.3.1 ; os_name != 'Linux' + # via torch + intel-openmp==2021.4.0 ; os_name != 'Linux' and platform_system == 'Windows' + # via mkl + jinja2==3.1.3 + # via torch + lit==18.1.2 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mkl==2021.4.0 ; os_name != 'Linux' and platform_system == 'Windows' + # via torch + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + nvidia-cublas-cu12==12.1.3.1 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch + nvidia-cuda-cupti-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cuda-nvrtc-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cuda-runtime-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cudnn-cu12==8.9.2.26 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cufft-cu12==11.0.2.54 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-curand-cu12==10.3.2.106 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cusolver-cu12==11.4.5.107 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cusparse-cu12==12.1.0.106 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via + # nvidia-cusolver-cu12 + # torch + nvidia-nccl-cu12==2.20.5 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-nvjitlink-cu12==12.4.99 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + nvidia-nvtx-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + sympy==1.12 + # via torch + tbb==2021.11.0 ; os_name != 'Linux' and platform_system == 'Windows' + # via mkl + torch==2.0.0+cu118 ; os_name == 'Linux' + # via + # -r requirements.in + # example + # triton + torch==2.3.0 ; os_name != 'Linux' + # via -r requirements.in + torch==2.0.0+cpu ; os_name == 'Linux' + # via + # -r requirements.in + # example + triton==2.0.0 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 30 packages in [TIME] + "### + ); + + Ok(()) +} + /// Perform a universal resolution that requires narrowing the supported Python range in one of the /// fork branches. /// From 4155d782a3128d6726c8866df839fb53c2baac08 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 19 Jul 2024 17:44:32 -0400 Subject: [PATCH 2/2] exclude non-local version from range --- crates/uv-resolver/src/resolver/mod.rs | 9 ++- crates/uv/tests/pip_compile.rs | 81 +++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/crates/uv-resolver/src/resolver/mod.rs b/crates/uv-resolver/src/resolver/mod.rs index 6000f31dd675..97f6c5bfe48a 100644 --- a/crates/uv-resolver/src/resolver/mod.rs +++ b/crates/uv-resolver/src/resolver/mod.rs @@ -2049,10 +2049,17 @@ impl ForkState { // If the specifier is an exact version and the user requested a local version for this // fork that's more precise than the specifier, use the local version instead. if let Some(specifier) = specifier { + let locals = locals.get(name, &self.markers); + + // Prioritize local versions over the original version range. + if !locals.is_empty() { + *version = Range::empty(); + } + // It's possible that there are multiple matching local versions requested with // different marker expressions. All of these are potentially compatible until we // narrow to a specific fork. - for local in locals.get(name, &self.markers) { + for local in locals { let local = specifier .iter() .map(|specifier| { diff --git a/crates/uv/tests/pip_compile.rs b/crates/uv/tests/pip_compile.rs index 2ce23fe2510a..14a5b4f6f1c0 100644 --- a/crates/uv/tests/pip_compile.rs +++ b/crates/uv/tests/pip_compile.rs @@ -6986,7 +6986,7 @@ fn universal_overlapping_local_requirement() -> Result<()> { Ok(()) } -/// If a dependency requests distinct local versions with distinct marker expressions, +/// If a dependency requests distinct local versions with disjoint marker expressions, /// we should fork the root requirement. #[test] fn universal_disjoint_local_requirement() -> Result<()> { @@ -7064,6 +7064,85 @@ fn universal_disjoint_local_requirement() -> Result<()> { Ok(()) } +/// If a dependency requests distinct local versions and non-local versions with disjoint marker +/// expressions, we should fork the root requirement. +#[test] +fn universal_disjoint_base_or_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0; python_version < '3.10'", + "torch==2.0.0+cu118 ; python_version >= '3.10' and python_version <= '3.12'", + "torch==2.0.0+cpu ; python_version > '3.12'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, + // but the local versions are still respected correctly. + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cpu + # via + # -r requirements.in + # example + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 13 packages in [TIME] + "### + ); + + Ok(()) +} + /// If a dependency requests a local version with an overlapping marker expression /// that form a nested fork, we should prefer the local in both children of the outer /// fork.