diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index a20902a39a..a10c3932d8 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -105,7 +105,16 @@ def is_TimeDependent(self): @cached_property def _fd(self): - return dict(ChainMap(*[getattr(i, '_fd', {}) for i in self._args_diff])) + # Filter out all args with fd order too high + fd_args = [] + for f in self._args_diff: + try: + if f.space_order <= self.space_order and \ + (not f.is_TimeDependent or f.time_order <= self.time_order): + fd_args.append(f) + except AttributeError: + pass + return dict(ChainMap(*[getattr(i, '_fd', {}) for i in fd_args])) @cached_property def _symbolic_functions(self): diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 887793e8e7..69ee00c6c5 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -483,7 +483,7 @@ def dtype(self): If two Clusters perform calculations with different precision, the data type with highest precision is returned. """ - dtypes = {i.dtype for i in self} + dtypes = {i.dtype for i in self} - {None} return infer_dtype(dtypes) diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 5fa11a6fa0..f1c1951c97 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -457,6 +457,21 @@ def test_all_shortcuts(self, so): for fd in g._fd: assert getattr(g, fd) + for d in grid.dimensions: + assert 'd%s' % d.name in f._fd + assert 'd%s' % d.name in g._fd + for o in range(2, min(7, so+1)): + assert 'd%s%s' % (d.name, o) in f._fd + assert 'd%s%s' % (d.name, o) in g._fd + + def test_shortcuts_mixed(self): + grid = Grid(shape=(10,)) + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=4) + assert 'dx4' not in (f*g)._fd + assert 'dx4' not in (f+g)._fd + assert 'dx4' not in (g*f.dx)._fd + def test_transpose_simple(self): grid = Grid(shape=(4, 4)) diff --git a/tests/test_lower_exprs.py b/tests/test_lower_exprs.py index 4903dbf9a3..0d0b56ca82 100644 --- a/tests/test_lower_exprs.py +++ b/tests/test_lower_exprs.py @@ -65,9 +65,10 @@ def test_symbolic_constant_times_add(self): dt = grid.time_dim.spacing u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2) - f = Function(name='f', grid=grid) + f = Function(name='f', grid=grid, space_order=4) eq = Eq(u.forward, u.laplace + dt**0.2*u.biharmonic(1/f)) + leq = collect_derivatives.func([eq])[0] assert len(eq.rhs.args) == 3