From 1afb3c7f3226ccc44732c2bec1377979383b6c36 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 23 Jul 2025 16:39:24 +0100 Subject: [PATCH 1/4] feat: Extend comptime arguments to arbitrary non-linear types --- guppylang/tys/parsing.py | 11 ------- tests/integration/test_comptime_args.py | 38 ++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index f3182ddf7..683306e8c 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -365,17 +365,6 @@ def type_with_flags_from_ast( flags |= InputFlags.Comptime if not ty.copyable or not ty.droppable: raise GuppyError(LinearComptimeError(node.right, ty)) - - # TODO: For now we can only do `nat` comptime args since they lower to - # Hugr bounded nats. Extend to arbitrary types via monomorphization. - # See https://github.com/CQCL/guppylang/issues/1008 - if ( - not isinstance(ty, NumericType) - or not ty.kind == NumericType.Kind.Nat - ): - raise GuppyError( - UnsupportedError(node.right, f"`{ty}` comptime arguments") - ) case _: raise GuppyError(InvalidFlagError(node.right)) return ty, flags diff --git a/tests/integration/test_comptime_args.py b/tests/integration/test_comptime_args.py index 69835eaf7..c1df09e1f 100644 --- a/tests/integration/test_comptime_args.py +++ b/tests/integration/test_comptime_args.py @@ -2,7 +2,7 @@ from guppylang.std.builtins import nat, comptime, array -def test_basic(validate): +def test_basic_nat(validate): @guppy def foo(n: nat @ comptime) -> nat: return nat(n + 1) @@ -14,6 +14,42 @@ def main() -> nat: validate(guppy.compile(main)) +def test_basic_int(validate): + @guppy + def foo(n: int @ comptime) -> int: + return n + 1 + + @guppy + def main() -> int: + return foo(42) + + validate(guppy.compile(main)) + + +def test_basic_float(validate): + @guppy + def foo(f: float @ comptime) -> float: + return f + 1.5 + + @guppy + def main() -> float: + return foo(42.0) + + validate(guppy.compile(main)) + + +def test_basic_bool(validate): + @guppy + def foo(b: bool @ comptime) -> bool: + return not b + + @guppy + def main() -> bool: + return foo(True) + + validate(guppy.compile(main)) + + def test_multiple(validate): @guppy def foo( From 8974423a060bf38c30196d3f6d268b29c241cf38 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 23 Jul 2025 17:07:11 +0100 Subject: [PATCH 2/4] Prevent generic comptime args --- guppylang/tys/parsing.py | 11 +++++++++ guppylang/tys/subst.py | 27 +++++++++++++++++++-- tests/error/comptime_arg_errors/generic.err | 6 ++--- tests/error/comptime_arg_errors/generic.py | 11 ++++++--- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index 683306e8c..d83af99b5 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -36,6 +36,7 @@ NonLinearOwnedError, ) from guppylang.tys.param import ConstParam, Parameter, TypeParam +from guppylang.tys.subst import BoundVarFinder from guppylang.tys.ty import ( FuncInput, FunctionType, @@ -365,6 +366,16 @@ def type_with_flags_from_ast( flags |= InputFlags.Comptime if not ty.copyable or not ty.droppable: raise GuppyError(LinearComptimeError(node.right, ty)) + # For now, we don't allow comptime annotations on generic inputs + # TODO: In the future we might want to allow stuff like + # `def foo[T: (Copy, Discard](x: T @comptime)`. + # Also see the todo in `parse_parameter`. + var_finder = BoundVarFinder() + ty.visit(var_finder) + if var_finder.bound_vars: + raise GuppyError( + UnsupportedError(node.left, "Generic comptime arguments") + ) case _: raise GuppyError(InvalidFlagError(node.right)) return ty, flags diff --git a/guppylang/tys/subst.py b/guppylang/tys/subst.py index 1da25e5ab..59cd61d77 100644 --- a/guppylang/tys/subst.py +++ b/guppylang/tys/subst.py @@ -4,7 +4,7 @@ from guppylang.error import InternalGuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg -from guppylang.tys.common import Transformer +from guppylang.tys.common import Transformer, Visitor from guppylang.tys.const import BoundConstVar, Const, ConstBase, ExistentialConstVar from guppylang.tys.ty import ( BoundTypeVar, @@ -13,7 +13,7 @@ Type, TypeBase, ) -from guppylang.tys.var import ExistentialVar +from guppylang.tys.var import ExistentialVar, BoundVar Subst = dict[ExistentialVar, Type | Const] Inst = Sequence[Argument] @@ -82,3 +82,26 @@ def _transform_FunctionType(self, ty: FunctionType) -> Type | None: if ty.parametrized: raise InternalGuppyError("Tried to instantiate under binder") return None + + +class BoundVarFinder(Visitor): + """Type visitor that looks for occurrences of bound variables.""" + + bound_vars: set[BoundVar] + + def __init__(self) -> None: + self.bound_vars = set() + + @functools.singledispatchmethod + def visit(self, ty: Any) -> bool: # type: ignore[override] + return False + + @visit.register + def _transform_BoundTypeVar(self, ty: BoundTypeVar) -> bool: + self.bound_vars.add(ty) + return False + + @visit.register + def _transform_BoundConstVar(self, c: BoundConstVar) -> bool: + self.bound_vars.add(c) + return False diff --git a/tests/error/comptime_arg_errors/generic.err b/tests/error/comptime_arg_errors/generic.err index 888f5f8e7..8f4120ad0 100644 --- a/tests/error/comptime_arg_errors/generic.err +++ b/tests/error/comptime_arg_errors/generic.err @@ -1,8 +1,8 @@ -Error: Unsupported (at $FILE:8:15) +Error: Unsupported (at $FILE:8:11) | 6 | 7 | @guppy -8 | def main(q: T @comptime) -> None: - | ^^^^^^^^ `T` comptime arguments are not supported +8 | def foo(q: T @comptime) -> T: + | ^ Generic comptime arguments are not supported Guppy compilation failed due to 1 previous error diff --git a/tests/error/comptime_arg_errors/generic.py b/tests/error/comptime_arg_errors/generic.py index 2e5d223bb..bad0eac53 100644 --- a/tests/error/comptime_arg_errors/generic.py +++ b/tests/error/comptime_arg_errors/generic.py @@ -1,12 +1,17 @@ from guppylang import guppy from guppylang.std.builtins import comptime -T = guppy.type_var("T") +T = guppy.type_var("T", copyable=True, droppable=True) @guppy -def main(q: T @comptime) -> None: - pass +def foo(q: T @comptime) -> T: + return T + + +@guppy +def main() -> int: + return foo(42) guppy.compile(main) From 0780de0a1cd8127d8b0a36bb4f438557bb65bfa8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 23 Jul 2025 17:07:33 +0100 Subject: [PATCH 3/4] Delete test --- tests/error/comptime_arg_errors/unsupported_ty.err | 8 -------- tests/error/comptime_arg_errors/unsupported_ty.py | 10 ---------- 2 files changed, 18 deletions(-) delete mode 100644 tests/error/comptime_arg_errors/unsupported_ty.err delete mode 100644 tests/error/comptime_arg_errors/unsupported_ty.py diff --git a/tests/error/comptime_arg_errors/unsupported_ty.err b/tests/error/comptime_arg_errors/unsupported_ty.err deleted file mode 100644 index a7be9a281..000000000 --- a/tests/error/comptime_arg_errors/unsupported_ty.err +++ /dev/null @@ -1,8 +0,0 @@ -Error: Unsupported (at $FILE:6:17) - | -4 | -5 | @guppy -6 | def main(n: int @comptime) -> None: - | ^^^^^^^^ `int` comptime arguments are not supported - -Guppy compilation failed due to 1 previous error diff --git a/tests/error/comptime_arg_errors/unsupported_ty.py b/tests/error/comptime_arg_errors/unsupported_ty.py deleted file mode 100644 index 44e8c1e2b..000000000 --- a/tests/error/comptime_arg_errors/unsupported_ty.py +++ /dev/null @@ -1,10 +0,0 @@ -from guppylang import guppy -from guppylang.std.builtins import comptime - - -@guppy -def main(n: int @comptime) -> None: - pass - - -guppy.compile(main) From 45b48cbc637b6d1c778ed589e9160de5a5ea6fbe Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 23 Jul 2025 17:08:04 +0100 Subject: [PATCH 4/4] Lint --- guppylang/tys/subst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppylang/tys/subst.py b/guppylang/tys/subst.py index 59cd61d77..2b1aa9b9c 100644 --- a/guppylang/tys/subst.py +++ b/guppylang/tys/subst.py @@ -13,7 +13,7 @@ Type, TypeBase, ) -from guppylang.tys.var import ExistentialVar, BoundVar +from guppylang.tys.var import BoundVar, ExistentialVar Subst = dict[ExistentialVar, Type | Const] Inst = Sequence[Argument]