Skip to content

Commit

Permalink
mpi: Fix data_gather for sparse functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 29, 2024
1 parent 602d5e6 commit 674f53a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ def parallel(item, m):
else:
testname = "%s::%s" % (item.fspath, item.name)

args = ["-n", "1", pyversion, "-m", "pytest", "--no-summary", "-s",
args = ["-n", "1", pyversion, "-m", "pytest", "-s",
"--runxfail", "-qq", testname]
if nprocs > 1:
args.extend([":", "-n", "%d" % (nprocs - 1), pyversion, "-m", "pytest",
"-s", "--runxfail", "--tb=no", "-qq", "--no-summary", testname])
"-s", "--runxfail", "--tb=no", "-qq", testname])
# OpenMPI requires an explicit flag for oversubscription. We need it as some
# of the MPI tests will spawn lots of processes
if mpi_distro == 'OpenMPI':
Expand Down
14 changes: 14 additions & 0 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,20 @@ def decompose(cls, npoint, distributor):
raise TypeError('Need `npoint` int or tuple argument')
return tuple(glb_npoint)

@cached_property
def all_ranges(self):
"""The global ranges of all MPI ranks."""
ret = []
for i in self.decomposition[0]:
# i might be empty if there is less receivers than rank such as for a
# point source
try:
ret.append(EnrichedTuple(range(min(i), max(i) + 1),
getters=self.dimensions))
except ValueError:
ret.append(EnrichedTuple(range(0, 0), getters=self.dimensions))
return tuple(ret)

@property
def distributor(self):
return self._distributor
Expand Down
29 changes: 28 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from devito import (Grid, Function, TimeFunction, SparseTimeFunction, Dimension, # noqa
Eq, Operator, ALLOC_GUARD, ALLOC_ALIGNED, configuration,
switchconfig)
switchconfig, SparseFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction)
from devito.data import LEFT, RIGHT, Decomposition, loc_data_idx, convert_index
from devito.tools import as_tuple
from devito.types import Scalar
Expand Down Expand Up @@ -1485,6 +1486,32 @@ def test_gather_time_function(self, mode):
else:
assert ans == np.array(None)

@pytest.mark.parallel(mode=[4, 6])
@pytest.mark.parametrize('sfunc', [SparseFunction,
SparseTimeFunction,
PrecomputedSparseFunction,
PrecomputedSparseTimeFunction])
@pytest.mark.parametrize('target_rank', [0, 2])
def test_gather_sparse(self, mode, sfunc, target_rank):
grid = Grid((11, 11))
myrank = grid._distributor.comm.Get_rank()
nt = 10
coords = [[0, 0], [0, .25], [0, .75], [0, 1]]
s = sfunc(name='s', grid=grid, npoint=4, r=4, nt=nt, coordinates=coords)

np.random.seed(1234)
try:
a = np.random.rand(s.nt, s.npoint_global)
except AttributeError:
a = np.random.rand(s.npoint_global,)

s.data[:] = a
out = s.data_gather(rank=target_rank)
if myrank == target_rank:
assert np.allclose(out, a)
else:
assert not out


def test_scalar_arg_substitution():
"""
Expand Down

0 comments on commit 674f53a

Please sign in to comment.