Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
426a5e6
feat: add wasm module (WIP)
qartik Apr 17, 2025
2da0811
Bits and pieces (WIP)
croyzor May 7, 2025
219e293
Maintain uniqueness of WASM contexts for a module
croyzor May 9, 2025
7361fe2
Add discard; make progress
croyzor May 9, 2025
d5c570f
Unification of WasmModuleTypes
croyzor May 9, 2025
af7f7d8
Redundant comment
croyzor May 9, 2025
95b1cb2
Oups - add missing file
croyzor May 15, 2025
cc8aef1
Associate WASM functions with modules
croyzor May 15, 2025
caf37f4
Let more types be wasmable
croyzor May 19, 2025
c8a7b71
Add dummy WASM decorators for sphinx
croyzor May 19, 2025
c6d66d8
Add ConstStringArg
croyzor May 20, 2025
fe23a6c
checkpoint
croyzor May 22, 2025
55805ee
yeee
croyzor May 23, 2025
8de275a
Remove array from test file
croyzor May 27, 2025
c0f70b9
cleanup
croyzor May 27, 2025
07e1f62
Merge remote-tracking branch 'origin/main' into 755-add-support-for-wasm
croyzor May 28, 2025
9b3bf0b
cleanup
croyzor May 28, 2025
d355b67
Improve printing of WASM errors
croyzor Jun 4, 2025
3f4bf50
Add tests for WASM errors
croyzor Jun 4, 2025
46f6469
Take GuppyModule arg in WASM decorators
croyzor Jun 4, 2025
50986f7
bad idea
croyzor Jun 11, 2025
061a6ba
Merge remote-tracking branch 'origin/main' into 755-add-support-for-wasm
croyzor Jun 17, 2025
b2aef12
nth times a charm
croyzor Jun 23, 2025
5ed8479
Merge remote-tracking branch 'origin/main' into 755-add-support-for-wasm
croyzor Jun 23, 2025
e27d362
tests: Update error tests
croyzor Jun 23, 2025
147a73a
refactor: Move wasm type error into right file
croyzor Jun 23, 2025
4ff1525
refactor: Use helper for itousize
croyzor Jun 23, 2025
8a0a188
cleanup: Remove duplicate def of `WasmCallChecker`
croyzor Jun 23, 2025
88b5b0c
cleanup: Remove duplicate defs
croyzor Jun 23, 2025
faadea6
refactor: Idiomatic usage of tket2 exts
croyzor Jun 23, 2025
84ebbdb
cleanup: Undo addition of `cached_sig` field to GlobalCall
croyzor Jun 23, 2025
1d09e45
refactor: Move ConstWasmModule to tket2_exts
croyzor Jun 23, 2025
79beec5
cleanup: Remove commented function
croyzor Jun 23, 2025
7aebba1
Update guppylang/tys/builtin.py
croyzor Jun 23, 2025
3b584d3
Update guppylang/checker/errors/wasm.py
croyzor Jun 23, 2025
9ea17ec
tests: Add backticks and update test
croyzor Jun 24, 2025
5b3f43d
Update comments
croyzor Jun 24, 2025
e24e3f5
Revert uv lock file
croyzor Jun 24, 2025
e677be2
remove ConstStringArg
croyzor Jun 24, 2025
3fafda6
Restrict allowed types in WASM signatures
croyzor Jun 24, 2025
9482b31
Add test with comptime interaction
croyzor Jun 24, 2025
925c22d
Update wasm call checker
croyzor Jun 24, 2025
8488e33
Parse wasm functions; move type signature sanity check out of checking
croyzor Jun 24, 2025
c65d4a0
Add test case
croyzor Jun 24, 2025
277182c
Add missing backticks
croyzor Jun 25, 2025
c73ef71
Remove redundant assert
croyzor Jun 25, 2025
3aa675f
Merge remote-tracking branch 'origin/main' into 755-add-support-for-wasm
croyzor Jun 26, 2025
cba859c
Undo uv.lock changes
croyzor Jun 26, 2025
4565fb9
Remove redundant check + error
croyzor Jun 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/random_walk_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
@guppy
def random_walk_phase_estimation(
eigenstate: Callable[[], qubit],
controlled_oracle: Callable[[qubit @owned, qubit @owned, float], tuple[qubit, qubit]],
controlled_oracle: Callable[
[qubit @ owned, qubit @ owned, float], tuple[qubit, qubit]
],
num_iters: int,
reset_rate: int,
mu: float,
Expand Down Expand Up @@ -62,7 +64,9 @@ def random_walk_phase_estimation(


@guppy
def example_controlled_oracle(q1: qubit @owned, q2: qubit @owned, t: float) -> tuple[qubit, qubit]:
def example_controlled_oracle(
q1: qubit @ owned, q2: qubit @ owned, t: float
) -> tuple[qubit, qubit]:
"""A controlled e^itH gate for the example Hamiltonian H = -0.5 * Z"""
# This is just a controlled rz gate
a = angle(-0.5 * t)
Expand Down
34 changes: 34 additions & 0 deletions guppylang/checker/errors/wasm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import ClassVar

from guppylang.diagnostic import Error
from guppylang.tys.ty import Type


@dataclass(frozen=True)
class WasmError(Error):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class WasmError(Error):
@dataclass(frozen=True)
class WasmError(Error):

title: ClassVar[str] = "WASM signature error"


@dataclass(frozen=True)
class FirstArgNotModule(WasmError):
span_label: ClassVar[str] = (
"First argument to WASM function should be a reference to a WASM module."
" Found `{ty}` instead"
)
ty: Type


@dataclass(frozen=True)
class UnWasmableType(WasmError):
span_label: ClassVar[str] = (
"WASM function signature contained an unsupported type: `{ty}`"
)
ty: Type


@dataclass(frozen=True)
class WasmTypeConversionError(Error):
title: ClassVar[str] = "Can't convert type to WASM"
span_label: ClassVar[str] = "`{thing}` cannot be converted to WASM"
ty: Type
93 changes: 92 additions & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
from typing_extensions import dataclass_transform

from guppylang.ast_util import annotate_location
from guppylang.compiler.core import GlobalConstId
from guppylang.definition.common import DefId
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
CustomCallChecker,
CustomFunctionDef,
CustomInoutCallCompiler,
DefaultCallChecker,
NotImplementedCallCompiler,
Expand All @@ -28,6 +30,7 @@
from guppylang.definition.extern import RawExternDef
from guppylang.definition.function import (
CompiledFunctionDef,
PyFunc,
RawFunctionDef,
)
from guppylang.definition.overloaded import OverloadedFunctionDef
Expand All @@ -40,14 +43,30 @@
from guppylang.definition.struct import RawStructDef
from guppylang.definition.traced import RawTracedFunctionDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.definition.wasm import RawWasmFunctionDef
from guppylang.dummy_decorator import _DummyGuppy, sphinx_running
from guppylang.engine import DEF_STORE
from guppylang.span import Loc, SourceMap, Span
from guppylang.std._internal.checker import WasmCallChecker
from guppylang.std._internal.compiler.wasm import (
WasmModuleCallCompiler,
WasmModuleDiscardCompiler,
WasmModuleInitCompiler,
)
from guppylang.tracing.object import GuppyDefinition, TypeVarGuppyDefinition
from guppylang.tys.arg import Argument
from guppylang.tys.builtin import (
WasmModuleTypeDef,
)
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Inst
from guppylang.tys.ty import FunctionType, NoneType, NumericType
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
NumericType,
)

S = TypeVar("S")
T = TypeVar("T")
Expand Down Expand Up @@ -420,6 +439,78 @@ def load_pytket(
DEF_STORE.register_def(defn, get_calling_frame())
return GuppyDefinition(defn)

def wasm_module(
self, filename: str, filehash: int
) -> Decorator[builtins.type[T], GuppyDefinition]:
def dec(cls: builtins.type[T]) -> GuppyDefinition:
# N.B. Only one module per file and vice-versa
wasm_module = WasmModuleTypeDef(
DefId.fresh(),
cls.__name__,
None,
filename,
filehash,
)

wasm_module_ty = wasm_module.check_instantiate([], None)

DEF_STORE.register_def(wasm_module, get_calling_frame())
for val in cls.__dict__.values():
if isinstance(val, GuppyDefinition):
DEF_STORE.register_impl(wasm_module.id, val.wrapped.name, val.id)
# Add a constructor to the class
call_method = CustomFunctionDef(
DefId.fresh(),
"__new__",
None,
FunctionType(
[
FuncInput(
NumericType(NumericType.Kind.Nat), flags=InputFlags.Owned
)
],
wasm_module_ty,
),
DefaultCallChecker(),
WasmModuleInitCompiler(),
True,
GlobalConstId.fresh(f"{cls.__name__}.__new__"),
True,
)
discard = CustomFunctionDef(
DefId.fresh(),
"discard",
None,
FunctionType([FuncInput(wasm_module_ty, InputFlags.Owned)], NoneType()),
DefaultCallChecker(),
WasmModuleDiscardCompiler(),
False,
GlobalConstId.fresh(f"{cls.__name__}.__discard__"),
True,
)
DEF_STORE.register_def(call_method, get_calling_frame())
DEF_STORE.register_impl(wasm_module.id, "__new__", call_method.id)
DEF_STORE.register_def(discard, get_calling_frame())
DEF_STORE.register_impl(wasm_module.id, "discard", discard.id)

return GuppyDefinition(wasm_module)

return dec

def wasm(self, f: PyFunc) -> GuppyDefinition:
func = RawWasmFunctionDef(
DefId.fresh(),
f.__name__,
None,
f,
WasmCallChecker(),
WasmModuleCallCompiler(f.__name__),
True,
signature=None,
)
DEF_STORE.register_def(func, get_calling_frame())
return GuppyDefinition(func)


def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr:
"""Helper function to parse expressions that are provided as strings.
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class CustomFunctionDef(CompiledCallableDef):
has_signature: Whether the function has a declared signature.
"""

defined_at: AstNode
defined_at: AstNode | None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change still needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the __call__ and discard methods of WASM modules, which don't really have a place. I think we could get rid of it if we actually parsed WASM module definitions like we do structs, but I think that's a refactor for down the line...
If you have any other ideas I'm very happy to hear them!

ty: FunctionType
call_checker: "CustomCallChecker"
call_compiler: "CustomInoutCallCompiler"
Expand Down
58 changes: 58 additions & 0 deletions guppylang/definition/wasm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING

from guppylang.ast_util import AstNode
from guppylang.checker.errors.wasm import (
FirstArgNotModule,
UnWasmableType,
)
from guppylang.definition.custom import CustomFunctionDef, RawCustomFunctionDef
from guppylang.error import GuppyError
from guppylang.span import SourceMap
from guppylang.tys.builtin import wasm_module_info
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
NumericType,
TupleType,
Type,
)

if TYPE_CHECKING:
from guppylang.checker.core import Globals


class RawWasmFunctionDef(RawCustomFunctionDef):
def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
# Place to highlight in error messages
match fun_ty.inputs[0]:
case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_info(
ty
) is not None:
pass
case FuncInput(ty=ty):
raise GuppyError(FirstArgNotModule(loc, ty))
for inp in fun_ty.inputs[1:]:
if not self.is_type_wasmable(inp.ty):
raise GuppyError(UnWasmableType(loc, inp.ty))
if not self.is_type_wasmable(fun_ty.output):
match fun_ty.output:
case NoneType():
pass
case _:
raise GuppyError(UnWasmableType(loc, fun_ty.output))

def is_type_wasmable(self, ty: Type) -> bool:
match ty:
case NumericType():
return True
case TupleType(element_types=tys):
return all(self.is_type_wasmable(ty) for ty in tys)

return False

def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
parsed = super().parse(globals, sources)
self.sanitise_type(parsed.defined_at, parsed.ty)
return parsed
14 changes: 14 additions & 0 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from guppylang.checker.expr_checker import (
ExprChecker,
ExprSynthesizer,
check_call,
check_num_args,
check_type_against,
synthesize_call,
Expand Down Expand Up @@ -557,3 +558,16 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
assert len(inst) == 0, "func_ty is not generic"
node = BarrierExpr(args=args, func_ty=func_ty)
return with_loc(self.node, node), ret_ty


class WasmCallChecker(CustomCallChecker):
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're skipping the sanitise_type call here

# Use default implementation from the expression checker
args, subst, inst = check_call(self.func.ty, args, ty, self.node, self.ctx)

return GlobalCall(def_id=self.func.id, args=args, type_args=inst), subst

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
# Use default implementation from the expression checker
args, ty, inst = synthesize_call(self.func.ty, args, self.node, self.ctx)
return GlobalCall(def_id=self.func.id, args=args, type_args=inst), ty
38 changes: 36 additions & 2 deletions guppylang/std/_internal/compiler/tket2_exts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from dataclasses import dataclass

import hugr.model
import hugr.std
from hugr import val
from tket2_exts import (
debug,
futures,
Expand All @@ -8,26 +13,55 @@
quantum,
result,
rotation,
wasm,
)

BOOL_EXTENSION = opaque_bool()
DEBUG_EXTENSION = debug()
FUTURES_EXTENSION = futures()
QSYSTEM_EXTENSION = qsystem()
QSYSTEM_RANDOM_EXTENSION = qsystem_random()
QSYSTEM_UTILS_EXTENSION = qsystem_utils()
QUANTUM_EXTENSION = quantum()
RESULT_EXTENSION = result()
ROTATION_EXTENSION = rotation()
DEBUG_EXTENSION = debug()
WASM_EXTENSION = wasm()

TKET2_EXTENSIONS = [
BOOL_EXTENSION,
DEBUG_EXTENSION,
FUTURES_EXTENSION,
QSYSTEM_EXTENSION,
QSYSTEM_RANDOM_EXTENSION,
QSYSTEM_UTILS_EXTENSION,
QUANTUM_EXTENSION,
RESULT_EXTENSION,
ROTATION_EXTENSION,
DEBUG_EXTENSION,
WASM_EXTENSION,
]


@dataclass(frozen=True)
class ConstWasmModule(val.ExtensionValue):
"""Python wrapper for the tket2 ConstWasmModule type"""

wasm_file: str
wasm_hash: int

def to_value(self) -> val.Extension:
ty = WASM_EXTENSION.get_type("module").instantiate([])

name = "tket2.wasm.ConstWasmModule"
payload = {"name": self.wasm_file, "hash": self.wasm_hash}
return val.Extension(name, typ=ty, val=payload, extensions=["tket2.wasm"])

def __str__(self) -> str:
return (
f"ConstWasmModule(wasm_file={self.wasm_file}, wasm_hash={self.wasm_hash})"
)

def to_model(self) -> hugr.model.Term:
file_tm = hugr.model.Literal(self.wasm_file)
hash_tm = hugr.model.Literal(self.wasm_hash)

return hugr.model.Apply("tket2.wasm.ConstWasmModule", [file_tm, hash_tm])
Loading
Loading