Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ typetag = { workspace = true }
semver = { workspace = true, features = ["serde"] }
zstd = { workspace = true, optional = true }
relrc = { workspace = true, features = ["petgraph", "serde"] }
ordered-float = { workspace = true, features = ["serde"] }
base64.workspace = true

[dev-dependencies]
rstest = { workspace = true }
Expand Down
5 changes: 5 additions & 0 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,10 @@ impl<'a> Context<'a> {
TypeArg::Type { ty } => self.export_type(ty),
TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()),
TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()),
TypeArg::Float { value } => self.make_term(model::Literal::Float(*value).into()),
TypeArg::Bytes { value } => self.make_term(model::Literal::Bytes(value.clone()).into()),
TypeArg::List { elems } => {
// For now we assume that the sequence is meant to be a list.
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
Expand Down Expand Up @@ -973,6 +976,8 @@ impl<'a> Context<'a> {
// This ignores the bound on the natural for now.
TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]),
TypeParam::Bytes => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
TypeParam::Float => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
TypeParam::List { param } => {
let item_type = self.export_type_param(param, None);
self.make_term_apply(model::CORE_LIST_TYPE, &[item_type])
Expand Down
18 changes: 13 additions & 5 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,14 @@ impl<'a> Context<'a> {
return Ok(TypeParam::String);
}

if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? {
return Ok(TypeParam::Bytes);
}

if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? {
return Ok(TypeParam::Float);
}

if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? {
return Ok(TypeParam::max_nat());
}
Expand Down Expand Up @@ -1194,11 +1202,11 @@ impl<'a> Context<'a> {
Ok(TypeArg::BoundedNat { n: *value })
}

table::Term::Literal(model::Literal::Bytes(_)) => {
Err(error_unsupported!("`(bytes ..)` as `TypeArg`"))
}
table::Term::Literal(model::Literal::Float(_)) => {
Err(error_unsupported!("float literal as `TypeArg`"))
table::Term::Literal(model::Literal::Bytes(value)) => Ok(TypeArg::Bytes {
value: value.clone(),
}),
table::Term::Literal(model::Literal::Float(value)) => {
Ok(TypeArg::Float { value: *value })
}
table::Term::Func { .. } => Err(error_unsupported!("function constant as `TypeArg`")),

Expand Down
60 changes: 55 additions & 5 deletions hugr-core/src/types/type_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
//! [`TypeDef`]: crate::extension::TypeDef

use itertools::Itertools;
use ordered_float::OrderedFloat;
#[cfg(test)]
use proptest_derive::Arbitrary;
use std::num::NonZeroU64;
use std::sync::Arc;
use thiserror::Error;

use super::row_var::MaybeRV;
Expand Down Expand Up @@ -79,6 +81,10 @@ pub enum TypeParam {
},
/// Argument is a [`TypeArg::String`].
String,
/// Argument is a [`TypeArg::Bytes`].
Bytes,
/// Argument is a [`TypeArg::Float`].
Float,
/// Argument is a [`TypeArg::List`]. A list of indeterminate size containing
/// parameters all of the (same) specified element type.
#[display("List[{param}]")]
Expand Down Expand Up @@ -171,6 +177,19 @@ pub enum TypeArg {
/// The string value for the parameter.
arg: String,
},
/// Instance of [`TypeParam::Bytes`]. Byte string.
#[display("bytes")]
Bytes {
/// The value of the bytes parameter.
#[serde(with = "base64")]
value: Arc<[u8]>,
},
/// Instance of [`TypeParam::Float`]. 64-bit floating point number.
#[display("{}", value.into_inner())]
Float {
/// The value of the float parameter.
value: OrderedFloat<f64>,
},
/// Instance of [`TypeParam::List`] defined by a sequence of elements of the same type.
#[display("[{}]", {
use itertools::Itertools as _;
Expand Down Expand Up @@ -301,12 +320,15 @@ impl TypeArg {
pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> {
match self {
TypeArg::Type { ty } => ty.validate(var_decls),
TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()),
TypeArg::List { elems } => {
// TODO: Full validation would check that the type of the elements agrees
elems.iter().try_for_each(|a| a.validate(var_decls))
}
TypeArg::Tuple { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)),
TypeArg::BoundedNat { .. }
| TypeArg::String { .. }
| TypeArg::Float { .. }
| TypeArg::Bytes { .. } => Ok(()),
TypeArg::Variable {
v: TypeArgVariable { idx, cached_decl },
} => {
Expand All @@ -326,7 +348,10 @@ impl TypeArg {
// RowVariables are represented as TypeArg::Variable
ty.substitute1(t).into()
}
TypeArg::BoundedNat { .. } | TypeArg::String { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's
TypeArg::BoundedNat { .. }
| TypeArg::String { .. }
| TypeArg::Bytes { .. }
| TypeArg::Float { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's
TypeArg::List { elems } => {
let mut are_types = elems.iter().map(|ta| match ta {
TypeArg::Type { .. } => true,
Expand Down Expand Up @@ -369,9 +394,11 @@ impl Transformable for TypeArg {
TypeArg::Type { ty } => ty.transform(tr),
TypeArg::List { elems } => elems.transform(tr),
TypeArg::Tuple { elems } => elems.transform(tr),
TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Variable { .. } => {
Ok(false)
}
TypeArg::BoundedNat { .. }
| TypeArg::String { .. }
| TypeArg::Variable { .. }
| TypeArg::Float { .. }
| TypeArg::Bytes { .. } => Ok(false),
}
}
}
Expand Down Expand Up @@ -489,6 +516,29 @@ pub enum TypeArgError {
InvalidValue(TypeArg),
}

/// Helper for to serialize and deserialize the byte string in `TypeArg::Bytes` via base64.
mod base64 {
use std::sync::Arc;

use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};

pub fn serialize<S: Serializer>(v: &Arc<[u8]>, s: S) -> Result<S::Ok, S::Error> {
let base64 = BASE64_STANDARD.encode(v);
base64.serialize(s)
}

pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Arc<[u8]>, D::Error> {
let base64 = String::deserialize(d)?;
BASE64_STANDARD
.decode(base64.as_bytes())
.map(|v| v.into())
.map_err(serde::de::Error::custom)
}
}

#[cfg(test)]
mod test {
use itertools::Itertools;
Expand Down
39 changes: 38 additions & 1 deletion hugr-py/src/hugr/_serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Annotated, Any, Literal

from pydantic import (
Base64Bytes,
BaseModel,
ConfigDict,
Field,
Expand Down Expand Up @@ -94,6 +95,20 @@
return tys.StringParam()


class BytesParam(BaseTypeParam):
tp: Literal["Bytes"] = "Bytes"

def deserialize(self) -> tys.BytesParam:
return tys.BytesParam()

Check warning on line 102 in hugr-py/src/hugr/_serialization/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_serialization/tys.py#L102

Added line #L102 was not covered by tests


class FloatParam(BaseTypeParam):
tp: Literal["Float"] = "Float"

def deserialize(self) -> tys.FloatParam:
return tys.FloatParam()

Check warning on line 109 in hugr-py/src/hugr/_serialization/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_serialization/tys.py#L109

Added line #L109 was not covered by tests


class ListParam(BaseTypeParam):
tp: Literal["List"] = "List"
param: TypeParam
Expand All @@ -114,7 +129,13 @@
"""A type parameter."""

root: Annotated[
TypeTypeParam | BoundedNatParam | StringParam | ListParam | TupleParam,
TypeTypeParam
| BoundedNatParam
| StringParam
| FloatParam
| BytesParam
| ListParam
| TupleParam,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tp")

Expand Down Expand Up @@ -158,6 +179,22 @@
return tys.StringArg(value=self.arg)


class FloatArg(BaseTypeArg):
tya: Literal["Float"] = "Float"
value: float

def deserialize(self) -> tys.FloatArg:
return tys.FloatArg(value=self.value)

Check warning on line 187 in hugr-py/src/hugr/_serialization/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_serialization/tys.py#L187

Added line #L187 was not covered by tests


class BytesArg(BaseTypeArg):
tya: Literal["Bytes"] = "Bytes"
value: Base64Bytes

def deserialize(self) -> tys.BytesArg:
return tys.BytesArg(value=bytes(self.value))

Check warning on line 195 in hugr-py/src/hugr/_serialization/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_serialization/tys.py#L195

Added line #L195 was not covered by tests


class ListArg(BaseTypeArg):
tya: Literal["List"] = "List"
elems: list[TypeArg]
Expand Down
60 changes: 60 additions & 0 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@
return model.Apply("core.str")


@dataclass(frozen=True)
class FloatParam(TypeParam):
"""Float type parameter."""

def _to_serial(self) -> stys.FloatParam:
return stys.FloatParam()

Check warning on line 162 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L162

Added line #L162 was not covered by tests

def __str__(self) -> str:
return "Float"

Check warning on line 165 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L165

Added line #L165 was not covered by tests

def to_model(self) -> model.Term:
return model.Apply("core.float")

Check warning on line 168 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L168

Added line #L168 was not covered by tests


@dataclass(frozen=True)
class BytesParam(TypeParam):
"""Bytes type parameter."""

def _to_serial(self) -> stys.BytesParam:
return stys.BytesParam()

Check warning on line 176 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L176

Added line #L176 was not covered by tests

def __str__(self) -> str:
return "Bytes"

Check warning on line 179 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L179

Added line #L179 was not covered by tests

def to_model(self) -> model.Term:
return model.Apply("core.bytes")

Check warning on line 182 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L182

Added line #L182 was not covered by tests


@dataclass(frozen=True)
class ListParam(TypeParam):
"""Type parameter which requires a list of type arguments."""
Expand Down Expand Up @@ -244,6 +272,38 @@
return model.Literal(self.value)


@dataclass(frozen=True)
class FloatArg(TypeArg):
"""A floating point type argument."""

value: float

def _to_serial(self) -> stys.FloatArg:
return stys.FloatArg(value=self.value)

Check warning on line 282 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L282

Added line #L282 was not covered by tests

def __str__(self) -> str:
return f'"{self.value}"'

Check warning on line 285 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L285

Added line #L285 was not covered by tests

def to_model(self) -> model.Term:
return model.Literal(self.value)

Check warning on line 288 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L288

Added line #L288 was not covered by tests


@dataclass(frozen=True)
class BytesArg(TypeArg):
"""A bytes type argument."""

value: bytes

def _to_serial(self) -> stys.BytesArg:
return stys.BytesArg(value=self.value)

Check warning on line 298 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L298

Added line #L298 was not covered by tests

def __str__(self) -> str:
return "bytes"

Check warning on line 301 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L301

Added line #L301 was not covered by tests

def to_model(self) -> model.Term:
return model.Literal(self.value)

Check warning on line 304 in hugr-py/src/hugr/tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/tys.py#L304

Added line #L304 was not covered by tests


@dataclass(frozen=True)
class ListArg(TypeArg):
"""Sequence of type arguments for a :class:`ListParam`."""
Expand Down
6 changes: 6 additions & 0 deletions specification/hugr.md
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,9 @@ such declarations may include (bind) any number of type parameters, of kinds as
TypeParam ::= Type(Any|Copyable)
| BoundedUSize(u64|) -- note optional bound
| Extensions
| String
| Bytes
| Float
| List(TypeParam) -- homogeneous, any sized
| Tuple([TypeParam]) -- heterogenous, fixed size
| Opaque(Name, [TypeArg]) -- e.g. Opaque("Array", [5, Opaque("usize", [])])
Expand All @@ -841,6 +844,9 @@ TypeArgs appropriate for the function's TypeParams:
```haskell
TypeArg ::= Type(Type) -- could be a variable of kind Type, or contain variable(s)
| BoundedUSize(u64)
| String(String)
| Bytes([u8])
| Float(f64)
| Extensions(Extensions) -- may contain TypeArg's of kind Extensions
| List([TypeArg])
| Tuple([TypeArg])
Expand Down
Loading
Loading