Skip to content

Commit

Permalink
Make scalar to symbol promotion robust to node order in state (#1766)
Browse files Browse the repository at this point in the history
Fixes #1727
  • Loading branch information
tbennun authored Nov 15, 2024
1 parent f757687 commit b5f91e1
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
2 changes: 2 additions & 0 deletions dace/sdfg/analysis/schedule_tree/treenodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(self, children: Optional[List['ScheduleTreeNode']] = None):
if self.children:
for child in children:
child.parent = self
self.containers = {}
self.symbols = {}

def as_string(self, indent: int = 0):
if not self.children:
Expand Down
4 changes: 3 additions & 1 deletion dace/transformation/passes/scalar_to_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,8 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]):
for state in sdfg.states():
scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in array_names]
for node in scalar_nodes:
if node not in state:
continue
symname = array_names[node.data]
for out_edge in state.out_edges(node):
for e in state.memlet_tree(out_edge):
Expand Down Expand Up @@ -649,7 +651,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]:
scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote]
# Step 2: Assignment tasklets
for node in scalar_nodes:
if state.in_degree(node) == 0:
if node not in state or state.in_degree(node) == 0:
continue
in_edge = state.in_edges(node)[0]
input = in_edge.src
Expand Down
30 changes: 30 additions & 0 deletions tests/passes/scalar_to_symbol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,35 @@ def test_double_index_bug():
assert getattr(sympy_node, "name", None) != "indices"


def test_reversed_order():
"""
Tests a failure reported in issue #1727.
"""
sdfg = dace.SDFG('tester')
sdfg.add_array('inputs', [1], dace.int32)
sdfg.add_transient('a', [1], dace.int32)
sdfg.add_transient('b', [1], dace.int32)
sdfg.add_array('output', [1], dace.int32)
initstate = sdfg.add_state()
state = sdfg.add_state_after(initstate)
finistate = sdfg.add_state_after(state)

# Note the order here
w = state.add_write('b')
t = state.add_tasklet('assign', {'inp'}, {'out'}, 'out = inp')
r = state.add_read('a')
state.add_edge(t, 'out', w, None, dace.Memlet('b'))
state.add_edge(r, None, t, 'inp', dace.Memlet('a'))

initstate.add_nedge(initstate.add_read('inputs'), initstate.add_write('a'), dace.Memlet('inputs'))
finistate.add_nedge(finistate.add_read('b'), finistate.add_write('output'), dace.Memlet('output'))

sdfg.validate()
promoted = scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {})
assert promoted == {'a', 'b'}
sdfg.compile()


if __name__ == '__main__':
test_find_promotable()
test_promote_simple()
Expand All @@ -753,3 +782,4 @@ def test_double_index_bug():
test_ternary_expression(False)
test_ternary_expression(True)
test_double_index_bug()
test_reversed_order()

0 comments on commit b5f91e1

Please sign in to comment.