Skip to content

Commit

Permalink
uv-resolver: support unambiguous omission of 'source' and 'version'
Browse files Browse the repository at this point in the history
When there is only one distribution for a particular package name, any
dependencies (the edges in the resolution graph) that reference that
package name are completely unambiguous. Therefore, we can actually omit
their version and source information and instead derive it from the
distribution entry.

We add some tests to check the success and error cases. That is, when
`source` or `version` is omitted and there is more than one
corresponding distribution for the package name (i.e., it's ambiguous),
then lock deserialization should fail.
  • Loading branch information
BurntSushi committed Jun 26, 2024
1 parent 4cb1595 commit 4accbfd
Show file tree
Hide file tree
Showing 7 changed files with 944 additions and 18 deletions.
284 changes: 266 additions & 18 deletions crates/uv-resolver/src/lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,15 @@ impl Lock {
doc.insert("requires-python", value(requires_python.to_string()));
}

// Count the number of distributions for each package name. When
// there's only one distribution for a particular package name (the
// overwhelmingly common case), we can omit some data (like source and
// version) on dependency edges since it is strictly redundant.
let mut dist_count_by_name: FxHashMap<PackageName, u64> = FxHashMap::default();
for dist in &self.distributions {
*dist_count_by_name.entry(dist.id.name.clone()).or_default() += 1;
}

let mut distributions = ArrayOfTables::new();
for dist in &self.distributions {
let mut table = Table::new();
Expand All @@ -469,7 +478,7 @@ impl Lock {
let deps = dist
.dependencies
.iter()
.map(Dependency::to_toml)
.map(|dep| dep.to_toml(&dist_count_by_name))
.collect::<ArrayOfTables>();
table.insert("dependencies", Item::ArrayOfTables(deps));
}
Expand All @@ -479,7 +488,7 @@ impl Lock {
for (extra, deps) in &dist.optional_dependencies {
let deps = deps
.iter()
.map(Dependency::to_toml)
.map(|dep| dep.to_toml(&dist_count_by_name))
.collect::<ArrayOfTables>();
optional_deps.insert(extra.as_ref(), Item::ArrayOfTables(deps));
}
Expand All @@ -491,7 +500,7 @@ impl Lock {
for (extra, deps) in &dist.dev_dependencies {
let deps = deps
.iter()
.map(Dependency::to_toml)
.map(|dep| dep.to_toml(&dist_count_by_name))
.collect::<ArrayOfTables>();
dev_dependencies.insert(extra.as_ref(), Item::ArrayOfTables(deps));
}
Expand Down Expand Up @@ -532,10 +541,27 @@ impl TryFrom<LockWire> for Lock {
type Error = LockError;

fn try_from(wire: LockWire) -> Result<Lock, LockError> {
// Build a map from package name to distribution ID for every package
// name that appears on exactly one distribution (the overwhelmingly
// common case). Dependency edges that omit `version` and/or `source`
// are resolved through this map while unwiring; a name that is absent
// from the map cannot be resolved that way.
let mut unambiguous_dist_ids: FxHashMap<PackageName, DistributionId> = FxHashMap::default();
// Package names that have been seen on more than one distribution.
let mut ambiguous = FxHashSet::default();
for dist in &wire.distributions {
// Already known to be ambiguous: nothing more to record.
if ambiguous.contains(&dist.id.name) {
continue;
}
// Second sighting of this name: evict the entry recorded earlier and
// remember the name as ambiguous so later sightings do not re-add it.
if unambiguous_dist_ids.remove(&dist.id.name).is_some() {
ambiguous.insert(dist.id.name.clone());
continue;
}
// First sighting of this name.
unambiguous_dist_ids.insert(dist.id.name.clone(), dist.id.clone());
}

let distributions = wire
.distributions
.into_iter()
.map(DistributionWire::unwire)
.map(|dist| dist.unwire(&unambiguous_dist_ids))
.collect::<Result<Vec<_>, _>>()?;
Lock::new(wire.version, distributions, wire.requires_python)
}
Expand Down Expand Up @@ -844,9 +870,14 @@ struct DistributionWire {
}

impl DistributionWire {
fn unwire(self) -> Result<Distribution, LockError> {
fn unwire(
self,
unambiguous_dist_ids: &FxHashMap<PackageName, DistributionId>,
) -> Result<Distribution, LockError> {
let unwire_deps = |deps: Vec<DependencyWire>| -> Result<Vec<Dependency>, LockError> {
deps.into_iter().map(DependencyWire::unwire).collect()
deps.into_iter()
.map(|dep| dep.unwire(unambiguous_dist_ids))
.collect()
};
Ok(Distribution {
id: self.id,
Expand Down Expand Up @@ -922,16 +953,38 @@ impl std::fmt::Display for DistributionId {
#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, serde::Deserialize)]
struct DistributionIdForDependency {
name: PackageName,
version: Version,
source: Source,
version: Option<Version>,
source: Option<Source>,
}

impl DistributionIdForDependency {
fn unwire(self) -> Result<DistributionId, LockError> {
fn unwire(
self,
unambiguous_dist_ids: &FxHashMap<PackageName, DistributionId>,
) -> Result<DistributionId, LockError> {
let unambiguous_dist_id = unambiguous_dist_ids.get(&self.name);
let version = self.version.map(Ok::<_, LockError>).unwrap_or_else(|| {
let Some(dist_id) = unambiguous_dist_id else {
return Err(LockErrorKind::MissingDependencyVersion {
name: self.name.clone(),
}
.into());
};
Ok(dist_id.version.clone())
})?;
let source = self.source.map(Ok::<_, LockError>).unwrap_or_else(|| {
let Some(dist_id) = unambiguous_dist_id else {
return Err(LockErrorKind::MissingDependencySource {
name: self.name.clone(),
}
.into());
};
Ok(dist_id.source.clone())
})?;
Ok(DistributionId {
name: self.name,
version: self.version,
source: self.source,
version,
source,
})
}
}
Expand All @@ -940,8 +993,8 @@ impl From<DistributionId> for DistributionIdForDependency {
fn from(id: DistributionId) -> DistributionIdForDependency {
DistributionIdForDependency {
name: id.name,
version: id.version,
source: id.source,
version: Some(id.version),
source: Some(id.source),
}
}
}
Expand Down Expand Up @@ -1762,11 +1815,17 @@ impl Dependency {
}

/// Returns the TOML representation of this dependency.
fn to_toml(&self) -> Table {
fn to_toml(&self, dist_count_by_name: &FxHashMap<PackageName, u64>) -> Table {
let count = dist_count_by_name
.get(&self.distribution_id.name)
.copied()
.expect("all dependencies have a corresponding distribution");
let mut table = Table::new();
table.insert("name", value(self.distribution_id.name.to_string()));
table.insert("version", value(self.distribution_id.version.to_string()));
table.insert("source", value(self.distribution_id.source.to_string()));
if count > 1 {
table.insert("version", value(self.distribution_id.version.to_string()));
table.insert("source", value(self.distribution_id.source.to_string()));
}
if let Some(ref extra) = self.extra {
table.insert("extra", value(extra.to_string()));
}
Expand Down Expand Up @@ -1813,9 +1872,12 @@ struct DependencyWire {
}

impl DependencyWire {
fn unwire(self) -> Result<Dependency, LockError> {
fn unwire(
self,
unambiguous_dist_ids: &FxHashMap<PackageName, DistributionId>,
) -> Result<Dependency, LockError> {
Ok(Dependency {
distribution_id: self.distribution_id.unwire()?,
distribution_id: self.distribution_id.unwire(unambiguous_dist_ids)?,
extra: self.extra,
marker: self.marker,
})
Expand Down Expand Up @@ -2034,6 +2096,26 @@ enum LockErrorKind {
#[source]
err: VerbatimUrlError,
},
/// An error that occurs when an ambiguous `distribution.dependency` is
/// missing a `version` field.
#[error(
"dependency {name} has missing `version` \
field but has more than one matching distribution"
)]
MissingDependencyVersion {
/// The name of the dependency that is missing a `version` field.
name: PackageName,
},
/// An error that occurs when an ambiguous `distribution.dependency` is
/// missing a `source` field.
#[error(
"dependency {name} has missing `source` \
field but has more than one matching distribution"
)]
MissingDependencySource {
/// The name of the dependency that is missing a `source` field.
name: PackageName,
},
}

/// An error that occurs when a source string could not be parsed.
Expand Down Expand Up @@ -2092,6 +2174,172 @@ impl std::fmt::Display for HashParseError {
mod tests {
use super::*;

/// Deserializes a lock in which `b` depends on `a`, the dependency entry
/// gives `version` but omits `source`, and there is exactly one distribution
/// named `a`.
///
/// The debug form of the deserialization result is pinned by an `insta`
/// snapshot. NOTE(review): presumably this is a success case, because the
/// omitted `source` can be taken from the single `a` distribution — confirm
/// against the stored snapshot.
#[test]
fn missing_dependency_source_unambiguous() {
let data = r#"
version = 1
[[distribution]]
name = "a"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "b"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution.dependencies]]
name = "a"
version = "0.1.0"
"#;
let result: Result<Lock, _> = toml::from_str(data);
insta::assert_debug_snapshot!(result);
}

/// Deserializes a lock in which `b` depends on `a`, the dependency entry
/// gives `source` but omits `version`, and there is exactly one distribution
/// named `a`.
///
/// The debug form of the deserialization result is pinned by an `insta`
/// snapshot. NOTE(review): presumably this is a success case, because the
/// omitted `version` can be taken from the single `a` distribution — confirm
/// against the stored snapshot.
#[test]
fn missing_dependency_version_unambiguous() {
let data = r#"
version = 1
[[distribution]]
name = "a"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "b"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution.dependencies]]
name = "a"
source = "registry+https://pypi.org/simple"
"#;
let result: Result<Lock, _> = toml::from_str(data);
insta::assert_debug_snapshot!(result);
}

/// Deserializes a lock in which `b` depends on `a`, the dependency entry
/// carries only `name` (both `version` and `source` are omitted), and there
/// is exactly one distribution named `a`.
///
/// The debug form of the deserialization result is pinned by an `insta`
/// snapshot. NOTE(review): presumably this is a success case, because both
/// omitted fields can be taken from the single `a` distribution — confirm
/// against the stored snapshot.
#[test]
fn missing_dependency_source_version_unambiguous() {
let data = r#"
version = 1
[[distribution]]
name = "a"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "b"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution.dependencies]]
name = "a"
"#;
let result: Result<Lock, _> = toml::from_str(data);
insta::assert_debug_snapshot!(result);
}

/// Deserializes a lock in which `b` depends on `a`, the dependency entry
/// gives `version` but omits `source`, and there are two distributions named
/// `a` (versions 0.1.0 and 0.1.1).
///
/// The debug form of the deserialization result is pinned by an `insta`
/// snapshot. NOTE(review): presumably this is an error case reporting the
/// missing `source` field, since the package name alone does not identify a
/// single distribution — confirm against the stored snapshot.
#[test]
fn missing_dependency_source_ambiguous() {
let data = r#"
version = 1
[[distribution]]
name = "a"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "a"
version = "0.1.1"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "b"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution.dependencies]]
name = "a"
version = "0.1.0"
"#;
let result: Result<Lock, _> = toml::from_str(data);
insta::assert_debug_snapshot!(result);
}

/// Deserializes a lock in which `b` depends on `a`, the dependency entry
/// gives `source` but omits `version`, and there are two distributions named
/// `a` (versions 0.1.0 and 0.1.1).
///
/// The debug form of the deserialization result is pinned by an `insta`
/// snapshot. NOTE(review): presumably this is an error case reporting the
/// missing `version` field, since the package name alone does not identify a
/// single distribution — confirm against the stored snapshot.
#[test]
fn missing_dependency_version_ambiguous() {
let data = r#"
version = 1
[[distribution]]
name = "a"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "a"
version = "0.1.1"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "b"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution.dependencies]]
name = "a"
source = "registry+https://pypi.org/simple"
"#;
let result: Result<Lock, _> = toml::from_str(data);
insta::assert_debug_snapshot!(result);
}

/// Deserializes a lock in which `b` depends on `a`, the dependency entry
/// carries only `name` (both `version` and `source` are omitted), and there
/// are two distributions named `a` (versions 0.1.0 and 0.1.1).
///
/// The debug form of the deserialization result is pinned by an `insta`
/// snapshot. NOTE(review): presumably this is an error case; which of the
/// two missing fields gets reported depends on the order in which the
/// dependency's fields are resolved — confirm against the stored snapshot.
#[test]
fn missing_dependency_source_version_ambiguous() {
let data = r#"
version = 1
[[distribution]]
name = "a"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "a"
version = "0.1.1"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution]]
name = "b"
version = "0.1.0"
source = "registry+https://pypi.org/simple"
sdist = { url = "https://example.com", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 0 }
[[distribution.dependencies]]
name = "a"
"#;
let result: Result<Lock, _> = toml::from_str(data);
insta::assert_debug_snapshot!(result);
}

#[test]
fn hash_required_present() {
let data = r#"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
source: crates/uv-resolver/src/lock.rs
expression: result
---
Err(
Error {
inner: Error {
inner: TomlError {
message: "dependency a has missing `source` field but has more than one matching distribution",
raw: None,
keys: [],
span: None,
},
},
},
)
Loading

0 comments on commit 4accbfd

Please sign in to comment.