Skip to content

Commit

Permalink
Respect local versions for all user requirements (#5232)
Browse files Browse the repository at this point in the history
## Summary

This fixes a few bugs introduced by
#5104. I previously thought we could
track conflicting locals the same way we track conflicting URLs in
forks, but it turns out that ends up being very tricky. URL forks work
because we prioritize directly URL requirements. We can't prioritize
locals in the same way without conflicting with the URL prioritization
(this may be possible but it's not trivial), so we run into issues where
a correct resolution depends on the order in which dependencies are
traversed.

Instead, we track local versions across all forks in `Locals`. When
applying a local version, we apply all locals with markers that
intersect with the current fork. This way we end up applying some local
versions without creating a fork. For example, given:
```
// pyproject.toml
dependencies = [
    "torch==2.0.0+cu118 ; platform_machine == 'x86_64'",
]

// requirements.in
torch==2.0.0
.
```

We choose `2.0.0+cu118` in all cases. However, if a disjoint fork is
created based on local versions, the resolver will choose the most
compatible local when it narrows to a specific fork. Thus we correctly
respect local versions when forking:
```
// pyproject.toml
dependencies = [
    "torch==2.0.0+cu118 ; platform_machine == 'x86_64'",
    "torch==2.0.0+cpu ; platform_machine != 'x86_64'"
]

// requirements.in
torch==2.0.0
.
``` 

We should also be able to use a similar strategy for
#5150.

## Test Plan

This fixes #5220 locally for me,
as well as a few other bugs that were not reported yet.
  • Loading branch information
ibraheemdev committed Jul 19, 2024
1 parent 92e1102 commit bb73edb
Show file tree
Hide file tree
Showing 4 changed files with 717 additions and 94 deletions.
49 changes: 15 additions & 34 deletions crates/uv-resolver/src/pubgrub/dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,37 @@ use pypi_types::{
use uv_normalize::{ExtraName, PackageName};

use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner};
use crate::resolver::ForkLocals;
use crate::{PubGrubSpecifier, ResolveError};

#[derive(Clone, Debug)]
pub(crate) struct PubGrubDependency {
pub(crate) package: PubGrubPackage,
pub(crate) version: Range<Version>,

/// The original version specifiers from the requirement.
pub(crate) specifier: Option<VersionSpecifiers>,

/// This field is set if the [`Requirement`] had a URL. We still use a URL from [`Urls`]
/// even if this field is None where there is an override with a URL or there is a different
/// requirement or constraint for the same package that has a URL.
pub(crate) url: Option<VerbatimParsedUrl>,
/// The local version for this requirement, if specified.
pub(crate) local: Option<Version>,
}

impl PubGrubDependency {
pub(crate) fn from_requirement<'a>(
requirement: &'a Requirement,
source_name: Option<&'a PackageName>,
fork_locals: &'a ForkLocals,
) -> impl Iterator<Item = Result<Self, ResolveError>> + 'a {
// Add the package, plus any extra variants.
iter::once(None)
.chain(requirement.extras.clone().into_iter().map(Some))
.map(|extra| PubGrubRequirement::from_requirement(requirement, extra, fork_locals))
.map(|extra| PubGrubRequirement::from_requirement(requirement, extra))
.filter_map_ok(move |requirement| {
let PubGrubRequirement {
package,
version,
specifier,
url,
local,
} = requirement;
match &*package {
PubGrubPackageInner::Package { name, .. } => {
Expand All @@ -55,15 +55,15 @@ impl PubGrubDependency {
Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url,
local,
})
}
PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url,
local,
}),
PubGrubPackageInner::Extra { name, .. } => {
debug_assert!(
Expand All @@ -73,8 +73,8 @@ impl PubGrubDependency {
Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url: None,
local: None,
})
}
_ => None,
Expand All @@ -88,8 +88,8 @@ impl PubGrubDependency {
pub(crate) struct PubGrubRequirement {
pub(crate) package: PubGrubPackage,
pub(crate) version: Range<Version>,
pub(crate) specifier: Option<VersionSpecifiers>,
pub(crate) url: Option<VerbatimParsedUrl>,
pub(crate) local: Option<Version>,
}

impl PubGrubRequirement {
Expand All @@ -98,11 +98,10 @@ impl PubGrubRequirement {
pub(crate) fn from_requirement(
requirement: &Requirement,
extra: Option<ExtraName>,
fork_locals: &ForkLocals,
) -> Result<Self, ResolveError> {
let (verbatim_url, parsed_url) = match &requirement.source {
RequirementSource::Registry { specifier, .. } => {
return Self::from_registry_requirement(specifier, extra, requirement, fork_locals);
return Self::from_registry_requirement(specifier, extra, requirement);
}
RequirementSource::Url {
subdirectory,
Expand Down Expand Up @@ -165,48 +164,30 @@ impl PubGrubRequirement {
requirement.marker.clone(),
),
version: Range::full(),
specifier: None,
url: Some(VerbatimParsedUrl {
parsed_url,
verbatim: verbatim_url.clone(),
}),
local: None,
})
}

fn from_registry_requirement(
specifier: &VersionSpecifiers,
extra: Option<ExtraName>,
requirement: &Requirement,
fork_locals: &ForkLocals,
) -> Result<PubGrubRequirement, ResolveError> {
// If the specifier is an exact version and the user requested a local version for this
// fork that's more precise than the specifier, use the local version instead.
let version = if let Some(local) = fork_locals.get(&requirement.name) {
specifier
.iter()
.map(|specifier| {
ForkLocals::map(local, specifier)
.map_err(ResolveError::InvalidVersion)
.and_then(|specifier| {
Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?)
})
})
.fold_ok(Range::full(), |range, specifier| {
range.intersection(&specifier.into())
})?
} else {
PubGrubSpecifier::from_pep440_specifiers(specifier)?.into()
};
let version = PubGrubSpecifier::from_pep440_specifiers(specifier)?.into();

let requirement = Self {
package: PubGrubPackage::from_package(
requirement.name.clone(),
extra,
requirement.marker.clone(),
),
version,
specifier: Some(specifier.clone()),
url: None,
local: None,
version,
};

Ok(requirement)
Expand Down
88 changes: 69 additions & 19 deletions crates/uv-resolver/src/resolver/locals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,74 @@ use std::str::FromStr;
use distribution_filename::{SourceDistFilename, WheelFilename};
use distribution_types::RemoteSource;
use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError};
use pep508_rs::PackageName;
use pep508_rs::{MarkerEnvironment, MarkerTree, PackageName};
use pypi_types::RequirementSource;
use rustc_hash::FxHashMap;

/// A map of package names to their associated, required local versions in a given fork.
use crate::{marker::is_disjoint, DependencyMode, Manifest, ResolverMarkers};

/// A map of package names to their associated, required local versions across all forks.
#[derive(Debug, Default, Clone)]
pub(crate) struct ForkLocals(FxHashMap<PackageName, Version>);
pub(crate) struct Locals(FxHashMap<PackageName, Vec<(Option<MarkerTree>, Version)>>);

impl Locals {
/// Determine the set of permitted local versions in the [`Manifest`].
pub(crate) fn from_manifest(
manifest: &Manifest,
markers: Option<&MarkerEnvironment>,
dependencies: DependencyMode,
) -> Self {
let mut required: FxHashMap<PackageName, Vec<_>> = FxHashMap::default();

// Add all direct requirements and constraints. There's no need to look for conflicts,
// since conflicts will be enforced by the solver.
for requirement in manifest.requirements(markers, dependencies) {
if let Some(local) = from_source(&requirement.source) {
required
.entry(requirement.name.clone())
.or_default()
.push((requirement.marker.clone(), local));
}
}

impl ForkLocals {
/// Insert the local [`Version`] to which a package is pinned for this fork.
pub(crate) fn insert(&mut self, package_name: PackageName, local: Version) {
assert!(local.is_local());
self.0.insert(package_name, local);
Self(required)
}

/// Return the local [`Version`] to which a package is pinned in this fork, if any.
pub(crate) fn get(&self, package_name: &PackageName) -> Option<&Version> {
self.0.get(package_name)
/// Return a list of local versions that are compatible with a package in the given fork.
pub(crate) fn get(
&self,
package_name: &PackageName,
markers: &ResolverMarkers,
) -> Vec<&Version> {
let Some(locals) = self.0.get(package_name) else {
return Vec::new();
};

match markers {
// If we are solving for a specific environment we already filtered
// compatible requirements `from_manifest`.
ResolverMarkers::SpecificEnvironment(_) => {
locals.first().map(|(_, local)| local).into_iter().collect()
}

// Return all locals that were requested with markers that are compatible
// with the current fork.
//
// Compatibility implies that the markers are not disjoint. The resolver will
// choose the most compatible local when it narrows to the specific fork.
ResolverMarkers::Fork(fork) => locals
.iter()
.filter(|(marker, _)| {
!marker
.as_ref()
.is_some_and(|marker| is_disjoint(fork, marker))
})
.map(|(_, local)| local)
.collect(),

// If we haven't forked yet, all locals are potentially compatible.
ResolverMarkers::Universal => locals.iter().map(|(_, local)| local).collect(),
}
}

/// Given a specifier that may include the version _without_ a local segment, return a specifier
Expand Down Expand Up @@ -190,7 +240,7 @@ mod tests {
use pypi_types::ParsedUrl;
use pypi_types::RequirementSource;

use super::{from_source, ForkLocals};
use super::{from_source, Locals};

#[test]
fn extract_locals() -> Result<()> {
Expand Down Expand Up @@ -251,7 +301,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -260,7 +310,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -269,7 +319,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -278,7 +328,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?
);

Expand All @@ -287,7 +337,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?
);

Expand All @@ -296,7 +346,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);

Expand All @@ -305,7 +355,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?
);

Expand Down
Loading

0 comments on commit bb73edb

Please sign in to comment.