Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: prevent derivative shortcut with incompatible fd order #2254

Merged
merged 2 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ def is_TimeDependent(self):

@cached_property
def _fd(self):
return dict(ChainMap(*[getattr(i, '_fd', {}) for i in self._args_diff]))
# Filter out all args with fd order too high
fd_args = []
for f in self._args_diff:
try:
if f.space_order <= self.space_order and \
(not f.is_TimeDependent or f.time_order <= self.time_order):
fd_args.append(f)
except AttributeError:
pass
return dict(ChainMap(*[getattr(i, '_fd', {}) for i in fd_args]))

@cached_property
def _symbolic_functions(self):
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def dtype(self):
If two Clusters perform calculations with different precision, the
data type with highest precision is returned.
"""
dtypes = {i.dtype for i in self}
dtypes = {i.dtype for i in self} - {None}

return infer_dtype(dtypes)

Expand Down
15 changes: 15 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,21 @@ def test_all_shortcuts(self, so):
for fd in g._fd:
assert getattr(g, fd)

for d in grid.dimensions:
assert 'd%s' % d.name in f._fd
assert 'd%s' % d.name in g._fd
for o in range(2, min(7, so+1)):
assert 'd%s%s' % (d.name, o) in f._fd
assert 'd%s%s' % (d.name, o) in g._fd

def test_shortcuts_mixed(self):
grid = Grid(shape=(10,))
f = Function(name='f', grid=grid, space_order=2)
g = Function(name='g', grid=grid, space_order=4)
assert 'dx4' not in (f*g)._fd
assert 'dx4' not in (f+g)._fd
assert 'dx4' not in (g*f.dx)._fd

def test_transpose_simple(self):
grid = Grid(shape=(4, 4))

Expand Down
3 changes: 2 additions & 1 deletion tests/test_lower_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def test_symbolic_constant_times_add(self):
dt = grid.time_dim.spacing

u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)
f = Function(name='f', grid=grid)
f = Function(name='f', grid=grid, space_order=4)

eq = Eq(u.forward, u.laplace + dt**0.2*u.biharmonic(1/f))

leq = collect_derivatives.func([eq])[0]

assert len(eq.rhs.args) == 3
Expand Down
Loading