From 674f53a2556f2fcf2378a1e34efc5750809b3ee0 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 29 May 2024 12:00:47 -0400 Subject: [PATCH] mpi: Fix data_gather for sparse functions --- conftest.py | 4 ++-- devito/mpi/distributed.py | 14 ++++++++++++++ tests/test_data.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/conftest.py b/conftest.py index c38829f6e8b..c80aa6c6264 100644 --- a/conftest.py +++ b/conftest.py @@ -149,11 +149,11 @@ def parallel(item, m): else: testname = "%s::%s" % (item.fspath, item.name) - args = ["-n", "1", pyversion, "-m", "pytest", "--no-summary", "-s", + args = ["-n", "1", pyversion, "-m", "pytest", "-s", "--runxfail", "-qq", testname] if nprocs > 1: args.extend([":", "-n", "%d" % (nprocs - 1), pyversion, "-m", "pytest", - "-s", "--runxfail", "--tb=no", "-qq", "--no-summary", testname]) + "-s", "--runxfail", "--tb=no", "-qq", testname]) # OpenMPI requires an explicit flag for oversubscription. We need it as some # of the MPI tests will spawn lots of processes if mpi_distro == 'OpenMPI': diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index a0af8384291..9d4e1d31bf1 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -481,6 +481,20 @@ def decompose(cls, npoint, distributor): raise TypeError('Need `npoint` int or tuple argument') return tuple(glb_npoint) + @cached_property + def all_ranges(self): + """The global ranges of all MPI ranks.""" + ret = [] + for i in self.decomposition[0]: + # i might be empty if there is less receivers than rank such as for a + # point source + try: + ret.append(EnrichedTuple(range(min(i), max(i) + 1), + getters=self.dimensions)) + except ValueError: + ret.append(EnrichedTuple(range(0, 0), getters=self.dimensions)) + return tuple(ret) + @property def distributor(self): return self._distributor diff --git a/tests/test_data.py b/tests/test_data.py index 87072c03da1..47c1c78d169 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,7 +3,8 @@ from devito import (Grid, Function, TimeFunction, SparseTimeFunction, Dimension, # noqa Eq, Operator, ALLOC_GUARD, ALLOC_ALIGNED, configuration, - switchconfig) + switchconfig, SparseFunction, PrecomputedSparseFunction, + PrecomputedSparseTimeFunction) from devito.data import LEFT, RIGHT, Decomposition, loc_data_idx, convert_index from devito.tools import as_tuple from devito.types import Scalar @@ -1485,6 +1486,32 @@ def test_gather_time_function(self, mode): else: assert ans == np.array(None) + @pytest.mark.parallel(mode=[4, 6]) + @pytest.mark.parametrize('sfunc', [SparseFunction, + SparseTimeFunction, + PrecomputedSparseFunction, + PrecomputedSparseTimeFunction]) + @pytest.mark.parametrize('target_rank', [0, 2]) + def test_gather_sparse(self, mode, sfunc, target_rank): + grid = Grid((11, 11)) + myrank = grid._distributor.comm.Get_rank() + nt = 10 + coords = [[0, 0], [0, .25], [0, .75], [0, 1]] + s = sfunc(name='s', grid=grid, npoint=4, r=4, nt=nt, coordinates=coords) + + np.random.seed(1234) + try: + a = np.random.rand(s.nt, s.npoint_global) + except AttributeError: + a = np.random.rand(s.npoint_global,) + + s.data[:] = a + out = s.data_gather(rank=target_rank) + if myrank == target_rank: + assert np.allclose(out, a) + else: + assert not out + def test_scalar_arg_substitution(): """