diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 48d3e9758e..f17f56025d 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1339,9 +1339,17 @@ def visit_If(self, node: ast.If) -> list: self.decls_stack[-2].extend(self.decls_stack[-1]) self.decls_stack.pop() + try: + condition = self.visit(node.test) + except KeyError as e: + raise GTScriptSyntaxError( + message="Using function calls in the condition of an if is not allowed," + + " the function needs to be assigned to a variable outside the condition.", + loc=nodes.Location.from_ast_node(node), + ) from e result.append( nodes.If( - condition=self.visit(node.test), + condition=condition, loc=nodes.Location.from_ast_node(node), main_body=nodes.BlockStmt(stmts=main_stmts, loc=nodes.Location.from_ast_node(node)), else_body=( diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 1f7a779835..bfd3517475 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -17,6 +17,7 @@ import gt4py.cartesian.definitions as gt_definitions from gt4py.cartesian import gtscript from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend, nodes +from gt4py.cartesian.frontend.exceptions import GTScriptSyntaxError from gt4py.cartesian.gtscript import ( __INLINED, FORWARD, @@ -1719,6 +1720,26 @@ def func(field: gtscript.Field[np.float64]): ) +@gtscript.function +def boolean_return(a): + return a == 1 + + +class TestFunctionIfError: + def test_function_if_error(self): + def func(field: gtscript.Field[np.float64]): # type: ignore + with computation(PARALLEL), interval(...): + field = 0 + if boolean_return(field): + field = 1 + + with pytest.raises( + GTScriptSyntaxError, + match="Using function calls in the condition of an if is not allowed", + ): + parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) + + class TestAnnotations: @staticmethod def sumdiff_defs(