Skip to content

Commit

Permalink
Merge pull request #2394 from devitocodes/data-assign-fix
Browse files Browse the repository at this point in the history
MPI: Fix data assignement for single mpi rank and factor haloupdate
  • Loading branch information
mloubout authored Jul 1, 2024
2 parents 4e86b12 + 41de762 commit 7f77489
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 5 deletions.
5 changes: 3 additions & 2 deletions devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def __array_finalize__(self, obj):
self._allocator = ALLOC_ALIGNED
elif obj._index_stash is not None:
# From `__getitem__`
self._is_distributed = obj._is_distributed
self._distributor = obj._distributor
glb_idx = obj._normalize_index(obj._index_stash)
self._modulo = tuple(m for i, m in zip(glb_idx, obj._modulo)
Expand All @@ -131,10 +130,12 @@ def __array_finalize__(self, obj):
decomposition.append(dec.reshape(i))
self._decomposition = tuple(decomposition)
self._allocator = obj._allocator
decomp = any(i is not None for i in self._decomposition)
self._is_distributed = decomp and obj._is_distributed
else:
self._is_distributed = obj._is_distributed
self._distributor = obj._distributor
self._allocator = obj._allocator
self._is_distributed = obj._is_distributed
if self.ndim == obj.ndim:
# E.g., from a ufunc, such as `np.add`
self._modulo = obj._modulo
Expand Down
4 changes: 3 additions & 1 deletion devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ def __init__(self, npoint, dimension, distributor):
def decompose(cls, npoint, distributor):
"""Distribute `npoint` points over `nprocs` MPI ranks."""
nprocs = distributor.nprocs
if isinstance(npoint, int):
if nprocs == 1:
return (npoint,)
elif isinstance(npoint, int):
# `npoint` is a global count. The `npoint` are evenly distributed
# across the various MPI ranks. Note that there is nothing smart
# in the following -- it's entirely possible that the MPI rank 0,
Expand Down
4 changes: 3 additions & 1 deletion devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MapNodes, MapHaloSpots, Transformer,
retrieve_iteration_tree)
from devito.ir.support import PARALLEL, Scope
from devito.ir.support.guards import GuardFactorEq
from devito.mpi.halo_scheme import HaloScheme
from devito.mpi.reduction_scheme import DistReduce
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
Expand Down Expand Up @@ -160,7 +161,8 @@ def rule2(dep, hs, loc_indices):

# Analysis
cond_mapper = MapHaloSpots().visit(iet)
cond_mapper = {hs: {i for i in v if i.is_Conditional}
cond_mapper = {hs: {i for i in v if i.is_Conditional and
not isinstance(i.condition, GuardFactorEq)}
for hs, v in cond_mapper.items()}

iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)
Expand Down
3 changes: 2 additions & 1 deletion devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,9 @@ def __shape_setup__(cls, **kwargs):
@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
time_dim = kwargs.get('time_dim', kwargs['grid'].time_dim)
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
dimensions = (time_dim,
*super().__indices_setup__(*args, **kwargs)[0])

if args:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,17 @@ def test_indexing_into_sparse(self):
sf.data[1:-1, 0] = np.arange(8)
assert np.all(sf.data[1:-1, 0] == np.arange(8))

@pytest.mark.parallel(mode=1)
def test_indexing_into_sparse_subfunc_singlempi(self, mode):
grid = Grid(shape=(4, 4))
s = SparseFunction(name='sf', grid=grid, npoint=1)
coords = np.random.rand(*s.coordinates.data.shape)
s.coordinates.data[:] = coords

s.coordinates.data[-1, :] = s.coordinates.data[-1, :] / 2

assert np.allclose(s.coordinates.data[-1, :], coords[-1, :] / 2)


class TestLocDataIDX:
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,27 @@ def test_avoid_redundant_haloupdate(self, mode):
calls = FindNodes(Call).visit(op)
assert len(calls) == 1

@pytest.mark.parallel(mode=1)
def test_avoid_redundant_haloupdate_cond(self, mode):
grid = Grid(shape=(12,))
x = grid.dimensions[0]
t = grid.stepping_dim

i = Dimension(name='i')
j = Dimension(name='j')

f = TimeFunction(name='f', grid=grid)
g = Function(name='g', grid=grid)
t_sub = ConditionalDimension(name='t_sub', parent=t, factor=2)

op = Operator([Eq(f.forward, f[t, x-1] + f[t, x+1] + 1.),
Inc(f[t+1, i], 1.), # no halo update as it's an Inc
# access `f` at `t`, not `t+1` through factor subdim!
Eq(g, f[t, j] + 1, implicit_dim=t_sub)])

calls = FindNodes(Call).visit(op)
assert len(calls) == 1

@pytest.mark.parallel(mode=1)
def test_avoid_haloupdate_if_distr_but_sequential(self, mode):
grid = Grid(shape=(12,))
Expand Down

0 comments on commit 7f77489

Please sign in to comment.