Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 62 additions & 28 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __init__(self, axis_name: str, loc: Optional[nodes.Location] = None):
def slice_from_value(node: ast.Expr) -> ast.Slice:
"""Create an ast.Slice node from a general ast.Expr node."""
slice_node = ast.Slice(
lower=node, upper=ast.BinOp(left=node, op=ast.Add(), right=ast.Constant(value=1))
lower=node,
upper=ast.BinOp(left=node, op=ast.Add(), right=ast.Constant(value=1)),
)
slice_node = ast.copy_location(slice_node, node)
return slice_node
Expand Down Expand Up @@ -361,7 +362,11 @@ class CallInliner(ast.NodeTransformer):

@classmethod
def apply(
cls, func_node: ast.FunctionDef, context: dict, *, call_stack: Optional[Set[str]] = None
cls,
func_node: ast.FunctionDef,
context: dict,
*,
call_stack: Optional[Set[str]] = None,
):
inliner = cls(context, call_stack=call_stack or set())
inliner(func_node)
Expand Down Expand Up @@ -456,7 +461,9 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex
call_ast = copy.deepcopy(call_info["ast"])
self.current_name = call_name
CallInliner.apply(
call_ast, call_info["local_context"], call_stack={*self.call_stack, call_name}
call_ast,
call_info["local_context"],
call_stack={*self.call_stack, call_name},
)

# Extract call arguments
Expand Down Expand Up @@ -509,7 +516,10 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex
template_fmt = "{name}__" + call_id_suffix

gt_meta.map_symbol_names(
call_ast, name_mapping, template_fmt=template_fmt, skip_names=self.all_skip_names
call_ast,
name_mapping,
template_fmt=template_fmt,
skip_names=self.all_skip_names,
)

# Replace returns by assignments in subroutine
Expand Down Expand Up @@ -645,7 +655,9 @@ def _make_temp_decls(


def _make_init_computations(
temp_decls: Dict[str, nodes.FieldDecl], init_values: Dict[str, Any], func_node: ast.AST
temp_decls: Dict[str, nodes.FieldDecl],
init_values: Dict[str, Any],
func_node: ast.AST,
) -> List[nodes.ComputationBlock]:
if not temp_decls:
return []
Expand All @@ -672,7 +684,8 @@ def _make_init_computations(
else:
stmts.append(
nodes.Assign(
target=nodes.FieldRef.at_center(name, axes=decl.axes), value=init_values[name]
target=nodes.FieldRef.at_center(name, axes=decl.axes),
value=init_values[name],
)
)

Expand Down Expand Up @@ -833,12 +846,14 @@ def __init__(
"int64": nodes.DataType.INT64,
"float32": nodes.DataType.FLOAT32,
"float64": nodes.DataType.FLOAT64,
"int": nodes.DataType.INT32
if self.literal_int_precision == 32
else nodes.DataType.INT64,
"float": nodes.DataType.FLOAT32
if self.literal_float_precision == 32
else nodes.DataType.FLOAT64,
"int": (
nodes.DataType.INT32 if self.literal_int_precision == 32 else nodes.DataType.INT64
),
"float": (
nodes.DataType.FLOAT32
if self.literal_float_precision == 32
else nodes.DataType.FLOAT64
),
} # Conversion table for types to DataTypes

def __call__(self, ast_root: ast.AST):
Expand Down Expand Up @@ -910,7 +925,8 @@ def _visit_with_horizontal(
self, node: ast.withitem, loc: nodes.Location
) -> List[Dict[str, nodes.AxisInterval]]:
syntax_error = GTScriptSyntaxError(
f"Invalid 'with' statement at line {loc.line} (column {loc.column})", loc=loc
f"Invalid 'with' statement at line {loc.line} (column {loc.column})",
loc=loc,
)

call_args = node.context_expr.args
Expand All @@ -930,7 +946,8 @@ def _are_intervals_nonoverlapping(self, compute_blocks: List[nodes.ComputationBl

def _visit_iteration_order_node(self, node: ast.withitem, loc: nodes.Location):
syntax_error = GTScriptSyntaxError(
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})", loc=loc
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})",
loc=loc,
)
comp_node = node.context_expr
if len(comp_node.args) + len(comp_node.keywords) != 1 or any(
Expand Down Expand Up @@ -992,7 +1009,8 @@ def _visit_interval_node(self, node: ast.withitem, loc: nodes.Location):
def _visit_computation_node(self, node: ast.With) -> nodes.ComputationBlock:
loc = nodes.Location.from_ast_node(node, scope=self.stencil_name)
syntax_error = GTScriptSyntaxError(
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})", loc=loc
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})",
loc=loc,
)

# Parse computation specification, i.e. `withItems` nodes
Expand Down Expand Up @@ -1143,7 +1161,9 @@ def visit_Index(self, node: ast.Index):
return index

def _eval_new_spatial_index(
self, index_nodes: Sequence[nodes.Expr], field_axes: Optional[Set[Literal["I", "J", "K"]]]
self,
index_nodes: Sequence[nodes.Expr],
field_axes: Optional[Set[Literal["I", "J", "K"]]],
) -> List[int]:
index_dict = {}
all_spatial_axes = ("I", "J", "K")
Expand Down Expand Up @@ -1188,7 +1208,9 @@ def _eval_new_spatial_index(
return [index_dict.get(axis, 0) for axis in ("I", "J", "K") if axis in field_axes]

def _eval_index(
self, node: ast.Subscript, field_axes: Optional[Set[Literal["I", "J", "K"]]] = None
self,
node: ast.Subscript,
field_axes: Optional[Set[Literal["I", "J", "K"]]] = None,
) -> list[int] | nodes.AbsoluteKIndex | None:
tuple_or_expr = node.slice.value if isinstance(node.slice, ast.Index) else node.slice
index_nodes = gt_utils.listify(
Expand Down Expand Up @@ -2012,12 +2034,16 @@ def annotate_definition(
temp_init_values: Dict[str, numbers.Number] = {}

frontend_types_to_native_types = nodes.frontend_type_to_native_type(
options.literal_int_precision
if options is not None
else gt_definitions.LITERAL_INT_PRECISION,
options.literal_float_precision
if options is not None
else gt_definitions.LITERAL_FLOAT_PRECISION,
(
options.literal_int_precision
if options is not None
else gt_definitions.LITERAL_INT_PRECISION
),
(
options.literal_float_precision
if options is not None
else gt_definitions.LITERAL_FLOAT_PRECISION
),
)

ann_assign_context = {
Expand Down Expand Up @@ -2189,7 +2215,9 @@ def resolve_external_symbols(
(
attr_name,
GTScriptParser.eval_external(
attr_name, context, nodes.Location.from_ast_node(attr_nodes[0])
attr_name,
context,
nodes.Location.from_ast_node(attr_nodes[0]),
),
)
)
Expand Down Expand Up @@ -2239,7 +2267,8 @@ def extract_arg_descriptors(self):
for arg_info, arg_annotation in zip(api_signature, api_annotations):
try:
assert arg_annotation in gtscript._VALID_DATA_TYPES or isinstance(
arg_annotation, (gtscript._SequenceDescriptor, gtscript._FieldDescriptor)
arg_annotation,
(gtscript._SequenceDescriptor, gtscript._FieldDescriptor),
), "Invalid parameter annotation"

if arg_annotation in gtscript._VALID_DATA_TYPES:
Expand All @@ -2255,7 +2284,10 @@ def extract_arg_descriptors(self):
data_type = nodes.DataType.from_dtype(np.dtype(arg_annotation))
length = arg_annotation.length
parameter_decls[arg_info.name] = nodes.VarDecl(
name=arg_info.name, data_type=data_type, length=length, is_api=True
name=arg_info.name,
data_type=data_type,
length=length,
is_api=True,
)
else:
assert isinstance(arg_annotation, gtscript._FieldDescriptor)
Expand Down Expand Up @@ -2337,7 +2369,9 @@ def run(self, backend_name: str):
fields_decls.update(temp_decls)

init_computations = _make_init_computations(
temp_decls, self.definition._gtscript_["temp_init_values"], func_node=main_func_node
temp_decls,
self.definition._gtscript_["temp_init_values"],
func_node=main_func_node,
)

# Generate definition IR
Expand All @@ -2363,7 +2397,7 @@ def run(self, backend_name: str):
parameters=[
parameter_decls[item.name] for item in api_signature if item.name in parameter_decls
],
computations=init_computations + computations if init_computations else computations,
computations=(init_computations + computations if init_computations else computations),
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
externals=self.resolved_externals,
docstring=inspect.getdoc(self.definition) or "",
loc=nodes.Location.from_ast_node(self.ast_root.body[0]),
Expand Down
33 changes: 28 additions & 5 deletions src/gt4py/cartesian/gtc/common.py
Comment thread
katrinafandrich marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,11 @@ def _precision_to_datatype(func: NativeFunction) -> DataType:
raise NotImplementedError(f"Found unknown precision specification {func}")

def _impl(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None:
if instance.func in (NativeFunction.ISFINITE, NativeFunction.ISINF, NativeFunction.ISNAN):
if instance.func in (
NativeFunction.ISFINITE,
NativeFunction.ISINF,
NativeFunction.ISNAN,
):
instance.dtype = DataType.BOOL # type: ignore[attr-defined]
elif instance.func in (
NativeFunction.INT32,
Expand Down Expand Up @@ -665,7 +669,12 @@ def visit_Node(
self.generic_visit(node, loop_order=loop_order, **kwargs)

def visit_AssignStmt(
self, node: AssignStmt, *, loop_order: LoopOrder, symtable: Dict[str, Any], **kwargs: Any
self,
node: AssignStmt,
*,
loop_order: LoopOrder,
symtable: Dict[str, Any],
**kwargs: Any,
) -> None:
decl = symtable.get(node.left.name, None)
if decl is None:
Expand Down Expand Up @@ -920,7 +929,10 @@ def data_type_to_typestr(dtype: DataType) -> str:
ComparisonOperator.EQ: "equal",
ComparisonOperator.NE: "not_equal",
},
LogicalOperator: {LogicalOperator.AND: "logical_and", LogicalOperator.OR: "logical_or"},
LogicalOperator: {
LogicalOperator.AND: "logical_and",
LogicalOperator.OR: "logical_or",
},
NativeFunction: {
NativeFunction.ABS: "abs",
NativeFunction.MIN: "minimum",
Expand Down Expand Up @@ -965,11 +977,22 @@ def data_type_to_typestr(dtype: DataType) -> str:

def op_to_ufunc(
op: Union[
UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction
UnaryOperator,
ArithmeticOperator,
ComparisonOperator,
LogicalOperator,
NativeFunction,
],
) -> np.ufunc:
if not isinstance(
op, (UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction)
op,
(
UnaryOperator,
ArithmeticOperator,
ComparisonOperator,
LogicalOperator,
NativeFunction,
),
):
raise TypeError(
"Can only convert instances of GTC operators and supported native functions to typestr."
Expand Down
12 changes: 9 additions & 3 deletions src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def visit_FieldAccess(
offsets = node.offset.to_dict()
if node.name in tmps_name_map:
assert offsets["i"] == offsets["j"] == offsets["k"] == 0, (
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
"Non-zero offset in temporary that is replaced?!"
"No K-offset capabilities of 3D temporaries. Must define '"
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
+ str(node.name)
+ "' as a FloatField."
)
return oir.ScalarAccess(name=tmps_name_map[node.name], dtype=node.dtype)
return self.generic_visit(node, tmps_name_map=tmps_name_map, **kwargs)
Expand All @@ -52,7 +54,9 @@ def visit_HorizontalExecution(
declarations=node.declarations
+ [
oir.LocalScalar(
name=tmps_name_map[tmp], dtype=symtable[tmp].dtype, loc=symtable[tmp].loc
name=tmps_name_map[tmp],
dtype=symtable[tmp].dtype,
loc=symtable[tmp].loc,
)
for tmp in local_tmps_to_replace
],
Expand All @@ -76,7 +80,9 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
name=node.name,
params=node.params,
vertical_loops=self.visit(
node.vertical_loops, new_symbol_name=symbol_name_creator(all_names), **kwargs
node.vertical_loops,
new_symbol_name=symbol_name_creator(all_names),
**kwargs,
),
declarations=[d for d in node.declarations if d.name not in tmps_to_replace],
loc=node.loc,
Expand Down
12 changes: 9 additions & 3 deletions src/gt4py/cartesian/stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _validate_args( # Function is too complex
)
):
raise ValueError(
f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
f"Compute domain too large (provided: {domain}, maximum: {max_domain}. Adjust K interval as needed.)"
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
)

if domain[2] < self.domain_info.min_sequential_axis_size:
Expand Down Expand Up @@ -434,7 +434,10 @@ def _validate_args( # Function is too complex

if (
arg_info.dimensions is not None
and (*field_info.axes, *(str(d) for d in range(len(field_info.data_dims))))
and (
*field_info.axes,
*(str(d) for d in range(len(field_info.data_dims))),
)
!= arg_info.dimensions
):
raise ValueError(
Expand Down Expand Up @@ -589,7 +592,10 @@ def _call_run(
exec_info["call_run_end_time"] = time.perf_counter()

def freeze(
self: StencilObject, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]
self: StencilObject,
*,
origin: dict[str, tuple[int, ...]],
domain: tuple[int, ...],
) -> FrozenStencil:
"""Return a StencilObject wrapper with a fixed domain and origin for each argument.

Expand Down
2 changes: 1 addition & 1 deletion tests/eve_tests/unit_tests/test_type_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_immutability(self):

fl = FrozenList([0, 1, 2, 3, 4, 5])

with pytest.raises(TypeError, match="object does not support item assignment"):
with pytest.raises(TypeError, match="object does not support item assignment."):
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
fl[2] = -2

def test_instance_check(self):
Expand Down