Skip to content

Commit

Permalink
api: switch AbstractFunction off grid interpolation to 0th derivative
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 6, 2024
1 parent c441745 commit b10ef96
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 62 deletions.
2 changes: 2 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,8 @@ def __new__(cls, *args, base=None, **kwargs):

func = DifferentiableOp._rebuild

# Since obj.base = base, then Differentiable.__eq__ leads to infinite recursion
# as it checks obj.base == other.base
__eq__ = sympy.Add.__eq__
__hash__ = sympy.Add.__hash__

Expand Down
58 changes: 31 additions & 27 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,43 +1071,47 @@ def _eval_deriv(self):
return self

@property
def _is_on_grid(self):
def _grid_map(self):
"""
Check whether the object is on the grid and requires averaging.
For example, if the original non-staggered function is f(x)
then f(x) is on the grid and f(x + h_x/2) is off the grid.
Mapper of off-grid interpolation points indices for each dimension.
"""
mapper = {}
for i, j, d in zip(self.indices, self.indices_ref, self.dimensions):
# Two indices are aligned if they differ by an Integer*spacing.
v = i - j
v = (i - j)/d.spacing
try:
if int(v/d.spacing) != v/d.spacing:
return False
except TypeError:
return False
return True
if not isinstance(v, sympy.Number) or int(v) == v:
continue
# Skip if index is just a Symbol or integer
elif (i.is_Symbol and not i.has(d)) or i.is_Integer:
continue
else:
mapper.update({d: i})
except (AttributeError, TypeError):
mapper.update({d: i})
return mapper

def _evaluate(self, **kwargs):
"""
Evaluate off the grid with 2nd order interpolation.
Directly available through zeroth order derivative of the base object
i.e f(x + a) = f(x).diff(x, deriv_order=0, fd_order=2, x0={x: x + a})
This allow to evaluate off grid points as EvalDerivative that are better
for the compiler.
"""
# Average values if at a location not on the Function's grid
if self._is_on_grid:
if not self._grid_map:
return self

weight = 1.0
avg_list = [self]
is_averaged = False
for i, ir, d in zip(self.indices, self.indices_ref, self.dimensions):
off = (i - ir)/d.spacing
if not isinstance(off, sympy.Number) or int(off) == off:
pass
else:
weight *= 1/2
is_averaged = True
avg_list = [(a.xreplace({i: i - d.spacing/2}) +
a.xreplace({i: i + d.spacing/2})) for a in avg_list]

if not is_averaged:
return self
return weight * sum(avg_list)
# Base function
retval = self.function
# Apply interpolation from inner most dim
for d, i in self._grid_map.items():
retval = retval.diff(d, 0, fd_order=2, x0={d: i})
# Evaluate. Since we used `self.function` it will be on the grid when evaluate
# is called again within FD
return retval.evaluate

@property
def shape(self):
Expand Down
56 changes: 28 additions & 28 deletions examples/seismic/tutorials/06_elastic_varying_parameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -392,22 +392,22 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.23 s\n"
"Operator `Kernel` ran in 0.22 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.2021389999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.19680499999999987, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.005733999999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005080999999999991, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.006415000000000033, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005789000000000039, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.006162000000000033, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005510000000000015, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section4', rank=None),\n",
" PerfEntry(time=0.006123000000000028, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.005339000000000027, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 14,
Expand Down Expand Up @@ -511,22 +511,22 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.27 s\n"
"Operator `Kernel` ran in 0.24 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.2191359999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.20575800000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.010224999999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.006976999999999987, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.011126999999999972, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.007339000000000036, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.010938000000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.0069080000000000166, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section4', rank=None),\n",
" PerfEntry(time=0.010993000000000024, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.006813000000000026, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 17,
Expand Down Expand Up @@ -699,20 +699,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.23 s\n"
"Operator `Kernel` ran in 0.21 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.200911, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.1906600000000002, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.007577999999999987, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005354999999999988, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.007810000000000017, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005681000000000019, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.00834500000000001, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.005547000000000023, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 24,
Expand Down Expand Up @@ -796,20 +796,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.26 s\n"
"Operator `Kernel` ran in 0.27 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.21829600000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.22429000000000004, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.011041000000000028, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.013601000000000016, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.01102099999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.012073000000000028, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.01219099999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.012027000000000043, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 26,
Expand Down Expand Up @@ -1016,22 +1016,22 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.65 s\n"
"Operator `Kernel` ran in 0.57 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.5643610000000013, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.5206050000000007, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.01983800000000004, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.010398000000000041, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.02182, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.012117999999999964, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.020853000000000024, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.011993000000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section4', rank=None),\n",
" PerfEntry(time=0.020895, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.011515999999999974, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 34,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def test_shift():
assert a.shift(x, x.spacing).shift(x, -x.spacing) == a
assert a.shift(x, x.spacing).shift(x, x.spacing) == a.shift(x, 2*x.spacing)
assert a.dx.evaluate.shift(x, x.spacing) == a.shift(x, x.spacing).dx.evaluate
assert not a.shift(x, .5 * x.spacing)._is_on_grid
assert a.shift(x, .5 * x.spacing)._grid_map == {x: x + .5 * x.spacing}
1 change: 0 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,6 @@ def test_sum_of_nested_derivatives(self, expr, exp_arrays, exp_ops):
op1 = Operator(eqn, opt=('collect-derivs', 'cire-sops', {'openmp': True}))
op2 = Operator(eqn, opt=('cire-sops', {'openmp': True}))
op3 = Operator(eqn, opt=('advanced', {'openmp': True}))
print(op3)

# Check code generation
arrays = [i for i in FindSymbols().visit(op1) if i.is_Array]
Expand Down
5 changes: 3 additions & 2 deletions tests/test_staggered_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
from sympy import simplify

from devito import (Function, Grid, NODE, VectorTimeFunction,
TimeFunction, Eq, Operator, div)
Expand Down Expand Up @@ -53,7 +54,7 @@ def test_avg(ndim):
avg = f
for dd in d:
avg = .5 * (avg + avg.subs({dd: dd - dd.spacing}))
assert shifted.evaluate == avg
assert simplify(shifted.evaluate - avg) == 0


@pytest.mark.parametrize('ndim', [1, 2, 3])
Expand All @@ -75,7 +76,7 @@ def test_is_param(ndim):
avg = f2
for dd in d:
avg = .5 * (avg + avg.subs({dd: dd - dd.spacing}))
assert f2._eval_at(var).evaluate == avg
assert simplify(f2._eval_at(var).evaluate - avg) == 0


@pytest.mark.parametrize('expr, expected', [
Expand Down
6 changes: 3 additions & 3 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ def test_is_on_grid():
x0 = x + .5 * x.spacing
u = Function(name="u", grid=grid, space_order=2)

assert u._is_on_grid
assert not u.subs({x: x0})._is_on_grid
assert all(uu._is_on_grid for uu in retrieve_functions(u.subs({x: x0}).evaluate))
assert u._grid_map == {}
assert u.subs({x: x0})._grid_map == {x: x0}
assert all(uu._grid_map == {} for uu in retrieve_functions(u.subs({x: x0}).evaluate))


@pytest.mark.parametrize('expr,expected', [
Expand Down

0 comments on commit b10ef96

Please sign in to comment.