Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ typetag = { workspace = true }
semver = { workspace = true, features = ["serde"] }
zstd = { workspace = true, optional = true }
relrc = { workspace = true, features = ["petgraph", "serde"] }
ordered-float = { workspace = true, features = ["serde"] }
base64.workspace = true

[dev-dependencies]
rstest = { workspace = true }
Expand Down
5 changes: 5 additions & 0 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,10 @@ impl<'a> Context<'a> {
TypeArg::Type { ty } => self.export_type(ty),
TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()),
TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()),
TypeArg::Float { value } => self.make_term(model::Literal::Float(*value).into()),
TypeArg::Bytes { value } => self.make_term(model::Literal::Bytes(value.clone()).into()),
TypeArg::List { elems } => {
// For now we assume that the sequence is meant to be a list.
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
Expand Down Expand Up @@ -973,6 +976,8 @@ impl<'a> Context<'a> {
// This ignores the bound on the natural for now.
TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]),
TypeParam::Bytes => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
TypeParam::Float => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
TypeParam::List { param } => {
let item_type = self.export_type_param(param, None);
self.make_term_apply(model::CORE_LIST_TYPE, &[item_type])
Expand Down
18 changes: 13 additions & 5 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,14 @@ impl<'a> Context<'a> {
return Ok(TypeParam::String);
}

if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? {
return Ok(TypeParam::Bytes);
}

if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? {
return Ok(TypeParam::Float);
}

if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? {
return Ok(TypeParam::max_nat());
}
Expand Down Expand Up @@ -1194,11 +1202,11 @@ impl<'a> Context<'a> {
Ok(TypeArg::BoundedNat { n: *value })
}

table::Term::Literal(model::Literal::Bytes(_)) => {
Err(error_unsupported!("`(bytes ..)` as `TypeArg`"))
}
table::Term::Literal(model::Literal::Float(_)) => {
Err(error_unsupported!("float literal as `TypeArg`"))
table::Term::Literal(model::Literal::Bytes(value)) => Ok(TypeArg::Bytes {
value: value.clone(),
}),
table::Term::Literal(model::Literal::Float(value)) => {
Ok(TypeArg::Float { value: *value })
}
table::Term::Func { .. } => Err(error_unsupported!("function constant as `TypeArg`")),

Expand Down
85 changes: 80 additions & 5 deletions hugr-core/src/types/type_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
//! [`TypeDef`]: crate::extension::TypeDef

use itertools::Itertools;
use ordered_float::OrderedFloat;
#[cfg(test)]
use proptest_derive::Arbitrary;
use std::num::NonZeroU64;
use std::sync::Arc;
use thiserror::Error;

use super::row_var::MaybeRV;
Expand Down Expand Up @@ -79,6 +81,10 @@ pub enum TypeParam {
},
/// Argument is a [`TypeArg::String`].
String,
/// Argument is a [`TypeArg::Bytes`].
Bytes,
/// Argument is a [`TypeArg::Float`].
Float,
/// Argument is a [`TypeArg::List`]. A list of indeterminate size containing
/// parameters all of the (same) specified element type.
#[display("List[{param}]")]
Expand Down Expand Up @@ -171,6 +177,19 @@ pub enum TypeArg {
/// The string value for the parameter.
arg: String,
},
/// Instance of [`TypeParam::Bytes`]. Byte string.
#[display("bytes")]
Bytes {
/// The value of the bytes parameter.
#[serde(with = "base64")]
value: Arc<[u8]>,
},
/// Instance of [`TypeParam::Float`]. 64-bit floating point number.
#[display("{}", value.into_inner())]
Float {
/// The value of the float parameter.
value: OrderedFloat<f64>,
},
/// Instance of [`TypeParam::List`] defined by a sequence of elements of the same type.
#[display("[{}]", {
use itertools::Itertools as _;
Expand Down Expand Up @@ -301,12 +320,15 @@ impl TypeArg {
pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> {
match self {
TypeArg::Type { ty } => ty.validate(var_decls),
TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()),
TypeArg::List { elems } => {
// TODO: Full validation would check that the type of the elements agrees
elems.iter().try_for_each(|a| a.validate(var_decls))
}
TypeArg::Tuple { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)),
TypeArg::BoundedNat { .. }
| TypeArg::String { .. }
| TypeArg::Float { .. }
| TypeArg::Bytes { .. } => Ok(()),
TypeArg::Variable {
v: TypeArgVariable { idx, cached_decl },
} => {
Expand All @@ -326,7 +348,10 @@ impl TypeArg {
// RowVariables are represented as TypeArg::Variable
ty.substitute1(t).into()
}
TypeArg::BoundedNat { .. } | TypeArg::String { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's
TypeArg::BoundedNat { .. }
| TypeArg::String { .. }
| TypeArg::Bytes { .. }
| TypeArg::Float { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's
TypeArg::List { elems } => {
let mut are_types = elems.iter().map(|ta| match ta {
TypeArg::Type { .. } => true,
Expand Down Expand Up @@ -369,9 +394,11 @@ impl Transformable for TypeArg {
TypeArg::Type { ty } => ty.transform(tr),
TypeArg::List { elems } => elems.transform(tr),
TypeArg::Tuple { elems } => elems.transform(tr),
TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Variable { .. } => {
Ok(false)
}
TypeArg::BoundedNat { .. }
| TypeArg::String { .. }
| TypeArg::Variable { .. }
| TypeArg::Float { .. }
| TypeArg::Bytes { .. } => Ok(false),
}
}
}
Expand Down Expand Up @@ -442,6 +469,8 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr
}

(TypeArg::String { .. }, TypeParam::String) => Ok(()),
(TypeArg::Bytes { .. }, TypeParam::Bytes) => Ok(()),
(TypeArg::Float { .. }, TypeParam::Float) => Ok(()),
_ => Err(TypeArgError::TypeMismatch {
arg: arg.clone(),
param: param.clone(),
Expand Down Expand Up @@ -489,6 +518,29 @@ pub enum TypeArgError {
InvalidValue(TypeArg),
}

/// Helper for to serialize and deserialize the byte string in `TypeArg::Bytes` via base64.
mod base64 {
use std::sync::Arc;

use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};

pub fn serialize<S: Serializer>(v: &Arc<[u8]>, s: S) -> Result<S::Ok, S::Error> {
let base64 = BASE64_STANDARD.encode(v);
base64.serialize(s)
}

pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Arc<[u8]>, D::Error> {
let base64 = String::deserialize(d)?;
BASE64_STANDARD
.decode(base64.as_bytes())
.map(|v| v.into())
.map_err(serde::de::Error::custom)
}
}

#[cfg(test)]
mod test {
use itertools::Itertools;
Expand Down Expand Up @@ -660,6 +712,16 @@ mod test {
);
}

#[test]
fn bytes_json_roundtrip() {
let bytes_arg = TypeArg::Bytes {
value: vec![0, 1, 2, 3, 255, 254, 253, 252].into(),
};
let serialized = serde_json::to_string(&bytes_arg).unwrap();
let deserialized: TypeArg = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized, bytes_arg);
}

mod proptest {

use proptest::prelude::*;
Expand All @@ -685,6 +747,9 @@ mod test {
use prop::collection::vec;
use prop::strategy::Union;
let mut strat = Union::new([
Just(Self::String).boxed(),
Just(Self::Bytes).boxed(),
Just(Self::Float).boxed(),
Just(Self::String).boxed(),
any::<TypeBound>().prop_map(|b| Self::Type { b }).boxed(),
any::<UpperBound>()
Expand Down Expand Up @@ -715,6 +780,16 @@ mod test {
let mut strat = Union::new([
any::<u64>().prop_map(|n| Self::BoundedNat { n }).boxed(),
any::<String>().prop_map(|arg| Self::String { arg }).boxed(),
any::<Vec<u8>>()
.prop_map(|bytes| Self::Bytes {
value: bytes.into(),
})
.boxed(),
any::<f64>()
.prop_map(|value| Self::Float {
value: value.into(),
})
.boxed(),
any_with::<Type>(depth)
.prop_map(|ty| Self::Type { ty })
.boxed(),
Expand Down
33 changes: 32 additions & 1 deletion hugr-core/tests/snapshots/model__roundtrip_params.snap
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,44 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-para

(mod)

(import core.fn)
(import core.bytes)

(import core.nat)

(import core.call)

(import core.type)

(import core.fn)

(import core.str)

(import core.float)

(define-func
example.swap
(param ?0 core.type)
(param ?1 core.type)
(core.fn [?0 ?1] [?1 ?0])
(dfg [%0 %1] [%1 %0] (signature (core.fn [?0 ?1] [?1 ?0]))))

(declare-func
example.literals
(param ?0 core.str)
(param ?1 core.nat)
(param ?2 core.bytes)
(param ?3 core.float)
(core.fn [] []))

(define-func example.call_literals (core.fn [] [])
(dfg
(signature (core.fn [] []))
((core.call
[]
[]
(example.literals
"string"
42
(bytes "SGVsbG8gd29ybGQg8J+Yig==")
6.023e23))
(signature (core.fn [] [])))))
15 changes: 15 additions & 0 deletions hugr-model/tests/fixtures/model-params.edn
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,18 @@
(core.fn [?a ?b] [?b ?a])
(dfg [%a %b] [%b %a]
(signature (core.fn [?a ?b] [?b ?a]))))

(declare-func example.literals
(param ?a core.str)
(param ?b core.nat)
(param ?c core.bytes)
(param ?d core.float)
(core.fn [] []))

(define-func example.call_literals
(core.fn [] [])
(dfg [] []
(signature (core.fn [] []))
((core.call
(example.literals "string" 42 (bytes "SGVsbG8gd29ybGQg8J+Yig==") 6.023e23))
(signature (core.fn [] [])))))
Loading
Loading