Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Revert "compiler: Relax WaitLock regions in a ScheduleTree" #2141

Merged
merged 3 commits on Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from devito.ir.stree.tree import (ScheduleTree, NodeIteration, NodeConditional,
NodeSync, NodeExprs, NodeSection, NodeHalo)
from devito.ir.support import (SEQUENTIAL, Any, Interval, IterationInterval,
IterationSpace, WaitLock, normalize_properties,
normalize_syncs)
IterationSpace, normalize_properties, normalize_syncs)
from devito.mpi.halo_scheme import HaloScheme
from devito.tools import Bunch, DefaultOrderedDict

Expand Down Expand Up @@ -157,10 +156,13 @@ def preprocess(clusters, options=None, **kwargs):
found.append(c1)
queue.remove(c1)

syncs = normalize_syncs(c.syncs, *[c1.syncs for c1 in found])
halo_scheme = HaloScheme.union([c1.halo_scheme for c1 in found])
syncs = normalize_syncs(*[c1.syncs for c1 in found])
if syncs:
ispace = c.ispace.project(syncs)
processed.append(c.rebuild(exprs=[], ispace=ispace, syncs=syncs))

processed.append(c.rebuild(syncs=syncs, halo_scheme=halo_scheme))
halo_scheme = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=halo_scheme))

# Sanity check!
try:
Expand All @@ -177,34 +179,15 @@ def reuse_partial_subtree(c0, c1, d=None):


def reuse_whole_subtree(c0, c1, d=None):
if not reuse_partial_subtree(c0, c1, d):
return False

syncs0 = c0.syncs.get(d, [])
syncs1 = c1.syncs.get(d, [])

if syncs0 == syncs1:
return True
elif not syncs0 and all(isinstance(s, WaitLock) for s in syncs1):
return True

return False
return (c0.guards.get(d) == c1.guards.get(d) and
c0.syncs.get(d) == c1.syncs.get(d))


def augment_partial_subtree(cluster, tip, mapper, it=None):
d = it.dim

try:
syncs = cluster.syncs[d]
if all(isinstance(s, WaitLock) for s in syncs):
# Unlike all other SyncOps, a WaitLock "floats" in the stree, in that
# it doesn't need to wrap any subtree. Thus, a WaitLock acts like
# a barrier to what follows inside `d`
NodeSync(syncs, tip)
else:
tip = NodeSync(syncs, tip)
except KeyError:
pass
if d in cluster.syncs:
tip = NodeSync(cluster.syncs[d], tip)

mapper[it].bottom = tip

Expand Down
2 changes: 1 addition & 1 deletion devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def _as_number(self, v, args):
else:
assert self.target.c0.is_Array
assert args is not None
return int(v.subs(args))
return int(subs_op_args(v, args))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a patch for sympy 1.12: `int(v.subs(args))` is replaced with `int(subs_op_args(v, args))`.


def _arg_defaults(self, allocator, alias, args=None):
# Lazy initialization if `allocator` is necessary as the `allocator`
Expand Down
39 changes: 17 additions & 22 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,12 @@ def test_tasking_in_isolation(self, opt):
op = Operator(eqns, opt=opt)

# Check generated code
trees = retrieve_iteration_tree(op)
assert len(trees) == 2
assert len(retrieve_iteration_tree(op)) == 3
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 1
sections = FindNodes(Section).visit(op)
assert len(sections) == 2
assert str(trees[0].root.nodes[0].body[0]) == 'while(lock0[0] == 0);'
body = sections[1].body[0].body[0]
assert len(sections) == 3
assert str(sections[0].body[0].body[0].body[0].body[0]) == 'while(lock0[0] == 0);'
body = sections[2].body[0].body[0]
assert str(body.body[0].condition) == 'Ne(lock0[0], 2)'
assert str(body.body[1]) == 'lock0[0] = 0;'
body = body.body[2]
Expand Down Expand Up @@ -220,12 +219,11 @@ def test_tasking_unfused_two_locks(self):
op = Operator(eqns, opt=('tasking', 'fuse', 'orchestrate', {'linearize': False}))

# Check generated code
trees = retrieve_iteration_tree(op)
assert len(trees) == 3
assert len(retrieve_iteration_tree(op)) == 3
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 1 + 2
sections = FindNodes(Section).visit(op)
assert len(sections) == 4
assert (str(trees[0].root.nodes[1].body[0]) ==
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
body = sections[2].body[0].body[0]
assert str(body.body[0].condition) == 'Ne(lock0[0], 2)'
Expand Down Expand Up @@ -272,12 +270,11 @@ def test_tasking_forcefuse(self):
{'fuse-tasks': True, 'linearize': False}))

# Check generated code
trees = retrieve_iteration_tree(op)
assert len(trees) == 3
assert len(retrieve_iteration_tree(op)) == 3
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 2
sections = FindNodes(Section).visit(op)
assert len(sections) == 3
assert (str(trees[0].root.nodes[1].body[0]) ==
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
body = sections[2].body[0].body[0]
assert str(body.body[0].condition) == 'Ne(lock0[0], 2) | Ne(lock1[0], 2)'
Expand Down Expand Up @@ -342,12 +339,12 @@ def test_tasking_multi_output(self):
op1 = Operator(eqns, opt=('tasking', 'orchestrate', {'linearize': False}))

# Check generated code
trees = retrieve_iteration_tree(op1)
assert len(trees) == 4
assert len(retrieve_iteration_tree(op1)) == 4
assert len([i for i in FindSymbols().visit(op1) if isinstance(i, Lock)]) == 1
assert str(trees[1].root.nodes[0].body[0]) == 'while(lock0[t2] == 0);'
sections = FindNodes(Section).visit(op1)
assert len(sections) == 2
assert str(sections[0].body[0].body[0].body[0].body[0]) ==\
'while(lock0[t2] == 0);'
for i in range(3):
assert 'lock0[t' in str(sections[1].body[0].body[0].body[1 + i]) # Set-lock
assert str(sections[1].body[0].body[0].body[4].body[-1]) ==\
Expand Down Expand Up @@ -379,13 +376,13 @@ def test_tasking_lock_placement(self):
op = Operator(eqns, opt=('tasking', 'orchestrate'))

# Check generated code -- the wait-lock is expected in section1
trees = retrieve_iteration_tree(op)
assert len(trees) == 5
assert len(retrieve_iteration_tree(op)) == 5
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 1
sections = FindNodes(Section).visit(op)
assert len(sections) == 3
assert sections[0].body[0].body[0].body[0].is_Iteration
assert str(trees[1].root.nodes[1].body[0]) == 'while(lock0[t1] == 0);'
assert str(sections[1].body[0].body[0].body[0].body[0]) ==\
'while(lock0[t1] == 0);'

@pytest.mark.parametrize('opt,ntmps', [
(('buffering', 'streaming', 'orchestrate'), 2),
Expand Down Expand Up @@ -817,12 +814,11 @@ def test_tasking_over_compiler_generated(self):

# Check generated code
for op in [op1, op2]:
trees = retrieve_iteration_tree(op)
assert len(trees) == 5
assert len(retrieve_iteration_tree(op)) == 5
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 1
sections = FindNodes(Section).visit(op)
assert len(sections) == 3
assert 'while(lock0[t1] == 0)' in str(trees[1].root.nodes[1].body[0])
assert 'while(lock0[t1] == 0)' in str(sections[1].body[0].body[0].body[0])

op0.apply(time_M=nt-1)
op1.apply(time_M=nt-1, u=u1, usave=usave1)
Expand Down Expand Up @@ -1278,8 +1274,7 @@ def test_fuse_compatible_guards(self):
't,x,y,x,y,x,y')
nodes = FindNodes(Conditional).visit(op)
assert len(nodes) == 2
assert len(nodes[1].then_body) == 4
assert str(nodes[1].then_body[0].body[0]) == 'while(lock0[0] == 0);'
assert len(nodes[1].then_body) == 3
assert len(retrieve_iteration_tree(nodes[1])) == 2


Expand Down