diff --git a/crates/uv-resolver/src/pubgrub/dependencies.rs b/crates/uv-resolver/src/pubgrub/dependencies.rs index 745ddef3fca3..08dd39e2852e 100644 --- a/crates/uv-resolver/src/pubgrub/dependencies.rs +++ b/crates/uv-resolver/src/pubgrub/dependencies.rs @@ -12,7 +12,7 @@ use pypi_types::{ use uv_normalize::{ExtraName, PackageName}; use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner}; -use crate::resolver::Locals; +use crate::resolver::ForkLocals; use crate::{PubGrubSpecifier, ResolveError}; #[derive(Clone, Debug)] @@ -23,23 +23,26 @@ pub(crate) struct PubGrubDependency { /// even if this field is None where there is an override with a URL or there is a different /// requirement or constraint for the same package that has a URL. pub(crate) url: Option, + /// The local version for this requirement, if specified. + pub(crate) local: Option, } impl PubGrubDependency { pub(crate) fn from_requirement<'a>( requirement: &'a Requirement, source_name: Option<&'a PackageName>, - locals: &'a Locals, + fork_locals: &'a ForkLocals, ) -> impl Iterator> + 'a { // Add the package, plus any extra variants. iter::once(None) .chain(requirement.extras.clone().into_iter().map(Some)) - .map(|extra| PubGrubRequirement::from_requirement(requirement, extra, locals)) + .map(|extra| PubGrubRequirement::from_requirement(requirement, extra, fork_locals)) .filter_map_ok(move |requirement| { let PubGrubRequirement { package, version, url, + local, } = requirement; match &*package { PubGrubPackageInner::Package { name, .. } => { @@ -53,12 +56,14 @@ impl PubGrubDependency { package: package.clone(), version: version.clone(), url, + local, }) } PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency { package: package.clone(), version: version.clone(), url, + local, }), PubGrubPackageInner::Extra { name, .. } => { debug_assert!( @@ -69,6 +74,7 @@ impl PubGrubDependency { package: package.clone(), version: version.clone(), url: None, + local: None, }) } _ => None, @@ -83,6 +89,7 @@ pub(crate) struct PubGrubRequirement { pub(crate) package: PubGrubPackage, pub(crate) version: Range, pub(crate) url: Option, + pub(crate) local: Option, } impl PubGrubRequirement { @@ -91,11 +98,11 @@ impl PubGrubRequirement { pub(crate) fn from_requirement( requirement: &Requirement, extra: Option, - locals: &Locals, + fork_locals: &ForkLocals, ) -> Result { let (verbatim_url, parsed_url) = match &requirement.source { RequirementSource::Registry { specifier, .. } => { - return Self::from_registry_requirement(specifier, extra, requirement, locals); + return Self::from_registry_requirement(specifier, extra, requirement, fork_locals); } RequirementSource::Url { subdirectory, @@ -162,6 +169,7 @@ impl PubGrubRequirement { parsed_url, verbatim: verbatim_url.clone(), }), + local: None, }) } @@ -169,15 +177,15 @@ impl PubGrubRequirement { specifier: &VersionSpecifiers, extra: Option, requirement: &Requirement, - locals: &Locals, + fork_locals: &ForkLocals, ) -> Result { - // If the specifier is an exact version, and the user requested a local version that's - // more precise than the specifier, use the local version instead. - let version = if let Some(expected) = locals.get(&requirement.name) { + // If the specifier is an exact version and the user requested a local version for this + // fork that's more precise than the specifier, use the local version instead. + let version = if let Some(local) = fork_locals.get(&requirement.name) { specifier .iter() .map(|specifier| { - Locals::map(expected, specifier) + ForkLocals::map(local, specifier) .map_err(ResolveError::InvalidVersion) .and_then(|specifier| { Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?) @@ -198,7 +206,9 @@ impl PubGrubRequirement { ), version, url: None, + local: None, }; + Ok(requirement) } } diff --git a/crates/uv-resolver/src/resolver/locals.rs b/crates/uv-resolver/src/resolver/locals.rs index 079085ba2063..142799881b29 100644 --- a/crates/uv-resolver/src/resolver/locals.rs +++ b/crates/uv-resolver/src/resolver/locals.rs @@ -1,46 +1,25 @@ -use std::iter; use std::str::FromStr; -use rustc_hash::FxHashMap; - use distribution_filename::{SourceDistFilename, WheelFilename}; use distribution_types::RemoteSource; use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError}; -use pep508_rs::MarkerEnvironment; +use pep508_rs::PackageName; use pypi_types::RequirementSource; -use uv_normalize::PackageName; - -use crate::{DependencyMode, Manifest}; - -#[derive(Debug, Default)] -pub(crate) struct Locals { - /// A map of package names to their associated, required local versions. - required: FxHashMap, -} +use rustc_hash::FxHashMap; -impl Locals { - /// Determine the set of permitted local versions in the [`Manifest`]. - pub(crate) fn from_manifest( - manifest: &Manifest, - markers: Option<&MarkerEnvironment>, - dependencies: DependencyMode, - ) -> Self { - let mut required: FxHashMap = FxHashMap::default(); +/// A map of package names to their associated, required local versions in a given fork. +#[derive(Debug, Default, Clone)] +pub(crate) struct ForkLocals(FxHashMap); - // Add all direct requirements and constraints. There's no need to look for conflicts, - // since conflicts will be enforced by the solver. - for requirement in manifest.requirements(markers, dependencies) { - for local in iter_locals(&requirement.source) { - required.insert(requirement.name.clone(), local); - } - } - - Self { required } +impl ForkLocals { + /// Insert the local [`Version`] to which a package is pinned for this fork. + pub(crate) fn insert(&mut self, package_name: PackageName, local: Version) { + self.0.insert(package_name, local); } - /// Return the local [`Version`] to which a package is pinned, if any. - pub(crate) fn get(&self, package: &PackageName) -> Option<&Version> { - self.required.get(package) + /// Return the local [`Version`] to which a package is pinned in this fork, if any. + pub(crate) fn get(&self, package_name: &PackageName) -> Option<&Version> { + self.0.get(package_name) } /// Given a specifier that may include the version _without_ a local segment, return a specifier @@ -140,63 +119,61 @@ fn is_compatible(expected: &Version, provided: &Version) -> bool { } } -/// If a [`VersionSpecifier`] contains exact equality specifiers for a local version, returns an -/// iterator over the local versions. -fn iter_locals(source: &RequirementSource) -> Box + '_> { +/// If a [`VersionSpecifier`] contains an exact equality specifier for a local version, +/// returns the local version. +pub(crate) fn from_source(source: &RequirementSource) -> Option { match source { // Extract all local versions from specifiers that require an exact version (e.g., // `==1.0.0+local`). RequirementSource::Registry { specifier: version, .. - } => Box::new( - version - .iter() - .filter(|specifier| { - matches!(specifier.operator(), Operator::Equal | Operator::ExactEqual) - }) - .filter(|specifier| !specifier.version().local().is_empty()) - .map(|specifier| specifier.version().clone()), - ), + } => version + .iter() + .filter(|specifier| { + matches!(specifier.operator(), Operator::Equal | Operator::ExactEqual) + }) + .filter(|specifier| !specifier.version().local().is_empty()) + .map(|specifier| specifier.version().clone()) + // It's technically possible for there to be multiple local segments here. + // For example, `a==1.0+foo,==1.0+bar`. However, in that case resolution + // will fail later. + .next(), // Exact a local version from a URL, if it includes a fully-qualified filename (e.g., // `torch-2.2.1%2Bcu118-cp311-cp311-linux_x86_64.whl`). - RequirementSource::Url { url, .. } => Box::new( - url.filename() - .ok() - .and_then(|filename| { - if let Ok(filename) = WheelFilename::from_str(&filename) { - Some(filename.version) - } else if let Ok(filename) = - SourceDistFilename::parsed_normalized_filename(&filename) - { - Some(filename.version) - } else { - None - } - }) - .into_iter() - .filter(pep440_rs::Version::is_local), - ), - RequirementSource::Git { .. } => Box::new(iter::empty()), + RequirementSource::Url { url, .. } => url + .filename() + .ok() + .and_then(|filename| { + if let Ok(filename) = WheelFilename::from_str(&filename) { + Some(filename.version) + } else if let Ok(filename) = + SourceDistFilename::parsed_normalized_filename(&filename) + { + Some(filename.version) + } else { + None + } + }) + .filter(pep440_rs::Version::is_local), + RequirementSource::Git { .. } => None, RequirementSource::Path { install_path: path, .. - } => Box::new( - path.file_name() - .and_then(|filename| { - let filename = filename.to_string_lossy(); - if let Ok(filename) = WheelFilename::from_str(&filename) { - Some(filename.version) - } else if let Ok(filename) = - SourceDistFilename::parsed_normalized_filename(&filename) - { - Some(filename.version) - } else { - None - } - }) - .into_iter() - .filter(pep440_rs::Version::is_local), - ), - RequirementSource::Directory { .. } => Box::new(iter::empty()), + } => path + .file_name() + .and_then(|filename| { + let filename = filename.to_string_lossy(); + if let Ok(filename) = WheelFilename::from_str(&filename) { + Some(filename.version) + } else if let Ok(filename) = + SourceDistFilename::parsed_normalized_filename(&filename) + { + Some(filename.version) + } else { + None + } + }) + .filter(pep440_rs::Version::is_local), + RequirementSource::Directory { .. } => None, } } @@ -212,7 +189,7 @@ mod tests { use pypi_types::ParsedUrl; use pypi_types::RequirementSource; - use crate::resolver::locals::{iter_locals, Locals}; + use super::{from_source, ForkLocals}; #[test] fn extract_locals() -> Result<()> { @@ -220,7 +197,7 @@ mod tests { let url = VerbatimUrl::from_url(Url::parse("https://example.com/foo-1.0.0+local.tar.gz")?); let source = RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url); - let locals: Vec<_> = iter_locals(&source).collect(); + let locals: Vec<_> = from_source(&source).into_iter().collect(); assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]); // Extract from a wheel in a URL. @@ -229,14 +206,14 @@ mod tests { )?); let source = RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url); - let locals: Vec<_> = iter_locals(&source).collect(); + let locals: Vec<_> = from_source(&source).into_iter().collect(); assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]); // Don't extract anything if the URL is opaque. let url = VerbatimUrl::from_url(Url::parse("git+https://example.com/foo/bar")?); let source = RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url); - let locals: Vec<_> = iter_locals(&source).collect(); + let locals: Vec<_> = from_source(&source).into_iter().collect(); assert!(locals.is_empty()); // Extract from `==` specifiers. @@ -248,7 +225,7 @@ mod tests { specifier: version, index: None, }; - let locals: Vec<_> = iter_locals(&source).collect(); + let locals: Vec<_> = from_source(&source).into_iter().collect(); assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]); // Ignore other specifiers. @@ -260,7 +237,7 @@ mod tests { specifier: version, index: None, }; - let locals: Vec<_> = iter_locals(&source).collect(); + let locals: Vec<_> = from_source(&source).into_iter().collect(); assert!(locals.is_empty()); Ok(()) @@ -273,7 +250,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?; assert_eq!( - Locals::map(&local, &specifier)?, + ForkLocals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -282,7 +259,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?; assert_eq!( - Locals::map(&local, &specifier)?, + ForkLocals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)? ); @@ -291,7 +268,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?; assert_eq!( - Locals::map(&local, &specifier)?, + ForkLocals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -300,7 +277,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?; assert_eq!( - Locals::map(&local, &specifier)?, + ForkLocals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)? ); @@ -309,7 +286,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?; assert_eq!( - Locals::map(&local, &specifier)?, + ForkLocals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)? ); @@ -318,7 +295,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?; assert_eq!( - Locals::map(&local, &specifier)?, + ForkLocals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -327,7 +304,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?; assert_eq!( - Locals::map(&local, &specifier)?, + ForkLocals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)? ); diff --git a/crates/uv-resolver/src/resolver/mod.rs b/crates/uv-resolver/src/resolver/mod.rs index 43fe644d393f..01a4dd4c01ac 100644 --- a/crates/uv-resolver/src/resolver/mod.rs +++ b/crates/uv-resolver/src/resolver/mod.rs @@ -26,7 +26,7 @@ use distribution_types::{ IncompatibleWheel, IndexLocations, InstalledDist, PythonRequirementKind, RemoteSource, ResolvedDist, ResolvedDistRef, SourceDist, VersionOrUrlRef, }; -pub(crate) use locals::Locals; +pub(crate) use locals::ForkLocals; use pep440_rs::{Version, MIN_VERSION}; use pep508_rs::{MarkerEnvironment, MarkerTree}; use platform_tags::Tags; @@ -92,7 +92,6 @@ struct ResolverState { git: GitResolver, exclusions: Exclusions, urls: Urls, - locals: Locals, dependency_mode: DependencyMode, hasher: HashStrategy, /// When not set, the resolver is in "universal" mode. @@ -197,7 +196,6 @@ impl selector: CandidateSelector::for_resolution(options, &manifest, markers), dependency_mode: options.dependency_mode, urls: Urls::from_manifest(&manifest, markers, git, options.dependency_mode)?, - locals: Locals::from_manifest(&manifest, markers, options.dependency_mode), project: manifest.project, requirements: manifest.requirements, constraints: manifest.constraints, @@ -292,6 +290,7 @@ impl ResolverState ResolverState ResolverState ResolverState ResolverState, ) -> Result { - let result = self.get_dependencies(package, version, fork_urls, markers, requires_python); + let result = self.get_dependencies( + package, + version, + fork_urls, + fork_locals, + markers, + requires_python, + ); if self.markers.is_some() { return result.map(|deps| match deps { Dependencies::Available(deps) => ForkedDependencies::Unforked(deps), @@ -1053,6 +1063,7 @@ impl ResolverState, ) -> Result { @@ -1073,7 +1084,12 @@ impl ResolverState, _>>()? } @@ -1202,13 +1218,13 @@ impl ResolverState, _>>()?; + // If a package has metadata for an enabled dependency group, // add a dependency from it to the same package with the group // enabled. - if extra.is_none() && dev.is_none() { for group in &self.dev { if !metadata.dev_dependencies.contains_key(group) { @@ -1222,6 +1238,7 @@ impl ResolverState ResolverState ResolverState ResolverState, @@ -1941,16 +1963,24 @@ impl ForkState { package, version, url, + local, } = dependency; - // From the [`Requirement`] to [`PubGrubDependency`] conversion, we get a URL if the - // requirement was a URL requirement. `Urls` applies canonicalization to this and - // override URLs to both URL and registry requirements, which we then check for - // conflicts using [`ForkUrl`]. if let Some(name) = package.name() { + // From the [`Requirement`] to [`PubGrubDependency`] conversion, we get a URL if the + // requirement was a URL requirement. `Urls` applies canonicalization to this and + // override URLs to both URL and registry requirements, which we then check for + // conflicts using [`ForkUrl`]. if let Some(url) = urls.get_url(name, url.as_ref(), git)? { self.fork_urls.insert(name, url, &self.markers)?; }; + + // `PubGrubDependency` also gives us a local version if specified by the user. + // Keep track of which local version we will be using in this fork for transitive + // dependencies. + if let Some(local) = local { + self.fork_locals.insert(name.clone(), local.clone()); + } } if let Some(for_package) = for_package { @@ -1972,6 +2002,7 @@ impl ForkState { package, version, url: _, + local: _, } = dependency; (package, version) }), diff --git a/crates/uv/tests/pip_compile.rs b/crates/uv/tests/pip_compile.rs index 51f7f9a9d618..d9098b95cfa2 100644 --- a/crates/uv/tests/pip_compile.rs +++ b/crates/uv/tests/pip_compile.rs @@ -6708,6 +6708,138 @@ fn universal_multi_version() -> Result<()> { Ok(()) } +#[test] +fn universal_disjoint_locals() -> Result<()> { + let context = TestContext::new("3.12"); + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc::indoc! {r" + --find-links https://download.pytorch.org/whl/torch_stable.html + + torch==2.0.0+cu118 ; platform_machine == 'x86_64' + torch==2.0.0+cpu ; platform_machine != 'x86_64' + "})?; + + uv_snapshot!(context.filters(), windows_filters=false, context.pip_compile() + .arg("requirements.in") + .arg("--universal"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cpu ; platform_machine != 'x86_64' + # via -r requirements.in + torch==2.0.0+cu118 ; platform_machine == 'x86_64' + # via + # -r requirements.in + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 12 packages in [TIME] + "### + ); + + Ok(()) +} + +#[test] +fn universal_transitive_disjoint_locals() -> Result<()> { + let context = TestContext::new("3.12"); + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc::indoc! {r" + --find-links https://download.pytorch.org/whl/torch_stable.html + + torch==2.0.0+cu118 ; platform_machine == 'x86_64' + torch==2.0.0+cpu ; platform_machine != 'x86_64' + torchvision==0.15.1 + "})?; + + uv_snapshot!(context.filters(), windows_filters=false, context.pip_compile() + .arg("requirements.in") + .arg("--universal"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + certifi==2024.2.2 + # via requests + charset-normalizer==3.3.2 + # via requests + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + filelock==3.13.1 + # via + # torch + # triton + idna==3.6 + # via requests + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + numpy==1.26.4 + # via torchvision + pillow==10.2.0 + # via torchvision + requests==2.31.0 + # via torchvision + sympy==1.12 + # via torch + torch==2.0.0+cpu + # via + # -r requirements.in + # torchvision + torch==2.0.0+cu118 + # via + # -r requirements.in + # torchvision + # triton + torchvision==0.15.1 + # via -r requirements.in + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + urllib3==2.2.1 + # via requests + + ----- stderr ----- + Resolved 20 packages in [TIME] + "### + ); + + Ok(()) +} + /// Perform a universal resolution that requires narrowing the supported Python range in one of the /// fork branches. ///