Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate extraneous branch-end gotos in code generation #1355

Merged
merged 5 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class ControlFlow:
# a string with its generated code.
dispatch_state: Callable[[SDFGState], str]

# The parent control flow block of this one, used to avoid generating extraneous ``goto``s
parent: Optional['ControlFlow']

@property
def first_state(self) -> SDFGState:
"""
Expand Down Expand Up @@ -222,11 +225,18 @@ def as_cpp(self, codegen, symbols) -> str:
out_edges = sdfg.out_edges(elem.state)
for j, e in enumerate(out_edges):
if e not in self.gotos_to_ignore:
# If this is the last generated edge and it leads
# to the next state, skip emitting goto
# Skip gotos to immediate successors
successor = None
if (j == (len(out_edges) - 1) and (i + 1) < len(self.elements)):
successor = self.elements[i + 1].first_state
# If this is the last generated edge
if j == (len(out_edges) - 1):
if (i + 1) < len(self.elements):
# If last edge leads to next state in block
successor = self.elements[i + 1].first_state
elif i == len(self.elements) - 1:
# If last edge leads to first state in next block
next_block = _find_next_block(self)
if next_block is not None:
successor = next_block.first_state

expr += elem.generate_transition(sdfg, e, successor)
else:
Expand Down Expand Up @@ -478,13 +488,14 @@ def children(self) -> List[ControlFlow]:

def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
leave_edge: Edge[InterstateEdge], back_edges: List[Edge[InterstateEdge]],
dispatch_state: Callable[[SDFGState], str]) -> Union[ForScope, WhileScope]:
dispatch_state: Callable[[SDFGState],
str], parent_block: GeneralBlock) -> Union[ForScope, WhileScope]:
"""
Helper method that constructs the correct structured loop construct from a
set of states. Can construct for or while loops.
"""

body = GeneralBlock(dispatch_state, [], [], [], [], [], True)
body = GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True)

guard_inedges = sdfg.in_edges(guard)
increment_edges = [e for e in guard_inedges if e in back_edges]
Expand Down Expand Up @@ -535,10 +546,10 @@ def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[Intersta
# Also ignore assignments in increment edge (handled in for stmt)
body.assignments_to_ignore.append(increment_edge)

return ForScope(dispatch_state, itvar, guard, init, condition, update, body, init_edges)
return ForScope(dispatch_state, parent_block, itvar, guard, init, condition, update, body, init_edges)

# Otherwise, it is a while loop
return WhileScope(dispatch_state, guard, condition, body)
return WhileScope(dispatch_state, parent_block, guard, condition, body)


def _cases_from_branches(
Expand Down Expand Up @@ -617,6 +628,31 @@ def _child_of(node: SDFGState, parent: SDFGState, ptree: Dict[SDFGState, SDFGSta
return False


def _find_next_block(block: ControlFlow) -> Optional[ControlFlow]:
"""
Returns the immediate successor control flow block.
"""
# Find block in parent
parent = block.parent
if parent is None:
return None
ind = next(i for i, b in enumerate(parent.children) if b is block)
if ind == len(parent.children) - 1 or isinstance(parent, (IfScope, IfElseChain, SwitchCaseScope)):
# If last block, or other children are not reachable from current node (branches),
# recursively continue upwards
return _find_next_block(parent)
return parent.children[ind + 1]


def _reset_block_parents(block: ControlFlow):
"""
Fixes block parents after processing.
"""
for child in block.children:
child.parent = block
_reset_block_parents(child)


def _structured_control_flow_traversal(sdfg: SDFG,
start: SDFGState,
ptree: Dict[SDFGState, SDFGState],
Expand Down Expand Up @@ -645,7 +681,7 @@ def _structured_control_flow_traversal(sdfg: SDFG,
"""

def make_empty_block():
return GeneralBlock(dispatch_state, [], [], [], [], [], True)
return GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True)

# Traverse states in custom order
visited = set() if visited is None else visited
Expand All @@ -657,7 +693,7 @@ def make_empty_block():
if node in visited or node is stop:
continue
visited.add(node)
stateblock = SingleState(dispatch_state, node)
stateblock = SingleState(dispatch_state, parent_block, node)

oe = sdfg.out_edges(node)
if len(oe) == 0: # End state
Expand Down Expand Up @@ -708,23 +744,25 @@ def make_empty_block():
if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())):
# If without else
if oe[0].dst is mergestate:
branch_block = IfScope(dispatch_state, sdfg, node, oe[1].data.condition, cblocks[oe[1]])
branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[1].data.condition,
cblocks[oe[1]])
elif oe[1].dst is mergestate:
branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]])
branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition,
cblocks[oe[0]])
else:
branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]],
cblocks[oe[1]])
branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition,
cblocks[oe[0]], cblocks[oe[1]])
else:
# If there are 2 or more edges (one is not the negation of the
# other):
switch = _cases_from_branches(oe, cblocks)
if switch:
# If all edges are of form "x == y" for a single x and
# integer y, it is a switch/case
branch_block = SwitchCaseScope(dispatch_state, sdfg, node, switch[0], switch[1])
branch_block = SwitchCaseScope(dispatch_state, parent_block, sdfg, node, switch[0], switch[1])
else:
# Otherwise, create if/else if/.../else goto exit chain
branch_block = IfElseChain(dispatch_state, sdfg, node,
branch_block = IfElseChain(dispatch_state, parent_block, sdfg, node,
[(e.data.condition, cblocks[e] if e in cblocks else make_empty_block())
for e in oe])
# End of branch classification
Expand All @@ -739,11 +777,11 @@ def make_empty_block():
loop_exit = None
scope = None
if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node:
scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state)
scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state, parent_block)
body_start = oe[0].dst
loop_exit = oe[1].dst
elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node:
scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state)
scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state, parent_block)
body_start = oe[1].dst
loop_exit = oe[0].dst

Expand Down Expand Up @@ -836,7 +874,8 @@ def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState
if len(common_frontier) == 1:
branch_merges[state] = next(iter(common_frontier))

root_block = GeneralBlock(dispatch_state, [], [], [], [], [], True)
root_block = GeneralBlock(dispatch_state, None, [], [], [], [], [], True)
_structured_control_flow_traversal(sdfg, sdfg.start_state, ptree, branch_merges, back_edges, dispatch_state,
root_block)
_reset_block_parents(root_block)
return root_block
2 changes: 1 addition & 1 deletion dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def dispatch_state(state: SDFGState) -> str:
# If disabled, generate entire graph as general control flow block
states_topological = list(sdfg.topological_sort(sdfg.start_state))
last = states_topological[-1]
cft = cflow.GeneralBlock(dispatch_state,
cft = cflow.GeneralBlock(dispatch_state, None,
[cflow.SingleState(dispatch_state, s, s is last) for s in states_topological], [],
[], [], [], False)

Expand Down
29 changes: 29 additions & 0 deletions tests/codegen/control_flow_detection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,33 @@ def test_single_outedge_branch():
assert np.allclose(res, 2)


def test_extraneous_goto():

@dace.program
def tester(a: dace.float64[20]):
if a[0] < 0:
a[1] = 1
a[2] = 1

sdfg = tester.to_sdfg(simplify=True)
assert 'goto' not in sdfg.generate_code()[0].code


def test_extraneous_goto_nested():

@dace.program
def tester(a: dace.float64[20]):
if a[0] < 0:
if a[0] < 1:
a[1] = 1
else:
a[1] = 2
a[2] = 1

sdfg = tester.to_sdfg(simplify=True)
assert 'goto' not in sdfg.generate_code()[0].code


if __name__ == '__main__':
test_for_loop_detection()
test_invalid_for_loop_detection()
Expand All @@ -128,3 +155,5 @@ def test_single_outedge_branch():
test_edge_sympy_function('TrueFalse')
test_edge_sympy_function('SwitchCase')
test_single_outedge_branch()
test_extraneous_goto()
test_extraneous_goto_nested()
Loading