Skip to content

Commit

Permalink
More test cases and check for entry nodes in the component
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Nov 19, 2024
1 parent 7dba3b4 commit f7c6327
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
15 changes: 11 additions & 4 deletions dace/transformation/passes/map_over_free_tasklet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,17 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]):
if isinstance(node, dace.nodes.NestedSDFG):
inner_sdfg = node.sdfg
self.apply_pass(inner_sdfg, {})
elif len(state.in_edges(node)) == 0 and sd[node] is None:
_, start_nodes, end_nodes = self._get_component(state, node)
self._apply(state, start_nodes, end_nodes, counter)
counter += 1
elif (len(state.in_edges(node)) == 0
and sd[node] is None
and not isinstance(node, dace.nodes.EntryNode)
and not isinstance(node, dace.nodes.ExitNode)
):
component, start_nodes, end_nodes = self._get_component(state, node)
# Only apply if there are no entry nodes in the component
has_entry_node = any([isinstance(v, dace.nodes.EntryNode) for v in component])
if not has_entry_node:
self._apply(state, start_nodes, end_nodes, counter)
counter += 1

def _apply(self, state: SDFGState,
start_nodes: List[dace.nodes.Node],
Expand Down
67 changes: 67 additions & 0 deletions tests/map_over_free_tasklet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,48 @@ def _add_chain(

return chain_elements

def _trivial_map_sdfg():
sdfg = dace.SDFG("main")
state = sdfg.add_state("_s")
arr_name, arr = sdfg.add_array("A", (5,), dace.dtypes.float32)
an = state.add_access(arr_name)
map_entry, map_exit = state.add_map(name="assign", ndrange={
"i":dace.subsets.Range([(0, 4, 1)])
})
t1 = state.add_tasklet("t1", {}, {"_out"}, "_out = 2.0")
map_exit.add_in_connector(f"IN_{arr_name}")
map_exit.add_out_connector(f"OUT_{arr_name}")

state.add_edge(map_entry, None, t1, None, dace.memlet.Memlet(None))
state.add_edge(t1, "_out", map_exit, f"IN_{arr_name}", dace.memlet.Memlet("A[i]"))
state.add_edge(map_exit, f"OUT_{arr_name}", an, None, dace.memlet.Memlet("A[0:5]"))

sdfg.validate()
return sdfg

def _trivial_copy_map_sdfg():
sdfg = dace.SDFG("main")
state = sdfg.add_state("_s")
arr_name, arr = sdfg.add_array("A", (5,), dace.dtypes.float32)
an = state.add_access(arr_name)
pre_an = state.add_access(arr_name)
map_entry, map_exit = state.add_map(name="assign", ndrange={
"i":dace.subsets.Range([(0, 4, 1)])
})
t1 = state.add_tasklet("t1", {"_in"}, {"_out"}, "_out = _in * 2.0")
map_exit.add_in_connector(f"IN_{arr_name}")
map_exit.add_out_connector(f"OUT_{arr_name}")
map_entry.add_in_connector(f"IN_{arr_name}")
map_entry.add_out_connector(f"OUT_{arr_name}")

state.add_edge(pre_an, None, map_entry, f"IN_{arr_name}", dace.memlet.Memlet("A[0:5]"))
state.add_edge(map_entry, f"OUT_{arr_name}", t1, "_in", dace.memlet.Memlet("A[i]"))
state.add_edge(t1, "_out", map_exit, f"IN_{arr_name}", dace.memlet.Memlet("A[i]"))
state.add_edge(map_exit, f"OUT_{arr_name}", an, None, dace.memlet.Memlet("A[0:5]"))

sdfg.validate()
return sdfg

def _trivial_chain_sdfg():
sdfg = dace.SDFG("main")
state = sdfg.add_state("_s")
Expand Down Expand Up @@ -201,9 +243,34 @@ def test_trivial_chain_in_nested_sdfg():
_check_recursive(sdfg)
sdfg.validate()

def _count_maps(sdfg):
m = 0
for state in sdfg.states():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
m += _count_maps(node.sdfg)
elif isinstance(node, dace.nodes.MapEntry):
m += 1
return m

def test_trivial_assign_map():
sdfg = _trivial_map_sdfg()
mapOverFreeTasklet = MapOverFreeTasklet()
mapOverFreeTasklet.apply_pass(sdfg, {})
assert(_count_maps(sdfg) == 1)

def test_trivial_copy_map():
sdfg = _trivial_copy_map_sdfg()
mapOverFreeTasklet = MapOverFreeTasklet()
mapOverFreeTasklet.apply_pass(sdfg, {})
assert(_count_maps(sdfg) == 1)


if __name__ == "__main__":
test_trivial_chain()
test_two_trivial_chains_sdfg()
test_multiple_input_chain_sdfg()
test_complex_chain_sdfg()
test_trivial_chain_in_nested_sdfg()
test_trivial_assign_map()
test_trivial_copy_map()

0 comments on commit f7c6327

Please sign in to comment.