From 833b419b0345e25aaca618fc04326ad824b68f5b Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 3 Nov 2025 13:52:44 -0500 Subject: [PATCH 01/12] Fix axis merge --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 262a6021..74fe9e02 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -221,7 +221,7 @@ def _push_tasklet_down( return 0 # Tasklet is a callback next_index = list_index(nodes, the_tasklet) - if next_index == len(nodes): + if next_index == len(nodes) - 1: return 0 # Last node - done next_node = nodes[next_index + 1] @@ -323,7 +323,10 @@ def _map_overcompute_merge( nodes: list[stree.ScheduleTreeNode], ) -> int: if _last_node(nodes, the_map): - return 0 + merged = 0 + for child in the_map.children: + merged += self._merge_node(child, the_map.children) + return merged next_node = _get_next_node(nodes, the_map) From 93f6b6d5812f476d5016707f7bf0aeb593a8a651 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 3 Nov 2025 13:54:17 -0500 Subject: [PATCH 02/12] Remove debug print --- ndsl/dsl/ndsl_runtime.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 721ae38b..065cc298 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -75,7 +75,6 @@ def check_for_quantity(object_: object) -> None: obj=self, config=self._dace_config, ) - print(type(self)) def __getattribute__(self, name: str) -> Any: attr = super().__getattribute__(name) From 00e717921ce505a3136bc21e65156bc41f7f1176 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 3 Nov 2025 13:54:41 -0500 Subject: [PATCH 03/12] Refine transients + utests --- ndsl/dsl/dace/orchestration.py | 45 +++++- ndsl/dsl/dace/stree/optimizations/__init__.py | 3 +- 
.../stree/optimizations/refine_transients.py | 151 ++++++++++++++++++ tests/stree_optimizer/test_optimization.py | 71 ++++++-- 4 files changed, 246 insertions(+), 24 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/refine_transients.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 4d959d15..ff06d9ec 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -11,11 +11,13 @@ from dace import method as dace_method from dace import nodes from dace import program as dace_program +from dace.dtypes import AllocationLifetime from dace.dtypes import DeviceType as DaceDeviceType from dace.dtypes import StorageType as DaceStorageType from dace.frontend.python.common import SDFGConvertible from dace.frontend.python.parser import DaceProgram from dace.transformation.auto.auto_optimize import make_transients_persistent +from dace.transformation.dataflow import MapExpansion from dace.transformation.helpers import get_parent_map from dace.transformation.passes.simplify import SimplifyPass from gt4py import storage @@ -34,7 +36,11 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations import ( + AxisIterator, + CartesianAxisMerge, + CartesianRefineTransients, +) from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -45,12 +51,9 @@ from ndsl.quantity import Quantity, State -_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False +_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = True """INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" -_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] -"""INTERNAL: Default schedule passes for CPU. 
To be replaced with proper configuration.""" - def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" @@ -126,7 +129,7 @@ def _simplify( # We disable ScalarToSymbolPromotion because it might push symbols onto edges # that DaCe itself can't parse anymore later, e.g. casts, inlined function # calls or (complicated) field accesses. - skip=["ScalarToSymbolPromotion"], + skip={"ScalarToSymbolPromotion"}, ).apply_pass(sdfg, {}) @@ -157,16 +160,44 @@ def _build_sdfg( _simplify(sdfg) if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): + # Break all loops into uni-dimensional loops to simplify optimizations + sdfg.apply_transformations_repeated(MapExpansion, validate=True) stree = sdfg.as_schedule_tree() with DaCeProgress(config, "Schedule Tree: optimization"): if config.is_gpu_backend(): GPUPipeline().run(stree) else: - CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree) + # passes = [MapExpand()] + passes = [] + + if config.get_backend() == "dace:cpu_kfirst": + passes.extend( + [ + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianAxisMerge(AxisIterator._K), + CartesianRefineTransients((2, 1, 0)), + ] + ) + else: + passes.extend( + [ + CartesianAxisMerge(AxisIterator._K), + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianRefineTransients((1, 0, 2)), + ] + ) + CPUPipeline(passes=passes).run(stree) with DaCeProgress(config, "Schedule Tree: go back to SDFG"): + for _, data in stree.get_root().containers.items(): + if data.transient: + data.lifetime = AllocationLifetime.State + sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) # Make the transients array persistents diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 47c764b3..8e371ee9 100644 --- 
a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,5 @@ from .axis_merge import AxisIterator, CartesianAxisMerge +from .refine_transients import CartesianRefineTransients -__all__ = ["AxisIterator", "CartesianAxisMerge"] +__all__ = ["AxisIterator", "CartesianAxisMerge", "CartesianRefineTransients"] diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py new file mode 100644 index 00000000..11172b83 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import dace.sdfg.analysis.schedule_tree.treenodes as stree +import dace.data +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( + AxisIterator, +) +from ndsl import ndsl_log + + +def _zero_index_of_tuple(tuple_: tuple[int, ...], index: int) -> tuple[int, ...]: + new_list = list(tuple_) + new_list[index] = 1 + return tuple(new_list) + + +def _reduce_axis_size_to_1( + axis_iterator: AxisIterator, + transient_map_access: set[stree.nodes.MapEntry], + data: dace.data.Data, + ijk_order: tuple[int, int, int], +) -> bool: + access_in_map_count = 0 + for map_entry in transient_map_access: + if axis_iterator.value[0] in map_entry.params[0]: + access_in_map_count += 1 + + # If this transient is used in exactly one single-Axis map + # therefore this dimension can be removed + if access_in_map_count != 1: + return False + + data.shape = _zero_index_of_tuple(data.shape, axis_iterator.value[1]) + data.set_strides_from_layout(*ijk_order) + return True + + +class CollectTransientAccessInCartesianMaps(stree.ScheduleNodeVisitor): + """Collect all access of transient arrays per Maps.""" + + def __init__(self) -> None: + self.transient_map_access: dict[str, set[stree.nodes.MapEntry]] = {} + self._cartesian_current_map_nesting: list[stree.nodes.MapEntry | None] = [ + None, + None, + None, + ] + + def __str__(self) -> str: + 
return "CartesianCollectMaps" + + def visit_MapScope(self, node: stree.MapScope) -> None: + if len(node.node.params) > 1: + ndsl_log.debug( + "Can't apply CartesianRefineTransients, require unidimensional Maps" + ) + + if AxisIterator._I.value[0] in node.node.params[0]: + self._cartesian_current_map_nesting[0] = node.node + elif AxisIterator._J.value[0] in node.node.params[0]: + self._cartesian_current_map_nesting[1] = node.node + elif AxisIterator._K.value[0] in node.node.params[0]: + self._cartesian_current_map_nesting[2] = node.node + + for child in node.children: + return self.visit(child) + + if AxisIterator._I.value[0] in node.node.params[0]: + self._cartesian_current_map_nesting[0] = None + elif AxisIterator._J.value[0] in node.node.params[0]: + self._cartesian_current_map_nesting[1] = None + elif AxisIterator._K.value[0] in node.node.params[0]: + self._cartesian_current_map_nesting[2] = None + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + for memlet in node.input_memlets(): + if self.containers[memlet.data].transient: + for map_entry in self._cartesian_current_map_nesting: + if map_entry is not None: + self.transient_map_access[memlet.data].add(map_entry) + for memlet in node.output_memlets(): + if self.containers[memlet.data].transient: + for map_entry in self._cartesian_current_map_nesting: + if map_entry is not None: + self.transient_map_access[memlet.data].add(map_entry) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for name, data in self.containers.items(): + if data.transient: + self.transient_map_access[name] = set() + + for child in node.children: + self.visit(child) + + +class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor): + """Rebuild memlets from containers to ensure they are scope to the right size.""" + + def __str__(self) -> str: + return "RefineTransientAxis" + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + for name, memlet in 
node.in_memlets.items(): + if self.containers[memlet.data].transient: + node.in_memlets[name] = memlet.from_array( + memlet.data, self.containers[memlet.data] + ) + + for name, memlet in node.out_memlets.items(): + if self.containers[memlet.data].transient: + node.out_memlets[name] = memlet.from_array( + memlet.data, self.containers[memlet.data] + ) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for child in node.children: + self.visit(child) + + +class CartesianRefineTransients(stree.ScheduleNodeTransformer): + """ """ + + def __init__(self, ijk_order: tuple[int, int, int]) -> None: + self.ijk_order = ijk_order + + def __str__(self) -> str: + return "CartesianRefineTransients" + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + collect_map = CollectTransientAccessInCartesianMaps() + collect_map.visit(node) + + # Remove Axis + refined_transient = 0 + for name, data in node.containers.items(): + if data.transient: + refined = False + for axis in AxisIterator: + refined |= _reduce_axis_size_to_1( + axis, + collect_map.transient_map_access[name], + data, + self.ijk_order, + ) + refined_transient += 1 if refined else 0 + + RebuildMemletsFromContainers().visit(node) + + ndsl_log.debug(f"🚀 {refined_transient} Transient refined") diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py index 6ed29479..0d6bee0e 100644 --- a/tests/stree_optimizer/test_optimization.py +++ b/tests/stree_optimizer/test_optimization.py @@ -1,7 +1,8 @@ -from ndsl import StencilFactory, orchestrate +import dace + +from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge from ndsl.dsl.gt4py import PARALLEL, computation, interval from 
ndsl.dsl.typing import FloatField @@ -33,26 +34,60 @@ def __call__(self, in_field: FloatField, out_field: FloatField) -> None: self.stencil_B(in_field, out_field) -def test_stree_roundtrip_no_opt() -> None: - """Dev Note: +def test_stree_merge_maps() -> None: + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu" + ) + + code = TriviallyMergeableCode(stencil_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + # Temporarily flip the internal switch + import ndsl.dsl.dace.orchestration as orch + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + + code(in_qty, out_qty) + + assert len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) == 1 + sdfg = list(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values())[0] + all_maps = [ + (me, state) + for me, state in sdfg.sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + + assert len(all_maps) == 3 + assert (out_qty.field[:] == 4).all() - The below code successfully merges top level K loop (2 loops) - How do we test it?! Running doesn't test merging and the compilation - is a near-black box. We could reach in the `dace_config.compiled_sdfg` - cache but it's keyed on the dace.program and if we can reach the program - well we can reach the SDFG and turn it into an stree for verification - Should we run orchestration "by hand"? - Can we intercept the `stree` ? After all we just want to check that! 
+ orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False - Test is deactivated for now""" - return +class LocalRefineableCode(NDSLRuntime): + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + super().__init__(stencil_factory.config.dace_config) + self.stencil_A = stencil_factory.from_dims_halo( + func=stencil_A, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.tmp = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + + def __call__(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil_A(in_field, self.tmp) + self.stencil_A(self.tmp, out_field) + + +def test_stree_roundtrip_transient_is_refined() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend="dace:cpu" ) - code = TriviallyMergeableCode(stencil_factory) + code = LocalRefineableCode(stencil_factory, quantity_factory) in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") @@ -60,10 +95,14 @@ def test_stree_roundtrip_no_opt() -> None: import ndsl.dsl.dace.orchestration as orch orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True - orch._INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] code(in_qty, out_qty) - assert (out_qty.field[:] == 4).all() + assert len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) == 1 + sdfg = list(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values())[0] + + for array in sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1) orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False From 4da7eeb5dff250c52feb236c269166bcd5d6dbf5 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 3 Nov 2025 14:03:07 -0500 Subject: [PATCH 04/12] Lint --- ndsl/dsl/dace/stree/optimizations/refine_transients.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git 
a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 11172b83..7b241cf3 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,11 +1,10 @@ from __future__ import annotations -import dace.sdfg.analysis.schedule_tree.treenodes as stree import dace.data -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( - AxisIterator, -) +import dace.sdfg.analysis.schedule_tree.treenodes as stree + from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator def _zero_index_of_tuple(tuple_: tuple[int, ...], index: int) -> tuple[int, ...]: From 61d7ee1af1381f24f615567c909b1ad1bd631ff0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 3 Nov 2025 15:16:39 -0500 Subject: [PATCH 05/12] Revert to deactivating the experimental stree work --- ndsl/dsl/dace/orchestration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index ff06d9ec..5f3998d9 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -51,7 +51,7 @@ from ndsl.quantity import Quantity, State -_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = True +_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False """INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" From 22bf96d98dbf89a7317ab98b79354d5a4ca36116 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 4 Nov 2025 15:27:50 -0500 Subject: [PATCH 06/12] Use context manager for `_INTERNAL__SCHEDULE_TREE_OPTIMIZATION` --- tests/stree_optimizer/test_optimization.py | 77 +++++++++++++--------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py index 0d6bee0e..fd08a1cb 100644 --- a/tests/stree_optimizer/test_optimization.py +++ 
b/tests/stree_optimizer/test_optimization.py @@ -1,5 +1,8 @@ +from types import TracebackType + import dace +import ndsl.dsl.dace.orchestration as orch from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM @@ -7,6 +10,22 @@ from ndsl.dsl.typing import FloatField +class StreeOptimization: + def __init__(self) -> None: + pass + + def __enter__(self) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + + def stencil_A(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): out_field = in_field @@ -44,25 +63,24 @@ def test_stree_merge_maps() -> None: in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") - # Temporarily flip the internal switch - import ndsl.dsl.dace.orchestration as orch - - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + with StreeOptimization(): + code(in_qty, out_qty) - code(in_qty, out_qty) - - assert len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) == 1 - sdfg = list(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values())[0] - all_maps = [ - (me, state) - for me, state in sdfg.sdfg.all_nodes_recursive() - if isinstance(me, dace.nodes.MapEntry) - ] - - assert len(all_maps) == 3 - assert (out_qty.field[:] == 4).all() + assert ( + len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) + == 1 + ) + sdfg = list( + stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() + )[0] + all_maps = [ + (me, state) + for me, state in sdfg.sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] - 
orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + assert len(all_maps) == 3 + assert (out_qty.field[:] == 4).all() class LocalRefineableCode(NDSLRuntime): @@ -91,18 +109,17 @@ def test_stree_roundtrip_transient_is_refined() -> None: in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") - # Temporarily flip the internal switch - import ndsl.dsl.dace.orchestration as orch - - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + with StreeOptimization(): + code(in_qty, out_qty) - code(in_qty, out_qty) - - assert len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) == 1 - sdfg = list(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values())[0] - - for array in sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == (1, 1, 1) + assert ( + len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) + == 1 + ) + sdfg = list( + stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() + )[0] - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + for array in sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1) From c795c831b71bed31443f2faab11b4fb99fc62d4a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 4 Nov 2025 15:28:05 -0500 Subject: [PATCH 07/12] Typo --- ndsl/dsl/dace/orchestration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 5f3998d9..547121b5 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -170,7 +170,6 @@ def _build_sdfg( if config.is_gpu_backend(): GPUPipeline().run(stree) else: - # passes = [MapExpand()] passes = [] if config.get_backend() == "dace:cpu_kfirst": From bd967e99f90f89cd479dcb18e270325a4c03a862 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 4 Nov 2025 15:28:32 -0500 Subject: [PATCH 08/12] Clean refine transients code --- 
.../stree/optimizations/refine_transients.py | 90 +++++++++++-------- 1 file changed, 55 insertions(+), 35 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 7b241cf3..8d40bf1b 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,5 +1,7 @@ from __future__ import annotations +from types import TracebackType + import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree @@ -24,16 +26,49 @@ def _reduce_axis_size_to_1( if axis_iterator.value[0] in map_entry.params[0]: access_in_map_count += 1 - # If this transient is used in exactly one single-Axis map - # therefore this dimension can be removed if access_in_map_count != 1: return False + # If this transient is used in exactly one single-Axis map + # therefore this dimension can be removed. BUT we are not truly + # removing out, we are reducing it to 1 to not have to deal + # with different slicing data.shape = _zero_index_of_tuple(data.shape, axis_iterator.value[1]) data.set_strides_from_layout(*ijk_order) return True +class _CartesianMapNesting: + def __init__( + self, + cartesian_current_map_nesting: list[stree.nodes.MapEntry | None], + node: stree.MapScope, + ) -> None: + self._cartesian_current_map_nesting = cartesian_current_map_nesting + self._node = node + + def __enter__(self) -> None: + if AxisIterator._I.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[0] = self._node.node + elif AxisIterator._J.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[1] = self._node.node + elif AxisIterator._K.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[2] = self._node.node + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if AxisIterator._I.value[0] in 
self._node.node.params[0]: + self._cartesian_current_map_nesting[0] = None + elif AxisIterator._J.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[1] = None + elif AxisIterator._K.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[2] = None + + class CollectTransientAccessInCartesianMaps(stree.ScheduleNodeVisitor): """Collect all access of transient arrays per Maps.""" @@ -53,31 +88,14 @@ def visit_MapScope(self, node: stree.MapScope) -> None: ndsl_log.debug( "Can't apply CartesianRefineTransients, require unidimensional Maps" ) + return - if AxisIterator._I.value[0] in node.node.params[0]: - self._cartesian_current_map_nesting[0] = node.node - elif AxisIterator._J.value[0] in node.node.params[0]: - self._cartesian_current_map_nesting[1] = node.node - elif AxisIterator._K.value[0] in node.node.params[0]: - self._cartesian_current_map_nesting[2] = node.node - - for child in node.children: - return self.visit(child) - - if AxisIterator._I.value[0] in node.node.params[0]: - self._cartesian_current_map_nesting[0] = None - elif AxisIterator._J.value[0] in node.node.params[0]: - self._cartesian_current_map_nesting[1] = None - elif AxisIterator._K.value[0] in node.node.params[0]: - self._cartesian_current_map_nesting[2] = None + with _CartesianMapNesting(self._cartesian_current_map_nesting, node): + for child in node.children: + return self.visit(child) def visit_TaskletNode(self, node: stree.TaskletNode) -> None: - for memlet in node.input_memlets(): - if self.containers[memlet.data].transient: - for map_entry in self._cartesian_current_map_nesting: - if map_entry is not None: - self.transient_map_access[memlet.data].add(map_entry) - for memlet in node.output_memlets(): + for memlet in [*node.input_memlets(), *node.output_memlets()]: if self.containers[memlet.data].transient: for map_entry in self._cartesian_current_map_nesting: if map_entry is not None: @@ -119,7 +137,8 @@ def visit_ScheduleTreeRoot(self, node: 
stree.ScheduleTreeRoot) -> None: class CartesianRefineTransients(stree.ScheduleNodeTransformer): - """ """ + """Refine (reduce dimensionality) of transients based on their true use in + the cartesian dimensions.""" def __init__(self, ijk_order: tuple[int, int, int]) -> None: self.ijk_order = ijk_order @@ -134,16 +153,17 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: # Remove Axis refined_transient = 0 for name, data in node.containers.items(): - if data.transient: - refined = False - for axis in AxisIterator: - refined |= _reduce_axis_size_to_1( - axis, - collect_map.transient_map_access[name], - data, - self.ijk_order, - ) - refined_transient += 1 if refined else 0 + if not data.transient: + continue + refined = False + for axis in AxisIterator: + refined |= _reduce_axis_size_to_1( + axis, + collect_map.transient_map_access[name], + data, + self.ijk_order, + ) + refined_transient += 1 if refined else 0 RebuildMemletsFromContainers().visit(node) From c2ca2332da2a8a2884037d15eeb9f4cd131aa400 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 5 Nov 2025 10:04:48 -0500 Subject: [PATCH 09/12] Derive common strides layout from backend Refactor code to make re-sizing more compact in main algorithm Fix bad recursion Add todo list and verbose state of optimization --- ndsl/dsl/dace/orchestration.py | 4 +- .../stree/optimizations/refine_transients.py | 123 +++++++++++++----- 2 files changed, 91 insertions(+), 36 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 547121b5..7d18caaf 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -178,7 +178,7 @@ def _build_sdfg( CartesianAxisMerge(AxisIterator._I), CartesianAxisMerge(AxisIterator._J), CartesianAxisMerge(AxisIterator._K), - CartesianRefineTransients((2, 1, 0)), + CartesianRefineTransients(config.get_backend()), ] ) else: @@ -187,7 +187,7 @@ def _build_sdfg( CartesianAxisMerge(AxisIterator._K), 
CartesianAxisMerge(AxisIterator._I), CartesianAxisMerge(AxisIterator._J), - CartesianRefineTransients((1, 0, 2)), + CartesianRefineTransients(config.get_backend()), ] ) CPUPipeline(passes=passes).run(stree) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 8d40bf1b..e1e69b2a 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -7,35 +7,54 @@ from ndsl import ndsl_log from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator +import warnings -def _zero_index_of_tuple(tuple_: tuple[int, ...], index: int) -> tuple[int, ...]: - new_list = list(tuple_) - new_list[index] = 1 +def _change_index_of_tuple( + old_tuple: tuple[int, ...], index: int, value: int = 1 +) -> tuple[int, ...]: + """Return a copy of the given tuple with `old_tuple[index]` being replaced by `value`. + + Args: + old_tuple: to be copied + index: at which index to replace a value + value: to replace `old_tuple[index]` + """ + new_list = list(old_tuple) + new_list[index] = value return tuple(new_list) -def _reduce_axis_size_to_1( - axis_iterator: AxisIterator, +def _reduce_cartesian_axes_size_to_1( transient_map_access: set[stree.nodes.MapEntry], - data: dace.data.Data, + transient_data: dace.data.Data, ijk_order: tuple[int, int, int], ) -> bool: - access_in_map_count = 0 - for map_entry in transient_map_access: - if axis_iterator.value[0] in map_entry.params[0]: - access_in_map_count += 1 - - if access_in_map_count != 1: - return False - - # If this transient is used in exactly one single-Axis map - # therefore this dimension can be removed. 
BUT we are not truly - removing out, we are reducing it to 1 to not have to deal - with different slicing - data.shape = _zero_index_of_tuple(data.shape, axis_iterator.value[1]) - data.set_strides_from_layout(*ijk_order) - return True + """Reduce dimension size of transient to 1 if they are accessed only + in a single Map for the cartesian dimensions""" + refined = False + for axis in AxisIterator: + access_in_map_count = 0 + for map_entry in transient_map_access: + if axis.as_str() in map_entry.params[0]: + access_in_map_count += 1 + + if access_in_map_count != 1: + continue + + # This transient is used in exactly one single-Axis map + # therefore this dimension can be removed. BUT we are not truly + # removing it, we are reducing it to 1 to not have to deal + # with different slicing. + transient_data.shape = _change_index_of_tuple( + transient_data.shape, + axis.as_cartesian_index(), + value=1, + ) + transient_data.set_strides_from_layout(*ijk_order) + refined = True + + return refined class _CartesianMapNesting: @@ -92,7 +111,7 @@ def visit_MapScope(self, node: stree.MapScope) -> None: with _CartesianMapNesting(self._cartesian_current_map_nesting, node): for child in node.children: - return self.visit(child) + self.visit(child) def visit_TaskletNode(self, node: stree.TaskletNode) -> None: for memlet in [*node.input_memlets(), *node.output_memlets()]: @@ -138,10 +157,49 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: class CartesianRefineTransients(stree.ScheduleNodeTransformer): """Refine (reduce dimensionality) of transients based on their true use in - the cartesian dimensions.""" - - def __init__(self, ijk_order: tuple[int, int, int]) -> None: - self.ijk_order = ijk_order + the cartesian dimensions. + + + It can do: + - Looking at usage of a transient in a cartesian axis (e.g. loop over a + cartesian axis) it will reduce that axis to 1 if it exists in _only one_. 
+ It should but cannot do/will bug if: + - Dataflow analysis on the axis to prevent reducing an axis to one where + the transient is used with offset, leading to faulty numerics + - Using the dataflow above, we can reduce the dimensions to the correct lowest + size needed on the axis (e.g. transient[K] and transient[K+1], requires a 2-element + buffer) + - Current action when detecting a valid candidate is to reduce the size of the dimension + to 1, rather than removing it. This will effectively, if generic compilers do their job, reduce + the cache access significantly. This also has been implemented to _not_ deal with offset/slicing + downstream impact of removing an axis. Nevertheless the axis should be removed if it's not + used. + + More tests: + - Test for dataflow with offset + - Test for I/J refine but not in K + - Test for J refine but not in I or K + - Test with dataflow: if/else, while, etc. + - Test with ForScope (FORWARD/BACKWARD) instead of Map + """ + + def __init__(self, backend: str) -> None: + warnings.warn( + "CartesianRefineTransients is a WIP. It's usage is *severaly* limited" + "and will most likely lead to bad numerics. 
Check the docs, check utest.", + UserWarning, + stacklevel=2, + ) + + if backend in ["dace:cpu_kfirst"]: + self.ijk_order = (2, 1, 0) + elif backend in ["dace:cpu", "dace:gpu"]: + self.ijk_order = (1, 0, 2) + else: + raise NotImplementedError( + "[Schedule Tree Opt] CartesianRefineTransient not implemented for " + f"backend {backend}" + ) def __str__(self) -> str: return "CartesianRefineTransients" @@ -155,14 +213,11 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: for name, data in node.containers.items(): if not data.transient: continue - refined = False - for axis in AxisIterator: - refined |= _reduce_axis_size_to_1( - axis, - collect_map.transient_map_access[name], - data, - self.ijk_order, - ) + refined = _reduce_cartesian_axes_size_to_1( + collect_map.transient_map_access[name], + data, + self.ijk_order, + ) refined_transient += 1 if refined else 0 RebuildMemletsFromContainers().visit(node) From 762c6146441339e433af8971848e0d05a4b08313 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 5 Nov 2025 10:08:32 -0500 Subject: [PATCH 10/12] Lint --- ndsl/dsl/dace/stree/optimizations/refine_transients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index e1e69b2a..967d6923 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from types import TracebackType import dace.data @@ -7,7 +8,6 @@ from ndsl import ndsl_log from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator -import warnings def _change_index_of_tuple( From bb5ac8b4e96b5a109463ac35c40923975ab88d1b Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 5 Nov 2025 11:55:55 -0500 Subject: [PATCH 11/12] Remove `transient` to `State` lifetime - keep PR on target --- 
ndsl/dsl/dace/orchestration.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 7d18caaf..879ca658 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -193,10 +193,6 @@ def _build_sdfg( CPUPipeline(passes=passes).run(stree) with DaCeProgress(config, "Schedule Tree: go back to SDFG"): - for _, data in stree.get_root().containers.items(): - if data.transient: - data.lifetime = AllocationLifetime.State - sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) # Make the transients array persistents From 1567ff0afad49cd7070573731c61f679570a86f5 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 5 Nov 2025 11:58:02 -0500 Subject: [PATCH 12/12] Lint --- ndsl/dsl/dace/orchestration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 879ca658..a5c278ac 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -11,7 +11,6 @@ from dace import method as dace_method from dace import nodes from dace import program as dace_program -from dace.dtypes import AllocationLifetime from dace.dtypes import DeviceType as DaceDeviceType from dace.dtypes import StorageType as DaceStorageType from dace.frontend.python.common import SDFGConvertible