-
Notifications
You must be signed in to change notification settings - Fork 18
feat: Declare WASM modules in guppy #942
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 33 commits
426a5e6
2da0811
219e293
7361fe2
d5c570f
af7f7d8
95b1cb2
cc8aef1
caf37f4
c8a7b71
c6d66d8
fe23a6c
55805ee
8de275a
c0f70b9
07e1f62
9b3bf0b
d355b67
3f4bf50
46f6469
50986f7
061a6ba
b2aef12
5ed8479
e27d362
147a73a
4ff1525
8a0a188
88b5b0c
faadea6
84ebbdb
1d09e45
79beec5
7aebba1
3b584d3
9ea17ec
5b3f43d
e24e3f5
e677be2
3fafda6
9482b31
925c22d
8488e33
c65d4a0
277182c
c73ef71
3aa675f
cba859c
4565fb9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,42 @@ | ||||||
| from dataclasses import dataclass | ||||||
| from typing import ClassVar | ||||||
|
|
||||||
| from guppylang.diagnostic import Error | ||||||
| from guppylang.tys.ty import Type | ||||||
|
|
||||||
|
|
||||||
| @dataclass(frozen=True) | ||||||
| class WasmError(Error): | ||||||
| title: ClassVar[str] = "WASM signature error" | ||||||
|
|
||||||
|
|
||||||
| @dataclass(frozen=True) | ||||||
| class FirstArgNotModule(WasmError): | ||||||
| span_label: ClassVar[str] = ( | ||||||
| "First argument to WASM function should be a reference to a WASM module." | ||||||
| " Found {ty} instead" | ||||||
|
||||||
| " Found {ty} instead" | |
| " Found `{ty}` instead" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Backticks still missing here
croyzor marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| "WASM function didn't have a function type, instead found {ty}" | |
| "WASM function didn't have a function type, instead found `{ty}`" |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -14,10 +14,12 @@ | |||||
| from typing_extensions import dataclass_transform | ||||||
|
|
||||||
| from guppylang.ast_util import annotate_location | ||||||
| from guppylang.compiler.core import GlobalConstId | ||||||
| from guppylang.definition.common import DefId | ||||||
| from guppylang.definition.const import RawConstDef | ||||||
| from guppylang.definition.custom import ( | ||||||
| CustomCallChecker, | ||||||
| CustomFunctionDef, | ||||||
| CustomInoutCallCompiler, | ||||||
| DefaultCallChecker, | ||||||
| NotImplementedCallCompiler, | ||||||
|
|
@@ -28,6 +30,7 @@ | |||||
| from guppylang.definition.extern import RawExternDef | ||||||
| from guppylang.definition.function import ( | ||||||
| CompiledFunctionDef, | ||||||
| PyFunc, | ||||||
| RawFunctionDef, | ||||||
| ) | ||||||
| from guppylang.definition.overloaded import OverloadedFunctionDef | ||||||
|
|
@@ -43,11 +46,26 @@ | |||||
| from guppylang.dummy_decorator import _DummyGuppy, sphinx_running | ||||||
| from guppylang.engine import DEF_STORE | ||||||
| from guppylang.span import Loc, SourceMap, Span | ||||||
| from guppylang.std._internal.checker import WasmCallChecker | ||||||
| from guppylang.std._internal.compiler.wasm import ( | ||||||
| WasmModuleCallCompiler, | ||||||
| WasmModuleDiscardCompiler, | ||||||
| WasmModuleInitCompiler, | ||||||
| ) | ||||||
| from guppylang.tracing.object import GuppyDefinition, TypeVarGuppyDefinition | ||||||
| from guppylang.tys.arg import Argument | ||||||
| from guppylang.tys.builtin import ( | ||||||
| WasmModuleTypeDef, | ||||||
| ) | ||||||
| from guppylang.tys.param import Parameter | ||||||
| from guppylang.tys.subst import Inst | ||||||
| from guppylang.tys.ty import FunctionType, NoneType, NumericType | ||||||
| from guppylang.tys.ty import ( | ||||||
| FuncInput, | ||||||
| FunctionType, | ||||||
| InputFlags, | ||||||
| NoneType, | ||||||
| NumericType, | ||||||
| ) | ||||||
|
|
||||||
| S = TypeVar("S") | ||||||
| T = TypeVar("T") | ||||||
|
|
@@ -420,6 +438,78 @@ def load_pytket( | |||||
| DEF_STORE.register_def(defn, get_calling_frame()) | ||||||
| return GuppyDefinition(defn) | ||||||
|
|
||||||
| def wasm_module( | ||||||
| self, filename: str, filehash: int | ||||||
| ) -> Decorator[builtins.type[T], GuppyDefinition]: | ||||||
| def dec(cls: builtins.type[T]) -> GuppyDefinition: | ||||||
| # N.B. Only one module per file and vice-versa | ||||||
| wasm_module = WasmModuleTypeDef( | ||||||
| DefId.fresh(), | ||||||
| cls.__name__, | ||||||
| None, | ||||||
| filename, | ||||||
| filehash, | ||||||
| ) | ||||||
|
|
||||||
| wasm_module_ty = wasm_module.check_instantiate([], None) | ||||||
|
|
||||||
| DEF_STORE.register_def(wasm_module, get_calling_frame()) | ||||||
| for val in cls.__dict__.values(): | ||||||
| if isinstance(val, GuppyDefinition): | ||||||
| DEF_STORE.register_impl(wasm_module.id, val.wrapped.name, val.id) | ||||||
| # Add a __call__ to the class | ||||||
|
||||||
| # Add a __call__ to the class | |
| # Add a constructor to the class |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -15,6 +15,8 @@ | |||||
| from guppylang.compiler.core import CompilerContext, DFContainer, GlobalConstId | ||||||
| from guppylang.definition.common import ParsableDef | ||||||
| from guppylang.definition.value import CallReturnWires, CompiledCallableDef | ||||||
|
|
||||||
| # from guppylang.definition.wasm import WasmModule | ||||||
|
||||||
| # from guppylang.definition.wasm import WasmModule |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change still needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for the __call__ and discard methods of WASM modules, which don't really have a place. I think we could get rid of it if we actually parsed WASM module definitions like we do structs, but I think that's a refactor for down the line...
If you have any other ideas I'm very happy to hear them!
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,9 +11,16 @@ | |
| ArrayComprUnknownSizeError, | ||
| TypeMismatchError, | ||
| ) | ||
| from guppylang.checker.errors.wasm import ( | ||
| FirstArgNotModule, | ||
| NonFunctionWasmType, | ||
| UnWasmableType, | ||
| WasmTypeConversionError, | ||
| ) | ||
| from guppylang.checker.expr_checker import ( | ||
| ExprChecker, | ||
| ExprSynthesizer, | ||
| check_call, | ||
| check_num_args, | ||
| check_type_against, | ||
| synthesize_call, | ||
|
|
@@ -48,9 +55,11 @@ | |
| is_array_type, | ||
| is_bool_type, | ||
| is_sized_iter_type, | ||
| is_string_type, | ||
| nat_type, | ||
| sized_iter_type, | ||
| string_type, | ||
| wasm_module_info, | ||
| ) | ||
| from guppylang.tys.const import Const, ConstValue | ||
| from guppylang.tys.subst import Subst | ||
|
|
@@ -60,7 +69,9 @@ | |
| InputFlags, | ||
| NoneType, | ||
| NumericType, | ||
| OpaqueType, | ||
| StructType, | ||
| TupleType, | ||
| Type, | ||
| unify, | ||
| ) | ||
|
|
@@ -557,3 +568,71 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: | |
| assert len(inst) == 0, "func_ty is not generic" | ||
| node = BarrierExpr(args=args, func_ty=func_ty) | ||
| return with_loc(self.node, node), ret_ty | ||
|
|
||
|
|
||
| class WasmCallChecker(CustomCallChecker): | ||
| type_sanitised: bool = False | ||
|
|
||
| def sanitise_type(self) -> None: | ||
| # Place to highlight in error messages | ||
| loc = self.func.defined_at | ||
|
|
||
| if isinstance(self.func.ty, FunctionType): | ||
| match self.func.ty.inputs[0]: | ||
| case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_info( | ||
| ty | ||
| ) is not None: | ||
| pass | ||
| case FuncInput(ty=ty): | ||
| raise GuppyError(FirstArgNotModule(loc, ty)) | ||
| for inp in self.func.ty.inputs[1:]: | ||
| if not self.is_type_wasmable(inp.ty): | ||
| raise GuppyError(UnWasmableType(loc, inp.ty)) | ||
| if not self.is_type_wasmable(self.func.ty.output): | ||
| raise GuppyError(UnWasmableType(loc, self.func.ty.output)) | ||
| else: | ||
| raise GuppyError(NonFunctionWasmType(loc, self.func.ty)) | ||
|
||
| self.type_sanitised = True | ||
|
|
||
| def is_type_wasmable(self, ty: Type) -> bool: | ||
| match ty: | ||
| case NumericType(): | ||
| return True | ||
| case NoneType(): | ||
| return True | ||
| case TupleType(element_types=tys): | ||
| return all(self.is_type_wasmable(ty) for ty in tys) | ||
| case FunctionType() as f: | ||
| return self.is_type_wasmable(f.output) and all( | ||
| self.is_type_wasmable(inp.ty) and inp.flags != InputFlags.Inout | ||
| for inp in f.inputs | ||
| ) | ||
|
||
| case OpaqueType() as ty: | ||
| if is_string_type(ty): | ||
| return True | ||
| elif is_array_type(ty): | ||
| return all( | ||
| self.is_type_wasmable(arg.ty) | ||
| for arg in ty.args | ||
| if isinstance(arg, TypeArg) | ||
| ) | ||
| return is_bool_type(ty) | ||
| case _: | ||
| return False | ||
|
|
||
| def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're skipping the |
||
| if not self.is_type_wasmable(ty): | ||
| raise GuppyTypeError(WasmTypeConversionError(self.node, ty)) | ||
|
||
|
|
||
| # Use default implementation from the expression checker | ||
| args, subst, inst = check_call(self.func.ty, args, ty, self.node, self.ctx) | ||
|
|
||
| return GlobalCall(def_id=self.func.id, args=args, type_args=inst), subst | ||
|
|
||
| def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: | ||
| if not self.type_sanitised: | ||
| self.sanitise_type() | ||
|
||
|
|
||
| # Use default implementation from the expression checker | ||
| args, ty, inst = synthesize_call(self.func.ty, args, self.node, self.ctx) | ||
| return GlobalCall(def_id=self.func.id, args=args, type_args=inst), ty | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.