Skip to content

Commit

Permalink
api: fix derivative corner case bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 3, 2024
1 parent 4ffcb23 commit 2fa1537
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 50 deletions.
33 changes: 18 additions & 15 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
fd_orders = kwargs.get('fd_order')
deriv_orders = kwargs.get('deriv_order')
if len(dims) == 1:
dims = tuple([dims[0]]*deriv_orders)
dims = tuple([dims[0]]*max(1, deriv_orders))
variable_count = [sympy.Tuple(s, dims.count(s))
for s in filter_ordered(dims)]
return dims, deriv_orders, fd_orders, variable_count
Expand All @@ -158,41 +158,39 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
orders = kwargs.get('deriv_order', 1)
if isinstance(orders, Iterable):
orders = orders[0]
if orders == 0:
new_dims = (dims[0],)
else:
new_dims = tuple([dims[0]]*orders)
new_dims = tuple([dims[0]]*max(1, orders))
elif len(dims) == 2 and not isinstance(dims[1], Iterable) and is_integer(dims[1]):
# special case of single dimension and order
new_dims = (dims[0],)
orders = dims[1]
new_dims = tuple([dims[0]]*max(1, orders))
else:
# Iterable of 2-tuple, e.g. ((x, 2), (y, 3))
new_dims = []
orders = []
d_ord = kwargs.get('deriv_order', tuple([1]*len(dims)))
for d, o in zip(dims, d_ord):
if isinstance(d, Iterable):
new_dims.extend([d[0] for _ in range(max(1, d[1]))])
new_dims.extend([d[0]]*max(1, d[1]))
orders.append(d[1])
else:
new_dims.extend([d for _ in range(max(1, o))])
new_dims.extend([d]*max(1, o))
orders.append(o)
new_dims = as_tuple(new_dims)
orders = as_tuple(orders)

# Finite difference orders depending on input dimension (.dt or .dx)
odims = filter_ordered(new_dims)
fd_orders = kwargs.get('fd_order', tuple([expr.time_order if
getattr(d, 'is_Time', False) else
expr.space_order for d in new_dims]))
if len(new_dims) == 1 and isinstance(fd_orders, Iterable):
expr.space_order for d in odims]))
if len(odims) == 1 and isinstance(fd_orders, Iterable):
fd_orders = fd_orders[0]

# SymPy expects the list of variable w.r.t. which we differentiate to be a list
# of 2-tuple `(s, count)` where s is the entity to diff wrt and count is the order
# of the derivative
variable_count = [sympy.Tuple(s, new_dims.count(s))
for s in filter_ordered(new_dims)]
for s in odims]
return new_dims, orders, fd_orders, variable_count

@classmethod
Expand All @@ -201,8 +199,12 @@ def _process_x0(cls, dims, **kwargs):
x0 = frozendict(kwargs.get('x0', {}))
except TypeError:
# Only given a value
assert len(dims) == 1
x0 = frozendict({dims[0]: kwargs.get('x0')})
_x0 = kwargs.get('x0')
assert len(dims) == 1 or _x0 is None
if _x0 is not None:
x0 = frozendict({dims[0]: _x0})
else:
x0 = frozendict({})

return x0

Expand Down Expand Up @@ -355,7 +357,7 @@ def _eval_at(self, func):
has to be computed at x=x + h_x/2.
"""
# If an x0 already exists do not overwrite it
x0 = self.x0 or dict(func.indices_ref._getters)
x0 = self.x0 or func.indices_ref._getters
if self.expr.is_Add:
# If `expr` has both staggered and non-staggered terms such as
# `(u(x + h_x/2) + v(x)).dx` then we exploit linearity of FD to split
Expand Down Expand Up @@ -427,7 +429,8 @@ def _eval_fd(self, expr, **kwargs):
matvec=self.transpose, x0=self.x0, expand=expand)
else:
assert self.method == 'FD'
res = generic_derivative(expr, *self.dims, self.fd_order, self.deriv_order,
res = generic_derivative(expr, self.dims[0], as_tuple(self.fd_order)[0],
self.deriv_order,
matvec=self.transpose, x0=self.x0, expand=expand)

# Step 3: Apply substitutions
Expand Down
6 changes: 5 additions & 1 deletion devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,15 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None):

# Evaluation point relative to the expression's grid
mid = (x0 - expr.indices_ref[dim]).subs({dim: 0, dim.spacing: 1})

# Indices range
o_min = int(np.ceil(mid - order/2)) + side.val
o_max = int(np.floor(mid + order/2)) + side.val
if o_max == o_min:
o_max += 1
if dim.is_Time or not expr.is_Staggered:
o_max += 1
else:
o_min -= 1

# StencilDimension and expression
d = make_stencil_dimension(expr, o_min, o_max)
Expand Down
64 changes: 32 additions & 32 deletions examples/seismic/tutorials/06_elastic_varying_parameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -392,22 +392,22 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.25 s\n"
"Operator `Kernel` ran in 0.23 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.21099099999999987, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.2021389999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.006651999999999994, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.005733999999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.007887000000000019, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.006415000000000033, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.007380000000000024, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.006162000000000033, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section4', rank=None),\n",
" PerfEntry(time=0.007149000000000023, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.006123000000000028, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 14,
Expand Down Expand Up @@ -511,22 +511,22 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.38 s\n"
"Operator `Kernel` ran in 0.27 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.29417999999999983, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.2191359999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.026481999999999884, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.010224999999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.015131999999999975, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.011126999999999972, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.011889999999999993, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.010938000000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section4', rank=None),\n",
" PerfEntry(time=0.024015000000000057, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.010993000000000024, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 17,
Expand Down Expand Up @@ -699,20 +699,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.30 s\n"
"Operator `Kernel` ran in 0.23 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.2539820000000001, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.200911, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.013418, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.007577999999999987, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.013824999999999992, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.007810000000000017, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.013175000000000003, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.00834500000000001, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 24,
Expand Down Expand Up @@ -796,20 +796,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.30 s\n"
"Operator `Kernel` ran in 0.26 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.2411920000000002, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.21829600000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.019666999999999997, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.011041000000000028, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.015708999999999966, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.01102099999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.016480000000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.01219099999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 26,
Expand Down Expand Up @@ -921,12 +921,12 @@
{
"data": {
"text/latex": [
"$\\displaystyle - \\frac{f(x, y)}{h_{x}} + \\frac{f(x + h_x, y)}{h_{x}}$"
"$\\displaystyle \\frac{f(x, y)}{h_{x}} - \\frac{f(x - h_x, y)}{h_{x}}$"
],
"text/plain": [
" f(x, y) f(x + hₓ, y)\n",
"- ─────── + ────────────\n",
" hₓ hₓ "
"f(x, y) f(x - hₓ, y)\n",
"─────── - ────────────\n",
" hₓ hₓ "
]
},
"execution_count": 31,
Expand Down Expand Up @@ -1016,22 +1016,22 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Operator `Kernel` ran in 0.63 s\n"
"Operator `Kernel` ran in 0.65 s\n"
]
},
{
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.5551359999999997, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.5643610000000013, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.014585000000000094, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.01983800000000004, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.017796999999999973, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.02182, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section3', rank=None),\n",
" PerfEntry(time=0.01653199999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.020853000000000024, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section4', rank=None),\n",
" PerfEntry(time=0.016534999999999966, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.020895, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 34,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def test_solve(self, operate_on_empty_cache):
# created by the finite difference (u.dt, u.dx2). We would have had
# three extra references to u(t + dt), u(x - h_x) and u(x + h_x).
# But this is not the case anymore!
assert len(_SymbolCache) == 11
assert len(_SymbolCache) == 12
clear_cache()
assert len(_SymbolCache) == 8
clear_cache()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2165,7 +2165,7 @@ def test_haloupdate_multi_op(self, mode):
op.apply()
f.data[:, :] = fo.data[:, :]

assert (np.isclose(norm(f), 17.24904, atol=1e-4, rtol=0))
assert (np.isclose(norm(f), 17.86754, atol=1e-4, rtol=0))

@pytest.mark.parallel(mode=1)
def test_haloupdate_issue_1613(self, mode):
Expand Down

0 comments on commit 2fa1537

Please sign in to comment.