Skip to content

Commit

Permalink
mpi: fix handling of subfunction when used alone
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Dec 13, 2023
1 parent ac68fb3 commit 030e0b6
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 22 deletions.
6 changes: 2 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ class IndexDerivative(IndexSum):
__rargs__ = ('expr', 'mapper')

def __new__(cls, expr, mapper, **kwargs):
dimensions = as_tuple(mapper.values())
dimensions = as_tuple(set(mapper.values()))

# Detect the Weights among the arguments
weightss = []
Expand Down Expand Up @@ -799,9 +799,7 @@ def _evaluate(self, **kwargs):
mapper = {w.subs(d, i): f.weights[n] for n, i in enumerate(d.range)}
expr = expr.xreplace(mapper)

basexpr = set(a.function for a in self.expr.args if a.is_Function) - {f}

return EvalDerivative(expr, base=basexpr.pop())
return EvalDerivative(expr, base=self.base)


# SymPy args ordering is the same for Derivatives and IndexDerivatives
Expand Down
5 changes: 3 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
timed_region, contains_val)
from devito.types import Grid, Evaluable
from devito.types import Grid, Evaluable, SubFunction

__all__ = ['Operator']

Expand Down Expand Up @@ -648,7 +648,8 @@ def _postprocess_arguments(self, args, **kwargs):
subfuncs = (args[getattr(p, s).name] for s in p._sub_functions)
p._arg_apply(args[p.name], *subfuncs, alias=kwargs.get(p.name))
except AttributeError:
p._arg_apply(args[p.name], alias=kwargs.get(p.name))
if not (isinstance(p, SubFunction) and p.parent in self.parameters):
p._arg_apply(args[p.name], alias=kwargs.get(p.name))

@cached_property
def _known_arguments(self):
Expand Down
1 change: 1 addition & 0 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,7 @@ def lower_schedule(schedule, meta, sregistry, ftemps):
properties[d] = normalize_properties(v, {PARALLEL_IF_PVT}) - \
{ROUNDABLE}
except KeyError:
# Non-dimension key such as (x, y) for diagonal stencil u(x+i hx, y+i hy)
pass

# Track star-shaped stencils for potential future optimization
Expand Down
21 changes: 19 additions & 2 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,12 @@ class SubFunction(Function):
DiscreteFunction it's bound to.
"""

__rkwargs__ = DiscreteFunction.__rkwargs__ + ('dimensions', 'shape')

def __init_finalize__(self, *args, **kwargs):
self._parent = kwargs.pop('parent', None)
super().__init_finalize__(*args, **kwargs)

def __padding_setup__(self, **kwargs):
# SubFunctions aren't expected to be used in time-consuming loops
return tuple((0, 0) for i in range(self.ndim))
Expand All @@ -1465,17 +1471,28 @@ def _halo_exchange(self):
return

def _arg_values(self, **kwargs):
if self.name in kwargs:
if self._parent is not None and self.parent.name not in kwargs:
return self._parent._arg_defaults(alias=self._parent).reduce_all()
elif self.name in kwargs:
raise RuntimeError("`%s` is a SubFunction, so it can't be assigned "
"a value dynamically" % self.name)
else:
return self._arg_defaults(alias=self)

return self._arg_defaults(alias=self)
def _arg_apply(self, *args, **kwargs):
if self._parent is not None:
return self._parent._arg_apply(*args, **kwargs)
return super()._arg_apply(*args, **kwargs)

@property
def origin(self):
# SubFunction have zero origin
return DimensionTuple(*(0 for _ in range(self.ndim)), getters=self.dimensions)

@property
def parent(self):
return self._parent


class TempFunction(DiscreteFunction):

Expand Down
25 changes: 16 additions & 9 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def __shape_setup__(cls, **kwargs):
if i is cls._sparse_position:
loc_shape.append(glb_npoint[grid.distributor.myrank])
elif d in grid.dimensions:
loc_shape.append(grid.dimension_map[d])
loc_shape.append(grid.dimension_map[d].loc)
else:
loc_shape.append(s)
return loc_shape
return tuple(loc_shape)

def __fd_setup__(self):
"""
Expand All @@ -100,11 +100,13 @@ def __distributor_setup__(self, **kwargs):
A `SparseDistributor` handles the SparseFunction decomposition based on
physical ownership, and allows to convert between global and local indices.
"""
return SparseDistributor(
kwargs.get('npoint', kwargs.get('npoint_global')),
self._sparse_dim,
kwargs['grid'].distributor
)
distributor = kwargs.get('distributor')
if distributor is None:
distributor = SparseDistributor(
kwargs.get('npoint', kwargs.get('npoint_global')),
self._sparse_dim, kwargs['grid'].distributor)

return distributor

def __subfunc_setup__(self, key, suffix, dtype=None):
# Shape and dimensions from args
Expand Down Expand Up @@ -153,7 +155,7 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
sf = SubFunction(
name=name, dtype=dtype, dimensions=dimensions,
shape=shape, space_order=0, initializer=key, alias=self.alias,
distributor=self._distributor
distributor=self._distributor, parent=self
)

if self.npoint == 0:
Expand All @@ -165,9 +167,13 @@ def __subfunc_setup__(self, key, suffix, dtype=None):

return sf

@property
def sparse_position(self):
return self._sparse_position

@property
def _sparse_dim(self):
return self.dimensions[self._sparse_position]
return self.dimensions[self.sparse_position]

@property
def _mpitype(self):
Expand All @@ -185,6 +191,7 @@ def _comm(self):

@property
def _coords_indices(self):
# Assume integer dtypemeans gridpioints
if self.gridpoints_data is not None:
return self.gridpoints_data
else:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,12 +616,12 @@ class SparseFirst(SparseFunction):
ds = DefaultDimension("ps", default_value=3)
grid = Grid((11, 11))
dims = grid.dimensions
s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3),
coordinates=[[.5, .5], [.2, .2]])
s = SparseFirst(name="s", grid=grid, npoint=4, dimensions=(dr, ds), shape=(4, 3),
coordinates=[[.1, .1], [.2, .2], [.3, .3], [.5, .5]])

# Check dimensions and shape are correctly initialized
assert s.indices[s._sparse_position] == dr
assert s.shape == (2, 3)
assert s.shape == (1, 3)
assert s.coordinates.indices[0] == dr

# Operator
Expand Down
35 changes: 33 additions & 2 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import scipy.sparse

from devito import Grid, TimeFunction, Eq, Operator, Dimension
from devito import Grid, TimeFunction, Eq, Operator, Dimension, Function
from devito import (SparseFunction, SparseTimeFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction, MatrixSparseTimeFunction)

Expand Down Expand Up @@ -455,6 +455,37 @@ def test_subs(self, sptype):
assert getattr(sps, subf).indices[0] == new_spdim
assert np.all(getattr(sps, subf).data == getattr(sp, subf).data)

@pytest.mark.parallel(mode=[1, 4])
def test_mpi_no_data(self):
grid = Grid((11, 11), extent=(10, 10))
time = grid.time_dim
# Base object
sp = SparseTimeFunction(name="s", grid=grid, npoint=1, nt=1,
coordinates=[[5., 5.]])

m = TimeFunction(name="m", grid=grid, space_order=2, time_order=1)
eq = [Eq(m.forward, m + m.laplace)]

op = Operator(eq + sp.inject(field=m.forward, expr=time))
# Not using the source data so can run with any time_M
op(time_M=5)

expected = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 4., -10., 4., 0., 0., 0., 0.],
[0., 0., 0., 6., -30., 55., -30., 6., 0., 0., 0.],
[0., 0., 4., -30., 102., -158., 102., -30., 4., 0., 0.],
[0., 1., -10., 55., -158., 239., -158., 55., -10., 1., 0.],
[0., 0., 4., -30., 102., -158., 102., -30., 4., 0., 0.],
[0., 0., 0., 6., -30., 55., -30., 6., 0., 0., 0.],
[0., 0., 0., 0., 4., -10., 4., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

ftest = Function(name='ftest', grid=grid, space_order=2)
ftest.data[:] = expected
assert np.all(m.data[0, :, :] == ftest.data[:])


if __name__ == "__main__":
TestMatrixSparseTimeFunction().test_mpi()
TestMatrixSparseTimeFunction().test_mpi_no_data()

0 comments on commit 030e0b6

Please sign in to comment.