Skip to content

Commit

Permalink
mpi: fix empty aindices halo
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Nov 7, 2023
1 parent f735820 commit 233e82b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
7 changes: 5 additions & 2 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ def preprocess(clusters, options=None, **kwargs):
for c in clusters:
if c.is_halo_touch:
hs = HaloScheme.union(e.rhs.halo_scheme for e in c.exprs)
queue.append(c.rebuild(halo_scheme=hs))
if hs.distributed_aindices:
queue.append(c.rebuild(halo_scheme=hs))
else:
processed.append(c.rebuild(exprs=None, halo_scheme=hs))
elif c.is_critical_region and c.syncs:
processed.append(c.rebuild(exprs=None, guards=c.guards, syncs=c.syncs))
elif c.is_wild:
Expand All @@ -165,7 +168,7 @@ def preprocess(clusters, options=None, **kwargs):

# Skip if the halo exchange would end up outside
# its iteration space
if h_indices and not h_indices & dims:
if h_indices and dims and not h_indices & dims:
continue

diff = dims - distributed_aindices
Expand Down
29 changes: 28 additions & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from devito import (Grid, Constant, Function, TimeFunction, SparseFunction,
SparseTimeFunction, Dimension, ConditionalDimension, SubDimension,
SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration,
switchconfig, generic_derivative, PrecomputedSparseFunction)
switchconfig, generic_derivative, PrecomputedSparseFunction,
DefaultDimension)
from devito.arch.compiler import OneapiCompiler
from devito.data import LEFT, RIGHT
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
Expand Down Expand Up @@ -600,6 +601,32 @@ def test_precomputed_sparse(self, r):
Operator(sf1.interpolate(u))()
assert np.all(sf1.data == 4)

@pytest.mark.parallel(mode=4)
def test_no_grid_dim_slow(self):
shape = (12, 13, 14)
nfreq = 5
nrec = 2

grid = Grid(shape=shape)
f = DefaultDimension(name="f", default_value=nfreq)

u = Function(name="u", grid=grid, dimensions=(*grid.dimensions, f),
shape=(*shape, nfreq), space_order=2)
u.data.fill(1)

class CoordSlowSparseFunction(SparseFunction):
_sparse_position = 0

r = Dimension(name="r")
s = CoordSlowSparseFunction(name="s", grid=grid, dimensions=(r, f),
shape=(nrec, nfreq), npoint=nrec)

rec_eq = s.interpolate(expr=u)

op = Operator(rec_eq)
op.apply()
assert np.all(s.data == 1)


class TestOperatorSimple(object):

Expand Down

0 comments on commit 233e82b

Please sign in to comment.