8
8
from .differentiable import Differentiable
9
9
from .tools import direct , transpose
10
10
from .rsfd import d45
11
- from devito .tools import as_mapper , as_tuple , filter_ordered , frozendict
11
+ from devito .tools import as_mapper , as_tuple , filter_ordered , frozendict , is_integer
12
12
from devito .types .utils import DimensionTuple
13
13
14
14
__all__ = ['Derivative' ]
@@ -121,7 +121,8 @@ def __new__(cls, expr, *dims, **kwargs):
121
121
processed .append (i )
122
122
obj ._ppsubs = tuple (processed )
123
123
124
- obj ._x0 = frozendict (kwargs .get ('x0' , {}))
124
+ obj ._x0 = cls ._process_x0 (obj ._dims , ** kwargs )
125
+
125
126
return obj
126
127
127
128
@classmethod
@@ -161,6 +162,10 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
161
162
new_dims = (dims [0 ],)
162
163
else :
163
164
new_dims = tuple ([dims [0 ]]* orders )
165
+ elif len (dims ) == 2 and not isinstance (dims [1 ], Iterable ) and is_integer (dims [1 ]):
166
+ # special case of single dimension and order
167
+ new_dims = (dims [0 ],)
168
+ orders = dims [1 ]
164
169
else :
165
170
# Iterable of 2-tuple, e.g. ((x, 2), (y, 3))
166
171
new_dims = []
@@ -171,16 +176,16 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
171
176
new_dims .extend ([d [0 ] for _ in range (max (1 , d [1 ]))])
172
177
orders .append (d [1 ])
173
178
else :
174
- new_dims .extend ([d for _ in range (o )])
179
+ new_dims .extend ([d for _ in range (max ( 1 , o ) )])
175
180
orders .append (o )
176
181
new_dims = as_tuple (new_dims )
177
182
orders = as_tuple (orders )
178
183
179
184
# Finite difference orders depending on input dimension (.dt or .dx)
180
185
fd_orders = kwargs .get ('fd_order' , tuple ([expr .time_order if
181
186
getattr (d , 'is_Time' , False ) else
182
- expr .space_order for d in dims ]))
183
- if len (dims ) == 1 and isinstance (fd_orders , Iterable ):
187
+ expr .space_order for d in new_dims ]))
188
+ if len (new_dims ) == 1 and isinstance (fd_orders , Iterable ):
184
189
fd_orders = fd_orders [0 ]
185
190
186
191
# SymPy expects the list of variable w.r.t. which we differentiate to be a list
@@ -190,26 +195,36 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
190
195
for s in filter_ordered (new_dims )]
191
196
return new_dims , orders , fd_orders , variable_count
192
197
198
+ @classmethod
199
+ def _process_x0 (cls , dims , ** kwargs ):
200
+ try :
201
+ x0 = frozendict (kwargs .get ('x0' , {}))
202
+ except TypeError :
203
+ # Only given a value
204
+ assert len (dims ) == 1
205
+ x0 = frozendict ({dims [0 ]: kwargs .get ('x0' )})
206
+
207
+ return x0
208
+
193
209
def __call__ (self , x0 = None , fd_order = None , side = None , method = None ):
210
+ x0 = self ._process_x0 (self .dims , x0 = x0 )
211
+ _x0 = frozendict ({** self .x0 , ** x0 })
194
212
if self .ndims == 1 :
195
213
fd_order = fd_order or self ._fd_order
196
214
side = side or self ._side
197
215
method = method or self ._method
198
- x0 = {self .dims [0 ]: x0 } if x0 is not None else self .x0
199
- return self ._new_from_self (fd_order = fd_order , side = side , x0 = x0 ,
216
+ return self ._new_from_self (fd_order = fd_order , side = side , x0 = _x0 ,
200
217
method = method )
201
218
202
219
if side is not None :
203
220
raise TypeError ("Side only supported for first order single"
204
221
"Dimension derivative such as `.dxl` or .dx(side=left)" )
205
222
# Cross derivative
206
- _x0 = dict (self ._x0 )
207
223
_fd_order = dict (self .fd_order ._getters )
208
224
try :
209
225
_fd_order .update (fd_order or {})
210
226
_fd_order = tuple (_fd_order .values ())
211
227
_fd_order = DimensionTuple (* _fd_order , getters = self .dims )
212
- _x0 .update (x0 or {})
213
228
except AttributeError :
214
229
raise TypeError ("Multi-dimensional Derivative, input expected as a dict" )
215
230
0 commit comments