Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 38 additions & 29 deletions hugr-core/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ mod package_json;
pub mod serde_with;

pub use header::{EnvelopeConfig, EnvelopeFormat, MAGIC_NUMBERS, ZstdConfig};
use hugr_model::v0::bumpalo::Bump;
pub use package_json::PackageEncodingError;

use crate::{Hugr, HugrView};
Expand Down Expand Up @@ -345,6 +346,13 @@ pub enum EnvelopeError {
/// The unrecognized flag bits.
flag_ids: Vec<usize>,
},
/// Error raised while checking for breaking extension version mismatch.
#[error(transparent)]
ExtensionVersion {
/// The source error.
#[from]
source: WithGenerator<ExtensionBreakingError>,
},
}

/// Internal implementation of [`read_envelope`] to call with/without the zstd decompression wrapper.
Expand All @@ -353,7 +361,7 @@ fn read_impl(
header: EnvelopeHeader,
registry: &ExtensionRegistry,
) -> Result<Package, EnvelopeError> {
match header.format {
let (package, combined_registry) = match header.format {
#[allow(deprecated)]
EnvelopeFormat::PackageJson => Ok(package_json::from_json_reader(payload, registry)?),
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
Expand All @@ -362,7 +370,13 @@ fn read_impl(
EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
decode_model_ast(payload, registry, header.format)
}
}
}?;

package.modules.iter().try_for_each(|module| {
check_breaking_extensions(module, &combined_registry)
.map_err(|err| WithGenerator::new(err, &package.modules))
})?;
Ok(package)
}

/// Read a HUGR model payload from a reader.
Expand All @@ -372,20 +386,15 @@ fn read_impl(
/// - `extension_registry`: An extension registry with additional extensions to use when
/// decoding the HUGR, if they are not already included in the package.
/// - `format`: The format of the payload.
///
/// Returns package and the combined extension registry
/// of the provided registry and the package extensions.
fn decode_model(
mut stream: impl BufRead,
extension_registry: &ExtensionRegistry,
format: EnvelopeFormat,
) -> Result<Package, EnvelopeError> {
use hugr_model::v0::bumpalo::Bump;

if format.model_version() != Some(0) {
return Err(EnvelopeError::FormatUnsupported {
format,
feature: None,
});
}

) -> Result<(Package, ExtensionRegistry), EnvelopeError> {
check_model_version(format)?;
let bump = Bump::default();
let model_package = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;

Expand All @@ -395,7 +404,19 @@ fn decode_model(
extension_registry.extend(extra_extensions);
}

Ok(import_package(&model_package, &extension_registry)?)
let package = import_package(&model_package, &extension_registry)?;

Ok((package, extension_registry))
}

fn check_model_version(format: EnvelopeFormat) -> Result<(), EnvelopeError> {
if format.model_version() != Some(0) {
return Err(EnvelopeError::FormatUnsupported {
format,
feature: None,
});
}
Ok(())
}

/// Read a HUGR model text payload from a reader.
Expand All @@ -409,16 +430,8 @@ fn decode_model_ast(
mut stream: impl BufRead,
extension_registry: &ExtensionRegistry,
format: EnvelopeFormat,
) -> Result<Package, EnvelopeError> {
use crate::import::import_package;
use hugr_model::v0::bumpalo::Bump;

if format.model_version() != Some(0) {
return Err(EnvelopeError::FormatUnsupported {
format,
feature: None,
});
}
) -> Result<(Package, ExtensionRegistry), EnvelopeError> {
check_model_version(format)?;

let mut extension_registry = extension_registry.clone();
if format == EnvelopeFormat::ModelTextWithExtensions {
Expand All @@ -444,12 +457,8 @@ fn decode_model_ast(
let model_package = ast_package.resolve(&bump)?;

let package = import_package(&model_package, &extension_registry)?;
for module in &package.modules {
check_breaking_extensions(module, &extension_registry).map_err(|err| {
PackageEncodingError::ExtensionVersion(WithGenerator::new(err, &package.modules))
})?;
}
Ok(package)

Ok((package, extension_registry))
}

/// Internal implementation of [`write_envelope`] to call with/without the zstd compression wrapper.
Expand Down
23 changes: 11 additions & 12 deletions hugr-core/src/envelope/package_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@ use derive_more::{Display, Error, From};
use itertools::Itertools;
use std::io;

use super::{ExtensionBreakingError, WithGenerator, check_breaking_extensions};
use super::WithGenerator;
use crate::extension::ExtensionRegistry;
use crate::extension::resolution::ExtensionResolutionError;
use crate::package::Package;
use crate::{Extension, Hugr};

/// Read a Package in json format from an io reader.
/// Returns package and the combined extension registry
/// of the provided registry and the package extensions.
pub(super) fn from_json_reader(
reader: impl io::Read,
extension_registry: &ExtensionRegistry,
) -> Result<Package, PackageEncodingError> {
) -> Result<(Package, ExtensionRegistry), PackageEncodingError> {
let val: serde_json::Value = serde_json::from_reader(reader)?;

let PackageDeser {
Expand All @@ -31,19 +33,18 @@ pub(super) fn from_json_reader(
let mut combined_registry = extension_registry.clone();
combined_registry.extend(&pkg_extensions);

for module in &modules {
check_breaking_extensions(module, &combined_registry)
.map_err(|err| WithGenerator::new(err, &modules))?;
}
modules
.iter_mut()
.try_for_each(|module| module.resolve_extension_defs(&combined_registry))
.map_err(|err| WithGenerator::new(err, &modules))?;

Ok(Package {
modules,
extensions: pkg_extensions,
})
Ok((
Package {
modules,
extensions: pkg_extensions,
},
combined_registry,
))
}

/// Write the Package in json format into an io writer.
Expand Down Expand Up @@ -85,8 +86,6 @@ pub enum PackageEncodingError {
IOError(#[from] io::Error),
/// Could not resolve the extension needed to encode the hugr.
ExtensionResolution(#[from] WithGenerator<ExtensionResolutionError>),
/// Error raised while checking for breaking extension version mismatch.
ExtensionVersion(#[from] WithGenerator<ExtensionBreakingError>),
}

/// A private package structure implementing the serde traits.
Expand Down
Loading