diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 4d959d15..a5c278ac 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -16,6 +16,7 @@ from dace.frontend.python.common import SDFGConvertible from dace.frontend.python.parser import DaceProgram from dace.transformation.auto.auto_optimize import make_transients_persistent +from dace.transformation.dataflow import MapExpansion from dace.transformation.helpers import get_parent_map from dace.transformation.passes.simplify import SimplifyPass from gt4py import storage @@ -34,7 +35,11 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations import ( + AxisIterator, + CartesianAxisMerge, + CartesianRefineTransients, +) from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -48,9 +53,6 @@ _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False """INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" -_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] -"""INTERNAL: Default schedule passes for CPU. To be replaced with proper configuration.""" - def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" @@ -126,7 +128,7 @@ def _simplify( # We disable ScalarToSymbolPromotion because it might push symbols onto edges # that DaCe itself can't parse anymore later, e.g. casts, inlined function # calls or (complicated) field accesses. 
- skip=["ScalarToSymbolPromotion"], + skip={"ScalarToSymbolPromotion"}, ).apply_pass(sdfg, {}) @@ -157,14 +159,37 @@ def _build_sdfg( _simplify(sdfg) if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): + # Break all loops into uni-dimensional loops to simplify optimizations + sdfg.apply_transformations_repeated(MapExpansion, validate=True) stree = sdfg.as_schedule_tree() with DaCeProgress(config, "Schedule Tree: optimization"): if config.is_gpu_backend(): GPUPipeline().run(stree) else: - CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree) + passes = [] + + if config.get_backend() == "dace:cpu_kfirst": + passes.extend( + [ + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianAxisMerge(AxisIterator._K), + CartesianRefineTransients(config.get_backend()), + ] + ) + else: + passes.extend( + [ + CartesianAxisMerge(AxisIterator._K), + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianRefineTransients(config.get_backend()), + ] + ) + CPUPipeline(passes=passes).run(stree) with DaCeProgress(config, "Schedule Tree: go back to SDFG"): sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 47c764b3..8e371ee9 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,5 @@ from .axis_merge import AxisIterator, CartesianAxisMerge +from .refine_transients import CartesianRefineTransients -__all__ = ["AxisIterator", "CartesianAxisMerge"] +__all__ = ["AxisIterator", "CartesianAxisMerge", "CartesianRefineTransients"] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 262a6021..74fe9e02 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ 
b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -221,7 +221,7 @@ def _push_tasklet_down( return 0 # Tasklet is a callback next_index = list_index(nodes, the_tasklet) - if next_index == len(nodes): + if next_index == len(nodes) - 1: return 0 # Last node - done next_node = nodes[next_index + 1] @@ -323,7 +323,10 @@ def _map_overcompute_merge( nodes: list[stree.ScheduleTreeNode], ) -> int: if _last_node(nodes, the_map): - return 0 + merged = 0 + for child in the_map.children: + merged += self._merge_node(child, the_map.children) + return merged next_node = _get_next_node(nodes, the_map) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py new file mode 100644 index 00000000..967d6923 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import warnings +from types import TracebackType + +import dace.data +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator + + +def _change_index_of_tuple( + old_tuple: tuple[int, ...], index: int, value: int = 1 +) -> tuple[int, ...]: + """Return a copy of the given tuple with `old_tuple[index]` being replaced by `value`. 
+ + Args: + old_tuple: to be copied + index: at which index to replace a value + value: to replace `old_tuple[index]` + """ + new_list = list(old_tuple) + new_list[index] = value + return tuple(new_list) + + +def _reduce_cartesian_axes_size_to_1( + transient_map_access: set[stree.nodes.MapEntry], + transient_data: dace.data.Data, + ijk_order: tuple[int, int, int], +) -> bool: + """Reduce dimension size of transient to 1 if they are accessed only + in a single Map for the cartesian dimensions""" + refined = False + for axis in AxisIterator: + access_in_map_count = 0 + for map_entry in transient_map_access: + if axis.as_str() in map_entry.params[0]: + access_in_map_count += 1 + + if access_in_map_count != 1: + continue + + # This transient is used in exactly one single-Axis map + # therefore this dimension can be removed. BUT we are not truly + # removing it, we are reducing it to 1 to not have to deal + # with different slicing. + transient_data.shape = _change_index_of_tuple( + transient_data.shape, + axis.as_cartesian_index(), + value=1, + ) + transient_data.set_strides_from_layout(*ijk_order) + refined = True + + return refined + + +class _CartesianMapNesting: + def __init__( + self, + cartesian_current_map_nesting: list[stree.nodes.MapEntry | None], + node: stree.MapScope, + ) -> None: + self._cartesian_current_map_nesting = cartesian_current_map_nesting + self._node = node + + def __enter__(self) -> None: + if AxisIterator._I.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[0] = self._node.node + elif AxisIterator._J.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[1] = self._node.node + elif AxisIterator._K.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[2] = self._node.node + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if AxisIterator._I.value[0] in self._node.node.params[0]: 
+ self._cartesian_current_map_nesting[0] = None + elif AxisIterator._J.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[1] = None + elif AxisIterator._K.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[2] = None + + +class CollectTransientAccessInCartesianMaps(stree.ScheduleNodeVisitor): + """Collect all accesses of transient arrays per Map.""" + + def __init__(self) -> None: + self.transient_map_access: dict[str, set[stree.nodes.MapEntry]] = {} + self._cartesian_current_map_nesting: list[stree.nodes.MapEntry | None] = [ + None, + None, + None, + ] + + def __str__(self) -> str: + return "CartesianCollectMaps" + + def visit_MapScope(self, node: stree.MapScope) -> None: + if len(node.node.params) > 1: + ndsl_log.debug( + "Can't apply CartesianRefineTransients, require unidimensional Maps" + ) + return + + with _CartesianMapNesting(self._cartesian_current_map_nesting, node): + for child in node.children: + self.visit(child) + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + for memlet in [*node.input_memlets(), *node.output_memlets()]: + if self.containers[memlet.data].transient: + for map_entry in self._cartesian_current_map_nesting: + if map_entry is not None: + self.transient_map_access[memlet.data].add(map_entry) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for name, data in self.containers.items(): + if data.transient: + self.transient_map_access[name] = set() + + for child in node.children: + self.visit(child) + + +class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor): + """Rebuild memlets from containers to ensure they are scoped to the right size.""" + + def __str__(self) -> str: + return "RefineTransientAxis" + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + for name, memlet in node.in_memlets.items(): + if self.containers[memlet.data].transient: + node.in_memlets[name] = memlet.from_array( 
+ memlet.data, self.containers[memlet.data] + ) + + for name, memlet in node.out_memlets.items(): + if self.containers[memlet.data].transient: + node.out_memlets[name] = memlet.from_array( + memlet.data, self.containers[memlet.data] + ) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for child in node.children: + self.visit(child) + + +class CartesianRefineTransients(stree.ScheduleNodeTransformer): + """Refine (reduce dimensionality) of transients based on their true use in + the cartesian dimensions. + + + It can do: + - Looking at usage of a transient in a cartesian axis (e.g. loop over a + cartesian axis) it will reduce that axis to 1 if it exists in _only one_. + It should but cannot do/will bug if: + - Dataflow analysis on the axis to prevent reducing an axis to one where + the transient is used with offset, leading to faulty numerics + - Using the dataflow above, we can reduce the dimensions to the correct lowest + size needed on the axis (e.g. transient[K] and transient[K+1], requires a 2-element + buffer) + - Current action when detecting a valid candidate is to reduce the size of the dimension + to 1, rather than removing it. This will effectively, if generic compilers do their job, reduce + the cache access significantly. This also has been implemented to _not_ deal with offset/slicing + downstream impact of removing an axis. Nevertheless the axis should be removed if it's not + used. + + More tests: + - Test for dataflow with offset + - Test for I/J refine but not in K + - Test for J refine but not in I or K + - Test with dataflow: if/else, while, etc. + - Test with ForScope (FORWARD/BACKWARD) instead of Map + """ + + def __init__(self, backend: str) -> None: + warnings.warn( + "CartesianRefineTransients is a WIP. It's usage is *severaly* limited" + "and will most likely lead to bad numerics. 
Check the docs, check utest.", + UserWarning, + stacklevel=2, + ) + + if backend in ["dace:cpu_kfirst"]: + self.ijk_order = (2, 1, 0) + elif backend in ["dace:cpu", "dace:gpu"]: + self.ijk_order = (1, 0, 2) + else: + raise NotImplementedError( + "[Schedule Tree Opt] CartesianRefineTransient not implemented for " + f"backend {backend}" + ) + + def __str__(self) -> str: + return "CartesianRefineTransients" + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + collect_map = CollectTransientAccessInCartesianMaps() + collect_map.visit(node) + + # Remove Axis + refined_transient = 0 + for name, data in node.containers.items(): + if not data.transient: + continue + refined = _reduce_cartesian_axes_size_to_1( + collect_map.transient_map_access[name], + data, + self.ijk_order, + ) + refined_transient += 1 if refined else 0 + + RebuildMemletsFromContainers().visit(node) + + ndsl_log.debug(f"🚀 {refined_transient} Transient refined") diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 721ae38b..065cc298 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -75,7 +75,6 @@ def check_for_quantity(object_: object) -> None: obj=self, config=self._dace_config, ) - print(type(self)) def __getattribute__(self, name: str) -> Any: attr = super().__getattribute__(name) diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py index 6ed29479..fd08a1cb 100644 --- a/tests/stree_optimizer/test_optimization.py +++ b/tests/stree_optimizer/test_optimization.py @@ -1,11 +1,31 @@ -from ndsl import StencilFactory, orchestrate +from types import TracebackType + +import dace + +import ndsl.dsl.dace.orchestration as orch +from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge 
from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField +class StreeOptimization: + def __init__(self) -> None: + pass + + def __enter__(self) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + + def stencil_A(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): out_field = in_field @@ -33,20 +53,7 @@ def __call__(self, in_field: FloatField, out_field: FloatField) -> None: self.stencil_B(in_field, out_field) -def test_stree_roundtrip_no_opt() -> None: - """Dev Note: - - The below code successfully merges top level K loop (2 loops) - How do we test it?! Running doesn't test merging and the compilation - is a near-black box. We could reach in the `dace_config.compiled_sdfg` - cache but it's keyed on the dace.program and if we can reach the program - well we can reach the SDFG and turn it into an stree for verification - Should we run orchestration "by hand"? - Can we intercept the `stree` ? After all we just want to check that! 
- - Test is deactivated for now""" - - return +def test_stree_merge_maps() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend="dace:cpu" @@ -56,14 +63,63 @@ def test_stree_roundtrip_no_opt() -> None: in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") - # Temporarily flip the internal switch - import ndsl.dsl.dace.orchestration as orch + with StreeOptimization(): + code(in_qty, out_qty) - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True - orch._INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] + assert ( + len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) + == 1 + ) + sdfg = list( + stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() + )[0] + all_maps = [ + (me, state) + for me, state in sdfg.sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + + assert len(all_maps) == 3 + assert (out_qty.field[:] == 4).all() + + +class LocalRefineableCode(NDSLRuntime): + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + super().__init__(stencil_factory.config.dace_config) + self.stencil_A = stencil_factory.from_dims_halo( + func=stencil_A, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.tmp = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + + def __call__(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil_A(in_field, self.tmp) + self.stencil_A(self.tmp, out_field) + + +def test_stree_roundtrip_transient_is_refined() -> None: + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu" + ) + + code = LocalRefineableCode(stencil_factory, quantity_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, 
Z_DIM], "") - code(in_qty, out_qty) + with StreeOptimization(): + code(in_qty, out_qty) - assert (out_qty.field[:] == 4).all() + assert ( + len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) + == 1 + ) + sdfg = list( + stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() + )[0] - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + for array in sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1)