From cf85b06cb47c9fc7336f6ca652a897e12b3c5571 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 31 Oct 2023 10:45:28 -0400 Subject: [PATCH 1/2] api: prevent derivative shortcut with incompatible fd order --- devito/finite_differences/differentiable.py | 10 +++++++++- tests/test_derivatives.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index a20902a39a..23152053b8 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -105,7 +105,15 @@ def is_TimeDependent(self): @cached_property def _fd(self): - return dict(ChainMap(*[getattr(i, '_fd', {}) for i in self._args_diff])) + # Filter out all args with fd order too high + fd_args = [] + for f in self._args_diff: + try: + if f.space_order <= self.space_order and f.time_order <= self.time_order: + fd_args.append(f) + except AttributeError: + pass + return dict(ChainMap(*[getattr(i, '_fd', {}) for i in fd_args])) @cached_property def _symbolic_functions(self): diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 5fa11a6fa0..f1c1951c97 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -457,6 +457,21 @@ def test_all_shortcuts(self, so): for fd in g._fd: assert getattr(g, fd) + for d in grid.dimensions: + assert 'd%s' % d.name in f._fd + assert 'd%s' % d.name in g._fd + for o in range(2, min(7, so+1)): + assert 'd%s%s' % (d.name, o) in f._fd + assert 'd%s%s' % (d.name, o) in g._fd + + def test_shortcuts_mixed(self): + grid = Grid(shape=(10,)) + f = Function(name='f', grid=grid, space_order=2) + g = Function(name='g', grid=grid, space_order=4) + assert 'dx4' not in (f*g)._fd + assert 'dx4' not in (f+g)._fd + assert 'dx4' not in (g*f.dx)._fd + def test_transpose_simple(self): grid = Grid(shape=(4, 4)) From e6ca17e995708ef07aba8cb49587688a0653b858 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 31 Oct 2023 12:27:19 -0400 Subject: [PATCH 2/2] compiler: prevent invering dtype on empty cluster --- devito/finite_differences/differentiable.py | 3 ++- devito/ir/clusters/cluster.py | 2 +- tests/test_lower_exprs.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 23152053b8..a10c3932d8 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -109,7 +109,8 @@ def _fd(self): fd_args = [] for f in self._args_diff: try: - if f.space_order <= self.space_order and f.time_order <= self.time_order: + if f.space_order <= self.space_order and \ + (not f.is_TimeDependent or f.time_order <= self.time_order): fd_args.append(f) except AttributeError: pass diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 887793e8e7..69ee00c6c5 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -483,7 +483,7 @@ def dtype(self): If two Clusters perform calculations with different precision, the data type with highest precision is returned. """ - dtypes = {i.dtype for i in self} + dtypes = {i.dtype for i in self} - {None} return infer_dtype(dtypes) diff --git a/tests/test_lower_exprs.py b/tests/test_lower_exprs.py index 4903dbf9a3..0d0b56ca82 100644 --- a/tests/test_lower_exprs.py +++ b/tests/test_lower_exprs.py @@ -65,9 +65,10 @@ def test_symbolic_constant_times_add(self): dt = grid.time_dim.spacing u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2) - f = Function(name='f', grid=grid) + f = Function(name='f', grid=grid, space_order=4) eq = Eq(u.forward, u.laplace + dt**0.2*u.biharmonic(1/f)) + leq = collect_derivatives.func([eq])[0] assert len(eq.rhs.args) == 3