diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index ab5175280..fe0bf4691 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -343,6 +343,15 @@ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr: def visit_Call(self, node: ast.Call) -> ast.AST: return is_comptime_expression(node) or self.generic_visit(node) + def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST: + # Desugar negated numeric constants into constants + match node.op, node.operand: + case ast.USub(), ast.Constant(value=float(v) | int(v)) as const: + const.value = -v + return with_loc(node, const) + case _: + return self.generic_visit(node) + def generic_visit(self, node: ast.AST) -> ast.AST: # Short-circuit expressions must be built using the `BranchBuilder`. However, we # can turn them into regular expressions by assigning True/False to a temporary diff --git a/guppylang/checker/errors/type_errors.py b/guppylang/checker/errors/type_errors.py index 6aec747ac..9a14a292c 100644 --- a/guppylang/checker/errors/type_errors.py +++ b/guppylang/checker/errors/type_errors.py @@ -323,3 +323,22 @@ class TupleIndexOutOfBoundsError(Error): ) index: int size: int + + +@dataclass(frozen=True) +class IntOverflowError(Error): + title: ClassVar[str] = "Integer {over_under}flow" + span_label: ClassVar[str] = ( + "Value does not fit into a {bits}-bit {signed_unsigned} integer" + ) + signed: bool + bits: int + is_underflow: bool + + @property + def over_under(self) -> str: + return "under" if self.is_underflow else "over" + + @property + def signed_unsigned(self) -> str: + return "signed" if self.signed else "unsigned" diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 0a7d63196..e9e0c18c9 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -69,6 +69,7 @@ BinaryOperatorNotDefinedError, ConstMismatchError, IllegalConstant, + IntOverflowError, ModuleMemberNotFoundError, NonLinearInstantiateError, NotCallableError, @@ -1336,9 +1337,11 @@ def python_value_to_guppy_type( return string_type() # Only resolve `int` to `nat` if the user specifically asked for it case int(n) if type_hint == nat_type() and n >= 0: + _int_bounds_check(n, node, signed=False) return nat_type() # Otherwise, default to `int` for consistency with Python - case int(): + case int(n): + _int_bounds_check(n, node, signed=True) return int_type() case float(): return float_type() @@ -1361,6 +1364,19 @@ def python_value_to_guppy_type( return None +def _int_bounds_check(value: int, node: AstNode, signed: bool) -> None: + bit_width = 1 << NumericType.INT_WIDTH + if signed: + max_v = (1 << (bit_width - 1)) - 1 + min_v = -(1 << (bit_width - 1)) + else: + max_v = (1 << bit_width) - 1 + min_v = 0 + if value < min_v or value > max_v: + err = IntOverflowError(node, signed, bit_width, value < min_v) + raise GuppyTypeError(err) + + def _python_list_to_guppy_type( vs: list[Any], node: ast.AST, globals: Globals, type_hint: Type | None ) -> OpaqueType | None: diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 3d67bb103..a36b4587e 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -36,7 +36,7 @@ CompiledCallableDef, CompiledValueDef, ) -from guppylang.error import GuppyComptimeError, GuppyError, InternalGuppyError +from guppylang.error import GuppyError, InternalGuppyError from guppylang.nodes import ( BarrierExpr, DesugaredArrayComp, @@ -58,7 +58,10 @@ TupleAccessAndDrop, TypeApply, ) -from guppylang.std._internal.compiler.arithmetic import convert_ifromusize +from guppylang.std._internal.compiler.arithmetic import ( + UnsignedIntVal, + convert_ifromusize, +) from guppylang.std._internal.compiler.array import ( array_convert_from_std_array, array_convert_to_std_array, @@ -785,16 +788,14 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: case str(): return hugr.std.prelude.StringVal(v) case int(): - bit_width = 1 << NumericType.INT_WIDTH - max_v = (1 << (bit_width - 1)) - 1 - min_v = -(1 << (bit_width - 1)) - if v < min_v or v > max_v: - msg = ( - f"Integer value {v} is out of bounds for {bit_width}" - "-bit signed integer." - ) - raise GuppyComptimeError(msg) - return hugr.std.int.IntVal(v, width=NumericType.INT_WIDTH) + assert isinstance(exp_ty, NumericType) + match exp_ty.kind: + case NumericType.Kind.Nat: + return UnsignedIntVal(v, width=NumericType.INT_WIDTH) + case NumericType.Kind.Int: + return hugr.std.int.IntVal(v, width=NumericType.INT_WIDTH) + case _: + raise InternalGuppyError("Unexpected numeric type") case float(): return hugr.std.float.FloatVal(v) case tuple(elts): diff --git a/guppylang/std/_internal/compiler/arithmetic.py b/guppylang/std/_internal/compiler/arithmetic.py index 85e151b99..3adfdcde9 100644 --- a/guppylang/std/_internal/compiler/arithmetic.py +++ b/guppylang/std/_internal/compiler/arithmetic.py @@ -1,9 +1,10 @@ """Native arithmetic operations from the HUGR std, and compilers for non native ones.""" from collections.abc import Sequence +from dataclasses import dataclass import hugr.std.int -from hugr import ops +from hugr import model, ops, val from hugr import tys as ht from hugr.std.int import int_t @@ -12,6 +13,30 @@ INT_T = int_t(NumericType.INT_WIDTH) + +@dataclass +class UnsignedIntVal(val.ExtensionValue): # TODO: Upstream this to hugr-py? + """Custom value for an unsigned integer.""" + + v: int + width: int + + def __post_init__(self) -> None: + assert self.v >= 0 + + def to_value(self) -> val.Extension: + payload = {"log_width": self.width, "value": self.v} + return val.Extension("ConstInt", typ=int_t(self.width), val=payload) + + def __str__(self) -> str: + return f"{self.v}" + + def to_model(self) -> model.Term: + return model.Apply( + "arithmetic.int.const", [model.Literal(self.width), model.Literal(self.v)] + ) + + # ------------------------------------------------------ # --------- std.arithmetic.int operations -------------- # ------------------------------------------------------ diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index c3ce0b424..9e6d144be 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -96,10 +96,11 @@ def arg_from_ast( # Integer literals are turned into nat args since these are the only ones we support # right now. # TODO: Once we also have int args etc, we need proper inference logic here - if isinstance(node, ast.Constant) and isinstance(node.value, int): - # Fun fact: int ast.Constant values are never negative since e.g. `-5` is a - # `ast.UnaryOp` negation of a `ast.Constant(5)` - assert node.value >= 0 + if ( + isinstance(node, ast.Constant) + and isinstance(node.value, int) + and node.value >= 0 + ): nat_ty = NumericType(NumericType.Kind.Nat) return ConstArg(ConstValue(nat_ty, node.value)) diff --git a/tests/error/tracing_errors/bad_int1.err b/tests/error/tracing_errors/bad_int1.err index 88dcdafef..047b0add8 100644 --- a/tests/error/tracing_errors/bad_int1.err +++ b/tests/error/tracing_errors/bad_int1.err @@ -3,4 +3,4 @@ Traceback (most recent call last): guppy.compile(test) File "$FILE", line 7, in test return x + (1 << 63) -guppylang.error.GuppyComptimeError: Integer value 9223372036854775808 is out of bounds for 64-bit signed integer. +guppylang.error.GuppyComptimeError: Integer overflow: Value does not fit into a 64-bit signed integer diff --git a/tests/error/tracing_errors/bad_int2.err b/tests/error/tracing_errors/bad_int2.err index e9dcabb8b..dcc7350b9 100644 --- a/tests/error/tracing_errors/bad_int2.err +++ b/tests/error/tracing_errors/bad_int2.err @@ -3,4 +3,4 @@ Traceback (most recent call last): guppy.compile(test) File "$FILE", line 7, in test return x + (-(1 << 63) - 1) -guppylang.error.GuppyComptimeError: Integer value -9223372036854775809 is out of bounds for 64-bit signed integer. +guppylang.error.GuppyComptimeError: Integer underflow: Value does not fit into a 64-bit signed integer diff --git a/tests/error/type_errors/int_overflow.err b/tests/error/type_errors/int_overflow.err new file mode 100644 index 000000000..9bb66a5a3 --- /dev/null +++ b/tests/error/type_errors/int_overflow.err @@ -0,0 +1,8 @@ +Error: Integer overflow (at $FILE:6:11) + | +4 | @guppy +5 | def foo() -> int: +6 | return 9_223_372_036_854_775_808 + | ^^^^^^^^^^^^^^^^^^^^^^^^^ Value does not fit into a 64-bit signed integer + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/int_overflow.py b/tests/error/type_errors/int_overflow.py new file mode 100644 index 000000000..70b9c1b79 --- /dev/null +++ b/tests/error/type_errors/int_overflow.py @@ -0,0 +1,9 @@ +from guppylang.decorator import guppy + + +@guppy +def foo() -> int: + return 9_223_372_036_854_775_808 + + +guppy.compile(foo) diff --git a/tests/error/type_errors/int_underflow.err b/tests/error/type_errors/int_underflow.err new file mode 100644 index 000000000..2ff084d20 --- /dev/null +++ b/tests/error/type_errors/int_underflow.err @@ -0,0 +1,8 @@ +Error: Integer underflow (at $FILE:6:11) + | +4 | @guppy +5 | def foo() -> int: +6 | return -9_223_372_036_854_775_809 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ Value does not fit into a 64-bit signed integer + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/int_underflow.py b/tests/error/type_errors/int_underflow.py new file mode 100644 index 000000000..72176f74f --- /dev/null +++ b/tests/error/type_errors/int_underflow.py @@ -0,0 +1,9 @@ +from guppylang.decorator import guppy + + +@guppy +def foo() -> int: + return -9_223_372_036_854_775_809 + + +guppy.compile(foo) diff --git a/tests/error/type_errors/nat_overflow.err b/tests/error/type_errors/nat_overflow.err new file mode 100644 index 000000000..5e5750f40 --- /dev/null +++ b/tests/error/type_errors/nat_overflow.err @@ -0,0 +1,8 @@ +Error: Integer overflow (at $FILE:7:11) + | +5 | @guppy +6 | def foo() -> nat: +7 | return 18_446_744_073_709_551_616 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ Value does not fit into a 64-bit unsigned integer + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/nat_overflow.py b/tests/error/type_errors/nat_overflow.py new file mode 100644 index 000000000..389f14c1d --- /dev/null +++ b/tests/error/type_errors/nat_overflow.py @@ -0,0 +1,10 @@ +from guppylang.decorator import guppy +from guppylang.std.num import nat + + +@guppy +def foo() -> nat: + return 18_446_744_073_709_551_616 + + +guppy.compile(foo) diff --git a/tests/error/type_errors/negative_nat_literal.err b/tests/error/type_errors/negative_nat_literal.err new file mode 100644 index 000000000..06dc1d894 --- /dev/null +++ b/tests/error/type_errors/negative_nat_literal.err @@ -0,0 +1,8 @@ +Error: Type mismatch (at $FILE:7:11) + | +5 | @guppy +6 | def foo() -> nat: +7 | return -1 + | ^^ Expected return value of type `nat`, got `int` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/negative_nat_literal.py b/tests/error/type_errors/negative_nat_literal.py new file mode 100644 index 000000000..9942a156a --- /dev/null +++ b/tests/error/type_errors/negative_nat_literal.py @@ -0,0 +1,10 @@ +from guppylang.decorator import guppy +from guppylang.std.num import nat + + +@guppy +def foo() -> nat: + return -1 + + +guppy.compile(foo) diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index bfe370579..0069c3587 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -40,6 +40,23 @@ def const() -> nat: validate(const) +def test_int_bounds(run_int_fn): + @guppy + def main() -> int: + return 9_223_372_036_854_775_807 + -9_223_372_036_854_775_808 + + run_int_fn(main, -1) + + +def test_nat_bounds(run_int_fn): + @guppy + def main() -> nat: + x: nat = 18_446_744_073_709_551_614 + return x - x + + run_int_fn(main, 0) + + def test_aug_assign(run_int_fn): @guppy def add(x: int) -> int: