Skip to content

Commit

Permalink
passes: Fixup _hoist_halospots
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Nov 9, 2020
1 parent a719a26 commit bd88a53
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
10 changes: 5 additions & 5 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def _hoist_halospots(iet):
# a stopper to halo hoisting

def rule0(dep, candidates):
# E.g., `dep=W<f,[x]> -> R<f,[x-1]>` and `candidates=(time, x)` => False
# E.g., `dep=W<f,[x]> -> R<f,[x-1]>` and `candidates=({time}, {x})` => False
# E.g., `dep=W<f,[t1, x, y]> -> R<f,[t0, x-1, y+1]>`, `dep.cause={t,time}` and
# `candidates=(x,)` => True
return (all(d in dep.distance_mapper for d in candidates) and
not dep.cause & candidates)
# `candidates=({x},)` => True
return (all(i & set(dep.distance_mapper) for i in candidates) and
not any(i & dep.cause for i in candidates))

def rule1(dep, candidates):
# An increment isn't a stopper to hoisting
Expand All @@ -94,7 +94,7 @@ def rule1(dep, candidates):

for f in hs.fmapper:
for n, i in enumerate(iters):
candidates = set().union(*[i.dim._defines for i in iters[n:]])
candidates = [i.dim._defines for i in iters[n:]]

test = True
for dep in scopes[i].d_flow.project(f):
Expand Down
20 changes: 14 additions & 6 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1731,9 +1731,13 @@ def test_cire_with_shifted_diagonal_halo_touch(self):
assert u0_norm == u1_norm

@pytest.mark.parallel(mode=4)
def test_cire_w_rotations(self):
@pytest.mark.parametrize('opt_options', [
{'cire-repeats-sops': 9, 'cire-rotate': True}, # Issue #1490 (rotating registers)
{'min-storage': True}, # Issue #1491 (min-storage option)
])
def test_cire_options(self, opt_options):
"""
MFE for issue #1490.
MFEs for several issues tracked on GitHub.
"""
grid = Grid(shape=(128, 128, 128), dtype=np.float64)

Expand All @@ -1746,19 +1750,23 @@ def test_cire_w_rotations(self):
eqn = Eq(p.forward, (p.dx).dx + (p.dy).dy + (p.dz).dz)

op0 = Operator(eqn, opt='noop')
op1 = Operator(eqn, opt=('advanced', {'cire-repeats-sops': 9,
'cire-rotate': True}))
op1 = Operator(eqn, opt=('advanced', opt_options))

# Check generated code
arrays = [i for i in FindSymbols().visit(op1._func_table['bf0']) if i.is_Array]
assert len(arrays) == 3
assert 'haloupdate_0' in op1._func_table
# We expect exactly one halo exchange
calls = FindNodes(Call).visit(op1)
assert len(calls) == 5
assert calls[0].name == 'haloupdate_0'
assert all(i.name == 'bf0' for i in calls[1:])

op0.apply(time_M=1)
op1.apply(time_M=1, p=p1)

# TODO: we can tighthen the tolerance, or switch to single precision,
# once issue #1438 is fixed
# TODO: we will tighten the tolerance, or switch to single precision,
# or both, once issue #1438 is fixed
assert np.allclose(p.data, p1.data, rtol=10e-11)

@pytest.mark.parallel(mode=[(4, 'full', True)])
Expand Down

0 comments on commit bd88a53

Please sign in to comment.