Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: cleanup FD tools and support zeroth order derivative #2368

Merged
merged 7 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
for s in filter_ordered(dims)]
return dims, deriv_orders, fd_orders, variable_count

# Sanitise `dims`. ((x, 2), (y, 0)) is valid input, but (y, 0) should be dropped.
dims = tuple(d for d in dims if not (isinstance(d, Iterable) and d[1] == 0))

# Check `dims`. It can be a single Dimension, an iterable of Dimensions, or even
# an iterable of 2-tuple (Dimension, deriv_order)
if len(dims) == 0:
Expand All @@ -160,15 +157,18 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
orders = kwargs.get('deriv_order', 1)
if isinstance(orders, Iterable):
orders = orders[0]
new_dims = tuple([dims[0]]*orders)
if orders == 0:
new_dims = (dims[0],)
else:
new_dims = tuple([dims[0]]*orders)
else:
# Iterable of 2-tuple, e.g. ((x, 2), (y, 3))
new_dims = []
orders = []
d_ord = kwargs.get('deriv_order', tuple([1]*len(dims)))
for d, o in zip(dims, d_ord):
if isinstance(d, Iterable):
new_dims.extend([d[0] for _ in range(d[1])])
new_dims.extend([d[0] for _ in range(max(1, d[1]))])
orders.append(d[1])
else:
new_dims.extend([d for _ in range(o)])
Expand Down
4 changes: 4 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def __eq__(self, other):
if ret is NotImplemented or not ret:
# Non comparable or not equal as sympy objects
return False

return all(getattr(self, i, None) == getattr(other, i, None)
for i in self.__rkwargs__)

Expand Down Expand Up @@ -876,6 +877,9 @@ def __new__(cls, *args, base=None, **kwargs):

func = DifferentiableOp._rebuild

__eq__ = sympy.Add.__eq__
mloubout marked this conversation as resolved.
Show resolved Hide resolved
__hash__ = sympy.Add.__hash__

def _new_rawargs(self, *args, **kwargs):
kwargs.pop('is_commutative', None)
return self.func(*args, **kwargs)
Expand Down
4 changes: 4 additions & 0 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ def generic_derivative(expr, dim, fd_order, deriv_order, matvec=direct, x0=None,
if deriv_order == 1 and fd_order == 2 and coefficients != 'symbolic':
fd_order = 1

# Zeroth order derivative is just the expression itself if not shifted
if deriv_order == 0 and not x0:
return expr

# Enforce stable time coefficients
if dim.is_Time and coefficients != 'symbolic':
coefficients = 'taylor'
Expand Down
4 changes: 2 additions & 2 deletions devito/finite_differences/rsfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from devito.types import NODE
from devito.types.dimension import StencilDimension
from .differentiable import Weights, DiffDerivative
from .tools import generate_indices_staggered, fd_weights_registry
from .tools import generate_indices, fd_weights_registry

__all__ = ['drot', 'd45']

Expand Down Expand Up @@ -42,7 +42,7 @@ def drot(expr, dim, dir=1, x0=None):
s = 2**(expr.grid.dim - 1)

# Center point and indices
start, indices = generate_indices_staggered(expr, dim, expr.space_order, x0=x0)
indices, start = generate_indices(expr, dim, expr.space_order, x0=x0)

# a-dimensional center for FD coefficients.
adim_start = x0.get(dim, expr.indices_ref[dim]).subs({dim: 0, dim.spacing: 1})
Expand Down
133 changes: 15 additions & 118 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,128 +275,25 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None):
-------
An IndexSet, representing an ordered list of indices.
"""
if expr.is_Staggered and not dim.is_Time and side is None:
x0, indices = generate_indices_staggered(expr, dim, order, side=side, x0=x0)
else:
x0 = (x0 or {dim: dim}).get(dim, dim)
# Check if called from first_derivative()
indices = generate_indices_cartesian(expr, dim, order, side, x0)
# Evaluation point
x0 = ((x0 or {}).get(dim) or expr.indices_ref[dim])

assert isinstance(indices, IndexSet)
# Shift for side
side = side or centered

return indices, x0
# Evaluation point relative to the expression's grid
mid = (x0 - expr.indices_ref[dim]).subs({dim: 0, dim.spacing: 1})
# Indices range
o_min = int(np.ceil(mid - order/2)) + side.val
o_max = int(np.floor(mid + order/2)) + side.val
if o_max == o_min:
o_max += 1

# StencilDimension and expression
d = make_stencil_dimension(expr, o_min, o_max)
iexpr = expr.indices_ref[dim] + d * dim.spacing

def generate_indices_cartesian(expr, dim, order, side, x0):
"""
Indices for the finite-difference scheme on a cartesian grid.

Parameters
----------
expr : expr-like
Expression that is differentiated.
dim : Dimension
Dimensions w.r.t which the derivative is taken.
order : int
Order of the finite-difference scheme.
side : Side
Side of the scheme (centered, left, right).
x0 : dict of {Dimension: Dimension or expr-like or Number}
Origin of the scheme, ie. `x`, `x + .5 * x.spacing`, ...

Returns
-------
An IndexSet, representing an ordered list of indices.
"""
shift = 0
# Shift if `x0` is not on the grid
offset_c = 0 if sympify(x0).is_Integer else (dim - x0)/dim.spacing
offset_c = np.sign(offset_c) * (offset_c % 1)
offset = offset_c * dim.spacing
# Spacing
diff = dim.spacing
if side in [left, right]:
shift = 1
diff *= side.val
# Indices
if order < 2:
indices = [x0, x0 + diff] if offset == 0 else [x0 - offset, x0 + offset]
return IndexSet(dim, indices)
else:
# Left and right max offsets for indices
o_min = -order//2 + int(np.ceil(-offset_c))
o_max = order//2 - int(np.ceil(offset_c))

d = make_stencil_dimension(expr, o_min, o_max)
iexpr = x0 + (d + shift) * diff + offset
return IndexSet(dim, expr=iexpr)


def generate_indices_staggered(expr, dim, order, side=None, x0=None):
"""
Indices for the finite-difference scheme on a staggered grid.

Parameters
----------
expr : expr-like
Expression that is differentiated.
dim : Dimension
Dimensions w.r.t which the derivative is taken.
order : int
Order of the finite-difference scheme.
side : Side, optional
Side of the scheme (centered, left, right).
x0 : dict of {Dimension: Dimension or expr-like or Number}, optional
Origin of the scheme, ie. `x`, `x + .5 * x.spacing`, ...

Returns
-------
An IndexSet, representing an ordered list of indices.
"""
diff = dim.spacing
start = (x0 or {}).get(dim) or expr.indices_ref[dim]
try:
ind0 = expr.indices_ref[dim]
except AttributeError:
ind0 = start

if start != ind0:
if order < 2:
indices = [start - diff/2, start + diff/2]
indices = IndexSet(dim, indices)
else:
o_min = -order//2 + 1
o_max = order//2

d = make_stencil_dimension(expr, o_min, o_max)
iexpr = start - diff/2 + d * diff
indices = IndexSet(dim, expr=iexpr)
else:
if order < 2:
indices = [start, start - diff]
indices = IndexSet(dim, indices)
else:
if x0 is None or order % 2 == 0:
# No _eval_at or even order derivatives
# keep the centered indices
o_min = -order//2
o_max = order//2
elif start is dim:
# Staggered FD requires half cell shift
# for stability
o_min = -order//2 + 1
o_max = order//2
start = dim + diff/2
else:
o_min = -order//2
o_max = order//2 - 1
start = dim

d = make_stencil_dimension(expr, o_min, o_max)
iexpr = ind0 + d * diff
indices = IndexSet(dim, expr=iexpr)

return start, indices
return IndexSet(dim, expr=iexpr), x0


def make_shift_x0(shift, ndim):
Expand Down
23 changes: 16 additions & 7 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def test_fd_new_side(self):
assert u.dxl(side=centered).evaluate == u.dx.evaluate

@pytest.mark.parametrize('so, expected', [
(2, 'u(x)/h_x - u(x - 1.0*h_x)/h_x'),
(2, 'u(x)/h_x - u(x - h_x)/h_x'),
(4, '1.125*u(x)/h_x + 0.0416666667*u(x - 2*h_x)/h_x - '
'1.125*u(x - h_x)/h_x - 0.0416666667*u(x + h_x)/h_x'),
(6, '1.171875*u(x)/h_x - 0.0046875*u(x - 3*h_x)/h_x + '
Expand Down Expand Up @@ -445,8 +445,8 @@ def test_fd_new_lo(self):
x = grid.dimensions[0]
u = Function(name="u", grid=grid, space_order=2)

dplus = "-u(x)/h_x + u(x + 1.0*h_x)/h_x"
dminus = "u(x)/h_x - u(x - 1.0*h_x)/h_x"
dplus = "-u(x)/h_x + u(x + h_x)/h_x"
dminus = "u(x)/h_x - u(x - h_x)/h_x"
assert str(u.dx(x0=x + .5 * x.spacing).evaluate) == dplus
assert str(u.dx(x0=x - .5 * x.spacing).evaluate) == dminus
assert str(u.dx(x0=x + .5 * x.spacing, fd_order=1).evaluate) == dplus
Expand Down Expand Up @@ -639,13 +639,22 @@ def test_zero_spec(self):
"""
grid = Grid((11, 11))
x, y = grid.dimensions
f = Function(name="f", grid=grid, space_order=4)
f = Function(name="f", grid=grid, space_order=2)
# Check that both specifications match
drv0 = Derivative(f, (x, 2))
drv1 = Derivative(f, (x, 2), (y, 0))
assert drv0.dims == drv1.dims
assert drv0.fd_order == drv1.fd_order
assert drv0.deriv_order == drv1.deriv_order
assert drv0.dims == (x,)
assert drv1.dims == (x, y)
assert drv0.fd_order == 2
assert drv1.fd_order == (2, 2)
assert drv0.deriv_order == 2
assert drv1.deriv_order == (2, 0)

assert drv0.evaluate == drv1.evaluate
fmean = .5 * (f + f._subs(y, y + y.spacing))
drv1x0 = drv1(x0={y: y+y.spacing/2}).evaluate

assert simplify(fmean.dx2.evaluate - drv1x0) == 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would also be a good idea to check equivalence between the Derivative(f, y, deriv_order=0, x0=y+y.spacing/2) and Derivative(f, (y, 0), x0={y: y+y.spacing/2}) specifications.

There should also be tests for Derivative(f, (x, 0), (y, 0), x0={x: x+x.spacing/2, y: y+y.spacing/2}) and Derivative(f, x, y,, deriv_order=(0, 0), x0={x: x+x.spacing/2, y: y+y.spacing/2})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added


# Check that substitution can applied correctly
expr0 = drv0 + 1
Expand Down
Loading