diff --git a/README.md b/README.md index 43dfd280..b256eda7 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ To run the GPU backends, you'll need: - Libraries: MPI compiled with CUDA support - CUDA 11.2+ - Python package: - - `cupy` (latest with proper driver support [see install notes](https://docs.cupy.dev/en/stable/install.html)) + - `cupy` (latest with proper driver support [see install notes](https://docs.cupy.dev/en/stable/install.html)) A simple way to install MPI is using pre-built wheels, e.g. diff --git a/ndsl/__init__.py b/ndsl/__init__.py index aa275711..c7730282 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,4 +1,5 @@ from . import dsl # isort:skip +from .logging import ndsl_log # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import LocalComm from .comm.mpi import MPIComm @@ -22,7 +23,6 @@ from .halo.data_transformer import HaloExchangeSpec from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from .initialization import GridSizer, QuantityFactory, SubtileGridSizer -from .logging import ndsl_log from .monitor.netcdf_monitor import NetCDFMonitor from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index a9d99902..43b6c4f3 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numbers import os from collections.abc import Callable, Sequence from typing import Any @@ -32,6 +33,8 @@ negative_qtracers_checker, sdfg_nan_checker, ) +from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -41,6 +44,13 @@ from ndsl.optional_imports import cupy as cp 
+_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False +"""INTERNAL: Developer flag to turn on the untested schedule tree roundtrip optimizer.""" + +_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] +"""INTERNAL: Default schedule passes for CPU. To be replaced with proper configuration.""" + + def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" return func @@ -124,18 +134,47 @@ def _build_sdfg( ) -> None: """Build the .so out of the SDFG on the top tile ranks only.""" is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile + device_type = DaceDeviceType.GPU if config.is_gpu_backend() else DaceDeviceType.CPU if is_compiling: with DaCeProgress(config, "Validate original SDFG"): sdfg.validate() + # Fully specialize all known symbols and then propagate these changes in the simplify + # pass that follows. This is not only a smart idea in general, but also simplifies (haha) + # the schedule tree (optimization) roundtrip. 
+ with DaCeProgress(config, "Fully specialize symbols"): + for my_sdfg in sdfg.all_sdfgs_recursive(): + if my_sdfg.parent_nsdfg_node is not None: + repl_dict = {} + for sym, val in my_sdfg.parent_nsdfg_node.symbol_mapping.items(): + if isinstance(val, numbers.Number): + repl_dict[sym] = val + my_sdfg.replace_dict(repl_dict) + + with DaCeProgress(config, "Simplify (1)"): + _simplify(sdfg) + + if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + with DaCeProgress(config, "Schedule Tree: generate from SDFG"): + stree = sdfg.as_schedule_tree() + + with DaCeProgress(config, "Schedule Tree: optimization"): + if config.is_gpu_backend(): + GPUPipeline().run(stree) + else: + CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree) + + with DaCeProgress(config, "Schedule Tree: go back to SDFG"): + sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) + # Make the transients array persistents if config.is_gpu_backend(): # TODO # The following should happen on the stree level _to_gpu(sdfg) - make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.GPU) + make_transients_persistent(sdfg=sdfg, device=device_type) # Upload args to device _upload_to_device(list(args) + list(kwargs.values())) @@ -145,7 +184,7 @@ def _build_sdfg( for _sd, _aname, arr in sdfg.arrays_recursive(): if arr.shape == (1,): arr.storage = DaceStorageType.Register - make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.CPU) + make_transients_persistent(sdfg=sdfg, device=device_type) # Build non-constants & non-transients from the sdfg_kwargs sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs) @@ -157,8 +196,8 @@ def _build_sdfg( if k in sdfg_kwargs and tup[1].transient: del sdfg_kwargs[k] - with DaCeProgress(config, "Simplify"): - _simplify(sdfg, validate=False, verbose=True) + with DaCeProgress(config, "Simplify (2)"): + _simplify(sdfg) # Move all memory that can be into a pool to lower memory pressure. # Change Persistent memory (sub-SDFG) into Scope and flag it. 
@@ -182,6 +221,9 @@ def _build_sdfg( negative_delp_checker(sdfg) negative_qtracers_checker(sdfg) + with DaCeProgress(config, "Validate before compile"): + sdfg.validate() + # Compile with DaCeProgress(config, "Codegen & compile"): sdfg.compile() @@ -495,7 +537,7 @@ def orchestrate( raise RuntimeError( f"Could not orchestrate, " f"{type(obj).__name__}.{method_to_orchestrate} " - "does not exists" + "does not exist." ) if dace_compiletime_args is None: @@ -535,7 +577,9 @@ def __call__(self, *arg, **kwarg): # type: ignore[no-untyped-def] return wrapped(*arg, **kwarg) def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] - return wrapped.__sdfg__(*args, **kwargs) + sdfg = wrapped.__sdfg__(*args, **kwargs) + sdfg.validate() + return sdfg def __sdfg_closure__(self, reevaluate=None): # type: ignore[no-untyped-def] return wrapped.__sdfg_closure__(reevaluate) diff --git a/ndsl/dsl/dace/sdfg/loop_transform.py b/ndsl/dsl/dace/sdfg/loop_transform.py new file mode 100644 index 00000000..7e6cf1d4 --- /dev/null +++ b/ndsl/dsl/dace/sdfg/loop_transform.py @@ -0,0 +1,19 @@ +from dace import SDFG, ScheduleType, nodes + + +def make_SDFG_CPU_sequential(sdfg: SDFG) -> None: + """Utility to turn a CPU-based SDFG to pure serial by removing OpenMP""" + # Disable OpenMP sections + for sd in sdfg.all_sdfgs_recursive(): + sd.openmp_sections = False + + # Disable OpenMP maps + for node, _ in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.EntryNode): + schedule = getattr(node, "schedule", False) + if schedule in ( + ScheduleType.CPU_Multicore, + ScheduleType.CPU_Persistent, + ScheduleType.Default, + ): + node.schedule = ScheduleType.Sequential diff --git a/ndsl/dsl/dace/stree/__init__.py b/ndsl/dsl/dace/stree/__init__.py new file mode 100644 index 00000000..6435e662 --- /dev/null +++ b/ndsl/dsl/dace/stree/__init__.py @@ -0,0 +1,4 @@ +from .pipeline import CPUPipeline, GPUPipeline + + +__all__ = ["CPUPipeline", "GPUPipeline"] diff --git 
a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py new file mode 100644 index 00000000..47c764b3 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -0,0 +1,4 @@ +from .axis_merge import AxisIterator, CartesianAxisMerge + + +__all__ = ["AxisIterator", "CartesianAxisMerge"] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py new file mode 100644 index 00000000..262a6021 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -0,0 +1,444 @@ +from __future__ import annotations + +import copy +import re +from typing import Any + +import dace +import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.properties import CodeBlock + +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( + AxisIterator, + no_data_dependencies_on_cartesian_axis, +) +from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( + detect_cycle, + list_index, + swap_node_position_in_tree, +) + + +def _is_axis_map(node: stree.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a map over the given axis.""" + map_parameter = node.node.params + return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + + +def _both_same_single_axis_maps( + first: stree.MapScope, + second: stree.MapScope, + axis: AxisIterator, +) -> bool: + return ( + (len(first.node.params) == 1 and len(second.node.params) == 1) # Single axis + and first.node.params[0] == second.node.params[0] # Same axis + and _is_axis_map(first, axis) # Correct axis + ) + + +def _can_merge_axis_maps( + first: stree.MapScope, + second: stree.MapScope, + axis: AxisIterator, +) -> bool: + return _both_same_single_axis_maps( + first, second, axis + ) and no_data_dependencies_on_cartesian_axis( + first, + second, + axis, + ) + + +class InsertOvercomputationGuard(stree.ScheduleNodeTransformer): + def __init__( + self, + 
axis_as_string: str, + *, + merged_range: dace.subsets.Range, + original_range: dace.subsets.Range, + ): + self._axis_as_string = axis_as_string + self._merged_range = merged_range + self._original_range = original_range + + def _execution_condition(self) -> CodeBlock: + # NOTE range.ranges are inclusive, e.g. + # Range(0:4) -> ranges = (start=0, stop=3, step=1) + range = self._original_range + start = range.ranges[0][0] + stop = range.ranges[0][1] + step = range.ranges[0][2] + return CodeBlock( + f"{self._axis_as_string} >= {start} " + f"and {self._axis_as_string} <= {stop} " + f"and ({self._axis_as_string} - {start}) % {step} == 0" + ) + + def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: + all_children_are_maps = all( + [isinstance(child, stree.MapScope) for child in node.children] + ) + if not all_children_are_maps: + if self._merged_range != self._original_range: + node.children = [ + stree.IfScope( + condition=self._execution_condition(), children=node.children + ) + ] + return node + + node.children = self.visit(node.children) + return node + + +def _get_next_node( + nodes: list[stree.ScheduleTreeNode], + node: stree.ScheduleTreeNode, +) -> stree.ScheduleTreeNode: + return nodes[list_index(nodes, node) + 1] + + +def _last_node( + nodes: list[stree.ScheduleTreeNode], node: stree.ScheduleTreeNode +) -> bool: + return list_index(nodes, node) >= len(nodes) - 1 + + +def _sanitize_axis(axis: AxisIterator, name_to_normalize: str) -> str: + axis_clean = f"{axis.as_str()}" + pattern = f"{axis.as_str()}_[0-9]*" + + return re.sub(pattern, axis_clean, name_to_normalize) + + +class NormalizeAxisSymbol(stree.ScheduleNodeVisitor): + def __init__(self, axis: AxisIterator) -> None: + self.axis = axis + + def visit_MapScope( + self, + map_scope: stree.MapScope, + axis_replacements: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + if axis_replacements is None: + axis_replacements = {} + for index, param in enumerate(map_scope.node.params): + 
sanitized_param = _sanitize_axis(self.axis, param) + axis_replacements[param] = sanitized_param + map_scope.node.params[index] = sanitized_param + + # visit children + for child in map_scope.children: + self.visit(child, axis_rpl_dict=axis_replacements) + + def visit_TaskletNode( + self, + node: stree.TaskletNode, + axis_replacements: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + if axis_replacements is None: + axis_replacements = {} + for memlets in node.in_memlets.values(): + memlets.replace(axis_replacements) + for memlets in node.out_memlets.values(): + memlets.replace(axis_replacements) + + +class CartesianAxisMerge(stree.ScheduleNodeTransformer): + """Merge a cartesian axis if they are contiguous in code-flow. + + Can do: + - merge a given axis with the next maps at the same recursion level + - can overcompute (eager) to allow for more merging at the cost of an if + + Args: + axis: AxisIterator to be merged + eager: overcompute with a conditional guard + """ + + def __init__( + self, + axis: AxisIterator, + *, + eager: bool = True, + ) -> None: + self.axis = axis + self.eager = eager + + def __str__(self) -> str: + return f"CartesianAxisMerge({self.axis.name})" + + def _merge_node( + self, + node: stree.ScheduleTreeNode, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + """Direct code to the correct resolver for the node (e.g. visitor) + + Dev Note: Order matters! + Default behavior for base class must be _after_ bespoke leaf class + behavior (e.g. IfScope before ControlFlowScope) + """ + + if isinstance(node, stree.MapScope): + return self._map_overcompute_merge(node, nodes) + elif isinstance(node, stree.IfScope): + return self._push_ifelse_down(node, nodes) + elif isinstance(node, stree.TaskletNode): + return self._push_tasklet_down(node, nodes) + elif isinstance(node, stree.ControlFlowScope): + return self._default_control_flow(node, nodes) + else: + ndsl_log.debug( + f" (╯°□°)╯︵ ┻━┻: can't merge {type(node)}. Recursion ends." 
+ ) + return 0 + + def _default_control_flow( + self, + the_control_flow: stree.ControlFlowScope, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + if len(the_control_flow.children) != 0: + return self._merge(the_control_flow) + + return 0 + + def _push_tasklet_down( + self, + the_tasklet: stree.TaskletNode, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + """Push tasklet into a consecutive map.""" + in_memlets = the_tasklet.input_memlets() + if len(in_memlets) != 0: + if "__pystate" in [tasklet.data for tasklet in the_tasklet.input_memlets()]: + return 0 # Tasklet is a callback + + next_index = list_index(nodes, the_tasklet) + if next_index == len(nodes): + return 0 # Last node - done + + next_node = nodes[next_index + 1] + + # Before checking the possibility of merging - attempt to surface + # a map from the next nodes + merged = self._merge_node(next_node, nodes) + + # Attempt to push the tasklet in the next map + ndsl_log.debug(" Push tasklet down into next map") + next_node = nodes[next_index + 1] + if isinstance(next_node, stree.MapScope): + next_node.children.insert(0, the_tasklet) + the_tasklet.parent = next_node + nodes.remove(the_tasklet) + merged += self._merge_node(next_node, nodes) + + return merged + + def _push_ifelse_down( + self, + the_if: stree.IfScope, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + merged = 0 + + # Recurse down if/else/elif + if_index = list_index(nodes, the_if) + if len(the_if.children) != 0: + merged += self._merge_node(the_if.children[0], the_if.children) + for else_index in range(if_index + 1, len(nodes)): + else_node = nodes[else_index] + if else_index < len(nodes) and ( + isinstance(else_node, stree.ElseScope) + or isinstance(else_node, stree.ElifScope) + ): + merged += self._merge_node(else_node, else_node.children) + else: + break + + # Look at swapping if/else/elif first map w/ control flow + + # Gather all first maps - if they do not exists, get out + all_maps = [] + if isinstance(the_if.children[0], 
stree.MapScope): + all_maps.append(the_if.children[0]) + else: + return merged + for else_index in range(if_index + 1, len(nodes)): + else_node = nodes[else_index] + if else_index < len(nodes) and ( + isinstance(else_node, stree.ElseScope) + or isinstance(else_node, stree.ElifScope) + ): + if isinstance(else_node.children[0], stree.MapScope): + all_maps.append(else_node.children[0]) + else: + return merged + + else: + break + + # Check for mergeability + if len(all_maps) > 1: + the_map = all_maps[0] + for _map in all_maps[1:]: + if not _can_merge_axis_maps(the_map, _map, self.axis): + return merged + + # We are good to go - swap it all + ndsl_log.debug(f" Push IF {the_if.condition.as_string} down") + inner_if_map = the_if.children[0] + + # Swap IF & maps + if_index = list_index(nodes, the_if) + swap_node_position_in_tree(the_if, inner_if_map) + + # Swap ELIF/ELSE & maps + for else_index in range(if_index + 1, len(nodes)): + if else_index < len(nodes) and ( + isinstance(nodes[else_index], stree.ElseScope) + or isinstance(nodes[else_index], stree.ElifScope) + ): + swap_node_position_in_tree( + nodes[else_index], nodes[else_index].children[0] + ) + else: + break + + # Merge the Maps + assert isinstance(nodes[if_index], stree.MapScope) + merged += self._map_overcompute_merge(nodes[if_index], nodes) + + return merged + + def _map_overcompute_merge( + self, + the_map: stree.MapScope, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + if _last_node(nodes, the_map): + return 0 + + next_node = _get_next_node(nodes, the_map) + + # If the next node is not a MapScope - recurse + if not isinstance(next_node, stree.MapScope): + merged = self._merge_node(next_node, nodes) + new_next_node = _get_next_node(nodes, the_map) + if new_next_node == next_node: + return merged + return merged + self._merge_node(the_map, nodes) + + # Attempt to merge consecutive maps + if not _can_merge_axis_maps(the_map, next_node, self.axis): + return 0 + + # Over compute to merge: + # - force-merge by 
expanding the ranges + # - then, guard children to only run in their respective range + first_range = the_map.node.map.range + second_range = next_node.node.map.range + merged_range = dace.subsets.Range( + [ + ( + f"min({first_range.ranges[0][0]}, {second_range.ranges[0][0]})", + f"max({first_range.ranges[0][1]}, {second_range.ranges[0][1]})", + 1, # NOTE: we can optimize this to gcd later + ) + ] + ) + + ndsl_log.debug( + f" Merge {self.axis.name} map: {first_range} ⋃ {second_range} -> {merged_range}" + ) + + # push IfScope down if children are just maps + axis_as_str = the_map.node.params[0] + first_map = InsertOvercomputationGuard( + axis_as_str, merged_range=merged_range, original_range=first_range + ).visit(the_map) + second_map = InsertOvercomputationGuard( + axis_as_str, + merged_range=merged_range, + original_range=second_range, + ).visit(next_node) + merged_children: list[stree.MapScope] = [ + *first_map.children, + *second_map.children, + ] + first_map.children = merged_children + + # TODO also merge containers and symbols (if applicable) + first_map.node.map.range = merged_range + + # delete now-merged second_map + del nodes[list_index(nodes, next_node)] + + return 1 + + def _merge(self, node: stree.ScheduleTreeRoot | stree.ScheduleTreeScope) -> int: + merged = 0 + + if __debug__: + detect_cycle(node.children, set()) + + i_candidate = 0 + while i_candidate < len(node.children): + next_node = node.children[i_candidate] + merged += self._merge_node(next_node, node.children) + i_candidate += 1 + + if __debug__: + detect_cycle(node.children, set()) + + return merged + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + """Merge as many maps as possible. 
+ + The algorithm works as follows: + - Start merging - move nodes to surface maps as much as possible + - Try to merge the surfaced maps + - When done, count the number of actual merges + - If NO merges - restore the previous children + (undo potential changes that didn't lead to map merge) + Then exit. + """ + + # TODO: many interval generate many iterator name right now + # e.g. _k_0, _k_1... + # This makes merging more difficult. We could write a pre-pass + # that cleans this up BUT we have an issue with the THIS_K feature + # in the tasklet... + # NormalizeAxisSymbol(self.axis).visit(node) + + overall_merged = 0 + i = 0 + while True: + i += 1 + ndsl_log.debug(f"🔥 Merge attempt #{i}") + previous_children = copy.deepcopy(node.children) + try: + merged = self._merge(node) + overall_merged += merged + if __debug__: + detect_cycle(node.children, set()) + except RecursionError as re: + raise re + + # If we didn't merge, we revert the children + # to the previous state + if merged == 0: + ndsl_log.debug("🥹 No merges, revert!") + node.children = previous_children + break + + ndsl_log.debug( + f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged" + ) diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py new file mode 100644 index 00000000..0626133e --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py @@ -0,0 +1,149 @@ +from enum import Enum + +import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.memlet import Memlet + +from ndsl import ndsl_log + + +class AxisIterator(Enum): + _I = ("__i", 0) + _J = ("__j", 1) + _K = ("__k", 2) + + def as_str(self) -> str: + return self.value[0] + + def as_cartesian_index(self) -> int: + return self.value[1] + + +def no_data_dependencies_on_cartesian_axis( + first: stree.MapScope, + second: stree.MapScope, + axis: AxisIterator, +) -> bool: + """Check for read after write. 
Allow when indexation on the axis + is not offset.""" + + write_collector = MemletCollector(collect_reads=False) + write_collector.visit(first) + read_collector = MemletCollector(collect_writes=False) + read_collector.visit(second) + for write in write_collector.out_memlets: + # TODO: this can be optimized to allow non-overlapping intervals and such in the future + + if write.subset.dims() <= axis.as_cartesian_index(): + # Dimension does not exist + continue + + previous_axis_index = write.subset[axis.as_cartesian_index()][0] + for read in read_collector.in_memlets: + if write.data == read.data: + if previous_axis_index != read.subset[axis.as_cartesian_index()][0]: + ndsl_log.debug( + f"[{axis.name} Merge] Found read after write conflict " + f"for {write.data} " + f"w/ different offset to {axis.name} (" + f"write at {write.subset[axis.as_cartesian_index()][0]}, " + f"read at {read.subset[axis.as_cartesian_index()][0]})" + ) + return False + return True + + +def no_data_dependencies( + first: stree.MapScope, + second: stree.MapScope, + restrict_check_to_k: bool = False, +) -> bool: + write_collector = MemletCollector(collect_reads=False) + write_collector.visit(first) + read_collector = MemletCollector(collect_writes=False) + read_collector.visit(second) + for write in write_collector.out_memlets: + # Make sure we don't have read after write conditions. 
+ # TODO: this can be optimized to allow non-overlapping intervals and such in the future + if restrict_check_to_k: + if write.subset.dims() < 3: + # Case of 2D write - no K dependency + continue + + previous_k_index = write.subset[2][0] + for read in read_collector.in_memlets: + if write.data == read.data: + if previous_k_index != read.subset[2][0]: + print( + "[K Merge] Found read after write conflict " + f"for {write.data} " + "w/ different offset to K (" + f"write at {write.subset[2][0]}, " + f"read at {read.subset[2][0]})" + ) + return False + + else: + if write.data in [read.data for read in read_collector.in_memlets]: + print( + f"[All dims merge] Found potential read after write conflict for {write.data}" + ) + return False + return True + + +class MemletCollector(stree.ScheduleNodeVisitor): + """Gathers in_memlets and out_memlets of TaskNodes and LibraryCalls.""" + + in_memlets: list[Memlet] + out_memlets: list[Memlet] + + def __init__( + self, *, collect_reads: bool = True, collect_writes: bool = True + ) -> None: + self._collect_reads = collect_reads + self._collect_writes = collect_writes + + self.in_memlets = [] + self.out_memlets = [] + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + if self._collect_reads: + self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) + if self._collect_writes: + self.out_memlets.extend([memlet for memlet in node.out_memlets.values()]) + + def visit_LibraryCall(self, node: stree.LibraryCall) -> None: + if self._collect_reads: + if isinstance(node.in_memlets, set): + self.in_memlets.extend(node.in_memlets) + else: + assert isinstance(node.in_memlets, dict) + self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) + + if self._collect_writes: + if isinstance(node.out_memlets, set): + self.out_memlets.extend(node.out_memlets) + else: + assert isinstance(node.out_memlets, dict) + self.out_memlets.extend( + [memlet for memlet in node.out_memlets.values()] + ) + + +def 
has_dynamic_memlets(first: stree.MapScope, second: stree.MapScope) -> bool: + first_collector = MemletCollector() + second_collector = MemletCollector() + first_collector.visit(first) + second_collector.visit(second) + has_dynamic_memlets = any( + [ + memlet.dynamic + for memlet in [ + *first_collector.in_memlets, + *first_collector.out_memlets, + *second_collector.in_memlets, + *second_collector.out_memlets, + ] + ] + ) + return has_dynamic_memlets diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py new file mode 100644 index 00000000..2583ec2d --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -0,0 +1,21 @@ +import dace.sdfg.analysis.schedule_tree.treenodes as stree +import dace.subsets as sbs + + +class SpecializeCartesianMaps(stree.ScheduleNodeVisitor): + def __init__(self, mappings: dict[str, int]) -> None: + super().__init__() + self._mappings = mappings + + def visit_MapScope(self, node: stree.MapScope) -> None: + dims = [] + for p in node.node.map.params: + if p == "__i": + dims.append((0, self._mappings["__I"], 1)) + if p == "__j": + dims.append((0, self._mappings["__J"], 1)) + if p.startswith("__k"): + dims.append((0, self._mappings["__K"], 1)) + node.node.map.range = sbs.Range(dims) + + self.visit(node.children) diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py new file mode 100644 index 00000000..b965fc4a --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -0,0 +1,44 @@ +from typing import Collection + +import dace.sdfg.analysis.schedule_tree.treenodes as stree + + +def swap_node_position_in_tree( + top_node: stree.ScheduleTreeScope, child_node: stree.ScheduleTreeScope +) -> None: + """Top node becomes child, child becomes top node""" + # Take refs before swap + top_children = top_node.parent.children + top_level_parent = top_node.parent + + # Swap childrens + 
top_node.children = child_node.children + child_node.children = [top_node] + top_children.insert(list_index(top_children, top_node), child_node) + + # Re-parent + top_node.parent = child_node + child_node.parent = top_level_parent + + # Remove now-pushed original node + top_children.remove(top_node) + + +def detect_cycle(nodes: list[stree.ScheduleTreeNode], visited: set) -> None: + """Detect the cycles in the tree.""" + # Dev note: isn't there a DaCe tool for this?! + for n in nodes: + if id(n) in visited: + breakpoint() + visited.add(id(n)) + if hasattr(n, "children"): + detect_cycle(n.children, visited) + + +def list_index( + collection: Collection[stree.ScheduleTreeNode], + node: stree.ScheduleTreeNode, +) -> int: + """Check if node is in list with "is" operator.""" + # compare with "is" to get memory comparison. ".index()" uses value comparison + return next(index for index, element in enumerate(collection) if element is node) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py new file mode 100644 index 00000000..10fb77cd --- /dev/null +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod + +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge + + +class StreePipeline(ABC): + @abstractmethod + def __hash__(self) -> int: + raise NotImplementedError("Missing implementation of __hash__") + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError("Missing implementation of __repr__") + + @abstractmethod + def run( + self, + stree: stree.ScheduleTreeRoot, + verbose: bool = False, + ) -> stree.ScheduleTreeRoot: + raise NotImplementedError("Missing implementation of run") + + +class CPUPipeline(StreePipeline): + def __init__( + self, passes: list[stree.ScheduleNodeTransformer] | None = None + ) -> None: + self.passes = ( + passes if passes is not None else [CartesianAxisMerge(AxisIterator._K)] + ) + 
+ def __repr__(self) -> str: + return str([type(p) for p in self.passes]) + + def __hash__(self) -> int: + return hash(repr(self)) + + def run( + self, + stree: stree.ScheduleTreeRoot, + verbose: bool = False, + ) -> stree.ScheduleTreeRoot: + for p in self.passes: + if verbose: + print(f"[Stree OPT] {p}") + p.visit(stree) + + return stree + + +class GPUPipeline(StreePipeline): + def __init__( + self, passes: list[stree.ScheduleNodeTransformer] | None = None + ) -> None: + self.passes = passes if passes else [] + + def __repr__(self) -> str: + return str([type(p) for p in self.passes]) + + def __hash__(self) -> int: + return hash(repr(self)) + + def run( + self, + stree: stree.ScheduleTreeRoot, + verbose: bool = False, + ) -> stree.ScheduleTreeRoot: + for p in self.passes: + if verbose: + print(f"[Stree OPT] {p}") + p.visit(stree) + + return stree diff --git a/ndsl/quantity/field_bundle.py b/ndsl/quantity/field_bundle.py index 22d466e1..f33ae8a6 100644 --- a/ndsl/quantity/field_bundle.py +++ b/ndsl/quantity/field_bundle.py @@ -25,6 +25,7 @@ class FieldBundle: """ _quantity: Quantity + _per_name_view: dict[str, Quantity] = {} _indexer: _FieldBundleIndexer = {} def __init__( @@ -83,13 +84,19 @@ def __getattr__(self, name: str) -> Quantity: return None # type: ignore # ToDo: extend the dims below to work with more than 4 dims assert len(self._quantity.data.shape) == 4 - return Quantity( - data=self._quantity.data[:, :, :, self.index(name)], - dims=self._quantity.dims[:-1], - units=self._quantity.units, - origin=self._quantity.origin[:-1], - extent=self._quantity.extent[:-1], - ) + + if name not in self._per_name_view: + # Memoize the Quantities returned here to ensure that we only ever + # have one `field.a_name`-Quantity floating around. If not, DaCe + # orchestration gets (rightly so) confused. 
+ self._per_name_view[name] = Quantity( + data=self._quantity.data[:, :, :, self.index(name)], + dims=self._quantity.dims[:-1], + units=self._quantity.units, + origin=self._quantity.origin[:-1], + extent=self._quantity.extent[:-1], + ) + return self._per_name_view[name] def index(self, name: str) -> int: """Get index from name.""" diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py new file mode 100644 index 00000000..fb32a35d --- /dev/null +++ b/tests/stree_optimizer/test_optimization.py @@ -0,0 +1,69 @@ +from ndsl import StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def stencil_A(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = in_field + + +def stencil_B(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = out_field + in_field * 3 + + +class TriviallyMergeableCode: + def __init__(self, stencil_factory: StencilFactory): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil_A = stencil_factory.from_dims_halo( + func=stencil_A, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_B = stencil_factory.from_dims_halo( + func=stencil_B, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, in_field: FloatField, out_field: FloatField): + self.stencil_A(in_field, out_field) + self.stencil_B(in_field, out_field) + + +def test_stree_roundtrip_no_opt(): + """Dev Note: + + The below code successfully merges top level K loop (2 loops) + How do we test it?! Running doesn't test merging and the compilation + is a near-black box. 
We could reach in the `dace_config.compiled_sdfg` + cache but it's keyed on the dace.program and if we can reach the program + well we can reach the SDFG and turn it into an stree for verification + Should we run orchestration "by hand"? + Can we intercept the `stree` ? After all we just want to check that! + + Test is deactivated for now""" + + return True + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu" + ) + + code = TriviallyMergeableCode(stencil_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + # Temporarily flip the internal switch + import ndsl.dsl.dace.orchestration as orch + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] + + code(in_qty, out_qty) + + assert (out_qty.field[:] == 4).all() + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False diff --git a/tests/stree_optimizer/test_pipeline.py b/tests/stree_optimizer/test_pipeline.py new file mode 100644 index 00000000..6d4c6f74 --- /dev/null +++ b/tests/stree_optimizer/test_pipeline.py @@ -0,0 +1,48 @@ +from ndsl import StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def double_map(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = in_field + + with computation(PARALLEL), interval(...): + out_field = out_field + in_field * 3 + + +class TriviallyMergeableCode: + def __init__(self, stencil_factory: StencilFactory): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil = stencil_factory.from_dims_halo( + func=double_map, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + 
) + + def __call__(self, in_field: FloatField, out_field: FloatField): + self.stencil(in_field, out_field) + + +def test_stree_roundtrip_no_opt(): + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + ) + + code = TriviallyMergeableCode(stencil_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + # Temporarily flip the internal switch + import ndsl.dsl.dace.orchestration as orch + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_PASSES = [] + + code(in_qty, out_qty) + + assert (out_qty.field[:] == 4).all() + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False