diff --git a/devito/finite_differences/coefficients.py b/devito/finite_differences/coefficients.py index 1a401b57e16..ab83e1b1c4c 100644 --- a/devito/finite_differences/coefficients.py +++ b/devito/finite_differences/coefficients.py @@ -1,7 +1,7 @@ import numpy as np from cached_property import cached_property -from devito.finite_differences import generate_indices +from devito.finite_differences import Weights, generate_indices from devito.finite_differences.tools import numeric_weights, symbolic_weights from devito.tools import filter_ordered, as_tuple @@ -268,8 +268,13 @@ def generate_subs(deriv_order, function, index): return subs # Determine which 'rules' are missing + sym = get_sym(functions) terms = obj.find(sym) + for i in obj.find(Weights): + for w in i.weights: + terms.update(w.find(sym)) + args_present = filter_ordered(term.args[1:] for term in terms) subs = obj.substitutions diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 170852a50ad..061e4babb76 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -637,6 +637,15 @@ def spacings(self): weights = Array.initvalue + def _xreplace(self, rule): + if self in rule: + return rule[self] + elif not rule: + return self + else: + weights, changed = zip(*[i._xreplace(rule) for i in self.weights]) + return self.func(initvalue=weights, function=None), any(changed) + class IndexDerivative(IndexSum): diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index fa076096a3a..1e269328c18 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -21,6 +21,23 @@ def test_backward_dt2(self): assert_structure(op, ['t,x,y'], 't,x,y') +class TestSymbolicCoefficients(object): + + def test_fallback_to_default(self): + grid = Grid(shape=(8, 8, 8)) + + u = TimeFunction(name='u', grid=grid, coefficients='symbolic', + space_order=4, time_order=2) + + eq = Eq(u.forward, u.dx2 + 1) + + op = Operator(eq, opt=('advanced', {'expand': False})) + + # Ensure all symbols have been resolved + op.arguments(dt=1, time_M=10) + op.cfunction + + class Test1Pass(object): def test_v0(self):