Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Revamp sparse subfunction #2374

Merged
merged 7 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,7 @@ def interpolation_coeffs(self):
sf = SubFunction(name="wsinc%s" % r.name, dtype=self.sfunction.dtype,
shape=shape, dimensions=dimensions,
space_order=0, alias=self.sfunction.alias,
distributor=self.sfunction._distributor,
parent=self.sfunction)
parent=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpicking: we can probably save a line here by compacting these four lines into three, but obviously this is for another day

coeffs[r] = sf
return coeffs

Expand Down
9 changes: 2 additions & 7 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
timed_region, contains_val)
from devito.types import Grid, Evaluable, SubFunction
from devito.types import Grid, Evaluable

__all__ = ['Operator']

Expand Down Expand Up @@ -660,12 +660,7 @@ def _postprocess_errors(self, retval):
def _postprocess_arguments(self, args, **kwargs):
"""Process runtime arguments upon returning from ``.apply()``."""
for p in self.parameters:
try:
subfuncs = (args[getattr(p, s).name] for s in p._sub_functions)
p._arg_apply(args[p.name], *subfuncs, alias=kwargs.get(p.name))
except AttributeError:
if not (isinstance(p, SubFunction) and p.parent in self.parameters):
p._arg_apply(args[p.name], alias=kwargs.get(p.name))
p._arg_apply(args[p.name], alias=kwargs.get(p.name))

@cached_property
def _known_arguments(self):
Expand Down
79 changes: 48 additions & 31 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from collections import OrderedDict
try:
from collections.abc import Iterable
except ImportError:
# Before python 3.10
from collections import Iterable
from itertools import product

import sympy
Expand Down Expand Up @@ -34,6 +29,14 @@
_default_radius = {'linear': 1, 'sinc': 4}


class SparseSubFunction(SubFunction):

def _arg_apply(self, dataobj, **kwargs):
if self.parent is not None:
return self.parent._dist_subfunc_gather(dataobj, self)
return super()._arg_apply(dataobj, **kwargs)


class AbstractSparseFunction(DiscreteFunction):

"""
Expand All @@ -49,7 +52,8 @@ class AbstractSparseFunction(DiscreteFunction):
_sub_functions = ()
"""SubFunctions encapsulated within this AbstractSparseFunction."""

__rkwargs__ = DiscreteFunction.__rkwargs__ + ('npoint_global', 'space_order')
__rkwargs__ = (DiscreteFunction.__rkwargs__ +
('dimensions', 'npoint_global', 'space_order'))

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)
Expand All @@ -61,10 +65,22 @@ def __init_finalize__(self, *args, **kwargs):

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (Dimension(name='p_%s' % kwargs["name"]),)
# Need this not to break MatrixSparseFunction
try:
_sub_funcs = tuple(cls._sub_functions)
except TypeError:
_sub_funcs = ()
# If a subfunction provided use the sparse dimension
for f in _sub_funcs:
try:
sparse_dim = kwargs[f].indices[0]
break
except (KeyError, AttributeError):
continue
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong indent level for this else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No the else goes there with the for. It's a python construct "if for didn't exit then " that avoids wrapping the for in a try/if

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, TIL

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes and it's quite handy

sparse_dim = Dimension(name='p_%s' % kwargs["name"])
mloubout marked this conversation as resolved.
Show resolved Hide resolved

dimensions = as_tuple(kwargs.get('dimensions', sparse_dim))
if args:
return tuple(dimensions), tuple(args)
else:
Expand Down Expand Up @@ -120,6 +136,22 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
if key is not None and not isinstance(key, SubFunction):
key = np.array(key)

# Check if already a SubFunction
if isinstance(key, SubFunction):
d = self.indices[self._sparse_position]
if d in key.indices:
# Can use as is, dimension already matches
if self.alias:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if self.alias:
    return key._rebuild(alias=self.alias, name=name)
return key

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like to keep the else for simple blocks and unindent when it's a decent code chunck that is the "main" part

return key._rebuild(alias=self.alias, name=name)
else:
return key
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This else can be removed and the contents of the block unindented by a level.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

# Need to rebuild so the dimensions match the parent
# SparseFunction, for example we end up here via `.subs(d, new_d)`
indices = (d, *key.indices[1:])
return key._rebuild(*indices, name=name, alias=self.alias)

# Given an array or nothing, create dimension and SubFunction
if key is not None:
dimensions = (self._sparse_dim, Dimension(name='d'))
if key.ndim > 2:
Expand All @@ -132,22 +164,13 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

# Check if already a SubFunction
if isinstance(key, SubFunction):
# Need to rebuild so the dimensions match the parent SparseFunction
indices = (self.indices[self._sparse_position], *key.indices[1:])
return key._rebuild(*indices, name=name, shape=shape,
alias=self.alias, halo=None)
elif key is not None and not isinstance(key, Iterable):
raise ValueError("`%s` must be either SubFunction "
"or iterable (e.g., list, np.ndarray)" % key)

if key is None:
# Fallback to default behaviour
dtype = dtype or self.dtype
else:
if (shape != key.shape and key.shape != (shape[1],)) and \
self._distributor.nprocs == 1:
if shape != key.shape and \
key.shape != (shape[1],) and \
self._distributor.nprocs == 1:
raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" %
(suffix, key.shape[:2], shape))

Expand All @@ -157,7 +180,7 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
else:
dtype = dtype or self.dtype

sf = SubFunction(
sf = SparseSubFunction(
name=name, dtype=dtype, dimensions=dimensions,
shape=shape, space_order=0, initializer=key, alias=self.alias,
distributor=self._distributor, parent=self
Expand Down Expand Up @@ -584,12 +607,6 @@ def _dist_scatter(self, data=None):
mapper.update(self._dist_subfunc_scatter(getattr(self, i)))
return mapper

def _dist_gather(self, data, *subfunc):
self._dist_data_gather(data)
for (sg, s) in zip(subfunc, self._sub_functions):
if getattr(self, s) is not None:
self._dist_subfunc_gather(sg, getattr(self, s))

def _eval_at(self, func):
return self

Expand Down Expand Up @@ -637,11 +654,11 @@ def _arg_values(self, **kwargs):

return values

def _arg_apply(self, dataobj, *subfuncs, alias=None):
def _arg_apply(self, dataobj, alias=None):
key = alias if alias is not None else self
if isinstance(key, AbstractSparseFunction):
# Gather into `self.data`
key._dist_gather(dataobj, *subfuncs)
key._dist_data_gather(dataobj)
elif self._distributor.nprocs > 1:
raise NotImplementedError("Don't know how to gather data from an "
"object of type `%s`" % type(key))
Expand Down Expand Up @@ -698,7 +715,7 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]),)
*super().__indices_setup__(*args, **kwargs)[0])

if args:
return tuple(dimensions), tuple(args)
Expand Down
10 changes: 3 additions & 7 deletions examples/seismic/elastic/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from devito import Eq, Operator, VectorTimeFunction, TensorTimeFunction
from devito import div, grad, diag, solve
from examples.seismic import PointSource, Receiver


def src_rec(v, tau, model, geometry):
Expand All @@ -9,12 +8,9 @@ def src_rec(v, tau, model, geometry):
"""
s = model.grid.time_dim.spacing
# Source symbol with input wavelet
src = PointSource(name='src', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nsrc)
rec1 = Receiver(name='rec1', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nrec)
rec2 = Receiver(name='rec2', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nrec)
src = geometry.src
rec1 = geometry.new_rec(name="rec1")
rec2 = geometry.new_rec(name="rec2")

# The source injection term
src_expr = src.inject(tau.forward.diagonal(), expr=src * s)
Expand Down
27 changes: 12 additions & 15 deletions examples/seismic/inversion/fwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from devito import configuration, Function, norm, mmax, mmin

from examples.seismic import demo_model, AcquisitionGeometry, Receiver
from examples.seismic import demo_model, AcquisitionGeometry
from examples.seismic.acoustic import AcousticWaveSolver

from inversion_utils import compute_residual, update_with_box
Expand Down Expand Up @@ -55,33 +55,30 @@
source_locations[:, 0] = 20.
source_locations[:, 1] = np.linspace(0., 1000, num=nshots)

# Create placeholders for the data residual and data
residual = geometry.new_rec(name='residual')
d_obs = geometry.new_rec(name='d_obs', coordinates=residual.coordinates)
d_syn = geometry.new_rec(name='d_syn', coordinates=residual.coordinates)

src = solver.geometry.src


def fwi_gradient(vp_in):
# Create symbols to hold the gradient
grad = Function(name="grad", grid=model.grid)
objective = 0.
for i in range(nshots):
# Create placeholders for the data residual and data
residual = Receiver(name='residual', grid=model.grid,
time_range=geometry.time_axis,
coordinates=geometry.rec_positions)
d_obs = Receiver(name='d_obs', grid=model.grid,
time_range=geometry.time_axis,
coordinates=geometry.rec_positions)
d_syn = Receiver(name='d_syn', grid=model.grid,
time_range=geometry.time_axis,
coordinates=geometry.rec_positions)
# Update source location
solver.geometry.src_positions[0, :] = source_locations[i, :]
src.coordinates.data[0, :] = source_locations[i, :]

# Generate synthetic data from true model
solver.forward(vp=model.vp, rec=d_obs)
solver.forward(vp=model.vp, rec=d_obs, src=src)

# Compute smooth data and full forward wavefield u0
_, u0, _ = solver.forward(vp=vp_in, save=True, rec=d_syn)
_, u0, _ = solver.forward(vp=vp_in, save=True, rec=d_syn, src=src)

# Compute gradient from data residual and update objective function
residual = compute_residual(residual, d_obs, d_syn)
compute_residual(residual, d_obs, d_syn)

objective += .5*norm(residual)**2
solver.jacobian_adjoint(rec=residual, u=u0, vp=vp_in, grad=grad)
Expand Down
5 changes: 1 addition & 4 deletions examples/seismic/inversion/inversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@ def compute_residual(res, dobs, dsyn):
assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data)
assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data)
# Create a difference operator
diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) -
dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]}))
diff_eq = Eq(res, dsyn - dobs)
Operator(diff_eq)()

return res


def update_with_box(vp, alpha, dm, vmin=2.0, vmax=3.5):
"""
Expand Down
9 changes: 9 additions & 0 deletions examples/seismic/test_seismic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,12 @@ def test_geom(shape):

assert geometry.new_rec(name="bonjour").name == "bonjour"
assert geometry.new_src(name="bonjour").name == "bonjour"

src1 = geometry.src
src2 = geometry.src
assert src1.coordinates is not src2.coordinates
assert src1._sparse_dim is src2._sparse_dim

src3 = geometry.new_src(name="src3", coordinates=src1.coordinates)
assert src1.coordinates is src3.coordinates
assert src1._sparse_dim is src3._sparse_dim
23 changes: 7 additions & 16 deletions examples/seismic/tti/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from devito import (Eq, Operator, Function, TimeFunction, NODE, Inc, solve,
cos, sin, sqrt, div, grad)
from examples.seismic import PointSource, Receiver
from examples.seismic.acoustic.operators import freesurface


Expand Down Expand Up @@ -444,10 +443,8 @@ def ForwardOperator(model, geometry, space_order=4,
v = TimeFunction(name='v', grid=model.grid, staggered=stagg_v,
save=geometry.nt if save else None,
time_order=time_order, space_order=space_order)
src = PointSource(name='src', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nsrc)
rec = Receiver(name='rec', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nrec)
src = geometry.src
rec = geometry.rec

# FD kernels of the PDE
FD_kernel = kernels[(kernel, len(model.shape))]
Expand Down Expand Up @@ -493,10 +490,8 @@ def AdjointOperator(model, geometry, space_order=4,
time_order=time_order, space_order=space_order)
r = TimeFunction(name='r', grid=model.grid, staggered=stagg_r,
time_order=time_order, space_order=space_order)
srca = PointSource(name='srca', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nsrc)
rec = Receiver(name='rec', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nrec)
srca = geometry.new_src(name='srca', src_type=None)
rec = geometry.rec

# FD kernels of the PDE
FD_kernel = kernels[(kernel, len(model.shape))]
Expand Down Expand Up @@ -535,11 +530,8 @@ def JacobianOperator(model, geometry, space_order=4,
time_order = 2

# Create source and receiver symbols
src = Receiver(name='src', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nsrc)

rec = Receiver(name='rec', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nrec)
src = geometry.src
rec = geometry.rec

# Create wavefields and a dm field
u0 = TimeFunction(name='u0', grid=model.grid, save=None, time_order=time_order,
Expand Down Expand Up @@ -607,8 +599,7 @@ def JacobianAdjOperator(model, geometry, space_order=4,

dm = Function(name="dm", grid=model.grid)

rec = Receiver(name='rec', grid=model.grid, time_range=geometry.time_axis,
npoint=geometry.nrec)
rec = geometry.rec

# FD kernels of the PDE
FD_kernel = kernels[('centered', len(model.shape))]
Expand Down
Loading
Loading