Skip to content

Commit

Permalink
respect local versions for all user requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
ibraheemdev committed Jul 19, 2024
1 parent 12dd450 commit c94e437
Show file tree
Hide file tree
Showing 4 changed files with 632 additions and 94 deletions.
49 changes: 15 additions & 34 deletions crates/uv-resolver/src/pubgrub/dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,37 @@ use pypi_types::{
use uv_normalize::{ExtraName, PackageName};

use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner};
use crate::resolver::ForkLocals;
use crate::{PubGrubSpecifier, ResolveError};

#[derive(Clone, Debug)]
pub(crate) struct PubGrubDependency {
pub(crate) package: PubGrubPackage,
pub(crate) version: Range<Version>,

/// The original version specifiers from the requirement.
pub(crate) specifier: Option<VersionSpecifiers>,

/// This field is set if the [`Requirement`] had a URL. We still use a URL from [`Urls`]
/// even if this field is None where there is an override with a URL or there is a different
/// requirement or constraint for the same package that has a URL.
pub(crate) url: Option<VerbatimParsedUrl>,
/// The local version for this requirement, if specified.
pub(crate) local: Option<Version>,
}

impl PubGrubDependency {
pub(crate) fn from_requirement<'a>(
requirement: &'a Requirement,
source_name: Option<&'a PackageName>,
fork_locals: &'a ForkLocals,
) -> impl Iterator<Item = Result<Self, ResolveError>> + 'a {
// Add the package, plus any extra variants.
iter::once(None)
.chain(requirement.extras.clone().into_iter().map(Some))
.map(|extra| PubGrubRequirement::from_requirement(requirement, extra, fork_locals))
.map(|extra| PubGrubRequirement::from_requirement(requirement, extra))
.filter_map_ok(move |requirement| {
let PubGrubRequirement {
package,
version,
specifier,
url,
local,
} = requirement;
match &*package {
PubGrubPackageInner::Package { name, .. } => {
Expand All @@ -55,15 +55,15 @@ impl PubGrubDependency {
Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url,
local,
})
}
PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url,
local,
}),
PubGrubPackageInner::Extra { name, .. } => {
debug_assert!(
Expand All @@ -73,8 +73,8 @@ impl PubGrubDependency {
Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url: None,
local: None,
})
}
_ => None,
Expand All @@ -88,8 +88,8 @@ impl PubGrubDependency {
pub(crate) struct PubGrubRequirement {
pub(crate) package: PubGrubPackage,
pub(crate) version: Range<Version>,
pub(crate) specifier: Option<VersionSpecifiers>,
pub(crate) url: Option<VerbatimParsedUrl>,
pub(crate) local: Option<Version>,
}

impl PubGrubRequirement {
Expand All @@ -98,11 +98,10 @@ impl PubGrubRequirement {
pub(crate) fn from_requirement(
requirement: &Requirement,
extra: Option<ExtraName>,
fork_locals: &ForkLocals,
) -> Result<Self, ResolveError> {
let (verbatim_url, parsed_url) = match &requirement.source {
RequirementSource::Registry { specifier, .. } => {
return Self::from_registry_requirement(specifier, extra, requirement, fork_locals);
return Self::from_registry_requirement(specifier, extra, requirement);
}
RequirementSource::Url {
subdirectory,
Expand Down Expand Up @@ -165,48 +164,30 @@ impl PubGrubRequirement {
requirement.marker.clone(),
),
version: Range::full(),
specifier: None,
url: Some(VerbatimParsedUrl {
parsed_url,
verbatim: verbatim_url.clone(),
}),
local: None,
})
}

fn from_registry_requirement(
specifier: &VersionSpecifiers,
extra: Option<ExtraName>,
requirement: &Requirement,
fork_locals: &ForkLocals,
) -> Result<PubGrubRequirement, ResolveError> {
// If the specifier is an exact version and the user requested a local version for this
// fork that's more precise than the specifier, use the local version instead.
let version = if let Some(local) = fork_locals.get(&requirement.name) {
specifier
.iter()
.map(|specifier| {
ForkLocals::map(local, specifier)
.map_err(ResolveError::InvalidVersion)
.and_then(|specifier| {
Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?)
})
})
.fold_ok(Range::full(), |range, specifier| {
range.intersection(&specifier.into())
})?
} else {
PubGrubSpecifier::from_pep440_specifiers(specifier)?.into()
};
let version = PubGrubSpecifier::from_pep440_specifiers(specifier)?.into();

let requirement = Self {
package: PubGrubPackage::from_package(
requirement.name.clone(),
extra,
requirement.marker.clone(),
),
version,
specifier: Some(specifier.clone()),
url: None,
local: None,
version,
};

Ok(requirement)
Expand Down
88 changes: 69 additions & 19 deletions crates/uv-resolver/src/resolver/locals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,74 @@ use std::str::FromStr;
use distribution_filename::{SourceDistFilename, WheelFilename};
use distribution_types::RemoteSource;
use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError};
use pep508_rs::PackageName;
use pep508_rs::{MarkerEnvironment, MarkerTree, PackageName};
use pypi_types::RequirementSource;
use rustc_hash::FxHashMap;

/// A map of package names to their associated, required local versions in a given fork.
use crate::{marker::is_disjoint, DependencyMode, Manifest, ResolverMarkers};

/// A map of package names to their associated, required local versions across all forks.
#[derive(Debug, Default, Clone)]
pub(crate) struct ForkLocals(FxHashMap<PackageName, Version>);
pub(crate) struct Locals(FxHashMap<PackageName, Vec<(Option<MarkerTree>, Version)>>);

impl Locals {
/// Determine the set of permitted local versions in the [`Manifest`].
pub(crate) fn from_manifest(
manifest: &Manifest,
markers: Option<&MarkerEnvironment>,
dependencies: DependencyMode,
) -> Self {
let mut required: FxHashMap<PackageName, Vec<_>> = FxHashMap::default();

// Add all direct requirements and constraints. There's no need to look for conflicts,
// since conflicts will be enforced by the solver.
for requirement in manifest.requirements(markers, dependencies) {
if let Some(local) = from_source(&requirement.source) {
required
.entry(requirement.name.clone())
.or_default()
.push((requirement.marker.clone(), local));
}
}

impl ForkLocals {
/// Insert the local [`Version`] to which a package is pinned for this fork.
pub(crate) fn insert(&mut self, package_name: PackageName, local: Version) {
assert!(local.is_local());
self.0.insert(package_name, local);
Self(required)
}

/// Return the local [`Version`] to which a package is pinned in this fork, if any.
pub(crate) fn get(&self, package_name: &PackageName) -> Option<&Version> {
self.0.get(package_name)
/// Return a list of local versions that are compatible with a package in the given fork.
pub(crate) fn get(
&self,
package_name: &PackageName,
markers: &ResolverMarkers,
) -> Vec<&Version> {
let Some(locals) = self.0.get(package_name) else {
return Vec::new();
};

match markers {
// If we are solving for a specific environment we already filtered
// compatible requirements `from_manifest`.
ResolverMarkers::SpecificEnvironment(_) => {
locals.first().map(|(_, local)| local).into_iter().collect()
}

// Return all locals that were requested with markers that are compatible
// with the current fork.
//
// Compatibility implies that the markers are not disjoint. The resolver will
// choose the most compatible local when it narrows to the specific fork.
ResolverMarkers::Fork(fork) => locals
.iter()
.filter(|(marker, _)| {
!marker
.as_ref()
.is_some_and(|marker| is_disjoint(fork, marker))
})
.map(|(_, local)| local)
.collect(),

// If we haven't forked yet, all locals are potentially compatible.
ResolverMarkers::Universal => locals.iter().map(|(_, local)| local).collect(),
}
}

/// Given a specifier that may include the version _without_ a local segment, return a specifier
Expand Down Expand Up @@ -190,7 +240,7 @@ mod tests {
use pypi_types::ParsedUrl;
use pypi_types::RequirementSource;

use super::{from_source, ForkLocals};
use super::{from_source, Locals};

#[test]
fn extract_locals() -> Result<()> {
Expand Down Expand Up @@ -251,7 +301,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -260,7 +310,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -269,7 +319,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -278,7 +328,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?
);

Expand All @@ -287,7 +337,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?
);

Expand All @@ -296,7 +346,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -305,7 +355,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?
);

Expand Down
Loading

0 comments on commit c94e437

Please sign in to comment.