Skip to content

Commit

Permalink
Merge pull request #2379 from devitocodes/sparse-gather
Browse files Browse the repository at this point in the history
mpi: Fix data_gather  for sparse functions
  • Loading branch information
mloubout authored May 29, 2024
2 parents 602d5e6 + 7b1079e commit daa1d85
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
14 changes: 14 additions & 0 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,20 @@ def decompose(cls, npoint, distributor):
raise TypeError('Need `npoint` int or tuple argument')
return tuple(glb_npoint)

@cached_property
def all_ranges(self):
"""The global ranges of all MPI ranks."""
ret = []
for i in self.decomposition[0]:
# i might be empty if there is less receivers than rank such as for a
# point source
try:
ret.append(EnrichedTuple(range(min(i), max(i) + 1),
getters=self.dimensions))
except ValueError:
ret.append(EnrichedTuple(range(0, 0), getters=self.dimensions))
return tuple(ret)

@property
def distributor(self):
return self._distributor
Expand Down
29 changes: 28 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from devito import (Grid, Function, TimeFunction, SparseTimeFunction, Dimension, # noqa
Eq, Operator, ALLOC_GUARD, ALLOC_ALIGNED, configuration,
switchconfig)
switchconfig, SparseFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction)
from devito.data import LEFT, RIGHT, Decomposition, loc_data_idx, convert_index
from devito.tools import as_tuple
from devito.types import Scalar
Expand Down Expand Up @@ -1485,6 +1486,32 @@ def test_gather_time_function(self, mode):
else:
assert ans == np.array(None)

@pytest.mark.parallel(mode=[4, 6])
@pytest.mark.parametrize('sfunc', [SparseFunction,
SparseTimeFunction,
PrecomputedSparseFunction,
PrecomputedSparseTimeFunction])
@pytest.mark.parametrize('target_rank', [0, 2])
def test_gather_sparse(self, mode, sfunc, target_rank):
grid = Grid((11, 11))
myrank = grid._distributor.comm.Get_rank()
nt = 10
coords = [[0, 0], [0, .25], [0, .75], [0, 1]]
s = sfunc(name='s', grid=grid, npoint=4, r=4, nt=nt, coordinates=coords)

np.random.seed(1234)
try:
a = np.random.rand(s.nt, s.npoint_global)
except AttributeError:
a = np.random.rand(s.npoint_global,)

s.data[:] = a
out = s.data_gather(rank=target_rank)
if myrank == target_rank:
assert np.allclose(out, a)
else:
assert not out


def test_scalar_arg_substitution():
"""
Expand Down

0 comments on commit daa1d85

Please sign in to comment.