Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 282 additions & 4 deletions hugr-core/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ pub mod serde_with;
pub use header::{EnvelopeConfig, EnvelopeFormat, MAGIC_NUMBERS, ZstdConfig};
pub use package_json::PackageEncodingError;

use crate::Hugr;
use crate::{extension::ExtensionRegistry, package::Package};
use crate::{Hugr, HugrView};
use crate::{
extension::{ExtensionRegistry, Version},
package::Package,
};
use header::EnvelopeHeader;
use std::io::BufRead;
use std::io::Write;
Expand All @@ -61,6 +64,51 @@ use itertools::Itertools as _;
use crate::import::ImportError;
use crate::{Extension, import::import_package};

/// Key used to store the name of the generator that produced the envelope.
pub const GENERATOR_KEY: &str = "__generator";

/// Get the name of the generator from the metadata of the HUGR modules.
/// If multiple modules have different generators, a comma-separated list is returned in
/// module order.
/// If no generator is found, `None` is returned.
fn get_generator<H: HugrView>(modules: &[H]) -> Option<String> {
let generators: Vec<String> = modules
.iter()
.filter_map(|hugr| hugr.get_metadata(hugr.module_root(), GENERATOR_KEY))
.map(|v| v.to_string())
.collect();
if generators.is_empty() {
return None;
}

Some(generators.join(", "))
}

fn gen_str(generator: &Option<String>) -> String {
match generator {
Some(g) => format!("\ngenerated by {g}"),
None => String::new(),
}
}

/// Wrap an error with a generator string.
#[derive(Error, Debug)]
#[error("{inner}{}", gen_str(&self.generator))]
pub struct WithGenerator<E: std::fmt::Display> {
inner: E,
/// The name of the generator that produced the envelope, if any.
generator: Option<String>,
}

impl<E: std::fmt::Display> WithGenerator<E> {
fn new(err: E, modules: &[impl HugrView]) -> Self {
Self {
inner: err,
generator: get_generator(modules),
}
}
}

/// Read a HUGR envelope from a reader.
///
/// Returns the deserialized package and the configuration used to encode it.
Expand Down Expand Up @@ -213,6 +261,7 @@ pub enum EnvelopeError {
/// The source error.
#[from]
source: ImportError,
// TODO add generator to model import errors
},
/// Error reading a HUGR model payload.
#[error(transparent)]
Expand Down Expand Up @@ -402,14 +451,96 @@ fn encode_model<'h>(
_ => unreachable!(),
}

// Apend extensions for binary model.
// Append extensions for binary model.
if format == EnvelopeFormat::ModelWithExtensions {
serde_json::to_writer(writer, &extensions.iter().collect_vec())?;
}

Ok(())
}

/// Key used to store the list of used extensions in the metadata of a HUGR.
pub const USED_EXTENSIONS_KEY: &str = "__used_extensions";

#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize)]
struct UsedExtension {
name: String,
version: Version,
}

#[derive(Debug, Error)]
#[error(
"Extension '{name}' version mismatch: registered version is {registered}, but used version is {used}"
)]
/// Error raised when the reported used version of an extension
/// does not match the registered version in the extension registry.
pub struct ExtensionVersionMismatch {
name: String,
registered: Version,
used: Version,
}

#[derive(Debug, Error)]
#[non_exhaustive]
/// Error raised when checking for breaking changes in used extensions.
pub enum ExtensionBreakingError {
/// The extension version in the metadata does not match the registered version.
#[error("{0}")]
ExtensionVersionMismatch(ExtensionVersionMismatch),

/// Error deserializing the used extensions metadata.
#[error("Failed to deserialize used extensions metadata")]
Deserialization(#[from] serde_json::Error),
}
/// If HUGR metadata contains a list of used extensions, under the key [`USED_EXTENSIONS_KEY`],
/// and extension is registered in the given registry, check that the
/// version of the extension in the metadata matches the registered version (up to
/// MAJOR.MINOR).
fn check_breaking_extensions(
hugr: impl crate::HugrView,
registry: &ExtensionRegistry,
) -> Result<(), ExtensionBreakingError> {
let Some(exts) = hugr.get_metadata(hugr.module_root(), USED_EXTENSIONS_KEY) else {
return Ok(()); // No used extensions metadata, nothing to check
};
let used_exts: Vec<UsedExtension> = serde_json::from_value(exts.clone())?; // TODO handle errors properly

for ext in used_exts {
let Some(registered) = registry.get(ext.name.as_str()) else {
continue; // Extension not registered, ignore
};
if !compatible_versions(registered.version(), &ext.version) {
// This is a breaking change, raise an error.

return Err(ExtensionBreakingError::ExtensionVersionMismatch(
ExtensionVersionMismatch {
name: ext.name,
registered: registered.version().clone(),
used: ext.version,
},
));
}
}

Ok(())
}

/// Check if two versions are compatible according to:
/// - Major version must match.
/// - If major version is 0, minor version must match.
fn compatible_versions(v1: &Version, v2: &Version) -> bool {
if v1.major != v2.major {
return false; // Major version mismatch
}

if v1.major == 0 {
// For major version 0, we only allow minor version matches
return v1.minor == v2.minor;
}

true
}

#[cfg(test)]
pub(crate) mod test {
use super::*;
Expand All @@ -420,9 +551,13 @@ pub(crate) mod test {

use crate::HugrView;
use crate::builder::test::{multi_module_package, simple_package};
use crate::extension::PRELUDE_REGISTRY;
use crate::extension::{Extension, ExtensionRegistry, Version};
use crate::extension::{ExtensionId, PRELUDE_REGISTRY};
use crate::hugr::HugrMut;
use crate::hugr::test::check_hugr_equality;
use crate::std_extensions::STD_REG;
use serde_json::json;
use std::sync::Arc;

/// Returns an `ExtensionRegistry` with the extensions from both
/// sets. Avoids cloning if the first one already contains all
Expand Down Expand Up @@ -538,4 +673,147 @@ pub(crate) mod test {

assert_eq!(package, new_package);
}

#[rstest]
#[case::simple(simple_package())]
fn test_check_breaking_extensions(#[case] mut package: Package) {
// extension with major version 0
let test_ext_v0 =
Extension::new(ExtensionId::new_unchecked("test-v0"), Version::new(0, 2, 3));
// extension with major version > 0
let test_ext_v1 =
Extension::new(ExtensionId::new_unchecked("test-v1"), Version::new(1, 2, 3));

// Create a registry with the test extensions
let registry =
ExtensionRegistry::new([Arc::new(test_ext_v0.clone()), Arc::new(test_ext_v1.clone())]);
let mut hugr = package.modules.remove(0);

// No metadata - should pass
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

// Matching version for v0 - should pass
let used_exts = json!([{ "name": "test-v0", "version": "0.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

// Matching major/minor but different patch for v0 - should pass
let used_exts = json!([{ "name": "test-v0", "version": "0.2.4" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

//Different minor version for v0 - should fail
let used_exts = json!([{ "name": "test-v0", "version": "0.3.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
used
})) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(0, 3, 3)
);

// Different major version for v0 - should fail
let used_exts = json!([{ "name": "test-v0", "version": "1.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
used
})) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(1, 2, 3)
);

// Matching version for v1 - should pass
let used_exts = json!([{ "name": "test-v1", "version": "1.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

// Different minor version for v1 - should pass
let used_exts = json!([{ "name": "test-v1", "version": "1.3.0" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

// Different patch for v1 - should pass
let used_exts = json!([{ "name": "test-v1", "version": "1.2.4" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

// Different major version for v1 - should fail
let used_exts = json!([{ "name": "test-v1", "version": "2.2.3" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
used
})) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(2, 2, 3)
);

// Non-registered extension - should pass
let used_exts = json!([{ "name": "unknown", "version": "1.0.0" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

// Multiple extensions - one mismatch should fail
let used_exts = json!([
{ "name": "unknown", "version": "1.0.0" },
{ "name": "test-v1", "version": "2.0.0" }
]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
name,
registered,
used
})) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(2, 0, 0)
);

// Invalid metadata format - should fail with deserialization error
hugr.set_metadata(
hugr.module_root(),
USED_EXTENSIONS_KEY,
json!("not an array"),
);
assert_matches!(
check_breaking_extensions(&hugr, &registry),
Err(ExtensionBreakingError::Deserialization(_))
);

// Multiple extensions with all compatible versions - should pass
let used_exts = json!([
{ "name": "test-v0", "version": "0.2.5" },
{ "name": "test-v1", "version": "1.9.9" }
]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
}

#[test]
fn test_with_generator_error_message() {
let test_ext = Extension::new(ExtensionId::new_unchecked("test"), Version::new(1, 0, 0));
let registry = ExtensionRegistry::new([Arc::new(test_ext)]);

let mut hugr = simple_package().modules.remove(0);

// Set a generator name in the metadata
let generator_name = json!({ "name": "TestGenerator", "version": "1.2.3" });
hugr.set_metadata(hugr.module_root(), GENERATOR_KEY, generator_name.clone());

// Set incompatible extension version in metadata
let used_exts = json!([{ "name": "test", "version": "2.0.0" }]);
hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);

// Create the error and wrap it with WithGenerator
let err = check_breaking_extensions(&hugr, &registry).unwrap_err();
let with_gen = WithGenerator::new(err, &[&hugr]);

let err_msg = with_gen.to_string();
assert!(err_msg.contains("Extension 'test' version mismatch"));
assert!(err_msg.contains(generator_name.to_string().as_str()));
}
}
18 changes: 13 additions & 5 deletions hugr-core/src/envelope/package_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use derive_more::{Display, Error, From};
use itertools::Itertools;
use std::io;

use super::{ExtensionBreakingError, WithGenerator, check_breaking_extensions};
use crate::extension::ExtensionRegistry;
use crate::extension::resolution::ExtensionResolutionError;
use crate::hugr::ExtensionError;
Expand All @@ -21,19 +22,24 @@ pub(super) fn from_json_reader(
extensions: pkg_extensions,
} = serde_json::from_value::<PackageDeser>(val.clone())?;
let mut modules = modules.into_iter().map(|h| h.0).collect_vec();

let pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
pkg_extensions,
&extension_registry.into(),
)?;
)
.map_err(|err| WithGenerator::new(err, &modules))?;

// Resolve the operations in the modules using the defined registries.
let mut combined_registry = extension_registry.clone();
combined_registry.extend(&pkg_extensions);

for module in &mut modules {
module.resolve_extension_defs(&combined_registry)?;
for module in &modules {
check_breaking_extensions(module, &combined_registry)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth having a variant of this function that takes a whole package?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the package is only generated at the end of this function so I'm not so sure

.map_err(|err| WithGenerator::new(err, &modules))?;
}
modules
.iter_mut()
.try_for_each(|module| module.resolve_extension_defs(&combined_registry))
.map_err(|err| WithGenerator::new(err, &modules))?;

Ok(Package {
modules,
Expand Down Expand Up @@ -64,7 +70,9 @@ pub enum PackageEncodingError {
/// Error raised while reading from a file.
IOError(io::Error),
/// Could not resolve the extension needed to encode the hugr.
ExtensionResolution(ExtensionResolutionError),
ExtensionResolution(WithGenerator<ExtensionResolutionError>),
/// Error raised while checking for breaking extension version mismatch.
ExtensionVersion(WithGenerator<ExtensionBreakingError>),
/// Could not resolve the runtime extensions for the hugr.
RuntimeExtensionResolution(ExtensionError),
}
Expand Down
4 changes: 2 additions & 2 deletions hugr-model/src/v0/binary/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use bumpalo::Bump;
use bumpalo::collections::Vec as BumpVec;
use std::io::BufRead;

/// An error encounted while deserialising a model.
/// An error encountered while deserialising a model.
#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
#[non_exhaustive]
pub enum ReadError {
#[from(forward)]
/// An error encounted while decoding a model from a `capnproto` buffer.
/// An error encountered while decoding a model from a `capnproto` buffer.
DecodingError(capnp::Error),
}

Expand Down
Loading
Loading