Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpi: Fix data_gather for sparse functions #2379

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions devito/mpi/distributed.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this fixing any known issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No there isn't one just somethign found out

Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,20 @@ def decompose(cls, npoint, distributor):
raise TypeError('Need `npoint` int or tuple argument')
return tuple(glb_npoint)

@cached_property
def all_ranges(self):
"""The global ranges of all MPI ranks."""
ret = []
for i in self.decomposition[0]:
# i might be empty if there is less receivers than rank such as for a
# point source
try:
ret.append(EnrichedTuple(range(min(i), max(i) + 1),
getters=self.dimensions))
except ValueError:
ret.append(EnrichedTuple(range(0, 0), getters=self.dimensions))
return tuple(ret)

@property
def distributor(self):
return self._distributor
Expand Down
29 changes: 28 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from devito import (Grid, Function, TimeFunction, SparseTimeFunction, Dimension, # noqa
Eq, Operator, ALLOC_GUARD, ALLOC_ALIGNED, configuration,
switchconfig)
switchconfig, SparseFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction)
from devito.data import LEFT, RIGHT, Decomposition, loc_data_idx, convert_index
from devito.tools import as_tuple
from devito.types import Scalar
Expand Down Expand Up @@ -1485,6 +1486,32 @@ def test_gather_time_function(self, mode):
else:
assert ans == np.array(None)

@pytest.mark.parallel(mode=[4, 6])
@pytest.mark.parametrize('sfunc', [SparseFunction,
SparseTimeFunction,
PrecomputedSparseFunction,
PrecomputedSparseTimeFunction])
@pytest.mark.parametrize('target_rank', [0, 2])
def test_gather_sparse(self, mode, sfunc, target_rank):
grid = Grid((11, 11))
myrank = grid._distributor.comm.Get_rank()
nt = 10
coords = [[0, 0], [0, .25], [0, .75], [0, 1]]
s = sfunc(name='s', grid=grid, npoint=4, r=4, nt=nt, coordinates=coords)

np.random.seed(1234)
try:
a = np.random.rand(s.nt, s.npoint_global)
except AttributeError:
a = np.random.rand(s.npoint_global,)

s.data[:] = a
out = s.data_gather(rank=target_rank)
if myrank == target_rank:
assert np.allclose(out, a)
else:
assert not out


def test_scalar_arg_substitution():
"""
Expand Down
Loading