Skip to content

Commit 4ffcb23

Browse files
committed
api: fix corner cases for x0/dims derivative specification
1 parent dd337e4 commit 4ffcb23

File tree

2 files changed

+47
-9
lines changed

2 files changed

+47
-9
lines changed

devito/finite_differences/derivative.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .differentiable import Differentiable
99
from .tools import direct, transpose
1010
from .rsfd import d45
11-
from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict
11+
from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict, is_integer
1212
from devito.types.utils import DimensionTuple
1313

1414
__all__ = ['Derivative']
@@ -121,7 +121,8 @@ def __new__(cls, expr, *dims, **kwargs):
121121
processed.append(i)
122122
obj._ppsubs = tuple(processed)
123123

124-
obj._x0 = frozendict(kwargs.get('x0', {}))
124+
obj._x0 = cls._process_x0(obj._dims, **kwargs)
125+
125126
return obj
126127

127128
@classmethod
@@ -161,6 +162,10 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
161162
new_dims = (dims[0],)
162163
else:
163164
new_dims = tuple([dims[0]]*orders)
165+
elif len(dims) == 2 and not isinstance(dims[1], Iterable) and is_integer(dims[1]):
166+
# special case of single dimension and order
167+
new_dims = (dims[0],)
168+
orders = dims[1]
164169
else:
165170
# Iterable of 2-tuple, e.g. ((x, 2), (y, 3))
166171
new_dims = []
@@ -171,16 +176,16 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
171176
new_dims.extend([d[0] for _ in range(max(1, d[1]))])
172177
orders.append(d[1])
173178
else:
174-
new_dims.extend([d for _ in range(o)])
179+
new_dims.extend([d for _ in range(max(1, o))])
175180
orders.append(o)
176181
new_dims = as_tuple(new_dims)
177182
orders = as_tuple(orders)
178183

179184
# Finite difference orders depending on input dimension (.dt or .dx)
180185
fd_orders = kwargs.get('fd_order', tuple([expr.time_order if
181186
getattr(d, 'is_Time', False) else
182-
expr.space_order for d in dims]))
183-
if len(dims) == 1 and isinstance(fd_orders, Iterable):
187+
expr.space_order for d in new_dims]))
188+
if len(new_dims) == 1 and isinstance(fd_orders, Iterable):
184189
fd_orders = fd_orders[0]
185190

186191
# SymPy expects the list of variable w.r.t. which we differentiate to be a list
@@ -190,26 +195,36 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
190195
for s in filter_ordered(new_dims)]
191196
return new_dims, orders, fd_orders, variable_count
192197

198+
@classmethod
199+
def _process_x0(cls, dims, **kwargs):
200+
try:
201+
x0 = frozendict(kwargs.get('x0', {}))
202+
except TypeError:
203+
# Only given a value
204+
assert len(dims) == 1
205+
x0 = frozendict({dims[0]: kwargs.get('x0')})
206+
207+
return x0
208+
193209
def __call__(self, x0=None, fd_order=None, side=None, method=None):
210+
x0 = self._process_x0(self.dims, x0=x0)
211+
_x0 = frozendict({**self.x0, **x0})
194212
if self.ndims == 1:
195213
fd_order = fd_order or self._fd_order
196214
side = side or self._side
197215
method = method or self._method
198-
x0 = {self.dims[0]: x0} if x0 is not None else self.x0
199-
return self._new_from_self(fd_order=fd_order, side=side, x0=x0,
216+
return self._new_from_self(fd_order=fd_order, side=side, x0=_x0,
200217
method=method)
201218

202219
if side is not None:
203220
raise TypeError("Side only supported for first order single"
204221
"Dimension derivative such as `.dxl` or .dx(side=left)")
205222
# Cross derivative
206-
_x0 = dict(self._x0)
207223
_fd_order = dict(self.fd_order._getters)
208224
try:
209225
_fd_order.update(fd_order or {})
210226
_fd_order = tuple(_fd_order.values())
211227
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)
212-
_x0.update(x0 or {})
213228
except AttributeError:
214229
raise TypeError("Multi-dimensional Derivative, input expected as a dict")
215230

tests/test_derivatives.py

+23
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,29 @@ def test_zero_fd_interp(self):
676676
finterp = Derivative(f, (x, 0), x0={x: x+x.spacing/2}).evaluate
677677
assert simplify(finterp - expected) == 0
678678

679+
def test_deriv_spec(self):
680+
grid = Grid((11, 11))
681+
x, y = grid.dimensions
682+
f = Function(name="f", grid=grid, space_order=4)
683+
684+
assert f.dx(x0=x + x.spacing) == f.dx(x0={x: x + x.spacing})
685+
assert Derivative(f, x, 1) == Derivative(f, (x, 1))
686+
687+
x0xy = {x: x+x.spacing/2, y: y+y.spacing/2}
688+
dxy0 = Derivative(f, (x, 0), (y, 0), x0=x0xy)
689+
dxy02 = Derivative(f, x, y, deriv_order=(0, 0), x0=x0xy)
690+
assert dxy0 == dxy02
691+
assert dxy0.dims == (x, y)
692+
assert dxy0.deriv_order == (0, 0)
693+
assert dxy0.fd_order == (4, 4)
694+
assert dxy0.x0 == x0xy
695+
696+
dxy0 = Derivative(f, (x, 0), (y, 0), x0={y: y+y.spacing/2})
697+
dxy02 = Derivative(f, x, y, deriv_order=(0, 0), x0={x: x+x.spacing/2})
698+
assert dxy0 != dxy02
699+
assert dxy0.x0 == {y: y+y.spacing/2}
700+
assert dxy02.x0 == {x: x+x.spacing/2}
701+
679702

680703
class TestTwoStageEvaluation(object):
681704

0 commit comments

Comments
 (0)