Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def __init__(self, frame: FrameType | None) -> None:
@cache
def builtin_defs() -> dict[str, Definition]:
import guppylang.std.builtins
from guppylang.tracing.object import GuppyDefinition
from guppylang.defs import GuppyDefinition

return BUILTIN_DEFS | {
name: val.wrapped
Expand Down Expand Up @@ -408,7 +408,7 @@ def __getitem__(self, item: DefId) -> ParsedDef: ...
def __getitem__(self, item: str) -> "ParsedDef | PythonObject": ...

def __getitem__(self, item: DefId | str) -> "ParsedDef | PythonObject":
from guppylang.tracing.object import GuppyDefinition
from guppylang.defs import GuppyDefinition

match item:
case DefId() as def_id:
Expand Down
2 changes: 1 addition & 1 deletion guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,8 @@ def _check_global(
)

def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
from guppylang.defs import GuppyDefinition
from guppylang.engine import ENGINE
from guppylang.tracing.object import GuppyDefinition

# A `value.attr` attribute access. Unfortunately, the `attr` is just a string,
# not an AST node, so we have to compute its span by hand. This is fine since
Expand Down
2 changes: 1 addition & 1 deletion guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def check_nested_func_def(
if not captured:
# If there are no captured vars, we treat the function like a global name
from guppylang.definition.function import ParsedFunctionDef
from guppylang.tracing.object import GuppyDefinition
from guppylang.defs import GuppyDefinition

func = ParsedFunctionDef(def_id, func_def.name, func_def, func_ty, None)
DEF_STORE.register_def(func, None)
Expand Down
84 changes: 38 additions & 46 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from pathlib import Path
from types import FrameType, ModuleType
from typing import Any, TypeVar, cast
from typing import Any, ParamSpec, TypeVar, cast

from hugr import ops
from hugr import tys as ht
Expand Down Expand Up @@ -33,21 +33,24 @@
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.extern import RawExternDef
from guppylang.definition.function import (
CompiledFunctionDef,
PyFunc,
RawFunctionDef,
)
from guppylang.definition.overloaded import OverloadedFunctionDef
from guppylang.definition.parameter import ConstVarDef, RawConstVarDef, TypeVarDef
from guppylang.definition.pytket_circuits import (
CompiledPytketDef,
RawLoadPytketDef,
RawPytketDef,
)
from guppylang.definition.struct import RawStructDef
from guppylang.definition.traced import RawTracedFunctionDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.definition.value import CompiledHugrNodeDef
from guppylang.definition.wasm import RawWasmFunctionDef
from guppylang.defs import (
GuppyDefinition,
GuppyFunctionDefinition,
GuppyTypeVarDefinition,
)
from guppylang.dummy_decorator import _DummyGuppy, sphinx_running
from guppylang.engine import DEF_STORE
from guppylang.span import Loc, SourceMap, Span
Expand All @@ -57,7 +60,6 @@
WasmModuleDiscardCompiler,
WasmModuleInitCompiler,
)
from guppylang.tracing.object import GuppyDefinition, TypeVarGuppyDefinition
from guppylang.tys.arg import Argument
from guppylang.tys.builtin import (
WasmModuleTypeDef,
Expand All @@ -75,6 +77,7 @@
S = TypeVar("S")
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")
Decorator = Callable[[S], T]

AnyRawFunctionDef = (
Expand Down Expand Up @@ -108,19 +111,15 @@ class ModuleIdentifier:
class _Guppy:
"""Class for the `@guppy` decorator."""

def __call__(self, f: F) -> F:
def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f)
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return the function unchanged, but in fact we return
# a `GuppyDefinition` that handles the comptime logic
return GuppyDefinition(defn) # type: ignore[return-value]
return GuppyFunctionDefinition(defn)

def comptime(self, f: F) -> F:
def comptime(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
defn = RawTracedFunctionDef(DefId.fresh(), f.__name__, None, f)
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return the function unchanged, but in fact we return
# a `GuppyDefinition` that handles the comptime logic
return GuppyDefinition(defn) # type: ignore[return-value]
return GuppyFunctionDefinition(defn)

def extend_type(self, defn: TypeDef) -> Callable[[type], type]:
"""Decorator to add new instance functions to a type."""
Expand Down Expand Up @@ -205,15 +204,15 @@ def type_var(
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return a `typing.TypeVar`, but in fact we return a special
# `GuppyDefinition` that pretends to be a TypeVar at runtime
return TypeVarGuppyDefinition(defn, TypeVar(name)) # type: ignore[return-value]
return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value]

def nat_var(self, name: str) -> TypeVar:
"""Creates a new const nat variable in a module."""
defn = ConstVarDef(DefId.fresh(), name, None, NumericType(NumericType.Kind.Nat))
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return a `typing.TypeVar`, but in fact we return a special
# `GuppyDefinition` that pretends to be a TypeVar at runtime
return TypeVarGuppyDefinition(defn, TypeVar(name)) # type: ignore[return-value]
return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value]

def const_var(self, name: str, ty: str) -> TypeVar:
"""Creates a new const type variable."""
Expand All @@ -224,7 +223,7 @@ def const_var(self, name: str, ty: str) -> TypeVar:
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return a `typing.TypeVar`, but in fact we return a special
# `GuppyDefinition` that pretends to be a TypeVar at runtime
return TypeVarGuppyDefinition(defn, TypeVar(name)) # type: ignore[return-value]
return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value]

def custom(
self,
Expand All @@ -233,15 +232,15 @@ def custom(
higher_order_value: bool = True,
name: str = "",
signature: FunctionType | None = None,
) -> Callable[[F], F]:
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
"""Decorator to add custom typing or compilation behaviour to function decls.

Optionally, usage of the function as a higher-order value can be disabled. In
that case, the function signature can be omitted if a custom call compiler is
provided.
"""

def dec(f: F) -> F:
def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
call_checker = checker or DefaultCallChecker()
func = RawCustomFunctionDef(
DefId.fresh(),
Expand All @@ -254,9 +253,7 @@ def dec(f: F) -> F:
signature,
)
DEF_STORE.register_def(func, get_calling_frame())
# We're pretending to return the function unchanged, but in fact we return
# a `GuppyDefinition` that handles the comptime logic
return GuppyDefinition(func) # type: ignore[return-value]
return GuppyFunctionDefinition(func)

return dec

Expand All @@ -267,7 +264,7 @@ def hugr_op(
higher_order_value: bool = True,
name: str = "",
signature: FunctionType | None = None,
) -> Callable[[F], F]:
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
"""Decorator to annotate function declarations as HUGR ops.

Args:
Expand All @@ -280,14 +277,14 @@ def hugr_op(
"""
return self.custom(OpCompiler(op), checker, higher_order_value, name, signature)

def declare(self, f: F) -> F:
def declare(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
defn = RawFunctionDecl(DefId.fresh(), f.__name__, None, f)
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return the function unchanged, but in fact we return
# a `GuppyDefinition` that handles the comptime logic
return GuppyDefinition(defn) # type: ignore[return-value]
return GuppyFunctionDefinition(defn)

def overload(self, *funcs: Any) -> Callable[[F], F]:
def overload(
self, *funcs: Any
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
"""Collects multiple function definitions into one overloaded function.

Consider the following example:
Expand Down Expand Up @@ -341,15 +338,13 @@ def combined_new(): ...
)
func_ids.append(func.id)

def dec(f: F) -> F:
def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
dummy_sig = FunctionType([], NoneType())
defn = OverloadedFunctionDef(
DefId.fresh(), f.__name__, None, dummy_sig, func_ids
)
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return the class unchanged, but in fact we return
# a `GuppyDefinition` that handles the comptime logic
return GuppyDefinition(defn) # type: ignore[return-value]
return GuppyFunctionDefinition(defn)

return dec

Expand Down Expand Up @@ -399,10 +394,7 @@ def compile(self, obj: Any) -> ModulePointer:
raise TypeError(f"Object is not a Guppy definition: {obj}")
return ENGINE.compile(obj.id)

def compile_function(
self,
obj: Any,
) -> FuncDefnPointer:
def compile_function(self, obj: GuppyFunctionDefinition[P, T]) -> FuncDefnPointer:
"""Compiles a single function definition."""
from guppylang.engine import ENGINE

Expand All @@ -418,13 +410,15 @@ def compile_function(
mono_args = tuple(None for _ in checked_def.params)

compiled_def = ENGINE.compiled[obj.id, mono_args]
assert isinstance(compiled_def, CompiledFunctionDef | CompiledPytketDef)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were missing a lot of definition kinds in this assert. The new CompiledHugrNodeDef ABC makes this check more uniform

node = compiled_def.func_def.parent_node
assert isinstance(compiled_def, CompiledHugrNodeDef)
node = compiled_def.hugr_node
return FuncDefnPointer(
compiled_module.package, compiled_module.module_index, node
)

def pytket(self, input_circuit: Any) -> Callable[[F], F]:
def pytket(
self, input_circuit: Any
) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
"""Adds a pytket circuit function definition with explicit signature."""
err_msg = "Only pytket circuits can be passed to guppy.pytket"
try:
Expand All @@ -436,12 +430,10 @@ def pytket(self, input_circuit: Any) -> Callable[[F], F]:
except ImportError:
raise TypeError(err_msg) from None

def func(f: F) -> F:
def func(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
defn = RawPytketDef(DefId.fresh(), f.__name__, None, f, input_circuit)
DEF_STORE.register_def(defn, get_calling_frame())
# We're pretending to return the function unchanged, but in fact we return
# a `GuppyDefinition` that handles the comptime logic
return GuppyDefinition(defn) # type: ignore[return-value]
return GuppyFunctionDefinition(defn)

return func

Expand All @@ -451,7 +443,7 @@ def load_pytket(
input_circuit: Any,
*,
use_arrays: bool = True,
) -> GuppyDefinition:
) -> GuppyFunctionDefinition[..., Any]:
"""Adds a pytket circuit function definition with implicit signature."""
err_msg = "Only pytket circuits can be passed to guppy.load_pytket"
try:
Expand All @@ -468,7 +460,7 @@ def load_pytket(
DefId.fresh(), name, None, span, input_circuit, use_arrays
)
DEF_STORE.register_def(defn, get_calling_frame())
return GuppyDefinition(defn)
return GuppyFunctionDefinition(defn)

def wasm_module(
self, filename: str, filehash: int
Expand Down Expand Up @@ -528,7 +520,7 @@ def dec(cls: builtins.type[T]) -> GuppyDefinition:

return dec

def wasm(self, f: PyFunc) -> GuppyDefinition:
def wasm(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
func = RawWasmFunctionDef(
DefId.fresh(),
f.__name__,
Expand All @@ -540,7 +532,7 @@ def wasm(self, f: PyFunc) -> GuppyDefinition:
signature=None,
)
DEF_STORE.register_def(func, get_calling_frame())
return GuppyDefinition(func)
return GuppyFunctionDefinition(func)


def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr:
Expand Down
9 changes: 7 additions & 2 deletions guppylang/definition/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompilerContext, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.definition.value import CompiledHugrNodeDef, CompiledValueDef, ValueDef
from guppylang.span import SourceMap
from guppylang.tys.parsing import type_from_ast

Expand Down Expand Up @@ -55,11 +55,16 @@ def compile_outer(


@dataclass(frozen=True)
class CompiledConstDef(ConstDef, CompiledValueDef):
class CompiledConstDef(ConstDef, CompiledValueDef, CompiledHugrNodeDef):
"""A constant that has been compiled to a Hugr node."""

const_node: Node

@property
def hugr_node(self) -> Node:
"""The Hugr node this definition was compiled into."""
return self.const_node

def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:
"""Loads the extern value into a local Hugr dataflow graph."""
return dfg.builder.load(self.const_node)
16 changes: 14 additions & 2 deletions guppylang/definition/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
load_with_args,
parse_py_func,
)
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.definition.value import (
CallableDef,
CallReturnWires,
CompiledCallableDef,
CompiledHugrNodeDef,
)
from guppylang.diagnostic import Error
from guppylang.error import GuppyError
from guppylang.nodes import GlobalCall
Expand Down Expand Up @@ -130,11 +135,18 @@ def compile_outer(


@dataclass(frozen=True)
class CompiledFunctionDecl(CheckedFunctionDecl, CompiledCallableDef):
class CompiledFunctionDecl(
CheckedFunctionDecl, CompiledCallableDef, CompiledHugrNodeDef
):
"""A function declaration with a corresponding Hugr node."""

declaration: Node

@property
def hugr_node(self) -> Node:
"""The Hugr node this definition was compiled into."""
return self.declaration

def load_with_args(
self,
type_args: Inst,
Expand Down
9 changes: 7 additions & 2 deletions guppylang/definition/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompilerContext, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.definition.value import CompiledHugrNodeDef, CompiledValueDef, ValueDef
from guppylang.span import SourceMap
from guppylang.tys.parsing import type_from_ast

Expand Down Expand Up @@ -70,11 +70,16 @@ def compile_outer(


@dataclass(frozen=True)
class CompiledExternDef(ExternDef, CompiledValueDef):
class CompiledExternDef(ExternDef, CompiledValueDef, CompiledHugrNodeDef):
"""An extern symbol definition that has been compiled to a Hugr constant."""

const_node: Node

@property
def hugr_node(self) -> Node:
"""The Hugr node this definition was compiled into."""
return self.const_node

def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:
"""Loads the extern value into a local Hugr dataflow graph."""
return dfg.builder.load(self.const_node)
Loading
Loading