NOTE(review): line breaks and leading indentation in this patch were reconstructed
from a whitespace-collapsed copy; verify the indentation of the context lines against
the target files before applying. The post-image blob hashes on the "index" lines
predate the review edits made to the added lines.

diff --git a/dace/transformation/passes/map_over_free_tasklet.py b/dace/transformation/passes/map_over_free_tasklet.py
index 4756c51263..0c66a90a4e 100644
--- a/dace/transformation/passes/map_over_free_tasklet.py
+++ b/dace/transformation/passes/map_over_free_tasklet.py
@@ -52,10 +52,13 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]):
                 if isinstance(node, dace.nodes.NestedSDFG):
                     inner_sdfg = node.sdfg
                     self.apply_pass(inner_sdfg, {})
-                elif len(state.in_edges(node)) == 0 and sd[node] is None:
-                    _, start_nodes, end_nodes = self._get_component(state, node)
-                    self._apply(state, start_nodes, end_nodes, counter)
-                    counter += 1
+                elif (len(state.in_edges(node)) == 0 and sd[node] is None
+                      and not isinstance(node, (dace.nodes.EntryNode, dace.nodes.ExitNode))):
+                    component, start_nodes, end_nodes = self._get_component(state, node)
+                    # Only apply if there are no entry nodes in the component
+                    if not any(isinstance(v, dace.nodes.EntryNode) for v in component):
+                        self._apply(state, start_nodes, end_nodes, counter)
+                        counter += 1
 
     def _apply(self, state: SDFGState,
                start_nodes: List[dace.nodes.Node],
diff --git a/tests/map_over_free_tasklet_test.py b/tests/map_over_free_tasklet_test.py
index a3fe8a5eaa..d0f086d101 100644
--- a/tests/map_over_free_tasklet_test.py
+++ b/tests/map_over_free_tasklet_test.py
@@ -53,6 +53,48 @@ def _add_chain(
     return chain_elements
 
 
+def _trivial_map_sdfg():
+    """Build an SDFG whose only tasklet already sits inside a map and writes ``A[i]``."""
+    sdfg = dace.SDFG("main")
+    state = sdfg.add_state("_s")
+    arr_name, _ = sdfg.add_array("A", (5,), dace.dtypes.float32)
+    an = state.add_access(arr_name)
+    map_entry, map_exit = state.add_map(name="assign", ndrange={"i": dace.subsets.Range([(0, 4, 1)])})
+    t1 = state.add_tasklet("t1", {}, {"_out"}, "_out = 2.0")
+    map_exit.add_in_connector(f"IN_{arr_name}")
+    map_exit.add_out_connector(f"OUT_{arr_name}")
+
+    state.add_edge(map_entry, None, t1, None, dace.memlet.Memlet(None))
+    state.add_edge(t1, "_out", map_exit, f"IN_{arr_name}", dace.memlet.Memlet("A[i]"))
+    state.add_edge(map_exit, f"OUT_{arr_name}", an, None, dace.memlet.Memlet("A[0:5]"))
+
+    sdfg.validate()
+    return sdfg
+
+
+def _trivial_copy_map_sdfg():
+    """Build an SDFG with one map that reads ``A[i]``, doubles it and writes it back to ``A[i]``."""
+    sdfg = dace.SDFG("main")
+    state = sdfg.add_state("_s")
+    arr_name, _ = sdfg.add_array("A", (5,), dace.dtypes.float32)
+    an = state.add_access(arr_name)
+    pre_an = state.add_access(arr_name)
+    map_entry, map_exit = state.add_map(name="assign", ndrange={"i": dace.subsets.Range([(0, 4, 1)])})
+    t1 = state.add_tasklet("t1", {"_in"}, {"_out"}, "_out = _in * 2.0")
+    map_exit.add_in_connector(f"IN_{arr_name}")
+    map_exit.add_out_connector(f"OUT_{arr_name}")
+    map_entry.add_in_connector(f"IN_{arr_name}")
+    map_entry.add_out_connector(f"OUT_{arr_name}")
+
+    state.add_edge(pre_an, None, map_entry, f"IN_{arr_name}", dace.memlet.Memlet("A[0:5]"))
+    state.add_edge(map_entry, f"OUT_{arr_name}", t1, "_in", dace.memlet.Memlet("A[i]"))
+    state.add_edge(t1, "_out", map_exit, f"IN_{arr_name}", dace.memlet.Memlet("A[i]"))
+    state.add_edge(map_exit, f"OUT_{arr_name}", an, None, dace.memlet.Memlet("A[0:5]"))
+
+    sdfg.validate()
+    return sdfg
+
+
 def _trivial_chain_sdfg():
     sdfg = dace.SDFG("main")
     state = sdfg.add_state("_s")
@@ -201,9 +243,39 @@ def test_trivial_chain_in_nested_sdfg():
     _check_recursive(sdfg)
     sdfg.validate()
 
+
+def _count_maps(sdfg):
+    """Return the number of map entry nodes in ``sdfg``, including those in nested SDFGs."""
+    num_maps = 0
+    for state in sdfg.states():
+        for node in state.nodes():
+            if isinstance(node, dace.nodes.NestedSDFG):
+                num_maps += _count_maps(node.sdfg)
+            elif isinstance(node, dace.nodes.MapEntry):
+                num_maps += 1
+    return num_maps
+
+
+def test_trivial_assign_map():
+    sdfg = _trivial_map_sdfg()
+    MapOverFreeTasklet().apply_pass(sdfg, {})
+    # The tasklet is already inside a map, so the pass must not add another one.
+    assert _count_maps(sdfg) == 1
+    sdfg.validate()
+
+
+def test_trivial_copy_map():
+    sdfg = _trivial_copy_map_sdfg()
+    MapOverFreeTasklet().apply_pass(sdfg, {})
+    assert _count_maps(sdfg) == 1
+    sdfg.validate()
+
+
 if __name__ == "__main__":
     test_trivial_chain()
     test_two_trivial_chains_sdfg()
     test_multiple_input_chain_sdfg()
     test_complex_chain_sdfg()
     test_trivial_chain_in_nested_sdfg()
+    test_trivial_assign_map()
+    test_trivial_copy_map()