diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index a0af838429..9d4e1d31bf 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -481,6 +481,20 @@ def decompose(cls, npoint, distributor): raise TypeError('Need `npoint` int or tuple argument') return tuple(glb_npoint) + @cached_property + def all_ranges(self): + """The global ranges of all MPI ranks.""" + ret = [] + for i in self.decomposition[0]: + # i might be empty if there is less receivers than rank such as for a + # point source + try: + ret.append(EnrichedTuple(range(min(i), max(i) + 1), + getters=self.dimensions)) + except ValueError: + ret.append(EnrichedTuple(range(0, 0), getters=self.dimensions)) + return tuple(ret) + @property def distributor(self): return self._distributor diff --git a/tests/test_data.py b/tests/test_data.py index 87072c03da..47c1c78d16 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,7 +3,8 @@ from devito import (Grid, Function, TimeFunction, SparseTimeFunction, Dimension, # noqa Eq, Operator, ALLOC_GUARD, ALLOC_ALIGNED, configuration, - switchconfig) + switchconfig, SparseFunction, PrecomputedSparseFunction, + PrecomputedSparseTimeFunction) from devito.data import LEFT, RIGHT, Decomposition, loc_data_idx, convert_index from devito.tools import as_tuple from devito.types import Scalar @@ -1485,6 +1486,32 @@ def test_gather_time_function(self, mode): else: assert ans == np.array(None) + @pytest.mark.parallel(mode=[4, 6]) + @pytest.mark.parametrize('sfunc', [SparseFunction, + SparseTimeFunction, + PrecomputedSparseFunction, + PrecomputedSparseTimeFunction]) + @pytest.mark.parametrize('target_rank', [0, 2]) + def test_gather_sparse(self, mode, sfunc, target_rank): + grid = Grid((11, 11)) + myrank = grid._distributor.comm.Get_rank() + nt = 10 + coords = [[0, 0], [0, .25], [0, .75], [0, 1]] + s = sfunc(name='s', grid=grid, npoint=4, r=4, nt=nt, coordinates=coords) + + np.random.seed(1234) + try: + a = np.random.rand(s.nt, s.npoint_global) + except AttributeError: + a = np.random.rand(s.npoint_global,) + + s.data[:] = a + out = s.data_gather(rank=target_rank) + if myrank == target_rank: + assert np.allclose(out, a) + else: + assert not out + def test_scalar_arg_substitution(): """