From 53c8e686994a9d5e63a954647de9cf83cebdf369 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 4 Aug 2025 16:38:26 +0200 Subject: [PATCH 01/35] NASA Team: Mileston 2 "release" branch This branch is what we use for the NASA team as we start to prepare for Milestone 2. Currently uses the following versions of externals: - GT4Py: follows "milestone2" branch on Roman's fork - DaCe: whatever GT4Py's uv.lock file says about dace-cartesian --- README.md | 10 ++++++++++ external/gt4py | 2 +- ndsl/dsl/__init__.py | 3 ++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e55b012e..2f75b217 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,16 @@ # NOAA/NASA Domain Specific Language middleware +:warning: :dragon: This is the equivalent of a "release" branch for the NASA Team's Milestone 2 work. In particular, we include the following experimental features: + +- "Hybrid" indexing that allows absolute K-indices +- Access to the iteration variable in K, working title `THIS_K` +- `round()` function in gtscript + +Your standard readme continues below. + +--- + NDSL is a middleware for climate and weather modelling developed jointly by NOAA and NASA. The middleware brings together [GT4Py](https://github.com/GridTools/gt4py/) (the `cartesian` flavor), ETH CSCS's stencil DSL, and [DaCe](https://github.com/spcl/dace/), ETH SPCL's data flow framework, both developed for high-performance and portability. On top of those pillars, NDSL deploys a series of optimized APIs for common operations (Halo exchange, domain decomposition, MPI, ...), a set of bespoke optimizations for the models targeted by the middleware and tools to port existing models. 
## Batteries-included for FV-based models diff --git a/external/gt4py b/external/gt4py index 68eea74b..a3da7126 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 68eea74b748747ac5415c93e479d7964f3ec6947 +Subproject commit a3da7126e677273751976bd2177c871163d009ca diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index e3fe0cc8..c41a113b 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -13,7 +13,8 @@ " before any `gt4py` imports." ) NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64")) -os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION) +os.environ["GT4PY_LITERAL_INT_PRECISION"] = str(NDSL_GLOBAL_PRECISION) +os.environ["GT4PY_LITERAL_FLOAT_PRECISION"] = str(NDSL_GLOBAL_PRECISION) # Set cache names for default gt backends workflow From c72c7740b9562cf437fa2b40d0d451f8d4b7ec9d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Tue, 5 Aug 2025 08:35:11 +0200 Subject: [PATCH 02/35] Expose erf, erfc, round, and new typecasts from ndsl.dsl.gt4py --- ndsl/dsl/gt4py/__init__.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ndsl/dsl/gt4py/__init__.py b/ndsl/dsl/gt4py/__init__.py index 3ae0dbd8..f92ccbba 100644 --- a/ndsl/dsl/gt4py/__init__.py +++ b/ndsl/dsl/gt4py/__init__.py @@ -26,12 +26,18 @@ computation, cos, cosh, + erf, + erfc, exp, externals, + float32, + float64, floor, function, gamma, horizontal, + int32, + int64, interval, isfinite, isinf, @@ -42,6 +48,7 @@ min, mod, region, + round, sin, sinh, sqrt, @@ -82,10 +89,16 @@ "cosh", "exp", "externals", + "erf", + "erfc", + "float32", + "float64", "floor", "function", "gamma", "horizontal", + "int32", + "int64", "interval", "isfinite", "isinf", @@ -96,6 +109,7 @@ "min", "mod", "region", + "round", "sin", "sinh", "sqrt", From 1e8ecfbde5f2632575708fb0aa11381d5ac22ab1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 7 Aug 2025 16:26:53 +0200 Subject: [PATCH 03/35] gt4py 
update: abs k and current k in debug backend This commit updates GT4Py to add support for the experimental features "absolute k indexing" and "expose current k-level" in the debug backend. --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index a3da7126..297e6039 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit a3da7126e677273751976bd2177c871163d009ca +Subproject commit 297e603905497b81bf314c34866aad63ac20172d From 258836d3546e27a9ff5d1072fabc2c24a4234e73 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 7 Aug 2025 17:37:58 +0200 Subject: [PATCH 04/35] gt4py update: fix literal precision --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 297e6039..b1720abf 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 297e603905497b81bf314c34866aad63ac20172d +Subproject commit b1720abfe370c76d580fea59807ffa20d8d10751 From b1a1abde4a1a692726c8c590fa85e96be64f7276 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 21 Aug 2025 11:43:13 +0200 Subject: [PATCH 05/35] dace|orchestration: Schedule tree roundtrip work (#206) * Roundtrip sdfg -> stree -> sdfg in orchestration with moaar validation * Remove debug prints and intermediate sdfg saving * Use default when calling simplify * Update gt4py/dace submodules (roundtrip work) This commit brings the changes needed for stree roundtrips to validate with the AI2 data (in the PyFV3 translate tests). * Update README * Quick note: skip ScalarToSymbolPromotion for now The pass messes up previously valid & validating SDFGs. We can live (performance wise) without it for the current milestone. Let's re-evaluate once we get back to DaCe mainline (v2). 
* Update gt4py & dace submodules (stree/rountrip) --- README.md | 1 + external/dace | 2 +- external/gt4py | 2 +- ndsl/dsl/dace/orchestration.py | 31 +++++++++++++++++++++++++++++-- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 445ca2c7..0f594637 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ - "Hybrid" indexing that allows absolute K-indices - Access to the iteration variable in K, working title `THIS_K` - `round()` function in gtscript +- "schedule tree roundtrip work": fixes in the gt4py/dace bridge and in dace's stree/sdfg conversions to validate PyFV3 translate tests (with the AI2 data) Your standard readme continues below. diff --git a/external/dace b/external/dace index 82541a94..54c669ed 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 82541a9401dcadca43edc33cf1db61a0fe21d0e5 +Subproject commit 54c669ed0f0bc97eb077e84ace58f7596744c8b5 diff --git a/external/gt4py b/external/gt4py index b1720abf..a65f56f8 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit b1720abfe370c76d580fea59807ffa20d8d10751 +Subproject commit a65f56f898096b22013cbd5a9574de112688adfa diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 38e7be09..0475fc02 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numbers import os from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -111,6 +112,7 @@ def _simplify( validate=validate, validate_all=validate_all, verbose=verbose, + # ScalarToSymbolPromotion is messing with us, so we disable it. skip=["ScalarToSymbolPromotion"], ).apply_pass(sdfg, {}) @@ -125,6 +127,28 @@ def _build_sdfg( with DaCeProgress(config, "Validate original SDFG"): sdfg.validate() + # Fully specialize all known symbols and then propagate these changes in the simplify + # pass that follows. 
This is not only a smart idea in general, but also simplifies (haha) + # the schedule tree (optimization) roundtrip. + with DaCeProgress(config, "Fully specialize symbols"): + for my_sdfg in sdfg.all_sdfgs_recursive(): + if my_sdfg.parent_nsdfg_node is not None: + repl_dict = {} + for sym, val in my_sdfg.parent_nsdfg_node.symbol_mapping.items(): + if isinstance(val, numbers.Number): + repl_dict[sym] = val + my_sdfg.replace_dict(repl_dict) + + with DaCeProgress(config, "Simplify (1)"): + _simplify(sdfg) + + # TODO uncomment if you want to test schedule tree roundtrip change and/or remove once + # we have the schedule tree optimization pipeline. + # with DaCeProgress(config, "Schedule tree roundtrip"): + # stree = sdfg.as_schedule_tree() + # # ScalarToSymbolPromotion is messing with us, so we disable it. + # sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) + # Make the transients array persistents if config.is_gpu_backend(): # TODO @@ -153,8 +177,8 @@ def _build_sdfg( if k in sdfg_kwargs and tup[1].transient: del sdfg_kwargs[k] - with DaCeProgress(config, "Simplify"): - _simplify(sdfg, validate=False, verbose=True) + with DaCeProgress(config, "Simplify (2)"): + _simplify(sdfg) # Move all memory that can be into a pool to lower memory pressure. # Change Persistent memory (sub-SDFG) into Scope and flag it. 
@@ -178,6 +202,9 @@ def _build_sdfg( negative_delp_checker(sdfg) negative_qtracers_checker(sdfg) + with DaCeProgress(config, "Validate before compile"): + sdfg.validate() + # Compile with DaCeProgress(config, "Codegen & compile"): sdfg.compile() From aac636067ce99f2c79c258ab3fb63f573618a11b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 22 Aug 2025 08:10:23 +0200 Subject: [PATCH 06/35] update gt4py to milestone2 --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index a65f56f8..9142c51a 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit a65f56f898096b22013cbd5a9574de112688adfa +Subproject commit 9142c51ad516b31fef3afafc132e8a4ed5b561be From 7521c0e0fd2b2e426be0496d6d8ec3e2931af570 Mon Sep 17 00:00:00 2001 From: "Christopher W. Kung" Date: Fri, 22 Aug 2025 10:22:09 -0400 Subject: [PATCH 07/35] Added device_synchronize call to fix GPU/MPI synchronization issue on MPI inplace all_reduce calls. Note that device_synchronize is Cupy/CUDA specific at the moment. 
--- ndsl/comm/communicator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index b4ea9a0c..d063cc1f 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -4,6 +4,7 @@ from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast import numpy as np +import cupy as cp import ndsl.constants as constants from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer @@ -16,6 +17,7 @@ from ndsl.performance.timer import NullTimer, Timer from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule +from ndsl.utils import device_synchronize def to_numpy(array, dtype=None) -> np.ndarray: @@ -142,6 +144,9 @@ def all_reduce_per_element( def all_reduce_per_element_in_place( self, quantity: Quantity, op: ReductionOperator ): + # Note that device_synchronization is Cupy/Cuda specific + # at the moment. + device_synchronize() self.comm.Allreduce_inplace(quantity.data, op) def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): From 7555ec94e5958a6295990b499954d79238575ef0 Mon Sep 17 00:00:00 2001 From: "Christopher W. 
Kung" Date: Fri, 22 Aug 2025 10:34:59 -0400 Subject: [PATCH 08/35] Linting --- ndsl/comm/communicator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index d063cc1f..b07091ab 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -4,7 +4,6 @@ from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast import numpy as np -import cupy as cp import ndsl.constants as constants from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer @@ -17,8 +16,6 @@ from ndsl.performance.timer import NullTimer, Timer from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule -from ndsl.utils import device_synchronize - def to_numpy(array, dtype=None) -> np.ndarray: """ From f7d962852832d580f7cf32520532dffa581b19be Mon Sep 17 00:00:00 2001 From: "Christopher W. Kung" Date: Fri, 22 Aug 2025 10:40:59 -0400 Subject: [PATCH 09/35] Linting again --- ndsl/comm/communicator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index b07091ab..c6ec58c5 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -17,6 +17,7 @@ from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule + def to_numpy(array, dtype=None) -> np.ndarray: """ Input array can be a numpy array or a cupy array. Returns numpy array. 
From ac5af0bac04f41fc78c031f1a5e2e01abc9095e4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:29:42 +0200 Subject: [PATCH 10/35] perf: set build type to release in dace config --- ndsl/dsl/dace/dace_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 484be830..c4e61b2b 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -203,6 +203,7 @@ def __init__( # Set the configuration of DaCe to a rigid & tested set of divergence # from the defaults when orchestrating if orchestration != DaCeOrchestration.Python: + dace.config.Config.set("compiler", "build_type", value="Release") # Required to True for gt4py storage/memory dace.config.Config.set( "compiler", From 573d44cf9a41e9fe84154160313b742d754332cb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 26 Aug 2025 16:02:12 +0200 Subject: [PATCH 11/35] perf: set -march=native flag for cpu --- ndsl/dsl/dace/dace_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index c4e61b2b..6e8073bc 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -215,7 +215,7 @@ def __init__( "compiler", "cpu", "args", - value=f"-std=c++14 -fPIC -Wall -Wextra -O{optimization_level}", + value=f"-march=native -std=c++17 -fPIC -Wall -Wextra -O{optimization_level}", ) # Potentially buggy - deactivate dace.config.Config.set( From caad0b06e0da41d15b3f3d58054fec92d51be1c7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 2 Sep 2025 16:49:44 +0200 Subject: [PATCH 12/35] fix: stencil wrapper field origins with data_dims Add support for fields with data_dims (or data_dims only fields) in the stencil wrapper's function to compute field origins. 
--- ndsl/dsl/dace/orchestration.py | 1 + ndsl/dsl/stencil.py | 13 ++++++--- tests/dsl/test_stencil_wrapper.py | 44 ++++++++++++++++++++++++------- 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 0475fc02..d16b4a3a 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -318,6 +318,7 @@ def _parse_sdfg( **kwargs, save=False, simplify=False, + validate=False, # TODO: should we have a "debug flag" to turn this on? ) return sdfg diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 1eea7cc0..464837fa 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -449,8 +449,10 @@ def __call__(self, *args, **kwargs) -> None: @classmethod def _compute_field_origins( - cls, field_info_mapping, origin: Union[Index3D, Mapping[str, Tuple[int, ...]]] - ) -> Dict[str, Tuple[int, ...]]: + cls, + field_info_mapping: dict[str, gt_definitions.FieldInfo], + origin: Index3D | Mapping[str, tuple[int, ...]], + ) -> dict[str, tuple[int, ...]]: """ Computes the origin for each field in the stencil call. @@ -463,8 +465,8 @@ def _compute_field_origins( origin_mapping: a mapping from field names to origins """ if isinstance(origin, tuple): - field_origins: Dict[str, Tuple[int, ...]] = {"_all_": origin} - origin_tuple: Tuple[int, ...] = origin + field_origins: dict[str, tuple[int, ...]] = {"_all_": origin} + origin_tuple: tuple[int, ...] 
= origin else: field_origins = {**origin} origin_tuple = origin["_all_"] @@ -477,6 +479,9 @@ def _compute_field_origins( for ax in field_info.axes: origin_index = {"I": 0, "J": 1, "K": 2}[ax] field_origin_list.append(origin_tuple[origin_index]) + for i, _data_dim in enumerate(field_info.data_dims): + if field_info.mask[len(field_info.domain_mask) + i]: + field_origin_list.append(0) field_origin = tuple(field_origin_list) else: field_origin = origin_tuple diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index df91f5c0..a1b0d248 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -4,6 +4,7 @@ import gt4py.cartesian.gtscript import numpy as np import pytest +from gt4py.cartesian import definitions from ndsl import ( CompilationConfig, @@ -46,46 +47,57 @@ def mock_gtscript_stencil(mock): gt4py.cartesian.gtscript.stencil = original_stencil -class MockFieldInfo: - def __init__(self, axes): - self.axes = axes +class MockFieldInfo(definitions.FieldInfo): + def __init__(self, *, axes: tuple[str, ...] = (), data_dims: tuple[int, ...] 
= ()): + # defaults + access = definitions.AccessKind.READ + boundary = None + dtype = np.float64 + + super().__init__( + axes=axes, + data_dims=data_dims, + access=access, + boundary=boundary, + dtype=dtype, + ) @pytest.mark.parametrize( "field_info, origin, field_origins", [ pytest.param( - {"a": MockFieldInfo(["I"])}, + {"a": MockFieldInfo(axes=("I"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1,)}, id="single_field_I", ), pytest.param( - {"a": MockFieldInfo(["J"])}, + {"a": MockFieldInfo(axes=("J"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (2,)}, id="single_field_J", ), pytest.param( - {"a": MockFieldInfo(["K"])}, + {"a": MockFieldInfo(axes=("K"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (3,)}, id="single_field_K", ), pytest.param( - {"a": MockFieldInfo(["I", "J"])}, + {"a": MockFieldInfo(axes=("I", "J"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1, 2)}, id="single_field_IJ", ), pytest.param( - {"a": MockFieldInfo(["I", "J", "K"])}, + {"a": MockFieldInfo(axes=("I", "J", "K"))}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, id="single_field_origin_mapping", ), pytest.param( - {"a": MockFieldInfo(["I", "J", "K"]), "b": MockFieldInfo(["I"])}, + {"a": MockFieldInfo(axes=("I", "J", "K")), "b": MockFieldInfo(axes=("I"))}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, {"_all_": (1, 2, 3), "a": (1, 2, 3), "b": (1,)}, id="two_fields_update_origin_mapping", @@ -97,11 +109,23 @@ def __init__(self, axes): id="single_field_None", ), pytest.param( - {"a": MockFieldInfo(["I", "J"]), "b": MockFieldInfo(["I", "J", "K"])}, + { + "a": MockFieldInfo(axes=("I", "J")), + "b": MockFieldInfo(axes=("I", "J", "K")), + }, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1, 2), "b": (1, 2, 3)}, id="two_fields", ), + pytest.param( + { + "field": MockFieldInfo(axes=("I", "J", "K")), + "table": MockFieldInfo(data_dims=(5,)), + }, + (1, 2, 3), + {"_all_": (1, 2, 3), "field": (1, 2, 3), "table": (0,)}, + id="field_and_table", + ), ], ) def test_compute_field_origins(field_info, origin, 
field_origins) -> None: From 0a66099e61e8f4305c55940a9a3c601576ef48c2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 2 Sep 2025 17:35:31 +0200 Subject: [PATCH 13/35] Unrelated: no unused arguments in stencil definition --- tests/dsl/test_caches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index 2aecfbb3..92cc4537 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -41,7 +41,7 @@ def restore_cache_dir(): gt_config.cache_settings["dir_name"] = cache_dir -def _stencil(inp: Field[float], out: Field[float], scalar: float): +def _stencil(inp: Field[float], out: Field[float]): with computation(PARALLEL), interval(...): out = inp @@ -83,7 +83,7 @@ def __init__(self, backend, orchestration: DaCeOrchestration): self.out = utils.make_storage(empty, grid_indexing, stencil_config, dtype=float) def __call__(self): - self.stencil(self.inp, self.out, self.inp[0, 0, 0]) + self.stencil(self.inp, self.out) @pytest.mark.skipif( From ef3c3de6f3633373ddb5e5bcf039f6066ec0f495 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 2 Sep 2025 18:17:42 +0200 Subject: [PATCH 14/35] Update gt4py to latest romanc/milestone2 --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 9142c51a..c58725f3 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 9142c51ad516b31fef3afafc132e8a4ed5b561be +Subproject commit c58725f3ebaf68f90028bbab80b38789df909e16 From 6257eb6578d4e5ea553878ee9a197acd9eee8466 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 3 Sep 2025 15:42:56 +0200 Subject: [PATCH 15/35] tests: Add test case for orchestrated tables Add a non-trivial test case for orchestrating tables. 
This is a mitigation for a gt4py-orchestration-issue that is easiest reproduced from NDSL (compared to a adding a test in gt4py directly). --- tests/dsl/test_stencil_tables.py | 79 ++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 tests/dsl/test_stencil_tables.py diff --git a/tests/dsl/test_stencil_tables.py b/tests/dsl/test_stencil_tables.py new file mode 100644 index 00000000..0b31826a --- /dev/null +++ b/tests/dsl/test_stencil_tables.py @@ -0,0 +1,79 @@ +import numpy as np +from gt4py.storage import ones, zeros + +from ndsl import ( + CompilationConfig, + DaceConfig, + DaCeOrchestration, + FrozenStencil, + GridIndexing, + StencilConfig, + StencilFactory, + orchestrate, +) +from ndsl.dsl.gt4py import FORWARD, PARALLEL, Field, GlobalTable, computation, interval +from ndsl.dsl.stencil import CompareToNumpyStencil +from tests.dsl import utils + + +def _stencil(inp: GlobalTable[np.int32, (5,)], out: Field[np.float64]) -> None: + with computation(PARALLEL), interval(0, -1): + out[0, 0, 0] = inp.A[1] + with computation(FORWARD), interval(-1, None): + out[0, 0, 0] = inp.A[1] + inp.A[2] + + +def _build_stencil( + backend: str, orchestrated: DaCeOrchestration +) -> tuple[FrozenStencil | CompareToNumpyStencil, GridIndexing, StencilConfig]: + # Make stencil and verify it ran + grid_indexing = GridIndexing( + domain=(5, 5, 5), + n_halo=2, + south_edge=True, + north_edge=True, + west_edge=True, + east_edge=True, + ) + + stencil_config = StencilConfig( + compilation_config=CompilationConfig(backend=backend, rebuild=True), + dace_config=DaceConfig(None, backend, 5, 5, orchestrated), + ) + + stencil_factory = StencilFactory(stencil_config, grid_indexing) + + built_stencil = stencil_factory.from_origin_domain( + _stencil, origin=(0, 0, 0), domain=grid_indexing.domain + ) + + return built_stencil, grid_indexing, stencil_config + + +class OrchestratedProgram: + def __init__(self, backend, orchestration: DaCeOrchestration): + self.stencil, 
grid_indexing, stencil_config = _build_stencil( + backend, orchestration + ) + orchestrate(obj=self, config=stencil_config.dace_config) + + self.inp = ones(shape=(5,), dtype=np.int32, backend=backend) + self.inp[1] = 42 + self.out = utils.make_storage(zeros, grid_indexing, stencil_config, dtype=float) + + def __call__(self): + self.stencil(self.inp, self.out) + + +def test_stecil_with_table_orchestrated() -> None: + program = OrchestratedProgram( + backend="dace:cpu", orchestration=DaCeOrchestration.BuildAndRun + ) + + # run the orchestrated stencil + program() + + # validate output + for k in range(4): + assert (program.out[:, :, k] == 42).all() + assert (program.out[:, :, 4] == 43).all() From 85011ef6adb1b5c4293907e5bec3dbb2537cc5c0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 25 Jun 2025 16:35:16 -0400 Subject: [PATCH 16/35] [orchestration] common cast operation replacments Cherry-picking (parts of) PR https://github.com/NOAA-GFDL/NDSL/pull/211 into the milestone2 branch. 
--- ndsl/dsl/dace/orchestration.py | 1 + ndsl/dsl/dace/replacements.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 ndsl/dsl/dace/replacements.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index d16b4a3a..5e0ef36a 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -19,6 +19,7 @@ from dace.transformation.passes.simplify import SimplifyPass from gt4py import storage +import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements from ndsl.comm.mpi import MPI from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( diff --git a/ndsl/dsl/dace/replacements.py b/ndsl/dsl/dace/replacements.py new file mode 100644 index 00000000..26a3ab78 --- /dev/null +++ b/ndsl/dsl/dace/replacements.py @@ -0,0 +1,32 @@ +"""This module uses DaCe's op_repository feature to override symbols/AST object +during parsing and replace them with an SDFG compatible representation. 
This +allow custom NDSL system to be natively orchestratable.""" + +from dace import SDFG, SDFGState, dtypes +from dace.frontend.common import op_repository as oprepo +from dace.frontend.python.newast import ProgramVisitor +from dace.frontend.python.replacements import UfuncInput, _datatype_converter + +from ndsl.dsl.typing import Float, Int + + +@oprepo.replaces("Float") +def _convert_Float(_pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arg: UfuncInput): + """Replace `Float(x)` with a typecast of `x` to the proper floating precision type""" + return _datatype_converter( + sdfg, + state, + arg, + dtype=dtypes.dtype_to_typeclass(Float), + ) + + +@oprepo.replaces("Int") +def _convert_Int(_pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arg: UfuncInput): + """Replace `Int(x)` with a typecast of `x` to the proper integer precision type""" + return _datatype_converter( + sdfg, + state, + arg, + dtype=dtypes.dtype_to_typeclass(Int), + ) From 13c912d466c2cc4b13205214d8a197ca6797bad4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 4 Sep 2025 17:23:14 +0200 Subject: [PATCH 17/35] FieldBundle memoization fix --- ndsl/dsl/dace/orchestration.py | 138 +++++++++++++++++---------------- ndsl/quantity/field_bundle.py | 25 +++--- 2 files changed, 86 insertions(+), 77 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 5e0ef36a..c8cf16e0 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -237,7 +237,7 @@ def _build_sdfg( if mode == DaCeOrchestration.BuildAndRun: if not is_compiling: ndsl_log.info( - f"{DaCeProgress.default_prefix(config)} Rank is not compiling." + f"{DaCeProgress.default_prefix(config)} Rank is not compiling. " "Waiting for compilation to end on all other ranks..." 
) MPI.COMM_WORLD.Barrier() @@ -464,9 +464,9 @@ def __get__(self, obj, objtype=None) -> SDFGEnabledCallable: def orchestrate( *, obj: object, - config: Optional[DaceConfig], + config: DaceConfig, method_to_orchestrate: str = "__call__", - dace_compiletime_args: Optional[Sequence[str]] = None, + dace_compiletime_args: Sequence[str] | None = None, ): """ Orchestrate a method of an object with DaCe. @@ -481,77 +481,79 @@ def orchestrate( dace_compiletime_args: list of names of arguments to be flagged has dace.compiletime for orchestration to behave """ + if not config.is_dace_orchestrated(): + return + if config is None: raise ValueError("DaCe config cannot be None") + if not hasattr(obj, method_to_orchestrate): + raise RuntimeError( + f"Could not orchestrate, " + f"{type(obj).__name__}.{method_to_orchestrate} " + "does not exist" + ) + if dace_compiletime_args is None: dace_compiletime_args = [] - if config.is_dace_orchestrated(): - if not hasattr(obj, method_to_orchestrate): - raise RuntimeError( - f"Could not orchestrate, " - f"{type(obj).__name__}.{method_to_orchestrate} " - "does not exists" - ) - - func = type.__getattribute__(type(obj), method_to_orchestrate) - - # Flag argument as dace.constant - for argument in dace_compiletime_args: - func.__annotations__[argument] = DaceCompiletime - - # Build DaCe orchestrated wrapper - # This is a JIT object, e.g. DaCe compilation will happen on call - wrapped = _LazyComputepathMethod(func, config).__get__(obj) - - if method_to_orchestrate == "__call__": - # Grab the function from the type of the child class - # Dev note: we need to use type for dunder call because: - # a = A() - # a() - # resolved to: type(a).__call__(a) - # therefore patching the instance call (e.g a.__call__) is not enough. - # We could patch the type(self), ergo the class itself - # but that would patch _every_ instance of A. 
- # What we can do is patch the instance.__class__ with a local made class - # in order to keep each instance with it's own patch. - # - # Re: type:ignore - # Mypy is unhappy about dynamic class name and the devs (per github - # issues discussion) is to make a plugin. Too much work -> ignore mypy - - class _(type(obj)): # type: ignore - __qualname__ = f"{type(obj).__qualname__}_patched" - __name__ = f"{type(obj).__name__}_patched" - - def __call__(self, *arg, **kwarg): - return wrapped(*arg, **kwarg) - - def __sdfg__(self, *args, **kwargs): - return wrapped.__sdfg__(*args, **kwargs) - - def __sdfg_closure__(self, reevaluate=None): - return wrapped.__sdfg_closure__(reevaluate) - - def __sdfg_signature__(self): - return wrapped.__sdfg_signature__() - - def closure_resolver( - self, constant_args, given_args, parent_closure=None - ): - return wrapped.closure_resolver( - constant_args, given_args, parent_closure - ) - - # We keep the original class type name to not perturb - # the workflows that uses it to build relevant info (path, hash...) - previous_cls_name = type(obj).__name__ - obj.__class__ = _ - type(obj).__name__ = previous_cls_name - else: - # For regular attribute - we can just patch as usual - setattr(obj, method_to_orchestrate, wrapped) + func = type.__getattribute__(type(obj), method_to_orchestrate) + + # Flag argument as dace.constant + for argument in dace_compiletime_args: + func.__annotations__[argument] = DaceCompiletime + + # Build DaCe orchestrated wrapper + # This is a JIT object, e.g. DaCe compilation will happen on call + wrapped = _LazyComputepathMethod(func, config).__get__(obj) + + if method_to_orchestrate == "__call__": + # Grab the function from the type of the child class + # Dev note: we need to use type for dunder call because: + # a = A() + # a() + # resolved to: type(a).__call__(a) + # therefore patching the instance call (e.g a.__call__) is not enough. 
+ # We could patch the type(self), ergo the class itself + # but that would patch _every_ instance of A. + # What we can do is patch the instance.__class__ with a local made class + # in order to keep each instance with it's own patch. + # + # Re: type:ignore + # Mypy is unhappy about dynamic class name and the devs (per github + # issues discussion) is to make a plugin. Too much work -> ignore mypy + + class _(type(obj)): # type: ignore + __qualname__ = f"{type(obj).__qualname__}_patched" + __name__ = f"{type(obj).__name__}_patched" + + def __call__(self, *arg, **kwarg): + return wrapped(*arg, **kwarg) + + def __sdfg__(self, *args, **kwargs): + sdfg = wrapped.__sdfg__(*args, **kwargs) + sdfg.validate() + return sdfg + + def __sdfg_closure__(self, reevaluate=None): + return wrapped.__sdfg_closure__(reevaluate) + + def __sdfg_signature__(self): + return wrapped.__sdfg_signature__() + + def closure_resolver(self, constant_args, given_args, parent_closure=None): + return wrapped.closure_resolver( + constant_args, given_args, parent_closure + ) + + # We keep the original class type name to not perturb + # the workflows that uses it to build relevant info (path, hash...) 
+ previous_cls_name = type(obj).__name__ + obj.__class__ = _ + type(obj).__name__ = previous_cls_name + else: + # For regular attribute - we can just patch as usual + setattr(obj, method_to_orchestrate, wrapped) def orchestrate_function( diff --git a/ndsl/quantity/field_bundle.py b/ndsl/quantity/field_bundle.py index df637fe2..6f57e430 100644 --- a/ndsl/quantity/field_bundle.py +++ b/ndsl/quantity/field_bundle.py @@ -25,6 +25,7 @@ class FieldBundle: """ _quantity: Quantity + _per_name_view: dict[str, Quantity] = {} _indexer: _FieldBundleIndexer = {} def __init__( @@ -80,13 +81,19 @@ def __getattr__(self, name: str) -> Quantity: return None # type: ignore # ToDo: extend the dims below to work with more than 4 dims assert len(self._quantity.data.shape) == 4 - return Quantity( - data=self._quantity.data[:, :, :, self.index(name)], - dims=self._quantity.dims[:-1], - units=self._quantity.units, - origin=self._quantity.origin[:-1], - extent=self._quantity.extent[:-1], - ) + + if name not in self._per_name_view: + # Memoize the Quantities returned here to ensue that we only ever + # have one `field.a_name`-Quantity floating around. If not, DaCe + # orchestration gets (rightly so) confused. + self._per_name_view[name] = Quantity( + data=self._quantity.data[:, :, :, self.index(name)], + dims=self._quantity.dims[:-1], + units=self._quantity.units, + origin=self._quantity.origin[:-1], + extent=self._quantity.extent[:-1], + ) + return self._per_name_view[name] def index(self, name: str) -> int: """Get index from name.""" @@ -141,7 +148,7 @@ class FieldBundleType: """Field Bundle Types to help with static sizing of Data Dimensions. Methods: - register: Register a type by sizing it's data dimensions + register: Register a type by sizing its data dimensions T: access any registered types for type hinting. 
""" @@ -151,7 +158,7 @@ class FieldBundleType: def register( cls, name: str, data_dims: tuple[int], dtype=Float ) -> gtscript._FieldDescriptor: - """Register a name type by name by giving the size of it's data dimensions. + """Register a name type by name by giving the size of its data dimensions. The same type cannot be registered twice and will error out. From 474517436df4a39fc95f699cb74c3275f26e6706 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 11 Sep 2025 18:14:37 +0200 Subject: [PATCH 18/35] Update gt4py: fix memlets into FrozenSDFG --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 8f6a45e3..cbb3eb9f 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 8f6a45e33772fef230eccdcce5b04e7145dfbdc8 +Subproject commit cbb3eb9ffb9c5250fb50bbfa406a60d6a1a8c653 From 83517d43d1321ba7c4bf313965fc20dc8ad3711e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 12 Sep 2025 12:22:39 +0200 Subject: [PATCH 19/35] update gt4py: tests memlet dimesion / fix domain symbols This gt4py update includes - tests for the memlet dimension fix - another fix to ensure that we always define all three cartesian symbols (even if we are only passing 2d fields and scalars into the stencil). --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index cbb3eb9f..644733df 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit cbb3eb9ffb9c5250fb50bbfa406a60d6a1a8c653 +Subproject commit 644733df3bfb4b806c684725a3af3e9501b269d1 From c897975acdb6eb311d26fac4562b1c79fe50b414 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 19 Sep 2025 14:17:50 +0200 Subject: [PATCH 20/35] cleanup: backends raise if not defined (#234) No need to assert - `from_backend()` raises a `ValueError` if a requested backend doesn't exist. 
--- ndsl/quantity/quantity.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index a4f7bd52..3a75c685 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -83,7 +83,6 @@ def __init__( if gt4py_backend is not None: gt4py_backend_cls = gt_backend.from_name(gt4py_backend) - assert gt4py_backend_cls is not None is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] dimensions: Tuple[Union[str, int], ...] = tuple( From 43030fa412e8c7cec1b888ec1ae542e5a3b0556e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 23 Sep 2025 14:43:54 +0200 Subject: [PATCH 21/35] GT4Py update This GT4Py update includes - dace fixes: FrozenSDFG fixes, iterator symbols - feature: `dace:cpu_kfirst` backend - tests: remove unused test utils - tests: print cache location at start (not end) - dace fixes: merge schedule tree roundtrip work - dace fix: memlet size of data dimensions - dace fix: use cached SDFGs from disk - dace perf: align loop structure and data layout - dace: remove unsued tile symbol function - refacor: invalid backend/frontend raise ValueError --- README.md | 2 -- external/gt4py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index 0f594637..b086c6ae 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,6 @@ - "Hybrid" indexing that allows absolute K-indices - Access to the iteration variable in K, working title `THIS_K` -- `round()` function in gtscript -- "schedule tree roundtrip work": fixes in the gt4py/dace bridge and in dace's stree/sdfg conversions to validate PyFV3 translate tests (with the AI2 data) Your standard readme continues below. 
diff --git a/external/gt4py b/external/gt4py index 644733df..61e785cc 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 644733df3bfb4b806c684725a3af3e9501b269d1 +Subproject commit 61e785cc04c63720da518b27a5c7ad229d347663 From cd599ffd277b2c99a4cd79f7f5364b1a0868c58a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 26 Sep 2025 09:57:38 +0200 Subject: [PATCH 22/35] gt4py update: no major changes in cartesian this is just to be up to date with the `milestone2` in gt4py, which was updated as preparation for setting up a PR for absolute k indexing as experimental feature. --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 61e785cc..16ffb572 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 61e785cc04c63720da518b27a5c7ad229d347663 +Subproject commit 16ffb572a17527321dd1de026abd5c7eac8f228c From 4206b3b5521bedc88aa0e55ea94f5a78fc68fc1e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 6 Oct 2025 11:16:03 +0200 Subject: [PATCH 23/35] gt4py update (abs K index fix in debug & dace) Bring in a fix for the issue that showed when IJ or K fields were used in combination with absolute K indexing. 
--- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 16ffb572..3927d1f0 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 16ffb572a17527321dd1de026abd5c7eac8f228c +Subproject commit 3927d1f04fff1b6f5834f4b72ff7494abe2956f6 From 5e563ea5ea4f362637e67cf1c626d889f1d91001 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 7 Oct 2025 09:37:33 +0200 Subject: [PATCH 24/35] gt4py update: absolute K indexing in mainline --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 3927d1f0..0ad8d959 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 3927d1f04fff1b6f5834f4b72ff7494abe2956f6 +Subproject commit 0ad8d95995a3e0291fb05d3de6b31af34b37fc98 From fc8b92fbdebcbbc13894e178f11255face26b81d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:38:54 +0200 Subject: [PATCH 25/35] absolute k indexing is now part of mainline gt4py (experimental) --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index b086c6ae..1454d54a 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@ :warning: :dragon: This is the equivalent of a "release" branch for the NASA Team's Milestone 2 work. In particular, we include the following experimental features: -- "Hybrid" indexing that allows absolute K-indices - Access to the iteration variable in K, working title `THIS_K` Your standard readme continues below. 
From 47bc3e36d566d76fad3f89e09d2a74334cfe3177 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 7 Oct 2025 11:55:52 -0400 Subject: [PATCH 26/35] Schedule Tree Pipeline + Untested Axis Merge (#251) * Roundtrip sdfg -> stree -> sdfg in orchestration - with moaar validation * Move in code the merge passes + K offset check * Insert optimization in orchestration * Conserve correct code-flow and stop merging when hitting non-map node as second candidate * Debug: Save STREE post opt Remove still assert * Split AxisMerge, add scalar tasklet push * Move algorithms to a 3-step method * Move up `ndsl_log` in `__init__` stack because it's a standalone file (cut on potential circular imports) * Working PushIfElse operator (on FvTp2D) dies on D_SW * Fix `list.index` re-using `_list_index` written by hand Allow scope operation to look more broadly at `next_node` mergeability * Remove debug prints and intermediate sdfg saving * Use default when calling simplify * Update gt4py/dace submodules (roundtrip work) This commit brings the changes needed for stree roundtrips to validate with the AI2 data (in the PyFV3 translate tests). * Update README * New algorithms - with revert when failure to merge and more aggressive depth-first merges * Quick note: skip ScalarToSymbolPromotion for now The pass messes up previously valid & validating SDFGs. We can live (performance wise) without it for the current milestone. Let's re-evaluate once we get back to DaCe mainline (v2).
* Add default ControlFlow behavior (recurse) Add - deactivated - AxisIteartor name sanitizer Fix single axis merge test * Add helper to detect if log is in Debug Add Release & march=native into dace compiler flags Unused orchestration pass Unused stree pass * Fix the GT4Py dependancy * Bad merge fix * Bad merge fix * Internal in code flag for stree optimizer * Move helpful "Make Sequential" SDFG transformation * Lint * Remove original Roman code that has been harvested for good * Unit test for roundtrip, proper pipeline setup * Lint * Fix to default Pipeline * Move helper function for SDFG, delete unused code * Move out tree common operations * Add mock test for optimization * Lint * Lint --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- README.md | 1 + ndsl/__init__.py | 2 +- ndsl/dsl/dace/orchestration.py | 32 +- ndsl/dsl/dace/sdfg/loop_transform.py | 19 + ndsl/dsl/dace/stree/__init__.py | 4 + ndsl/dsl/dace/stree/optimizations/__init__.py | 4 + .../dace/stree/optimizations/axis_merge.py | 437 ++++++++++++++++++ .../stree/optimizations/memlet_helpers.py | 145 ++++++ .../stree/optimizations/specialize_maps.py | 21 + .../stree/optimizations/tree_common_op.py | 42 ++ ndsl/dsl/dace/stree/pipeline.py | 72 +++ tests/stree_optimizer/test_optimization.py | 69 +++ tests/stree_optimizer/test_pipeline.py | 48 ++ 13 files changed, 887 insertions(+), 9 deletions(-) create mode 100644 ndsl/dsl/dace/sdfg/loop_transform.py create mode 100644 ndsl/dsl/dace/stree/__init__.py create mode 100644 ndsl/dsl/dace/stree/optimizations/__init__.py create mode 100644 ndsl/dsl/dace/stree/optimizations/axis_merge.py create mode 100644 ndsl/dsl/dace/stree/optimizations/memlet_helpers.py create mode 100644 ndsl/dsl/dace/stree/optimizations/specialize_maps.py create mode 100644 ndsl/dsl/dace/stree/optimizations/tree_common_op.py create mode 100644 ndsl/dsl/dace/stree/pipeline.py create mode 100644 tests/stree_optimizer/test_optimization.py create mode 100644 
tests/stree_optimizer/test_pipeline.py diff --git a/README.md b/README.md index 1454d54a..380b36f9 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ :warning: :dragon: This is the equivalent of a "release" branch for the NASA Team's Milestone 2 work. In particular, we include the following experimental features: - Access to the iteration variable in K, working title `THIS_K` +- "schedule tree roundtrip work": fixes in the gt4py/dace bridge and in dace's stree/sdfg conversions to validate PyFV3 translate tests (with the AI2 data) Your standard readme continues below. diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 675610e3..97d3d269 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,4 +1,5 @@ from . import dsl # isort:skip +from .logging import ndsl_log # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import LocalComm from .comm.mpi import MPIComm @@ -22,7 +23,6 @@ from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer -from .logging import ndsl_log from .monitor.netcdf_monitor import NetCDFMonitor from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index ad4ae3d1..a3769e62 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -33,6 +33,8 @@ negative_qtracers_checker, sdfg_nan_checker, ) +from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -42,6 +44,13 @@ from ndsl.optional_imports import cupy as cp +_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False +"""INTERNAL: Developer flag to turn the untested schedule tree roundtrip 
optimizer.""" + +_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] +"""INTERNAL: Default schedule passes for CPU. To be replaced with proper configuration.""" + + def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" return func @@ -123,6 +132,7 @@ def _build_sdfg( ) -> None: """Build the .so out of the SDFG on the top tile ranks only.""" is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile + device_type = DaceDeviceType.GPU if config.is_gpu_backend() else DaceDeviceType.CPU if is_compiling: with DaCeProgress(config, "Validate original SDFG"): @@ -143,12 +153,18 @@ def _build_sdfg( with DaCeProgress(config, "Simplify (1)"): _simplify(sdfg) - # TODO uncomment if you want to test schedule tree roundtrip change and/or remove once - # we have the schedule tree optimization pipeline. - # with DaCeProgress(config, "Schedule tree roundtrip"): - # stree = sdfg.as_schedule_tree() - # # ScalarToSymbolPromotion is messing with us, so we disable it. 
- # sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) + if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + with DaCeProgress(config, "Schedule Tree: generate from SDFG"): + stree = sdfg.as_schedule_tree() + + with DaCeProgress(config, "Schedule Tree: optimization"): + if config.is_gpu_backend(): + GPUPipeline().run(stree) + else: + CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree) + + with DaCeProgress(config, "Schedule Tree: go back to SDFG"): + sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) # Make the transients array persistents if config.is_gpu_backend(): @@ -156,7 +172,7 @@ def _build_sdfg( # The following should happen on the stree level _to_gpu(sdfg) - make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.GPU) + make_transients_persistent(sdfg=sdfg, device=device_type) # Upload args to device _upload_to_device(list(args) + list(kwargs.values())) @@ -166,7 +182,7 @@ def _build_sdfg( for _sd, _aname, arr in sdfg.arrays_recursive(): if arr.shape == (1,): arr.storage = DaceStorageType.Register - make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.CPU) + make_transients_persistent(sdfg=sdfg, device=device_type) # Build non-constants & non-transients from the sdfg_kwargs sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs) diff --git a/ndsl/dsl/dace/sdfg/loop_transform.py b/ndsl/dsl/dace/sdfg/loop_transform.py new file mode 100644 index 00000000..7e6cf1d4 --- /dev/null +++ b/ndsl/dsl/dace/sdfg/loop_transform.py @@ -0,0 +1,19 @@ +from dace import SDFG, ScheduleType, nodes + + +def make_SDFG_CPU_sequential(sdfg: SDFG) -> None: + """Utility to turn a CPU-based SDFG to pure serial by removing OpenMP""" + # Disable OpenMP sections + for sd in sdfg.all_sdfgs_recursive(): + sd.openmp_sections = False + + # Disable OpenMP maps + for node, _ in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.EntryNode): + schedule = getattr(node, "schedule", False) + if schedule in ( + ScheduleType.CPU_Multicore, + 
ScheduleType.CPU_Persistent, + ScheduleType.Default, + ): + node.schedule = ScheduleType.Sequential diff --git a/ndsl/dsl/dace/stree/__init__.py b/ndsl/dsl/dace/stree/__init__.py new file mode 100644 index 00000000..6435e662 --- /dev/null +++ b/ndsl/dsl/dace/stree/__init__.py @@ -0,0 +1,4 @@ +from .pipeline import CPUPipeline, GPUPipeline + + +__all__ = ["CPUPipeline", "GPUPipeline"] diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py new file mode 100644 index 00000000..47c764b3 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -0,0 +1,4 @@ +from .axis_merge import AxisIterator, CartesianAxisMerge + + +__all__ = ["AxisIterator", "CartesianAxisMerge"] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py new file mode 100644 index 00000000..27e3bc69 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +import copy +import re +from typing import Any, List + +import dace +import dace.sdfg.analysis.schedule_tree.treenodes as dst +from dace.properties import CodeBlock + +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( + AxisIterator, + no_data_dependencies_on_cartesian_axis, +) +from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( + detect_cycle, + list_index, + swap_node_position_in_tree, +) + + +def _is_axis_map(node: dst.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a map over K.""" + map_parameter = node.node.params + return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + + +def _both_same_single_axis_maps( + first: dst.MapScope, + second: dst.MapScope, + axis: AxisIterator, +) -> bool: + return ( + (len(first.node.params) == 1 and len(second.node.params) == 1) # Single axis + and first.node.params[0] == second.node.params[0] # Same axis + and _is_axis_map(first, axis) 
# Correct axis + ) + + +def _can_merge_axis_maps( + first: dst.MapScope, + second: dst.MapScope, + axis: AxisIterator, +) -> bool: + return _both_same_single_axis_maps( + first, second, axis + ) and no_data_dependencies_on_cartesian_axis( + first, + second, + axis, + ) + + +class InsertOvercomputationGuard(dst.ScheduleNodeTransformer): + def __init__( + self, + axis_as_string: str, + *, + merged_range: dace.subsets.Range, + original_range: dace.subsets.Range, + ): + self._axis_as_string = axis_as_string + self._merged_range = merged_range + self._original_range = original_range + + def _execution_condition(self) -> CodeBlock: + # NOTE range.ranges are inclusive, e.g. + # Range(0:4) -> ranges = (start=1, stop=3, step=1) + range = self._original_range + start = range.ranges[0][0] + stop = range.ranges[0][1] + step = range.ranges[0][2] + return CodeBlock( + f"{self._axis_as_string} >= {start} " + f"and {self._axis_as_string} <= {stop} " + f"and ({self._axis_as_string} - {start}) % {step} == 0" + ) + + def visit_MapScope(self, node: dst.MapScope) -> dst.MapScope: + all_children_are_maps = all( + [isinstance(child, dst.MapScope) for child in node.children] + ) + if not all_children_are_maps: + if self._merged_range != self._original_range: + node.children = [ + dst.IfScope( + condition=self._execution_condition(), children=node.children + ) + ] + return node + + node.children = self.visit(node.children) + return node + + +def _get_next_node( + nodes: list[dst.ScheduleTreeNode], + node: dst.ScheduleTreeNode, +) -> dst.ScheduleTreeNode: + return nodes[list_index(nodes, node) + 1] + + +def _last_node(nodes: list[dst.ScheduleTreeNode], node: dst.ScheduleTreeNode) -> bool: + return list_index(nodes, node) >= len(nodes) - 1 + + +def _sanitize_axis(axis: AxisIterator, name_to_normalize: str) -> str: + axis_clean = f"{axis.as_str()}" + pattern = f"{axis.as_str()}_[0-9]*" + + return re.sub(pattern, axis_clean, name_to_normalize) + + +class 
NormalizeAxisSymbol(dst.ScheduleNodeVisitor): + def __init__(self, axis: AxisIterator) -> None: + self.axis = axis + + def visit_MapScope( + self, + map_scope: dst.MapScope, + axis_rpl_dict: dict[str, str] = {}, + **kwargs: Any, + ) -> None: + for i_param, param in enumerate(map_scope.node.params): + sanitized_param = _sanitize_axis(self.axis, param) + axis_rpl_dict[param] = sanitized_param + map_scope.node.params[i_param] = sanitized_param + + # visit children + for child in map_scope.children: + self.visit(child, axis_rpl_dict=axis_rpl_dict) + + def visit_TaskletNode( + self, + node: dst.TaskletNode, + axis_rpl_dict: dict[str, str] = {}, + **kwargs: Any, + ) -> None: + for memlets in node.in_memlets.values(): + memlets.replace(axis_rpl_dict) + for memlets in node.out_memlets.values(): + memlets.replace(axis_rpl_dict) + + +class CartesianAxisMerge(dst.ScheduleNodeTransformer): + """Merge a cartesian axis if they are contiguous in code-flow. + + Can do: + - merge a given axis with the next maps at the same recursion level + - can overcompute (eager) do allow for more merging at a cost of an if + + Args: + axis: AxisIterator to be merged + eager: overcompute with a conditional guard + """ + + def __init__( + self, + axis: AxisIterator, + *, + eager: bool = True, + ) -> None: + self.axis = axis + self.eager = eager + + def __str__(self) -> str: + return f"CartesianAxisMerge({self.axis.name})" + + def _merge_node( + self, + node: dst.ScheduleTreeNode, + nodes: list[dst.ScheduleTreeNode], + ) -> int: + """Direct code to the correct resolver for the node (e.g. visitor) + + Dev Note: Order matters! + Default behavior for base class must be _after_ bespoke leaf class + behavior (e.g. 
IfScope before ControlFlowScope) + """ + + if isinstance(node, dst.MapScope): + return self._map_overcompute_merge(node, nodes) + elif isinstance(node, dst.IfScope): + return self._push_ifelse_down(node, nodes) + elif isinstance(node, dst.TaskletNode): + return self._push_tasklet_down(node, nodes) + elif isinstance(node, dst.ControlFlowScope): + return self._default_control_flow(node, nodes) + else: + ndsl_log.debug( + f" (╯°□°)╯︵ ┻━┻: can't merge {type(node)}. Recursion ends." + ) + return 0 + + def _default_control_flow( + self, + the_control_flow: dst.ControlFlowScope, + nodes: list[dst.ScheduleTreeNode], + ) -> int: + if len(the_control_flow.children) != 0: + return self._merge(the_control_flow) + + return 0 + + def _push_tasklet_down( + self, + the_tasklet: dst.TaskletNode, + nodes: list[dst.ScheduleTreeNode], + ) -> int: + """Push tasklet into a consecutive map.""" + in_memlets = the_tasklet.input_memlets() + if len(in_memlets) != 0: + if "__pystate" in [tasklet.data for tasklet in the_tasklet.input_memlets()]: + return 0 # Tasklet is a callback + + next_index = list_index(nodes, the_tasklet) + if next_index == len(nodes): + return 0 # Last node - done + + next_node = nodes[next_index + 1] + + # Before checking the possibility of merging - attempt to surface + # a map from the next nodes + merged = self._merge_node(next_node, nodes) + + # Attempt to push the tasklet in the next map + ndsl_log.debug(" Push tasklet down into next map") + next_node = nodes[next_index + 1] + if isinstance(next_node, dst.MapScope): + next_node.children.insert(0, the_tasklet) + the_tasklet.parent = next_node + nodes.remove(the_tasklet) + merged += self._merge_node(next_node, nodes) + + return merged + + def _push_ifelse_down( + self, + the_if: dst.IfScope, + nodes: list[dst.ScheduleTreeNode], + ) -> int: + merged = 0 + + # Recurse down if/else/elif + if_index = list_index(nodes, the_if) + if len(the_if.children) != 0: + merged += self._merge_node(the_if.children[0], 
the_if.children) + for else_index in range(if_index + 1, len(nodes)): + else_node = nodes[else_index] + if else_index < len(nodes) and ( + isinstance(else_node, dst.ElseScope) + or isinstance(else_node, dst.ElifScope) + ): + merged += self._merge_node(else_node, else_node.children) + else: + break + + # Look at swapping if/else/elif first map w/ control flow + + # Gather all first maps - if they do not exists, get out + all_maps = [] + if isinstance(the_if.children[0], dst.MapScope): + all_maps.append(the_if.children[0]) + else: + return merged + for else_index in range(if_index + 1, len(nodes)): + else_node = nodes[else_index] + if else_index < len(nodes) and ( + isinstance(else_node, dst.ElseScope) + or isinstance(else_node, dst.ElifScope) + ): + if isinstance(else_node.children[0], dst.MapScope): + all_maps.append(else_node.children[0]) + else: + return merged + else: + break + + # Check for mergeability + if len(all_maps) > 1: + the_map = all_maps[0] + for _map in all_maps[1:]: + if not _can_merge_axis_maps(the_map, _map, self.axis): + return merged + + # We are good to go - swap it all + ndsl_log.debug(f" Push IF {the_if.condition.as_string} down") + inner_if_map = the_if.children[0] + + # Swap IF & maps + if_index = list_index(nodes, the_if) + swap_node_position_in_tree(the_if, inner_if_map) + + # Swap ELIF/ELSE & maps + for else_index in range(if_index + 1, len(nodes)): + if else_index < len(nodes) and ( + isinstance(nodes[else_index], dst.ElseScope) + or isinstance(nodes[else_index], dst.ElifScope) + ): + swap_node_position_in_tree( + nodes[else_index], nodes[else_index].children[0] + ) + else: + break + + # Merge the Maps + assert isinstance(nodes[if_index], dst.MapScope) + merged += self._map_overcompute_merge(nodes[if_index], nodes) + + return merged + + def _map_overcompute_merge( + self, + the_map: dst.MapScope, + nodes: list[dst.ScheduleTreeNode], + ) -> int: + if _last_node(nodes, the_map): + return 0 + + next_node = _get_next_node(nodes, the_map) + 
+ # If we the next node is not a MapScope - recurse + if not isinstance(next_node, dst.MapScope): + merged = self._merge_node(next_node, nodes) + new_next_node = _get_next_node(nodes, the_map) + if new_next_node == next_node: + return merged + return merged + self._merge_node(the_map, nodes) + + # Attempt to merge consecutive maps + if not _can_merge_axis_maps(the_map, next_node, self.axis): + return 0 + + # Only for maps in K: + # - force-merge by expanding the ranges + # - then, guard children to only run in their respective range + first_range = the_map.node.map.range + second_range = next_node.node.map.range + merged_range = dace.subsets.Range( + [ + ( + f"min({first_range.ranges[0][0]}, {second_range.ranges[0][0]})", + f"max({first_range.ranges[0][1]}, {second_range.ranges[0][1]})", + 1, # NOTE: we can optimize this to gcd later + ) + ] + ) + + ndsl_log.debug( + f" Merge K map: {first_range} ⋃ {second_range} -> {merged_range}" + ) + + # push IfScope down if children are just maps + axis_as_str = the_map.node.params[0] + first_map = InsertOvercomputationGuard( + axis_as_str, merged_range=merged_range, original_range=first_range + ).visit(the_map) + second_map = InsertOvercomputationGuard( + axis_as_str, + merged_range=merged_range, + original_range=second_range, + ).visit(next_node) + merged_children: List[dst.MapScope] = [ + *first_map.children, + *second_map.children, + ] + first_map.children = merged_children + + # TODO also merge containers and symbols (if applicable) + + first_map.node.map.range = merged_range + + # delete now-merged second_map + del nodes[list_index(nodes, next_node)] + + return 1 + + def _merge(self, node: dst.ScheduleTreeRoot | dst.ScheduleTreeScope) -> int: + merged = 0 + + if __debug__: + detect_cycle(node.children, set()) + + i_candidate = 0 + while i_candidate < len(node.children): + next_node = node.children[i_candidate] + merged += self._merge_node(next_node, node.children) + i_candidate += 1 + + if __debug__: + 
detect_cycle(node.children, set()) + + return merged + + def visit_ScheduleTreeRoot(self, node: dst.ScheduleTreeRoot): + """Merge as many maps as possible. + + The algorithm works as follows: + - Start merging - move nodes to surface maps as much as possible + - Try to merge the surfaced maps + - When done, count the number of actual merges + - If NO merges - restore the previous children + (undo potential changes that didn't lead to map merge) + Then exit. + """ + + # ndsl_log.debug(f"🧹 Normalizing {self.axis} loop symbol") + # NormalizeAxisSymbol(self.axis).visit(node) + # with open("stree-IN-sanitized.txt", "w") as f: + # f.write(node.as_string(-1)) + + overall_merged = 0 + i = 0 + while True: + i += 1 + ndsl_log.debug(f"🔥 Merge attempt #{i}") + previous_children = copy.deepcopy(node.children) + try: + merged = self._merge(node) + overall_merged += merged + if __debug__: + detect_cycle(node.children, set()) + except RecursionError as re: + breakpoint() + raise re + + # If we didn't merge, we revert the children + # to the previous state + if merged == 0: + ndsl_log.debug("🥹 No merges, revert!") + node.children = previous_children + break + + ndsl_log.debug( + f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged" + ) diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py new file mode 100644 index 00000000..ecd138b4 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py @@ -0,0 +1,145 @@ +from enum import Enum + +import dace.sdfg.analysis.schedule_tree.treenodes as dst +from dace.memlet import Memlet + + +class AxisIterator(Enum): + _I = ("__i", 0) + _J = ("__j", 1) + _K = ("__k", 2) + + def as_str(self) -> str: + return self.value[0] + + def as_cartesian_index(self) -> int: + return self.value[1] + + +def no_data_dependencies_on_cartesian_axis( + first: dst.MapScope, + second: dst.MapScope, + axis: AxisIterator, +) -> bool: + """Check for read after write. 
Allow when indexation on the axis + is not offseted.""" + + write_collector = MemletCollector(collect_reads=False) + write_collector.visit(first) + read_collector = MemletCollector(collect_writes=False) + read_collector.visit(second) + for write in write_collector.out_memlets: + # TODO: this can be optimized to allow non-overlapping intervals and such in the future + + if write.subset.dims() <= axis.as_cartesian_index(): + # Dimension does not exist + continue + + previous_axis_index = write.subset[axis.as_cartesian_index()][0] + for read in read_collector.in_memlets: + if write.data == read.data: + if previous_axis_index != read.subset[axis.as_cartesian_index()][0]: + print( + f"[{axis.name} Merge] Found read after write conflict " + f"for {write.data} " + f"w/ different offset to {axis.name} (" + f"write at {write.subset[axis.as_cartesian_index()][0]}, " + f"read at {read.subset[axis.as_cartesian_index()][0]})" + ) + return False + return True + + +def no_data_dependencies( + first: dst.MapScope, + second: dst.MapScope, + restrict_check_to_k=False, +) -> bool: + write_collector = MemletCollector(collect_reads=False) + write_collector.visit(first) + read_collector = MemletCollector(collect_writes=False) + read_collector.visit(second) + for write in write_collector.out_memlets: + # Make sure we don't have read after write conditions. 
+ # TODO: this can be optimized to allow non-overlapping intervals and such in the future + if restrict_check_to_k: + if write.subset.dims() < 3: + # Case of 2D write - no K dependency + continue + + previous_k_index = write.subset[2][0] + for read in read_collector.in_memlets: + if write.data == read.data: + if previous_k_index != read.subset[2][0]: + print( + "[K Merge] Found read after write conflict " + f"for {write.data} " + "w/ different offset to K (" + f"write at {write.subset[2][0]}, " + f"read at {read.subset[2][0]})" + ) + return False + + else: + if write.data in [read.data for read in read_collector.in_memlets]: + print( + f"[All dims merge] Found potential read after write conflict for {write.data}" + ) + return False + return True + + +class MemletCollector(dst.ScheduleNodeVisitor): + """Gathers in_memlets and out_memlets of TaskNodes and LibraryCalls.""" + + in_memlets: list[Memlet] + out_memlets: list[Memlet] + + def __init__(self, *, collect_reads=True, collect_writes=True): + self._collect_reads = collect_reads + self._collect_writes = collect_writes + + self.in_memlets = [] + self.out_memlets = [] + + def visit_TaskletNode(self, node: dst.TaskletNode) -> None: + if self._collect_reads: + self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) + if self._collect_writes: + self.out_memlets.extend([memlet for memlet in node.out_memlets.values()]) + + def visit_LibraryCall(self, node: dst.LibraryCall) -> None: + if self._collect_reads: + if isinstance(node.in_memlets, set): + self.in_memlets.extend(node.in_memlets) + else: + assert isinstance(node.in_memlets, dict) + self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) + + if self._collect_writes: + if isinstance(node.out_memlets, set): + self.out_memlets.extend(node.out_memlets) + else: + assert isinstance(node.out_memlets, dict) + self.out_memlets.extend( + [memlet for memlet in node.out_memlets.values()] + ) + + +def has_dynamic_memlets(first: dst.MapScope, 
second: dst.MapScope) -> bool: + first_collector = MemletCollector() + second_collector = MemletCollector() + first_collector.visit(first) + second_collector.visit(second) + has_dynamic_memlets = any( + [ + memlet.dynamic + for memlet in [ + *first_collector.in_memlets, + *first_collector.out_memlets, + *second_collector.in_memlets, + *second_collector.out_memlets, + ] + ] + ) + return has_dynamic_memlets diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py new file mode 100644 index 00000000..90cb553b --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -0,0 +1,21 @@ +import dace.sdfg.analysis.schedule_tree.treenodes as dst +import dace.subsets as sbs + + +class SpecializeCartesianMaps(dst.ScheduleNodeVisitor): + def __init__(self, mappings: dict[str, int]) -> None: + super().__init__() + self._mappings = mappings + + def visit_MapScope(self, node: dst.MapScope): + dims = [] + for p in node.node.map.params: + if p == "__i": + dims.append((0, self._mappings["__I"], 1)) + if p == "__j": + dims.append((0, self._mappings["__J"], 1)) + if p.startswith("__k"): + dims.append((0, self._mappings["__K"], 1)) + node.node.map.range = sbs.Range(dims) + + self.visit(node.children) diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py new file mode 100644 index 00000000..c3da106e --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -0,0 +1,42 @@ +from typing import Collection + +import dace.sdfg.analysis.schedule_tree.treenodes as dst + + +def swap_node_position_in_tree(top_node, child_node): + """Top node becomes child, child becomes top node""" + # Take refs before swap + top_children = top_node.parent.children + top_level_parent = top_node.parent + + # Swap childrens + top_node.children = child_node.children + child_node.children = [top_node] + top_children.insert(list_index(top_children, top_node), 
child_node) + + # Re-parent + top_node.parent = child_node + child_node.parent = top_level_parent + + # Remove now-pushed original node + top_children.remove(top_node) + + +def detect_cycle(nodes, visited: set): + """Detect the cycles in the tree.""" + # Dev note: isn't there a DaCe tool for this?! + for n in nodes: + if id(n) in visited: + breakpoint() + visited.add(id(n)) + if hasattr(n, "children"): + detect_cycle(n.children, visited) + + +def list_index( + collection: Collection[dst.ScheduleTreeNode], + node: dst.ScheduleTreeNode, +) -> int: + """Check if node is in list with "is" operator.""" + # compare with "is" to get memory comparison. ".index()" uses value comparison + return next(index for index, element in enumerate(collection) if element is node) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py new file mode 100644 index 00000000..898107a8 --- /dev/null +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -0,0 +1,72 @@ +from abc import abstractmethod +from typing import Protocol + +import dace.sdfg.analysis.schedule_tree.treenodes as dst + +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge + + +class StreePipeline(Protocol): + @abstractmethod + def __hash__(self) -> int: + raise NotImplementedError("Missing implementation of __hash__") + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError("Missing implementation of __repr__") + + @abstractmethod + def run( + self, + stree: dst.ScheduleTreeRoot, + verbose: bool = False, + ) -> dst.ScheduleTreeRoot: + raise NotImplementedError("Missing implementation of run") + + +class CPUPipeline(StreePipeline): + def __init__(self, passes: list[dst.ScheduleNodeTransformer] | None = None) -> None: + self.passes = ( + passes if passes is not None else [CartesianAxisMerge(AxisIterator._K)] + ) + + def __repr__(self) -> str: + return str([type(p) for p in self.passes]) + + def __hash__(self) -> int: + return hash(repr(self)) + + def run( + self, + 
stree: dst.ScheduleTreeRoot, + verbose: bool = False, + ) -> dst.ScheduleTreeRoot: + for p in self.passes: + if verbose: + print(f"[Stree OPT] {p}") + p.visit(stree) + + return stree + + +class GPUPipeline(StreePipeline): + def __init__(self, passes: list[dst.ScheduleNodeTransformer] | None = None) -> None: + self.passes = passes if passes else [] + + def __repr__(self) -> str: + return str([type(p) for p in self.passes]) + + def __hash__(self) -> int: + return hash(repr(self)) + + def run( + self, + stree: dst.ScheduleTreeRoot, + verbose: bool = False, + ) -> dst.ScheduleTreeRoot: + for p in self.passes: + if verbose: + print(f"[Stree OPT] {p}") + p.visit(stree) + + return stree diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py new file mode 100644 index 00000000..fb32a35d --- /dev/null +++ b/tests/stree_optimizer/test_optimization.py @@ -0,0 +1,69 @@ +from ndsl import StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def stencil_A(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = in_field + + +def stencil_B(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = out_field + in_field * 3 + + +class TriviallyMergeableCode: + def __init__(self, stencil_factory: StencilFactory): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil_A = stencil_factory.from_dims_halo( + func=stencil_A, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_B = stencil_factory.from_dims_halo( + func=stencil_B, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, in_field: FloatField, out_field: FloatField): + 
self.stencil_A(in_field, out_field) + self.stencil_B(in_field, out_field) + + +def test_stree_roundtrip_no_opt(): + """Dev Note: + + The below code sucessfully merges top level K loop (2 loops) + How do we test it?! Running doesn't test merging and the compilation + is a near-black box. We could reach in the `dace_config.compiled_sdfg` + cache but it's keyed on the dace.program and if we can reach the program + well we can reach the SDFG and turn it into an stree for verification + Should we run orchestration "by hand"? + Can we intercept the `stree` ? After all we just want to check that! + + Test is deactivated for now""" + + return True + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu" + ) + + code = TriviallyMergeableCode(stencil_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + # Temporarily flip the internal switch + import ndsl.dsl.dace.orchestration as orch + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] + + code(in_qty, out_qty) + + assert (out_qty.field[:] == 4).all() + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False diff --git a/tests/stree_optimizer/test_pipeline.py b/tests/stree_optimizer/test_pipeline.py new file mode 100644 index 00000000..6d4c6f74 --- /dev/null +++ b/tests/stree_optimizer/test_pipeline.py @@ -0,0 +1,48 @@ +from ndsl import StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def double_map(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = in_field + + with computation(PARALLEL), interval(...): + out_field = 
out_field + in_field * 3 + + +class TriviallyMergeableCode: + def __init__(self, stencil_factory: StencilFactory): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil = stencil_factory.from_dims_halo( + func=double_map, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, in_field: FloatField, out_field: FloatField): + self.stencil(in_field, out_field) + + +def test_stree_roundtrip_no_opt(): + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + ) + + code = TriviallyMergeableCode(stencil_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + # Temporarily flip the internal switch + import ndsl.dsl.dace.orchestration as orch + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_PASSES = [] + + code(in_qty, out_qty) + + assert (out_qty.field[:] == 4).all() + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False From 00f832605b21d9cd094402044bb04ebe9ad98161 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 10 Oct 2025 12:10:14 -0400 Subject: [PATCH 27/35] [Clean up] Schedule Tree optimizer (WIP) (#255) * Use `ndsl_log` * Remove missed `breakpoint` and turn dead code into coding comment --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 9 +++++---- ndsl/dsl/dace/stree/optimizations/memlet_helpers.py | 6 ++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 27e3bc69..18142998 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -405,10 +405,12 @@ def visit_ScheduleTreeRoot(self, node: dst.ScheduleTreeRoot): Then exit. 
""" - # ndsl_log.debug(f"🧹 Normalizing {self.axis} loop symbol") + # TODO: many interval generate many iterator name right now + # e.g. _k_0, _k_1... + # This makes merging more difficult. We could write a pre-pass + # that cleans this up BUT we have an issue with the THIS_K feature + # in the tasklet... # NormalizeAxisSymbol(self.axis).visit(node) - # with open("stree-IN-sanitized.txt", "w") as f: - # f.write(node.as_string(-1)) overall_merged = 0 i = 0 @@ -422,7 +424,6 @@ def visit_ScheduleTreeRoot(self, node: dst.ScheduleTreeRoot): if __debug__: detect_cycle(node.children, set()) except RecursionError as re: - breakpoint() raise re # If we didn't merge, we revert the children diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py index ecd138b4..1613e014 100644 --- a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py +++ b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py @@ -3,6 +3,8 @@ import dace.sdfg.analysis.schedule_tree.treenodes as dst from dace.memlet import Memlet +from ndsl import ndsl_log + class AxisIterator(Enum): _I = ("__i", 0) @@ -22,7 +24,7 @@ def no_data_dependencies_on_cartesian_axis( axis: AxisIterator, ) -> bool: """Check for read after write. 
Allow when indexation on the axis - is not offseted.""" + is not offset.""" write_collector = MemletCollector(collect_reads=False) write_collector.visit(first) @@ -39,7 +41,7 @@ def no_data_dependencies_on_cartesian_axis( for read in read_collector.in_memlets: if write.data == read.data: if previous_axis_index != read.subset[axis.as_cartesian_index()][0]: - print( + ndsl_log.debug( f"[{axis.name} Merge] Found read after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" From 10f14dc6b6dc0fd7d6b59a83554450df2318fc0c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:36:29 +0200 Subject: [PATCH 28/35] gt4py update: push forscope down, shiny error messages --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 0ad8d959..d10dddef 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 0ad8d95995a3e0291fb05d3de6b31af34b37fc98 +Subproject commit d10dddefbeba97ca479709a3eb9c722647179be5 From b3136d5467861ec525384b5d612c44299aff5281 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 14 Oct 2025 10:59:17 +0200 Subject: [PATCH 29/35] update dace (& gt4py): fixes from v1/maintenance --- external/dace | 2 +- external/gt4py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/dace b/external/dace index 54c669ed..1033dfcf 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 54c669ed0f0bc97eb077e84ace58f7596744c8b5 +Subproject commit 1033dfcf9d118856d82c6ee8d6f6cfacec662335 diff --git a/external/gt4py b/external/gt4py index d10dddef..cdc655dc 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit d10dddefbeba97ca479709a3eb9c722647179be5 +Subproject commit cdc655dcb473e026e7bb6ca5309d5dad4631f546 From 379ef1b5537fcde065939f933d0ef5ac9edd855a Mon Sep 17 00:00:00 2001 From: Roman 
Cattaneo <> Date: Mon, 20 Oct 2025 10:09:57 +0200 Subject: [PATCH 30/35] fixup: add missing type after merge --- ndsl/dsl/dace/stree/optimizations/specialize_maps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py index 90cb553b..d1ce5b26 100644 --- a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -7,7 +7,7 @@ def __init__(self, mappings: dict[str, int]) -> None: super().__init__() self._mappings = mappings - def visit_MapScope(self, node: dst.MapScope): + def visit_MapScope(self, node: dst.MapScope) -> None: dims = [] for p in node.node.map.params: if p == "__i": From d5b0e6a67e01eee883569a50a36a0b070f9405c1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:43:11 +0200 Subject: [PATCH 31/35] update gt4py: K iteration index --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index cdc655dc..c6b40add 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit cdc655dcb473e026e7bb6ca5309d5dad4631f546 +Subproject commit c6b40add1385d2be5021e2fd15d38a1be13171f5 From a19c94d0d71fa2abf7f89ae28eefff8f81294eb7 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 23 Oct 2025 10:49:23 -0400 Subject: [PATCH 32/35] De-dragon the README --- README.md | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/README.md b/README.md index 380b36f9..b256eda7 100644 --- a/README.md +++ b/README.md @@ -3,15 +3,6 @@ # NOAA/NASA Domain Specific Language middleware -:warning: :dragon: This is the equivalent of a "release" branch for the NASA Team's Milestone 2 work. 
In particular, we include the following experimental features: - -- Access to the iteration variable in K, working title `THIS_K` -- "schedule tree roundtrip work": fixes in the gt4py/dace bridge and in dace's stree/sdfg conversions to validate PyFV3 translate tests (with the AI2 data) - -Your standard readme continues below. - ---- - NDSL is a middleware for climate and weather modelling developed jointly by NOAA and NASA. The middleware brings together [GT4Py](https://github.com/GridTools/gt4py/) (the `cartesian` flavor), ETH CSCS's stencil DSL, and [DaCe](https://github.com/spcl/dace/), ETH SPCL's data flow framework, both developed for high-performance and portability. On top of those pillars, NDSL deploys a series of optimized APIs for common operations (Halo exchange, domain decomposition, MPI, ...), a set of bespoke optimizations for the models targeted by the middleware and tools to port existing models. ## Batteries-included for FV-based models @@ -57,7 +48,7 @@ To run the GPU backends, you'll need: - Libraries: MPI compiled with CUDA support - CUDA 11.2+ - Python package: - - `cupy` (latest with proper driver support [see install notes](https://docs.cupy.dev/en/stable/install.html)) + - `cupy` (latest with proper driver support [see install notes](https://docs.cupy.dev/en/stable/install.html)) A simple way to install MPI is using pre-built wheels, e.g. 
From bd6ddc346f669f7f432140c1e1dd89df5076d7bd Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 23 Oct 2025 15:14:37 -0400 Subject: [PATCH 33/35] Rename `dst` to `stree` for moniker of `dace.sdfg.analysis.schedule_tree.treenodes` Better docs --- .../dace/stree/optimizations/axis_merge.py | 118 +++++++++--------- .../stree/optimizations/memlet_helpers.py | 18 +-- .../stree/optimizations/specialize_maps.py | 6 +- .../stree/optimizations/tree_common_op.py | 10 +- ndsl/dsl/dace/stree/pipeline.py | 22 ++-- 5 files changed, 90 insertions(+), 84 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index cd9ba94d..62a8c126 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -2,10 +2,10 @@ import copy import re -from typing import Any, List +from typing import Any import dace -import dace.sdfg.analysis.schedule_tree.treenodes as dst +import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.properties import CodeBlock from ndsl import ndsl_log @@ -20,15 +20,15 @@ ) -def _is_axis_map(node: dst.MapScope, axis: AxisIterator) -> bool: - """Returns true if node is a map over K.""" +def _is_axis_map(node: stree.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a map over the given axis.""" map_parameter = node.node.params return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) def _both_same_single_axis_maps( - first: dst.MapScope, - second: dst.MapScope, + first: stree.MapScope, + second: stree.MapScope, axis: AxisIterator, ) -> bool: return ( @@ -39,8 +39,8 @@ def _both_same_single_axis_maps( def _can_merge_axis_maps( - first: dst.MapScope, - second: dst.MapScope, + first: stree.MapScope, + second: stree.MapScope, axis: AxisIterator, ) -> bool: return _both_same_single_axis_maps( @@ -52,7 +52,7 @@ def _can_merge_axis_maps( ) -class 
InsertOvercomputationGuard(dst.ScheduleNodeTransformer): +class InsertOvercomputationGuard(stree.ScheduleNodeTransformer): def __init__( self, axis_as_string: str, @@ -77,14 +77,14 @@ def _execution_condition(self) -> CodeBlock: f"and ({self._axis_as_string} - {start}) % {step} == 0" ) - def visit_MapScope(self, node: dst.MapScope) -> dst.MapScope: + def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: all_children_are_maps = all( - [isinstance(child, dst.MapScope) for child in node.children] + [isinstance(child, stree.MapScope) for child in node.children] ) if not all_children_are_maps: if self._merged_range != self._original_range: node.children = [ - dst.IfScope( + stree.IfScope( condition=self._execution_condition(), children=node.children ) ] @@ -95,13 +95,15 @@ def visit_MapScope(self, node: dst.MapScope) -> dst.MapScope: def _get_next_node( - nodes: list[dst.ScheduleTreeNode], - node: dst.ScheduleTreeNode, -) -> dst.ScheduleTreeNode: + nodes: list[stree.ScheduleTreeNode], + node: stree.ScheduleTreeNode, +) -> stree.ScheduleTreeNode: return nodes[list_index(nodes, node) + 1] -def _last_node(nodes: list[dst.ScheduleTreeNode], node: dst.ScheduleTreeNode) -> bool: +def _last_node( + nodes: list[stree.ScheduleTreeNode], node: stree.ScheduleTreeNode +) -> bool: return list_index(nodes, node) >= len(nodes) - 1 @@ -112,28 +114,28 @@ def _sanitize_axis(axis: AxisIterator, name_to_normalize: str) -> str: return re.sub(pattern, axis_clean, name_to_normalize) -class NormalizeAxisSymbol(dst.ScheduleNodeVisitor): +class NormalizeAxisSymbol(stree.ScheduleNodeVisitor): def __init__(self, axis: AxisIterator) -> None: self.axis = axis def visit_MapScope( self, - map_scope: dst.MapScope, - axis_rpl_dict: dict[str, str] = {}, + map_scope: stree.MapScope, + axis_replacements: dict[str, str] = {}, **kwargs: Any, ) -> None: - for i_param, param in enumerate(map_scope.node.params): + for index, param in enumerate(map_scope.node.params): sanitized_param = 
_sanitize_axis(self.axis, param) - axis_rpl_dict[param] = sanitized_param - map_scope.node.params[i_param] = sanitized_param + axis_replacements[param] = sanitized_param + map_scope.node.params[index] = sanitized_param # visit children for child in map_scope.children: - self.visit(child, axis_rpl_dict=axis_rpl_dict) + self.visit(child, axis_rpl_dict=axis_replacements) def visit_TaskletNode( self, - node: dst.TaskletNode, + node: stree.TaskletNode, axis_rpl_dict: dict[str, str] = {}, **kwargs: Any, ) -> None: @@ -143,12 +145,12 @@ def visit_TaskletNode( memlets.replace(axis_rpl_dict) -class CartesianAxisMerge(dst.ScheduleNodeTransformer): +class CartesianAxisMerge(stree.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. Can do: - merge a given axis with the next maps at the same recursion level - - can overcompute (eager) do allow for more merging at a cost of an if + - can overcompute (eager) to allow for more merging at the cost of an if Args: axis: AxisIterator to be merged @@ -169,8 +171,8 @@ def __str__(self) -> str: def _merge_node( self, - node: dst.ScheduleTreeNode, - nodes: list[dst.ScheduleTreeNode], + node: stree.ScheduleTreeNode, + nodes: list[stree.ScheduleTreeNode], ) -> int: """Direct code to the correct resolver for the node (e.g. visitor) @@ -179,13 +181,13 @@ def _merge_node( behavior (e.g. 
IfScope before ControlFlowScope) """ - if isinstance(node, dst.MapScope): + if isinstance(node, stree.MapScope): return self._map_overcompute_merge(node, nodes) - elif isinstance(node, dst.IfScope): + elif isinstance(node, stree.IfScope): return self._push_ifelse_down(node, nodes) - elif isinstance(node, dst.TaskletNode): + elif isinstance(node, stree.TaskletNode): return self._push_tasklet_down(node, nodes) - elif isinstance(node, dst.ControlFlowScope): + elif isinstance(node, stree.ControlFlowScope): return self._default_control_flow(node, nodes) else: ndsl_log.debug( @@ -195,8 +197,8 @@ def _merge_node( def _default_control_flow( self, - the_control_flow: dst.ControlFlowScope, - nodes: list[dst.ScheduleTreeNode], + the_control_flow: stree.ControlFlowScope, + nodes: list[stree.ScheduleTreeNode], ) -> int: if len(the_control_flow.children) != 0: return self._merge(the_control_flow) @@ -205,8 +207,8 @@ def _default_control_flow( def _push_tasklet_down( self, - the_tasklet: dst.TaskletNode, - nodes: list[dst.ScheduleTreeNode], + the_tasklet: stree.TaskletNode, + nodes: list[stree.ScheduleTreeNode], ) -> int: """Push tasklet into a consecutive map.""" in_memlets = the_tasklet.input_memlets() @@ -227,7 +229,7 @@ def _push_tasklet_down( # Attempt to push the tasklet in the next map ndsl_log.debug(" Push tasklet down into next map") next_node = nodes[next_index + 1] - if isinstance(next_node, dst.MapScope): + if isinstance(next_node, stree.MapScope): next_node.children.insert(0, the_tasklet) the_tasklet.parent = next_node nodes.remove(the_tasklet) @@ -237,8 +239,8 @@ def _push_tasklet_down( def _push_ifelse_down( self, - the_if: dst.IfScope, - nodes: list[dst.ScheduleTreeNode], + the_if: stree.IfScope, + nodes: list[stree.ScheduleTreeNode], ) -> int: merged = 0 @@ -249,8 +251,8 @@ def _push_ifelse_down( for else_index in range(if_index + 1, len(nodes)): else_node = nodes[else_index] if else_index < len(nodes) and ( - isinstance(else_node, dst.ElseScope) - or 
isinstance(else_node, dst.ElifScope) + isinstance(else_node, stree.ElseScope) + or isinstance(else_node, stree.ElifScope) ): merged += self._merge_node(else_node, else_node.children) else: @@ -260,20 +262,21 @@ def _push_ifelse_down( # Gather all first maps - if they do not exists, get out all_maps = [] - if isinstance(the_if.children[0], dst.MapScope): + if isinstance(the_if.children[0], stree.MapScope): all_maps.append(the_if.children[0]) else: return merged for else_index in range(if_index + 1, len(nodes)): else_node = nodes[else_index] if else_index < len(nodes) and ( - isinstance(else_node, dst.ElseScope) - or isinstance(else_node, dst.ElifScope) + isinstance(else_node, stree.ElseScope) + or isinstance(else_node, stree.ElifScope) ): - if isinstance(else_node.children[0], dst.MapScope): + if isinstance(else_node.children[0], stree.MapScope): all_maps.append(else_node.children[0]) else: return merged + else: break @@ -295,8 +298,8 @@ def _push_ifelse_down( # Swap ELIF/ELSE & maps for else_index in range(if_index + 1, len(nodes)): if else_index < len(nodes) and ( - isinstance(nodes[else_index], dst.ElseScope) - or isinstance(nodes[else_index], dst.ElifScope) + isinstance(nodes[else_index], stree.ElseScope) + or isinstance(nodes[else_index], stree.ElifScope) ): swap_node_position_in_tree( nodes[else_index], nodes[else_index].children[0] @@ -305,23 +308,23 @@ def _push_ifelse_down( break # Merge the Maps - assert isinstance(nodes[if_index], dst.MapScope) + assert isinstance(nodes[if_index], stree.MapScope) merged += self._map_overcompute_merge(nodes[if_index], nodes) return merged def _map_overcompute_merge( self, - the_map: dst.MapScope, - nodes: list[dst.ScheduleTreeNode], + the_map: stree.MapScope, + nodes: list[stree.ScheduleTreeNode], ) -> int: if _last_node(nodes, the_map): return 0 next_node = _get_next_node(nodes, the_map) - # If we the next node is not a MapScope - recurse - if not isinstance(next_node, dst.MapScope): + # If the next node is not a MapScope 
- recurse + if not isinstance(next_node, stree.MapScope): merged = self._merge_node(next_node, nodes) new_next_node = _get_next_node(nodes, the_map) if new_next_node == next_node: @@ -332,7 +335,7 @@ def _map_overcompute_merge( if not _can_merge_axis_maps(the_map, next_node, self.axis): return 0 - # Only for maps in K: + # Over compute to merge: # - force-merge by expanding the ranges # - then, guard children to only run in their respective range first_range = the_map.node.map.range @@ -348,7 +351,7 @@ def _map_overcompute_merge( ) ndsl_log.debug( - f" Merge K map: {first_range} ⋃ {second_range} -> {merged_range}" + f" Merge {self.axis.name} map: {first_range} ⋃ {second_range} -> {merged_range}" ) # push IfScope down if children are just maps @@ -361,14 +364,13 @@ def _map_overcompute_merge( merged_range=merged_range, original_range=second_range, ).visit(next_node) - merged_children: List[dst.MapScope] = [ + merged_children: list[stree.MapScope] = [ *first_map.children, *second_map.children, ] first_map.children = merged_children # TODO also merge containers and symbols (if applicable) - first_map.node.map.range = merged_range # delete now-merged second_map @@ -376,7 +378,7 @@ def _map_overcompute_merge( return 1 - def _merge(self, node: dst.ScheduleTreeRoot | dst.ScheduleTreeScope) -> int: + def _merge(self, node: stree.ScheduleTreeRoot | stree.ScheduleTreeScope) -> int: merged = 0 if __debug__: @@ -393,7 +395,7 @@ def _merge(self, node: dst.ScheduleTreeRoot | dst.ScheduleTreeScope) -> int: return merged - def visit_ScheduleTreeRoot(self, node: dst.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: """Merge as many maps as possible. 
The algorithm works as follows: diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py index 2e9642c0..0626133e 100644 --- a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py +++ b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py @@ -1,6 +1,6 @@ from enum import Enum -import dace.sdfg.analysis.schedule_tree.treenodes as dst +import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.memlet import Memlet from ndsl import ndsl_log @@ -19,8 +19,8 @@ def as_cartesian_index(self) -> int: def no_data_dependencies_on_cartesian_axis( - first: dst.MapScope, - second: dst.MapScope, + first: stree.MapScope, + second: stree.MapScope, axis: AxisIterator, ) -> bool: """Check for read after write. Allow when indexation on the axis @@ -53,8 +53,8 @@ def no_data_dependencies_on_cartesian_axis( def no_data_dependencies( - first: dst.MapScope, - second: dst.MapScope, + first: stree.MapScope, + second: stree.MapScope, restrict_check_to_k: bool = False, ) -> bool: write_collector = MemletCollector(collect_reads=False) @@ -91,7 +91,7 @@ def no_data_dependencies( return True -class MemletCollector(dst.ScheduleNodeVisitor): +class MemletCollector(stree.ScheduleNodeVisitor): """Gathers in_memlets and out_memlets of TaskNodes and LibraryCalls.""" in_memlets: list[Memlet] @@ -106,13 +106,13 @@ def __init__( self.in_memlets = [] self.out_memlets = [] - def visit_TaskletNode(self, node: dst.TaskletNode) -> None: + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: if self._collect_reads: self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) if self._collect_writes: self.out_memlets.extend([memlet for memlet in node.out_memlets.values()]) - def visit_LibraryCall(self, node: dst.LibraryCall) -> None: + def visit_LibraryCall(self, node: stree.LibraryCall) -> None: if self._collect_reads: if isinstance(node.in_memlets, set): self.in_memlets.extend(node.in_memlets) @@ -130,7 +130,7 @@ def 
visit_LibraryCall(self, node: dst.LibraryCall) -> None: ) -def has_dynamic_memlets(first: dst.MapScope, second: dst.MapScope) -> bool: +def has_dynamic_memlets(first: stree.MapScope, second: stree.MapScope) -> bool: first_collector = MemletCollector() second_collector = MemletCollector() first_collector.visit(first) diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py index d1ce5b26..2583ec2d 100644 --- a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -1,13 +1,13 @@ -import dace.sdfg.analysis.schedule_tree.treenodes as dst +import dace.sdfg.analysis.schedule_tree.treenodes as stree import dace.subsets as sbs -class SpecializeCartesianMaps(dst.ScheduleNodeVisitor): +class SpecializeCartesianMaps(stree.ScheduleNodeVisitor): def __init__(self, mappings: dict[str, int]) -> None: super().__init__() self._mappings = mappings - def visit_MapScope(self, node: dst.MapScope) -> None: + def visit_MapScope(self, node: stree.MapScope) -> None: dims = [] for p in node.node.map.params: if p == "__i": diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py index e853a9e9..b965fc4a 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -1,10 +1,10 @@ from typing import Collection -import dace.sdfg.analysis.schedule_tree.treenodes as dst +import dace.sdfg.analysis.schedule_tree.treenodes as stree def swap_node_position_in_tree( - top_node: dst.ScheduleTreeScope, child_node: dst.ScheduleTreeScope + top_node: stree.ScheduleTreeScope, child_node: stree.ScheduleTreeScope ) -> None: """Top node becomes child, child becomes top node""" # Take refs before swap @@ -24,7 +24,7 @@ def swap_node_position_in_tree( top_children.remove(top_node) -def detect_cycle(nodes: list[dst.ScheduleTreeNode], visited: set) -> None: +def 
detect_cycle(nodes: list[stree.ScheduleTreeNode], visited: set) -> None: """Detect the cycles in the tree.""" # Dev note: isn't there a DaCe tool for this?! for n in nodes: @@ -36,8 +36,8 @@ def detect_cycle(nodes: list[dst.ScheduleTreeNode], visited: set) -> None: def list_index( - collection: Collection[dst.ScheduleTreeNode], - node: dst.ScheduleTreeNode, + collection: Collection[stree.ScheduleTreeNode], + node: stree.ScheduleTreeNode, ) -> int: """Check if node is in list with "is" operator.""" # compare with "is" to get memory comparison. ".index()" uses value comparison diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 898107a8..76bdae49 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Protocol -import dace.sdfg.analysis.schedule_tree.treenodes as dst +import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge @@ -18,14 +18,16 @@ def __repr__(self) -> str: @abstractmethod def run( self, - stree: dst.ScheduleTreeRoot, + stree: stree.ScheduleTreeRoot, verbose: bool = False, - ) -> dst.ScheduleTreeRoot: + ) -> stree.ScheduleTreeRoot: raise NotImplementedError("Missing implementation of run") class CPUPipeline(StreePipeline): - def __init__(self, passes: list[dst.ScheduleNodeTransformer] | None = None) -> None: + def __init__( + self, passes: list[stree.ScheduleNodeTransformer] | None = None + ) -> None: self.passes = ( passes if passes is not None else [CartesianAxisMerge(AxisIterator._K)] ) @@ -38,9 +40,9 @@ def __hash__(self) -> int: def run( self, - stree: dst.ScheduleTreeRoot, + stree: stree.ScheduleTreeRoot, verbose: bool = False, - ) -> dst.ScheduleTreeRoot: + ) -> stree.ScheduleTreeRoot: for p in self.passes: if verbose: print(f"[Stree OPT] {p}") @@ -50,7 +52,9 @@ def run( class GPUPipeline(StreePipeline): - def __init__(self, passes: 
list[dst.ScheduleNodeTransformer] | None = None) -> None: + def __init__( + self, passes: list[stree.ScheduleNodeTransformer] | None = None + ) -> None: self.passes = passes if passes else [] def __repr__(self) -> str: @@ -61,9 +65,9 @@ def __hash__(self) -> int: def run( self, - stree: dst.ScheduleTreeRoot, + stree: stree.ScheduleTreeRoot, verbose: bool = False, - ) -> dst.ScheduleTreeRoot: + ) -> stree.ScheduleTreeRoot: for p in self.passes: if verbose: print(f"[Stree OPT] {p}") From 2d58b2bfcd720b92d62b2c1f476e124823223a5d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 24 Oct 2025 09:21:05 -0400 Subject: [PATCH 34/35] Flip `Protocol` base class to the broader and cleaner ABC --- ndsl/dsl/dace/stree/pipeline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 76bdae49..10fb77cd 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -1,12 +1,11 @@ -from abc import abstractmethod -from typing import Protocol +from abc import ABC, abstractmethod import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge -class StreePipeline(Protocol): +class StreePipeline(ABC): @abstractmethod def __hash__(self) -> int: raise NotImplementedError("Missing implementation of __hash__") From 9483da716171bd8b8035baa7c0c29040bd49dd83 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 24 Oct 2025 10:46:43 -0400 Subject: [PATCH 35/35] Lint --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 62a8c126..262a6021 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -121,9 +121,11 @@ def __init__(self, axis: AxisIterator) -> None: 
def visit_MapScope( self, map_scope: stree.MapScope, - axis_replacements: dict[str, str] = {}, + axis_replacements: dict[str, str] | None = None, **kwargs: Any, ) -> None: + if axis_replacements is None: + axis_replacements = {} for index, param in enumerate(map_scope.node.params): sanitized_param = _sanitize_axis(self.axis, param) axis_replacements[param] = sanitized_param @@ -136,13 +138,15 @@ def visit_MapScope( def visit_TaskletNode( self, node: stree.TaskletNode, - axis_rpl_dict: dict[str, str] = {}, + axis_replacements: dict[str, str] | None = None, **kwargs: Any, ) -> None: + if axis_replacements is None: + axis_replacements = {} for memlets in node.in_memlets.values(): - memlets.replace(axis_rpl_dict) + memlets.replace(axis_replacements) for memlets in node.out_memlets.values(): - memlets.replace(axis_rpl_dict) + memlets.replace(axis_replacements) class CartesianAxisMerge(stree.ScheduleNodeTransformer):