diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 4b82fdc262..77d114ad17 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -41,18 +41,36 @@ def dump(exprs, c): for c in clusters: exprs = [] for e in c.exprs: - expr, v = _core(e, c, weights, mapper, sregistry) + # Optimization 1: if the LHS is already a Symbol, then surely it's + # usable as a temporary for one of the IndexDerivatives inside `e` + if e.lhs.is_Symbol and e.operation is None: + reusable = {e.lhs} + else: + reusable = set() + + expr, v = _core(e, c, weights, reusable, mapper, sregistry) + if v: dump(exprs, c) processed.extend(v) - exprs.append(expr) + + if e.lhs is expr.rhs: + # Optimization 2: `e` is of the form + # `r = IndexDerivative(...)` + # Rather than say + # `r = foo(IndexDerivative(...))` + # Since `r` is reusable (Optimization 1), we now have `r = r`, + # which can safely be discarded + pass + else: + exprs.append(expr) dump(exprs, c) return processed, weights, mapper -def _core(expr, c, weights, mapper, sregistry): +def _core(expr, c, weights, reusables, mapper, sregistry): """ Recursively carry out the core of `lower_index_derivatives`. """ @@ -62,7 +80,7 @@ def _core(expr, c, weights, mapper, sregistry): args = [] processed = [] for a in expr.args: - e, clusters = _core(a, c, weights, mapper, sregistry) + e, clusters = _core(a, c, weights, reusables, mapper, sregistry) args.append(e) processed.extend(clusters) @@ -97,8 +115,12 @@ def _core(expr, c, weights, mapper, sregistry): extra = (c.ispace.itdims + dims,) ispace = IterationSpace.union(c.ispace, ispace0, relations=extra) - name = sregistry.make_name(prefix='r') - s = Symbol(name=name, dtype=w.dtype) + try: + s = reusables.pop() + assert s.dtype is w.dtype + except KeyError: + name = sregistry.make_name(prefix='r') + s = Symbol(name=name, dtype=w.dtype) expr0 = Eq(s, 0.) ispace1 = ispace.project(lambda d: d is not dims[-1]) processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1)) diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 97b855326d..05ba84f0bd 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -326,7 +326,7 @@ def test_redundant_derivatives(self): exprs = FindNodes(Expression).visit(op) assert len(exprs) == 6 temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)] - assert len(temps) == 3 + assert len(temps) == 2 class Test2Pass(object):