Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: Cleanup and improve SubFunction #2198

Merged
merged 5 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,16 @@ def run(expr):
terms.append(i)

# Collect common funcs
w_funcs = Add(*w_funcs, evaluate=False)
w_funcs = collect(w_funcs, funcs, evaluate=False)
try:
terms.extend([Mul(k, collect_const(v), evaluate=False)
for k, v in w_funcs.items()])
except AttributeError:
assert w_funcs == 0
if len(w_funcs) > 1:
Copy link
Contributor Author

@mloubout mloubout Sep 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FabioLuporini this does not fix the pickle reconstruction issue with (-1)* so left the _print_Mul in printer for the time being.

See
https://github.com/devitocodes/devito/actions/runs/6113350028/job/16592681861#step:9:457

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Can you open an issue about that and refer to the failing test without the hack?

w_funcs = Add(*w_funcs, evaluate=False)
w_funcs = collect(w_funcs, funcs, evaluate=False)
try:
terms.extend([Mul(k, collect_const(v), evaluate=False)
for k, v in w_funcs.items()])
except AttributeError:
assert w_funcs == 0
else:
terms.extend(w_funcs)

# Collect common pows
w_pows = Add(*w_pows, evaluate=False)
Expand Down
9 changes: 6 additions & 3 deletions devito/symbolics/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# * Number
# * Symbol
# * Indexed
extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject)
extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject,
IndexedPointer)


def q_symbol(expr):
Expand All @@ -31,7 +32,9 @@ def q_symbol(expr):


def q_leaf(expr):
return expr.is_Atom or expr.is_Indexed or isinstance(expr, extra_leaves)
return (expr.is_Atom or
expr.is_Indexed or
isinstance(expr, extra_leaves))


def q_indexed(expr):
Expand All @@ -51,7 +54,7 @@ def q_derivative(expr):
def q_terminal(expr):
return (expr.is_Symbol or
expr.is_Indexed or
isinstance(expr, extra_leaves + (IndexedPointer,)))
isinstance(expr, extra_leaves))


def q_routine(expr):
Expand Down
17 changes: 4 additions & 13 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,7 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = grid.dimensions

if args:
assert len(args) == len(dimensions)
return tuple(dimensions), tuple(args)

# Staggered indices
Expand Down Expand Up @@ -1449,16 +1450,10 @@ class SubFunction(Function):
"""
A Function bound to a "parent" DiscreteFunction.

A SubFunction hands control of argument binding and halo exchange to its
parent DiscreteFunction.
A SubFunction hands control of argument binding and halo exchange to the
DiscreteFunction it's bound to.
"""

__rkwargs__ = Function.__rkwargs__ + ('parent',)

def __init_finalize__(self, *args, **kwargs):
super(SubFunction, self).__init_finalize__(*args, **kwargs)
self._parent = kwargs['parent']

def __padding_setup__(self, **kwargs):
# SubFunctions aren't expected to be used in time-consuming loops
return tuple((0, 0) for i in range(self.ndim))
Expand All @@ -1470,12 +1465,8 @@ def _arg_values(self, **kwargs):
if self.name in kwargs:
raise RuntimeError("`%s` is a SubFunction, so it can't be assigned "
"a value dynamically" % self.name)
else:
return self._parent._arg_defaults(alias=self._parent).reduce_all()

@property
def parent(self):
return self._parent
return self._arg_defaults(alias=self)

@property
def origin(self):
Expand Down
142 changes: 66 additions & 76 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = (Dimension(name='p_%s' % kwargs["name"]),)

if args:
indices = args
return tuple(dimensions), tuple(args)
else:
indices = dimensions

return dimensions, indices
return dimensions, dimensions

@classmethod
def __shape_setup__(cls, **kwargs):
Expand All @@ -80,16 +78,6 @@ def __shape_setup__(cls, **kwargs):
shape = (glb_npoint[grid.distributor.myrank],)
return shape

def func(self, *args, **kwargs):
# Rebuild subfunctions first to avoid new data creation as we have to use `_data`
# as a reconstruction kwargs to avoid the circular dependency
# with the parent in SubFunction
# This is also necessary to avoid shape issue in the SubFunction with mpi
for s in self._sub_functions:
if getattr(self, s) is not None:
kwargs.update({s: getattr(self, s).func(*args, **kwargs)})
return super().func(*args, **kwargs)

def __fd_setup__(self):
"""
Dynamically add derivative short-cuts.
Expand All @@ -108,24 +96,39 @@ def __distributor_setup__(self, **kwargs):
)

def __subfunc_setup__(self, key, suffix, dtype=None):
# Shape and dimensions from args
name = '%s_%s' % (self.name, suffix)

if key is not None and not isinstance(key, SubFunction):
key = np.array(key)

if key is not None:
dimensions = (self._sparse_dim, Dimension(name='d'))
if key.ndim > 2:
dimensions = (self._sparse_dim, Dimension(name='d'),
*mkdims("i", n=key.ndim-2))
else:
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim, *key.shape[2:])
else:
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

# Check if already a SubFunction
if isinstance(key, SubFunction):
return key
# Need to rebuild so the dimensions match the parent SparseFunction
indices = (self.indices[self._sparse_position], *key.indices[1:])
return key._rebuild(*indices, name=name, shape=shape,
alias=self.alias, halo=None)
elif key is not None and not isinstance(key, Iterable):
raise ValueError("`%s` must be either SubFunction "
"or iterable (e.g., list, np.ndarray)" % key)

name = '%s_%s' % (self.name, suffix)
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

if key is None:
# Fallback to default behaviour
dtype = dtype or self.dtype
else:
if key is not None:
key = np.array(key)

if (shape != key.shape[:2] and key.shape != (shape[1],)) and \
if (shape != key.shape and key.shape != (shape[1],)) and \
self._distributor.nprocs == 1:
raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" %
(suffix, key.shape[:2], shape))
Expand All @@ -136,12 +139,8 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
else:
dtype = dtype or self.dtype

if key is not None and key.ndim > 2:
shape = (*shape, *key.shape[2:])
dimensions = (*dimensions, *mkdims("i", n=key.ndim-2))

sf = SubFunction(
name=name, parent=self, dtype=dtype, dimensions=dimensions,
name=name, dtype=dtype, dimensions=dimensions,
shape=shape, space_order=0, initializer=key, alias=self.alias,
distributor=self._distributor
)
Expand Down Expand Up @@ -657,20 +656,6 @@ def time_dim(self):
"""The time Dimension."""
return self._time_dim

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
shape = kwargs.get('shape')
Expand All @@ -686,6 +671,18 @@ def __shape_setup__(cls, **kwargs):

return tuple(shape)

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]),)

if args:
return tuple(dimensions), tuple(args)
else:
return dimensions, dimensions

@property
def nt(self):
return self.shape[self._time_position]
Expand Down Expand Up @@ -791,7 +788,7 @@ class SparseFunction(AbstractSparseFunction):

_sub_functions = ('coordinates',)

__rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates_data',)
__rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates',)

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)
Expand Down Expand Up @@ -1014,8 +1011,8 @@ class PrecomputedSparseFunction(AbstractSparseFunction):
_sub_functions = ('gridpoints', 'coordinates', 'interpolation_coeffs')

__rkwargs__ = (AbstractSparseFunction.__rkwargs__ +
('r', 'gridpoints_data', 'coordinates_data',
'interpolation_coeffs_data'))
('r', 'gridpoints', 'coordinates',
'interpolation_coeffs'))

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)
Expand All @@ -1032,13 +1029,14 @@ def __init_finalize__(self, *args, **kwargs):
if r <= 0:
raise ValueError('`r` must be > 0')
# Make sure radius matches the coefficients size
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
if interpolation_coeffs is not None:
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
self._radius = r

if coordinates is not None and gridpoints is not None:
Expand Down Expand Up @@ -1179,6 +1177,15 @@ class PrecomputedSparseTimeFunction(AbstractSparseTimeFunction,
PrecomputedSparseFunction.__rkwargs__))


# *** MatrixSparse*Function API
# This is mostly legacy stuff which often escapes the devito's modus operandi

class DynamicSubFunction(SubFunction):

def _arg_defaults(self, **kwargs):
return {}


class MatrixSparseTimeFunction(AbstractSparseTimeFunction):
"""
A specialised type of SparseTimeFunction where the interpolation is externally
Expand Down Expand Up @@ -1378,7 +1385,7 @@ def __init_finalize__(self, *args, **kwargs):
else:
nnz_size = 1

self._mrow = SubFunction(
self._mrow = DynamicSubFunction(
name='mrow_%s' % self.name,
dtype=np.int32,
dimensions=(self.nnzdim,),
Expand All @@ -1387,7 +1394,7 @@ def __init_finalize__(self, *args, **kwargs):
parent=self,
allocator=self._allocator,
)
self._mcol = SubFunction(
self._mcol = DynamicSubFunction(
name='mcol_%s' % self.name,
dtype=np.int32,
dimensions=(self.nnzdim,),
Expand All @@ -1396,7 +1403,7 @@ def __init_finalize__(self, *args, **kwargs):
parent=self,
allocator=self._allocator,
)
self._mval = SubFunction(
self._mval = DynamicSubFunction(
name='mval_%s' % self.name,
dtype=self.dtype,
dimensions=(self.nnzdim,),
Expand All @@ -1413,8 +1420,8 @@ def __init_finalize__(self, *args, **kwargs):
self.par_dim_to_nnz_dim = DynamicDimension('par_dim_to_nnz_%s' % self.name)

# This map acts as an indirect sort of the sources according to their
# position along the parallelisation Dimension
self._par_dim_to_nnz_map = SubFunction(
# position along the parallelisation dimension
self._par_dim_to_nnz_map = DynamicSubFunction(
name='par_dim_to_nnz_map_%s' % self.name,
dtype=np.int32,
dimensions=(self.par_dim_to_nnz_dim,),
Expand All @@ -1423,7 +1430,7 @@ def __init_finalize__(self, *args, **kwargs):
space_order=0,
parent=self,
)
self._par_dim_to_nnz_m = SubFunction(
self._par_dim_to_nnz_m = DynamicSubFunction(
name='par_dim_to_nnz_m_%s' % self.name,
dtype=np.int32,
dimensions=(self._par_dim,),
Expand All @@ -1432,7 +1439,7 @@ def __init_finalize__(self, *args, **kwargs):
space_order=0,
parent=self,
)
self._par_dim_to_nnz_M = SubFunction(
self._par_dim_to_nnz_M = DynamicSubFunction(
name='par_dim_to_nnz_M_%s' % self.name,
dtype=np.int32,
dimensions=(self._par_dim,),
Expand Down Expand Up @@ -1671,23 +1678,6 @@ def inject(self, field, expr, u_t=None, p_t=None):

return out

@classmethod
def __indices_setup__(cls, *args, **kwargs):
"""
Return the default Dimension indices for a given data shape.
"""
dimensions = kwargs.get('dimensions')
if dimensions is None:
dimensions = (kwargs['grid'].time_dim, Dimension(
name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
# This happens before __init__, so we have to get 'npoint'
Expand Down
22 changes: 9 additions & 13 deletions examples/seismic/inversion/inversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,15 @@ def compute_residual(res, dobs, dsyn):
"""
Computes the data residual dsyn - dobs into residual
"""
if res.grid.distributor.is_parallel:
# If we run with MPI, we have to compute the residual via an operator
# First make sure we can take the difference and that receivers are at the
# same position
assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data)
assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data)
# Create a difference operator
diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) -
dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]}))
Operator(diff_eq)()
else:
# A simple data difference is enough in serial
res.data[:] = dsyn.data[:] - dobs.data[:]
# If we run with MPI, we have to compute the residual via an operator
# First make sure we can take the difference and that receivers are at the
# same position
assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data)
assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data)
# Create a difference operator
diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) -
dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]}))
Operator(diff_eq)()

return res

Expand Down
Loading