From 369febf61c37bf02b55d35bd09c5ead5e8fa9a0d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 10:47:39 +0200 Subject: [PATCH 01/11] Cleanups around shash() Cleans up types, adds docstrings and improves usage. --- src/gt4py/cartesian/caching.py | 2 +- src/gt4py/cartesian/definitions.py | 7 ++---- .../cartesian/frontend/gtscript_frontend.py | 2 +- src/gt4py/cartesian/utils/base.py | 25 ++++++++++++++++--- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index bb94fc1f82..aa61e369a0 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -325,7 +325,7 @@ def stencil_id(self) -> StencilID: # typeignore because attrclass StencilID has generated constructor return StencilID( # type: ignore self.builder.options.qualified_name, - gt_utils.shashed_id(gt_utils.shashed_id(fingerprint), self.options_id), + gt_utils.shashed_id(fingerprint, self.options_id), ) @property diff --git a/src/gt4py/cartesian/definitions.py b/src/gt4py/cartesian/definitions.py index b795d4280d..ad1ff57e07 100644 --- a/src/gt4py/cartesian/definitions.py +++ b/src/gt4py/cartesian/definitions.py @@ -105,17 +105,14 @@ class BuildOptions(AttributeClassLike): @property def qualified_name(self): - name = ".".join([self.module, self.name]) - return name + return ".".join([self.module, self.name]) @property def shashed_id(self): - result = gt_utils.shashed_id( + return gt_utils.shashed_id( self.name, self.module, self.format_source, *tuple(sorted(self.backend_opts.items())) ) - return result - @attribclass(frozen=True) class StencilID(AttributeClassLike): diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index e4a0665d47..41db4f51c4 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -480,7 +480,7 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex if isinstance(value, ast.Name) and name not in assigned_symbols } - call_id = gt_utils.shashed_id(call_name)[:3] + call_id = gt_utils.shashed_id(call_name, length=3) call_id_suffix = f"{call_id}_{node.lineno}_{node.col_offset}" template_fmt = "{name}__" + call_id_suffix diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index 961cb4dc4c..610cde2dd7 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -8,6 +8,8 @@ """Basic utilities for Python programming.""" +from __future__ import annotations + import collections.abc import functools import hashlib @@ -19,6 +21,7 @@ import sys import time import types +import warnings from typing import Generic, TypeVar from gt4py.cartesian import config as gt_config @@ -169,7 +172,13 @@ def normalize_mapping(mapping, key_types=(object,), *, filter_none=False): return result -def shash(*args, hash_algorithm=None): +def shash(*args, hash_algorithm: hashlib._Hash | None = None, length: int | None = None) -> str: + """Hash the given arguments. + + Args: + hash_algorithm: Specify the hashlib algorithm used. Defaults to sha 256. + length: Trim to the first `length` digits of the hash. Returns the full hash by default. + """ if hash_algorithm is None: hash_algorithm = hashlib.sha256() @@ -182,11 +191,19 @@ def shash(*args, hash_algorithm=None): item = str.encode(repr(item)) hash_algorithm.update(item) - return hash_algorithm.hexdigest() + digest = hash_algorithm.hexdigest() + if length is not None and length > len(digest): + warnings.warn( + f"Requested hash of length {length}, but the full hash's length is {len(digest)}. Returning the full hash.", + stacklevel=2, + ) + length = None + return digest[:length] if length is not None else digest -def shashed_id(*args, length=10, hash_algorithm=None): - return shash(*args, hash_algorithm=hash_algorithm)[:length] +def shashed_id(*args, length: int = 10) -> str: + """Hash the given arguments and trim to length.""" + return shash(*args, length=length) def classmethod_to_function(class_method, instance=None, owner=None, remove_cls_arg=False): From 24789af1453253f889c22758f33eba3de9698eba Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 10:52:40 +0200 Subject: [PATCH 02/11] Random cleanups in base utils --- src/gt4py/cartesian/utils/base.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index 610cde2dd7..b7d20b11d6 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -22,7 +22,7 @@ import time import types import warnings -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar from gt4py.cartesian import config as gt_config @@ -42,14 +42,14 @@ def jsonify(value, indent=2): return json.dumps(value, indent=indent, default=lambda obj: str(obj)) -def is_identifier_name(value, namespaced=True): +def is_identifier_name(value: Any, namespaced: bool = True) -> bool: if isinstance(value, str): if namespaced: return all(name.isidentifier() for name in value.split(".")) - else: - return value.isidentifier() - else: - return False + + return value.isidentifier() + + return False def listify(value): @@ -75,8 +75,8 @@ def get_member(instance, item_name): isinstance(instance, collections.abc.Sequence) and isinstance(item_name, int) ): return instance[item_name] - else: - return getattr(instance, item_name) + + return getattr(instance, item_name) except Exception: return NOTHING @@ -209,12 +209,10 @@ def shashed_id(*args, length: int = 10) -> str: def classmethod_to_function(class_method, instance=None, owner=None, remove_cls_arg=False): if remove_cls_arg: return functools.partial(class_method.__get__(instance, owner), None) - else: - return class_method.__get__(instance, owner) + return class_method.__get__(instance, owner) -def namespace_from_nested_dict(nested_dict): - assert isinstance(nested_dict, dict) +def namespace_from_nested_dict(nested_dict: dict) -> types.SimpleNamespace: return types.SimpleNamespace( **{ key: namespace_from_nested_dict(value) if isinstance(value, dict) else value @@ -318,7 +316,7 @@ def patch_module(module, member, new_value, *, recursive=True): if patched: originals[current] = patched - patch = dict( + return dict( module=module, original_value=member, patched_value=new_value, @@ -326,8 +324,6 @@ def patch_module(module, member, new_value, *, recursive=True): originals=originals, ) - return patch - def restore_module(patch, *, verify=True): """Restore a module patched with the `patch_module()` function.""" From 4256311d120ea05a4569a70b3612a810a0b96a74 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 10:58:48 +0200 Subject: [PATCH 03/11] Cleanup how we freeze DaCeStencilObjectes --- .../cartesian/backend/dace_stencil_object.py | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_stencil_object.py b/src/gt4py/cartesian/backend/dace_stencil_object.py index a6cb2be565..a3de5ad59b 100644 --- a/src/gt4py/cartesian/backend/dace_stencil_object.py +++ b/src/gt4py/cartesian/backend/dace_stencil_object.py @@ -100,30 +100,25 @@ def freeze( self: DaCeStencilObject, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...] ) -> DaCeFrozenStencil: key = DaCeStencilObject._get_domain_origin_key(domain, origin) + + # check if same sdfg already cached on disk if key in self._frozen_cache: return self._frozen_cache[key] - frozen_hash = shash(origin, domain) + # otherwise, wrap and save sdfg from scratch + frozen_sdfg = freeze_origin_domain_sdfg( + self.sdfg(), + arg_names=self.__sdfg_signature__()[0], + field_info=self.field_info, + origin=origin, + domain=domain, + ) + self._frozen_cache[key] = DaCeFrozenStencil(self, origin, domain, frozen_sdfg) - # check if same sdfg already cached on disk basename = os.path.splitext(self.SDFG_PATH)[0] - filename = basename + "_" + str(frozen_hash) + ".sdfg" - try: - frozen_sdfg = dace.SDFG.from_file(filename) - except FileNotFoundError: - # otherwise, wrap and save sdfg from scratch - inner_sdfg = self.sdfg() - - frozen_sdfg = freeze_origin_domain_sdfg( - inner_sdfg, - arg_names=self.__sdfg_signature__()[0], - field_info=self.field_info, - origin=origin, - domain=domain, - ) - frozen_sdfg.save(filename) + filename = f"{basename}_{shash(origin, domain)}.sdfg" + frozen_sdfg.save(filename) - self._frozen_cache[key] = DaCeFrozenStencil(self, origin, domain, frozen_sdfg) return self._frozen_cache[key] @classmethod From 4a8ba696ea98aed2c3316228bdf4d6190833c161 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:03:33 +0200 Subject: [PATCH 04/11] Fix typo in top-level README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 12d430821d..a5e1034577 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ The `pyproject.toml` file contains both the definition of the `gt4py` Python dis ### Development Tasks (`dev-tasks.py`) -Recurrent development tasks like bumping versions of used development tools or required third party dependencies have been collected as different subcommands in the [`dev-tasks.py`](./dev-tasks.py) script. Read the tool help for a brief description of every task and always use this tool to update the versions and sync the version configuration accross different files (e.g. `pyproject.toml` and `.pre-commit-config.yaml`). +Recurrent development tasks like bumping versions of used development tools or required third party dependencies have been collected as different subcommands in the [`dev-tasks.py`](./dev-tasks.py) script. Read the tool help for a brief description of every task and always use this tool to update the versions and sync the version configuration across different files (e.g. `pyproject.toml` and `.pre-commit-config.yaml`). ## 📖 Documentation From 301d7ee5a46028cc0fc73a8ebfbfae4259a84404 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:25:37 +0200 Subject: [PATCH 05/11] Random cleanups in dace_backend --- src/gt4py/cartesian/backend/dace_backend.py | 51 ++++++++++----------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 578a016059..164ae150a2 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -316,20 +316,20 @@ def freeze_origin_domain_sdfg(inner_sdfg, arg_names, field_info, *, origin, doma class SDFGManager: # Cache loaded SDFGs across all instances - _loaded_sdfgs: ClassVar[dict[str, dace.SDFG]] = dict() + _loaded_sdfgs: ClassVar[dict[str | pathlib.Path, dace.SDFG]] = dict() - def __init__(self, builder): + def __init__(self, builder: StencilBuilder) -> None: self.builder = builder @staticmethod - def _strip_history(sdfg): + def _strip_history(sdfg: dace.SDFG) -> None: # strip history from SDFG for faster save/load for tmp_sdfg in sdfg.all_sdfgs_recursive(): tmp_sdfg.transformation_hist = [] tmp_sdfg.orig_sdfg = None @staticmethod - def _save_sdfg(sdfg, path): + def _save_sdfg(sdfg: dace.SDFG, path: str) -> None: SDFGManager._strip_history(sdfg) sdfg.save(path) @@ -375,27 +375,26 @@ def expanded_sdfg(self): return copy.deepcopy(self._expanded_sdfg()) def _frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]): - frozen_hash = shash(origin, domain) + basename = self.builder.module_path.stem + path = f"{basename}_{shash(origin, domain)}.sdfg" + # check if same sdfg already cached on disk - path = self.builder.module_path - basename = os.path.splitext(path)[0] - path = basename + "_" + str(frozen_hash) + ".sdfg" - if path not in SDFGManager._loaded_sdfgs: - try: - sdfg = dace.SDFG.from_file(path) - except FileNotFoundError: - # otherwise, wrap and save sdfg from scratch - inner_sdfg = self.unexpanded_sdfg() - - sdfg = freeze_origin_domain_sdfg( - inner_sdfg, - arg_names=[arg.name for arg in self.builder.gtir.api_signature], - field_info=make_args_data_from_gtir(self.builder.gtir_pipeline).field_info, - origin=origin, - domain=domain, - ) - self._save_sdfg(sdfg, path) - SDFGManager._loaded_sdfgs[path] = sdfg + if path in SDFGManager._loaded_sdfgs: + return SDFGManager._loaded_sdfgs[path] + + # otherwise, wrap and save sdfg from scratch + inner_sdfg = self.unexpanded_sdfg() + + sdfg = freeze_origin_domain_sdfg( + inner_sdfg, + arg_names=[arg.name for arg in self.builder.gtir.api_signature], + field_info=make_args_data_from_gtir(self.builder.gtir_pipeline).field_info, + origin=origin, + domain=domain, + ) + SDFGManager._loaded_sdfgs[path] = sdfg + self._save_sdfg(sdfg, path) + return SDFGManager._loaded_sdfgs[path] def frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]): @@ -631,6 +630,7 @@ def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> list[ symbols[name] = fmt.format( name=name, ndim=len(array.shape), origin=",".join(str(o) for o in origin) ) + # the remaining arguments are variables and can be passed by name for sym in sdfg.signature_arglist(with_types=False, for_call=True): if sym not in symbols: @@ -678,8 +678,7 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> li ) ) elif name in sdfg.symbols and not name.startswith("__"): - assert name in sdfg.symbols - res[name] = "{dtype} {name}".format(dtype=sdfg.symbols[name].ctype, name=name) + res[name] = f"{sdfg.symbols[name].ctype} {name}" return list(res[node.name] for node in stencil_ir.params if node.name in res) def generate_sid_params(self, sdfg: dace.SDFG) -> list[str]: From 28c76919e631a3b5acc99b4e97e0e2539576e923 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:28:37 +0200 Subject: [PATCH 06/11] type hints for symbol_utils --- src/gt4py/cartesian/gtc/dace/symbol_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/symbol_utils.py b/src/gt4py/cartesian/gtc/dace/symbol_utils.py index 1f332d682b..0a3c00d7ec 100644 --- a/src/gt4py/cartesian/gtc/dace/symbol_utils.py +++ b/src/gt4py/cartesian/gtc/dace/symbol_utils.py @@ -11,8 +11,8 @@ from functools import lru_cache from typing import TYPE_CHECKING -import dace import numpy as np +from dace import dtypes, symbolic from gt4py import eve from gt4py.cartesian.gtc import common @@ -22,9 +22,9 @@ from gt4py.cartesian.gtc.dace import daceir as dcir -def data_type_to_dace_typeclass(data_type): +def data_type_to_dace_typeclass(data_type: common.DataType) -> dtypes.typeclass: dtype = np.dtype(common.data_type_to_typestr(data_type)) - return dace.dtypes.typeclass(dtype.type) + return dtypes.typeclass(dtype.type) def get_axis_bound_str(axis_bound, var_name): @@ -65,5 +65,7 @@ def get_axis_bound_diff_str(axis_bound1, axis_bound2, var_name: str): @lru_cache(maxsize=None) -def get_dace_symbol(name: eve.SymbolRef, dtype: common.DataType = common.DataType.INT32): - return dace.symbol(name, dtype=data_type_to_dace_typeclass(dtype)) +def get_dace_symbol( + name: eve.SymbolRef, dtype: common.DataType = common.DataType.INT32 +) -> symbolic.symbol: + return symbolic.symbol(name, dtype=data_type_to_dace_typeclass(dtype)) From 6a51f6c962b34c3c8ea39f145a7a03b6d0fff370 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:37:03 +0200 Subject: [PATCH 07/11] cleanup dace/utils --- src/gt4py/cartesian/backend/dace_backend.py | 6 ++++- src/gt4py/cartesian/gtc/dace/utils.py | 28 ++++++++++----------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 164ae150a2..099c4ff049 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -687,7 +687,11 @@ def generate_sid_params(self, sdfg: dace.SDFG) -> list[str]: for name, array in sdfg.arrays.items(): if array.transient: continue - domain_dim_flags = array_dimensions(array) + + domain_dim_flags = tuple(array_dimensions(array)) + if len(domain_dim_flags) != 3: + raise RuntimeError("Expected 3 cartesian array dimensions. Codegen error.") + data_ndim = len(array.shape) - sum(domain_dim_flags) sid_def = pybuffer_to_sid( name=name, diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 4ef48ebcd9..c552c9dfa1 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -12,9 +12,8 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -import dace -import dace.data import numpy as np +from dace import data, dtypes, properties, subsets, symbolic from gt4py import eve from gt4py.cartesian.gtc import common, oir @@ -23,33 +22,32 @@ from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents -def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo: +def get_dace_debuginfo(node: common.LocNode) -> dtypes.DebugInfo: if node.loc is None: - return dace.dtypes.DebugInfo(0) + return dtypes.DebugInfo(0) - return dace.dtypes.DebugInfo( + return dtypes.DebugInfo( node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename ) -def array_dimensions(array: dace.data.Array): - dims = [ +def array_dimensions(array: data.Array) -> list[bool]: + return [ any( re.match(f"__.*_{k}_stride", str(sym)) for st in array.strides - for sym in dace.symbolic.pystr_to_symbolic(st).free_symbols + for sym in symbolic.pystr_to_symbolic(st).free_symbols ) or any( re.match(f"__{k}", str(sym)) for sh in array.shape - for sym in dace.symbolic.pystr_to_symbolic(sh).free_symbols + for sym in symbolic.pystr_to_symbolic(sh).free_symbols ) for k in "IJK" ] - return dims -def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]: +def replace_strides(arrays: list[data.Array], get_layout_map) -> dict[str, str]: symbol_mapping = {} for array in arrays: dims = array_dimensions(array) @@ -61,7 +59,7 @@ def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, for idx in reversed(np.argsort(layout)): symbol = array.strides[idx] if symbol.is_symbol: - symbol_mapping[str(symbol)] = dace.symbolic.pystr_to_symbolic(stride) + symbol_mapping[str(symbol)] = symbolic.pystr_to_symbolic(stride) stride *= array.shape[idx] return symbol_mapping @@ -318,7 +316,7 @@ def compute_dcir_access_infos( collect_write=True, include_full_domain=False, **kwargs, -) -> dace.properties.DictProperty: +) -> properties.DictProperty: if block_extents is None: assert isinstance(oir_node, oir.Stencil) block_extents = compute_horizontal_block_extents(oir_node) @@ -514,7 +512,7 @@ def make_dace_subset( context_info: dcir.FieldAccessInfo, access_info: dcir.FieldAccessInfo, data_dims: Tuple[int, ...], -) -> dace.subsets.Range: +) -> subsets.Range: clamped_access_info = access_info clamped_context_info = context_info for axis in access_info.axes(): @@ -531,7 +529,7 @@ def make_dace_subset( ].to_dace_symbolic() res_ranges.append((subset_start - context_start, subset_end - context_start - 1, 1)) res_ranges.extend((0, dim - 1, 1) for dim in data_dims) - return dace.subsets.Range(res_ranges) + return subsets.Range(res_ranges) def untile_memlets(memlets: Sequence[dcir.Memlet], axes: Sequence[dcir.Axis]) -> List[dcir.Memlet]: From 8b47b2c0a3f74fbd0bd434b4108f5139a2409734 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:01:12 +0200 Subject: [PATCH 08/11] Test improvements --- .../stencil_definitions.py | 38 +++++++++++++++++-- .../test_code_generation.py | 23 ++++++++--- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 217c0ee488..29057b1e5a 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -61,6 +61,7 @@ def _register_decorator(actual_func): return _register_decorator(func) if func else _register_decorator +Field2D = gtscript.Field[gtscript.IJ, np.float64] Field3D = gtscript.Field[np.float64] Field3DBool = gtscript.Field[np.bool_] @@ -72,7 +73,7 @@ def copy_stencil(field_a: Field3D, field_b: Field3D): @gtscript.function -def afunc(b): +def a_gtscript_function(b): return sqrt(abs(b[0, 1, 0])) @@ -82,6 +83,27 @@ def arithmetic_ops(field_a: Field3D, field_b: Field3D): field_a = (((((field_b + 42.0) - 42.0) * +42.0) / -42.0) % 42.0) ** 2 +@register +def scalar_inputs(field_a: Field3D, scalar_in: float): + with computation(PARALLEL), interval(...): + field_a = field_a * scalar_in + + +@register +def unary_operation(field_a: Field3D, scalar_in: float): + with computation(PARALLEL), interval(...): + field_a = -scalar_in + + +@register +def temporary_stencil(field_a: Field3D, field_b: Field2D, scalar_in: float): + with computation(PARALLEL), interval(...): + tmp = field_a * scalar_in + + with computation(FORWARD), interval(0, 1): + field_b += tmp + + @register def data_types( bool_field: gtscript.Field[bool], @@ -127,8 +149,8 @@ def native_functions(field_a: Field3D, field_b: Field3D): acosh_res = acosh(cosh_res) tanh_res = tanh(acosh_res) atanh_res = atanh(tanh_res) - sqrt_res = afunc(atanh_res) - pow10_res = 10 ** (sqrt_res) + sqrt_res = a_gtscript_function(atanh_res) + pow10_res = 10 ** (atanh_res) log10_res = log10(pow10_res) exp_res = exp(log10_res) log_res = log(exp_res) @@ -148,6 +170,14 @@ def native_functions(field_a: Field3D, field_b: Field3D): ) +@register +def while_stencil(field_a: Field3D, field_b: Field3D): + with computation(BACKWARD), interval(...): + while field_a > 2.0: + field_b = -1 + field_a = -field_b + + @register def copy_stencil_plus_one(field_a: Field3D, field_b: Field3D): with computation(PARALLEL), interval(...): @@ -295,7 +325,7 @@ def large_k_interval(in_field: Field3D, out_field: Field3D): with computation(PARALLEL): with interval(0, 6): out_field = in_field - # this stenicl is only legal to call with fields that have more than 16 elements + # this stencil is only legal to call with fields that have more than 16 elements with interval(6, -10): out_field = in_field + 1 with interval(-10, None): diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index c2b82e4bac..e6768fa094 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -50,8 +50,8 @@ def test_generation(name, backend): dtype=(v.dtype, v.data_dims) if v.data_dims else v.dtype, dimensions=v.axes, backend=backend, - shape=(23, 23, 23), - aligned_index=(10, 10, 10), + shape=(23,) * len(v.axes), + aligned_index=(10,) * len(v.axes), ) else: args[k] = v(1.5) @@ -368,6 +368,17 @@ def stencil( with computation(PARALLEL), interval(...): out_field[0, 0, 0] = in_field * parameter + field_in = gt_storage.ones( + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + ) + field_out = gt_storage.zeros( + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + ) + + stencil(field_in, 3.1415, field_out) + + np.testing.assert_allclose(field_out[:, :, :], 3.1415) + @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets(backend): @@ -706,8 +717,8 @@ def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scala # - lev = 2 # - A[2] == 42 && B[2] == -1 => False # End of iteration state - # - A[...] = A[40, 2.0, 2.0, -1] - # - B[...] = A[1, 1, -1, 42] + # - A[...] = A[40, 2.0, 42, -1] + # - B[...] = B[1, 1, -1, 42] # ITERATION k = 1 of [2:1] # if condition # - A[1] == 2.0 && B[1] == 1 => True @@ -719,10 +730,10 @@ def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scala # - A[2] = -1 # - B[1] = -1 # - lev = 2 - # - A[1] == 2.0 && B[2] == -1 => False + # - A[1] == 2.0 && B[1] == -1 => False # End of stencil state # - A[...] = A[2.0, 2.0, -1, -1] - # - B[...] = A[1, -1, 2.0, 42] + # - B[...] = B[1, -1, 2.0, 42] assert (A[0, 0, :] == arraylib.array([2, 2, -1, -1])).all() assert (B[0, 0, :] == arraylib.array([1, -1, 2, 42])).all() From b037fab80dd40d663752a3179fe5a00a2ad5bf3d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 25 Jun 2025 09:41:11 +0200 Subject: [PATCH 09/11] Fixup: Don't cache frozen_sdfg to cwd --- src/gt4py/cartesian/backend/dace_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 099c4ff049..b4a53c6916 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -375,7 +375,7 @@ def expanded_sdfg(self): return copy.deepcopy(self._expanded_sdfg()) def _frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]): - basename = self.builder.module_path.stem + basename = self.builder.module_path.with_suffix("") path = f"{basename}_{shash(origin, domain)}.sdfg" # check if same sdfg already cached on disk From 420834e2ab8f9f531a824a68f2305284b2ebfed8 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 25 Jun 2025 14:10:19 +0200 Subject: [PATCH 10/11] Use type guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- src/gt4py/cartesian/utils/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index b7d20b11d6..a48308f87f 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -22,7 +22,7 @@ import time import types import warnings -from typing import Any, Generic, TypeVar +from typing import Generic, TypeGuard, TypeVar from gt4py.cartesian import config as gt_config @@ -42,7 +42,7 @@ def jsonify(value, indent=2): return json.dumps(value, indent=indent, default=lambda obj: str(obj)) -def is_identifier_name(value: Any, namespaced: bool = True) -> bool: +def is_identifier_name(value: object, namespaced: bool = True) -> TypeGuard[str]: if isinstance(value, str): if namespaced: return all(name.isidentifier() for name in value.split(".")) From aa902571cc22683add718205e83404f329994b8e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 25 Jun 2025 10:43:01 +0200 Subject: [PATCH 11/11] fix typo in ADR template --- docs/development/ADRs/Template.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/development/ADRs/Template.md b/docs/development/ADRs/Template.md index cad7b5d164..f038962f60 100644 --- a/docs/development/ADRs/Template.md +++ b/docs/development/ADRs/Template.md @@ -24,7 +24,7 @@ We chose option X because [justification. e.g., only option, which meets k.o. cr ## Consequences -What it now easier to do? What becomes more difficult with this change? +What is now easier to do? What becomes more difficult with this change? Describe the positive (e.g., improvement of quality attribute satisfaction, follow-up decisions required, ...) as well as the negative (e.g., compromising quality attribute, follow-up decisions required, ...) outcomes of this decision.