Skip to content
37 changes: 31 additions & 6 deletions ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dace.frontend.python.common import SDFGConvertible
from dace.frontend.python.parser import DaceProgram
from dace.transformation.auto.auto_optimize import make_transients_persistent
from dace.transformation.dataflow import MapExpansion
from dace.transformation.helpers import get_parent_map
from dace.transformation.passes.simplify import SimplifyPass
from gt4py import storage
Expand All @@ -34,7 +35,11 @@
sdfg_nan_checker,
)
from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline
from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge
from ndsl.dsl.dace.stree.optimizations import (
AxisIterator,
CartesianAxisMerge,
CartesianRefineTransients,
)
from ndsl.dsl.dace.utils import (
DaCeProgress,
memory_static_analysis,
Expand All @@ -48,9 +53,6 @@
_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False
"""INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer."""

_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)]
"""INTERNAL: Default schedule passes for CPU. To be replaced with proper configuration."""


def dace_inhibitor(func: Callable) -> Callable:
"""Triggers callback generation wrapping `func` while doing DaCe parsing."""
Expand Down Expand Up @@ -126,7 +128,7 @@ def _simplify(
# We disable ScalarToSymbolPromotion because it might push symbols onto edges
# that DaCe itself can't parse anymore later, e.g. casts, inlined function
# calls or (complicated) field accesses.
skip=["ScalarToSymbolPromotion"],
skip={"ScalarToSymbolPromotion"},
).apply_pass(sdfg, {})


Expand Down Expand Up @@ -157,14 +159,37 @@ def _build_sdfg(
_simplify(sdfg)

if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION:
# Here be 🐉 - but tests exist in test_optimization.py
with DaCeProgress(config, "Schedule Tree: generate from SDFG"):
# Break all loops into uni-dimensional loops to simplify optimizations
sdfg.apply_transformations_repeated(MapExpansion, validate=True)
stree = sdfg.as_schedule_tree()

with DaCeProgress(config, "Schedule Tree: optimization"):
if config.is_gpu_backend():
GPUPipeline().run(stree)
else:
CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree)
passes = []

if config.get_backend() == "dace:cpu_kfirst":
passes.extend(
[
CartesianAxisMerge(AxisIterator._I),
CartesianAxisMerge(AxisIterator._J),
CartesianAxisMerge(AxisIterator._K),
CartesianRefineTransients(config.get_backend()),
]
)
else:
passes.extend(
[
CartesianAxisMerge(AxisIterator._K),
CartesianAxisMerge(AxisIterator._I),
CartesianAxisMerge(AxisIterator._J),
CartesianRefineTransients(config.get_backend()),
]
)
CPUPipeline(passes=passes).run(stree)

with DaCeProgress(config, "Schedule Tree: go back to SDFG"):
sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"})
Expand Down
3 changes: 2 additions & 1 deletion ndsl/dsl/dace/stree/optimizations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .axis_merge import AxisIterator, CartesianAxisMerge
from .refine_transients import CartesianRefineTransients


__all__ = ["AxisIterator", "CartesianAxisMerge"]
__all__ = ["AxisIterator", "CartesianAxisMerge", "CartesianRefineTransients"]
7 changes: 5 additions & 2 deletions ndsl/dsl/dace/stree/optimizations/axis_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def _push_tasklet_down(
return 0 # Tasklet is a callback

next_index = list_index(nodes, the_tasklet)
if next_index == len(nodes):
if next_index == len(nodes) - 1:
return 0 # Last node - done

next_node = nodes[next_index + 1]
Expand Down Expand Up @@ -323,7 +323,10 @@ def _map_overcompute_merge(
nodes: list[stree.ScheduleTreeNode],
) -> int:
if _last_node(nodes, the_map):
return 0
merged = 0
for child in the_map.children:
merged += self._merge_node(child, the_map.children)
return merged

next_node = _get_next_node(nodes, the_map)

Expand Down
225 changes: 225 additions & 0 deletions ndsl/dsl/dace/stree/optimizations/refine_transients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
from __future__ import annotations

import warnings
from types import TracebackType

import dace.data
import dace.sdfg.analysis.schedule_tree.treenodes as stree

from ndsl import ndsl_log
from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator


def _change_index_of_tuple(
    old_tuple: tuple[int, ...], index: int, value: int = 1
) -> tuple[int, ...]:
    """Build a new tuple equal to `old_tuple` except at one position.

    The input tuple itself is never modified.

    Args:
        old_tuple: tuple used as the template
        index: position whose element gets swapped out
        value: element stored at `index` in the result (defaults to 1)

    Returns:
        A tuple of the same length as `old_tuple` holding `value` at `index`.
    """
    elements = [*old_tuple]
    # List item assignment: negative indices count from the end and an
    # out-of-range index raises IndexError.
    elements[index] = value
    return (*elements,)


def _reduce_cartesian_axes_size_to_1(
    transient_map_access: set[stree.nodes.MapEntry],
    transient_data: dace.data.Data,
    ijk_order: tuple[int, int, int],
) -> bool:
    """Shrink to 1 each cartesian dimension of a transient touched by a single map.

    For every member of `AxisIterator`, the maps in `transient_map_access`
    whose first parameter contains the axis name are counted. When exactly
    one map matches, the corresponding entry of `transient_data.shape` is set
    to 1 and the strides are recomputed from `ijk_order`.

    Args:
        transient_map_access: map entries in which the transient is accessed
        transient_data: data descriptor of the transient, modified in place
        ijk_order: layout forwarded to `set_strides_from_layout`

    Returns:
        True if at least one dimension was reduced, False otherwise.
    """
    did_refine = False

    for axis in AxisIterator:
        matching_maps = sum(
            1 for entry in transient_map_access if axis.as_str() in entry.params[0]
        )

        if matching_maps == 1:
            # Exactly one single-axis map uses this transient, so the
            # dimension is not needed. It is kept with size 1 rather than
            # dropped, so that slicing does not have to be adapted.
            shrunk_shape = _change_index_of_tuple(
                transient_data.shape,
                axis.as_cartesian_index(),
                value=1,
            )
            transient_data.shape = shrunk_shape
            transient_data.set_strides_from_layout(*ijk_order)
            did_refine = True

    return did_refine


class _CartesianMapNesting:
    """Context manager recording the map being visited in a shared nesting list.

    The nesting list has three slots, in I, J, K order. On entry, the slot
    whose axis name (first element of the `AxisIterator` member's value) is
    contained in the map's first parameter is set to the map entry. On exit,
    that slot is restored to the value it held before entry, so that a map
    nested inside another map of the same axis does not erase the outer
    entry for the remainder of the outer scope.

    Maps whose first parameter matches none of the three axes leave the
    nesting list untouched.
    """

    def __init__(
        self,
        cartesian_current_map_nesting: list[stree.nodes.MapEntry | None],
        node: stree.MapScope,
    ) -> None:
        """Store the shared nesting list and the map scope to register.

        Args:
            cartesian_current_map_nesting: three-slot list (I, J, K), mutated
                in place on enter/exit
            node: map scope whose entry node gets registered
        """
        self._cartesian_current_map_nesting = cartesian_current_map_nesting
        self._node = node
        # Slot claimed on __enter__ (None if the map is not on I, J or K)
        self._slot: int | None = None
        # Content of the claimed slot before __enter__ overwrote it
        self._previous: stree.nodes.MapEntry | None = None

    def _axis_slot(self) -> int | None:
        """Return the slot index (0=I, 1=J, 2=K) of this map, or None."""
        first_param = self._node.node.params[0]
        # First match wins, checked in I, J, K order.
        for slot, axis in enumerate(
            (AxisIterator._I, AxisIterator._J, AxisIterator._K)
        ):
            if axis.value[0] in first_param:
                return slot
        return None

    def __enter__(self) -> None:
        self._slot = self._axis_slot()
        if self._slot is None:
            return
        self._previous = self._cartesian_current_map_nesting[self._slot]
        self._cartesian_current_map_nesting[self._slot] = self._node.node

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self._slot is None:
            return
        # Restore rather than reset to None: an enclosing map on the same
        # axis must stay registered once this scope is left.
        self._cartesian_current_map_nesting[self._slot] = self._previous


class CollectTransientAccessInCartesianMaps(stree.ScheduleNodeVisitor):
    """Record, for each transient container, the cartesian maps around its accesses.

    Once a schedule tree root has been visited, `transient_map_access` maps
    every transient container name to the set of map entries that were
    registered in the I/J/K nesting list when a tasklet read or wrote it.
    """

    def __init__(self) -> None:
        self.transient_map_access: dict[str, set[stree.nodes.MapEntry]] = {}
        # One slot per cartesian axis, in I, J, K order
        self._cartesian_current_map_nesting: list[stree.nodes.MapEntry | None] = [
            None
        ] * 3

    def __str__(self) -> str:
        return "CartesianCollectMaps"

    def visit_MapScope(self, node: stree.MapScope) -> None:
        """Visit the children of a one-parameter map with the map registered.

        Maps with more than one parameter are logged and their children are
        not visited.
        """
        if len(node.node.params) <= 1:
            with _CartesianMapNesting(self._cartesian_current_map_nesting, node):
                for child in node.children:
                    self.visit(child)
            return

        ndsl_log.debug(
            "Can't apply CartesianRefineTransients, require unidimensional Maps"
        )

    def visit_TaskletNode(self, node: stree.TaskletNode) -> None:
        """Attach the currently registered maps to each transient the tasklet touches."""
        memlets = list(node.input_memlets())
        memlets.extend(node.output_memlets())

        for memlet in memlets:
            if not self.containers[memlet.data].transient:
                continue
            for map_entry in self._cartesian_current_map_nesting:
                if map_entry is None:
                    continue
                self.transient_map_access[memlet.data].add(map_entry)

    def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
        """Seed an empty access set per transient, then walk the tree."""
        self.containers = node.containers
        self.transient_map_access.update(
            {
                name: set()
                for name, data in self.containers.items()
                if data.transient
            }
        )

        for child in node.children:
            self.visit(child)


class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor):
    """Regenerate tasklet memlets of transients from their container descriptors.

    Every input and output memlet that refers to a transient container is
    replaced by `memlet.from_array(...)` built from the container as currently
    stored on the schedule tree root.
    """

    def __str__(self) -> str:
        return "RefineTransientAxis"

    def visit_TaskletNode(self, node: stree.TaskletNode) -> None:
        """Replace the transient memlets of `node`, inputs first, then outputs."""
        for memlets in (node.in_memlets, node.out_memlets):
            # Only existing keys are reassigned, so the dict size is stable
            # while iterating.
            for connector, memlet in memlets.items():
                descriptor = self.containers[memlet.data]
                if descriptor.transient:
                    memlets[connector] = memlet.from_array(memlet.data, descriptor)

    def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
        """Remember the root's containers, then visit every child."""
        self.containers = node.containers
        for child in node.children:
            self.visit(child)


class CartesianRefineTransients(stree.ScheduleNodeTransformer):
    """Refine (reduce dimensionality of) transients based on their true use in
    the cartesian dimensions.

    What it does:
        - Looking at the usage of a transient per cartesian axis (e.g. loop
          over a cartesian axis), it reduces that axis to size 1 if the
          transient is accessed in _only one_ map of that axis.

    What it should do but does not (and will produce bugs for):
        - Dataflow analysis on the axis, to prevent reducing an axis to one
          where the transient is used with an offset, leading to faulty
          numerics.
        - Using the dataflow above, reduce the dimension to the correct lowest
          size needed on the axis (e.g. transient[K] and transient[K+1]
          require a 2-element buffer).
        - The current action on a valid candidate is to reduce the size of
          the dimension to 1, rather than removing it. This will effectively,
          if generic compilers do their job, reduce the cache access
          significantly. It has also been implemented this way to _not_ deal
          with the offset/slicing downstream impact of removing an axis.
          Nevertheless the axis should be removed if it's not used.

    More tests needed:
        - Test for dataflow with offset
        - Test for I/J refine but not in K
        - Test for J refine but not in I or K
        - Test with dataflow: if/else, while, etc.
        - Test with ForScope (FORWARD/BACKWARD) instead of Map
    """

    def __init__(self, backend: str) -> None:
        """Select the stride layout for `backend` and warn about WIP status.

        Args:
            backend: backend name; one of "dace:cpu_kfirst", "dace:cpu" or
                "dace:gpu"

        Raises:
            NotImplementedError: if `backend` is not one of the names above.
        """
        # The two literals are concatenated by the parser: the trailing space
        # after "limited" is required to keep the words apart.
        warnings.warn(
            "CartesianRefineTransients is a WIP. Its usage is *severely* limited "
            "and will most likely lead to bad numerics. Check the docs, check utest.",
            UserWarning,
            stacklevel=2,
        )

        # Layout handed to `set_strides_from_layout` when a transient is refined
        if backend == "dace:cpu_kfirst":
            self.ijk_order = (2, 1, 0)
        elif backend in ("dace:cpu", "dace:gpu"):
            self.ijk_order = (1, 0, 2)
        else:
            raise NotImplementedError(
                "[Schedule Tree Opt] CartesianRefineTransient not implemented for "
                f"backend {backend}"
            )

    def __str__(self) -> str:
        return "CartesianRefineTransients"

    def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
        """Refine every eligible transient of the tree, then rebuild memlets.

        Args:
            node: root of the schedule tree; its transient containers and
                tasklet memlets are modified in place
        """
        collect_map = CollectTransientAccessInCartesianMaps()
        collect_map.visit(node)

        # Reduce eligible axes of each transient to size 1
        refined_transient = 0
        for name, data in node.containers.items():
            if not data.transient:
                continue
            refined = _reduce_cartesian_axes_size_to_1(
                collect_map.transient_map_access[name],
                data,
                self.ijk_order,
            )
            if refined:
                refined_transient += 1

        # Memlets still describe the old shapes: regenerate them
        RebuildMemletsFromContainers().visit(node)

        ndsl_log.debug(f"🚀 {refined_transient} Transient refined")
1 change: 0 additions & 1 deletion ndsl/dsl/ndsl_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def check_for_quantity(object_: object) -> None:
obj=self,
config=self._dace_config,
)
print(type(self))

def __getattribute__(self, name: str) -> Any:
attr = super().__getattribute__(name)
Expand Down
Loading