Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
- Vogt, Hannes. ETH Zurich - CSCS
- Weber, Benjamin. MeteoSwiss
- Wicky, Tobias. Allen Institute for AI
- Fandrich, Katrina. SSAI/NASA-GSFC
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
90 changes: 62 additions & 28 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __init__(self, axis_name: str, loc: Optional[nodes.Location] = None):
def slice_from_value(node: ast.Expr) -> ast.Slice:
"""Create an ast.Slice node from a general ast.Expr node."""
slice_node = ast.Slice(
lower=node, upper=ast.BinOp(left=node, op=ast.Add(), right=ast.Constant(value=1))
lower=node,
upper=ast.BinOp(left=node, op=ast.Add(), right=ast.Constant(value=1)),
)
slice_node = ast.copy_location(slice_node, node)
return slice_node
Expand Down Expand Up @@ -361,7 +362,11 @@ class CallInliner(ast.NodeTransformer):

@classmethod
def apply(
cls, func_node: ast.FunctionDef, context: dict, *, call_stack: Optional[Set[str]] = None
cls,
func_node: ast.FunctionDef,
context: dict,
*,
call_stack: Optional[Set[str]] = None,
):
inliner = cls(context, call_stack=call_stack or set())
inliner(func_node)
Expand Down Expand Up @@ -457,7 +462,9 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex
call_ast = copy.deepcopy(call_info["ast"])
self.current_name = call_name
CallInliner.apply(
call_ast, call_info["local_context"], call_stack={*self.call_stack, call_name}
call_ast,
call_info["local_context"],
call_stack={*self.call_stack, call_name},
)

# Extract call arguments
Expand Down Expand Up @@ -510,7 +517,10 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex
template_fmt = "{name}__" + call_id_suffix

gt_meta.map_symbol_names(
call_ast, name_mapping, template_fmt=template_fmt, skip_names=self.all_skip_names
call_ast,
name_mapping,
template_fmt=template_fmt,
skip_names=self.all_skip_names,
)

# Replace returns by assignments in subroutine
Expand Down Expand Up @@ -646,7 +656,9 @@ def _make_temp_decls(


def _make_init_computations(
temp_decls: Dict[str, nodes.FieldDecl], init_values: Dict[str, Any], func_node: ast.AST
temp_decls: Dict[str, nodes.FieldDecl],
init_values: Dict[str, Any],
func_node: ast.AST,
) -> List[nodes.ComputationBlock]:
if not temp_decls:
return []
Expand All @@ -673,7 +685,8 @@ def _make_init_computations(
else:
stmts.append(
nodes.Assign(
target=nodes.FieldRef.at_center(name, axes=decl.axes), value=init_values[name]
target=nodes.FieldRef.at_center(name, axes=decl.axes),
value=init_values[name],
)
)

Expand Down Expand Up @@ -834,12 +847,14 @@ def __init__(
"int64": nodes.DataType.INT64,
"float32": nodes.DataType.FLOAT32,
"float64": nodes.DataType.FLOAT64,
"int": nodes.DataType.INT32
if self.literal_int_precision == 32
else nodes.DataType.INT64,
"float": nodes.DataType.FLOAT32
if self.literal_float_precision == 32
else nodes.DataType.FLOAT64,
"int": (
nodes.DataType.INT32 if self.literal_int_precision == 32 else nodes.DataType.INT64
),
"float": (
nodes.DataType.FLOAT32
if self.literal_float_precision == 32
else nodes.DataType.FLOAT64
),
} # Conversion table for types to DataTypes

def __call__(self, ast_root: ast.AST):
Expand Down Expand Up @@ -911,7 +926,8 @@ def _visit_with_horizontal(
self, node: ast.withitem, loc: nodes.Location
) -> List[Dict[str, nodes.AxisInterval]]:
syntax_error = GTScriptSyntaxError(
f"Invalid 'with' statement at line {loc.line} (column {loc.column})", loc=loc
f"Invalid 'with' statement at line {loc.line} (column {loc.column})",
loc=loc,
)

call_args = node.context_expr.args
Expand All @@ -931,7 +947,8 @@ def _are_intervals_nonoverlapping(self, compute_blocks: List[nodes.ComputationBl

def _visit_iteration_order_node(self, node: ast.withitem, loc: nodes.Location):
syntax_error = GTScriptSyntaxError(
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})", loc=loc
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})",
loc=loc,
)
comp_node = node.context_expr
if len(comp_node.args) + len(comp_node.keywords) != 1 or any(
Expand Down Expand Up @@ -993,7 +1010,8 @@ def _visit_interval_node(self, node: ast.withitem, loc: nodes.Location):
def _visit_computation_node(self, node: ast.With) -> nodes.ComputationBlock:
loc = nodes.Location.from_ast_node(node, scope=self.stencil_name)
syntax_error = GTScriptSyntaxError(
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})", loc=loc
f"Invalid 'computation' specification at line {loc.line} (column {loc.column})",
loc=loc,
)

# Parse computation specification, i.e. `withItems` nodes
Expand Down Expand Up @@ -1143,7 +1161,9 @@ def visit_Index(self, node: ast.Index):
return self.visit(node.value)

def _eval_new_spatial_index(
self, index_nodes: Sequence[nodes.Expr], field_axes: Optional[Set[Literal["I", "J", "K"]]]
self,
index_nodes: Sequence[nodes.Expr],
field_axes: Optional[Set[Literal["I", "J", "K"]]],
) -> List[int]:
index_dict = {}
all_spatial_axes = ("I", "J", "K")
Expand Down Expand Up @@ -1188,7 +1208,9 @@ def _eval_new_spatial_index(
return [index_dict.get(axis, 0) for axis in ("I", "J", "K") if axis in field_axes]

def _eval_index(
self, node: ast.Subscript, field_axes: Optional[Set[Literal["I", "J", "K"]]] = None
self,
node: ast.Subscript,
field_axes: Optional[Set[Literal["I", "J", "K"]]] = None,
) -> list[int] | nodes.AbsoluteKIndex | None:
tuple_or_expr = node.slice.value if isinstance(node.slice, ast.Index) else node.slice
index_nodes = gt_utils.listify(
Expand Down Expand Up @@ -2012,12 +2034,16 @@ def annotate_definition(
temp_init_values: Dict[str, numbers.Number] = {}

frontend_types_to_native_types = nodes.frontend_type_to_native_type(
options.literal_int_precision
if options is not None
else gt_definitions.LITERAL_INT_PRECISION,
options.literal_float_precision
if options is not None
else gt_definitions.LITERAL_FLOAT_PRECISION,
(
options.literal_int_precision
if options is not None
else gt_definitions.LITERAL_INT_PRECISION
),
(
options.literal_float_precision
if options is not None
else gt_definitions.LITERAL_FLOAT_PRECISION
),
)

ann_assign_context = {
Expand Down Expand Up @@ -2189,7 +2215,9 @@ def resolve_external_symbols(
(
attr_name,
GTScriptParser.eval_external(
attr_name, context, nodes.Location.from_ast_node(attr_nodes[0])
attr_name,
context,
nodes.Location.from_ast_node(attr_nodes[0]),
),
)
)
Expand Down Expand Up @@ -2239,7 +2267,8 @@ def extract_arg_descriptors(self):
for arg_info, arg_annotation in zip(api_signature, api_annotations):
try:
assert arg_annotation in gtscript._VALID_DATA_TYPES or isinstance(
arg_annotation, (gtscript._SequenceDescriptor, gtscript._FieldDescriptor)
arg_annotation,
(gtscript._SequenceDescriptor, gtscript._FieldDescriptor),
), "Invalid parameter annotation"

if arg_annotation in gtscript._VALID_DATA_TYPES:
Expand All @@ -2255,7 +2284,10 @@ def extract_arg_descriptors(self):
data_type = nodes.DataType.from_dtype(np.dtype(arg_annotation))
length = arg_annotation.length
parameter_decls[arg_info.name] = nodes.VarDecl(
name=arg_info.name, data_type=data_type, length=length, is_api=True
name=arg_info.name,
data_type=data_type,
length=length,
is_api=True,
)
else:
assert isinstance(arg_annotation, gtscript._FieldDescriptor)
Expand Down Expand Up @@ -2337,7 +2369,9 @@ def run(self, backend_name: str):
fields_decls.update(temp_decls)

init_computations = _make_init_computations(
temp_decls, self.definition._gtscript_["temp_init_values"], func_node=main_func_node
temp_decls,
self.definition._gtscript_["temp_init_values"],
func_node=main_func_node,
)

# Generate definition IR
Expand All @@ -2363,7 +2397,7 @@ def run(self, backend_name: str):
parameters=[
parameter_decls[item.name] for item in api_signature if item.name in parameter_decls
],
computations=init_computations + computations,
computations=(init_computations + computations if init_computations else computations),
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
externals=self.resolved_externals,
docstring=inspect.getdoc(self.definition) or "",
loc=nodes.Location.from_ast_node(self.ast_root.body[0]),
Expand Down
33 changes: 28 additions & 5 deletions src/gt4py/cartesian/gtc/common.py
Comment thread
katrinafandrich marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,11 @@ def _precision_to_datatype(func: NativeFunction) -> DataType:
raise NotImplementedError(f"Found unknown precision specification {func}")

def _impl(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None:
if instance.func in (NativeFunction.ISFINITE, NativeFunction.ISINF, NativeFunction.ISNAN):
if instance.func in (
NativeFunction.ISFINITE,
NativeFunction.ISINF,
NativeFunction.ISNAN,
):
instance.dtype = DataType.BOOL # type: ignore[attr-defined]
elif instance.func in (
NativeFunction.INT32,
Expand Down Expand Up @@ -665,7 +669,12 @@ def visit_Node(
self.generic_visit(node, loop_order=loop_order, **kwargs)

def visit_AssignStmt(
self, node: AssignStmt, *, loop_order: LoopOrder, symtable: Dict[str, Any], **kwargs: Any
self,
node: AssignStmt,
*,
loop_order: LoopOrder,
symtable: Dict[str, Any],
**kwargs: Any,
) -> None:
decl = symtable.get(node.left.name, None)
if decl is None:
Expand Down Expand Up @@ -920,7 +929,10 @@ def data_type_to_typestr(dtype: DataType) -> str:
ComparisonOperator.EQ: "equal",
ComparisonOperator.NE: "not_equal",
},
LogicalOperator: {LogicalOperator.AND: "logical_and", LogicalOperator.OR: "logical_or"},
LogicalOperator: {
LogicalOperator.AND: "logical_and",
LogicalOperator.OR: "logical_or",
},
NativeFunction: {
NativeFunction.ABS: "abs",
NativeFunction.MIN: "minimum",
Expand Down Expand Up @@ -965,11 +977,22 @@ def data_type_to_typestr(dtype: DataType) -> str:

def op_to_ufunc(
op: Union[
UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction
UnaryOperator,
ArithmeticOperator,
ComparisonOperator,
LogicalOperator,
NativeFunction,
],
) -> np.ufunc:
if not isinstance(
op, (UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction)
op,
(
UnaryOperator,
ArithmeticOperator,
ComparisonOperator,
LogicalOperator,
NativeFunction,
),
):
raise TypeError(
"Can only convert instances of GTC operators and supported native functions to typestr."
Expand Down
16 changes: 11 additions & 5 deletions src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ def visit_FieldAccess(
) -> Union[oir.FieldAccess, oir.ScalarAccess]:
offsets = node.offset.to_dict()
if node.name in tmps_name_map:
assert offsets["i"] == offsets["j"] == offsets["k"] == 0, (
"Non-zero offset in temporary that is replaced?!"
)
if offsets["i"] != 0 or offsets["j"] != 0 or offsets["k"] != 0:
raise ValueError(
"No K-offset capabilities of 3D temporaries. Must "
Comment thread
katrinafandrich marked this conversation as resolved.
Outdated
f"define '{node.name}' as a FloatField."
)
return oir.ScalarAccess(name=tmps_name_map[node.name], dtype=node.dtype)
return self.generic_visit(node, tmps_name_map=tmps_name_map, **kwargs)

Expand All @@ -52,7 +54,9 @@ def visit_HorizontalExecution(
declarations=node.declarations
+ [
oir.LocalScalar(
name=tmps_name_map[tmp], dtype=symtable[tmp].dtype, loc=symtable[tmp].loc
name=tmps_name_map[tmp],
dtype=symtable[tmp].dtype,
loc=symtable[tmp].loc,
)
for tmp in local_tmps_to_replace
],
Expand All @@ -76,7 +80,9 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
name=node.name,
params=node.params,
vertical_loops=self.visit(
node.vertical_loops, new_symbol_name=symbol_name_creator(all_names), **kwargs
node.vertical_loops,
new_symbol_name=symbol_name_creator(all_names),
**kwargs,
),
declarations=[d for d in node.declarations if d.name not in tmps_to_replace],
loc=node.loc,
Expand Down
12 changes: 9 additions & 3 deletions src/gt4py/cartesian/stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _validate_args( # Function is too complex
)
):
raise ValueError(
f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
f"Compute domain too large (provided: {domain}, maximum: {max_domain}. Check stencil domain provided or adjust K interval as needed.)"
)

if domain[2] < self.domain_info.min_sequential_axis_size:
Expand Down Expand Up @@ -434,7 +434,10 @@ def _validate_args( # Function is too complex

if (
arg_info.dimensions is not None
and (*field_info.axes, *(str(d) for d in range(len(field_info.data_dims))))
and (
*field_info.axes,
*(str(d) for d in range(len(field_info.data_dims))),
)
!= arg_info.dimensions
):
raise ValueError(
Expand Down Expand Up @@ -589,7 +592,10 @@ def _call_run(
exec_info["call_run_end_time"] = time.perf_counter()

def freeze(
self: StencilObject, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]
self: StencilObject,
*,
origin: dict[str, tuple[int, ...]],
domain: tuple[int, ...],
) -> FrozenStencil:
"""Return a StencilObject wrapper with a fixed domain and origin for each argument.

Expand Down