Skip to content

Commit

Permalink
Merge pull request #2288 from devitocodes/tweak-unexpansion
Browse files Browse the repository at this point in the history
compiler: Improve IndexDerivative lowering
  • Loading branch information
FabioLuporini authored Jan 2, 2024
2 parents c888cee + df4e53d commit f1db9d4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
34 changes: 28 additions & 6 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,36 @@ def dump(exprs, c):
for c in clusters:
exprs = []
for e in c.exprs:
expr, v = _core(e, c, weights, mapper, sregistry)
# Optimization 1: if the LHS is already a Symbol, then surely it's
# usable as a temporary for one of the IndexDerivatives inside `e`
if e.lhs.is_Symbol and e.operation is None:
reusable = {e.lhs}
else:
reusable = set()

expr, v = _core(e, c, weights, reusable, mapper, sregistry)

if v:
dump(exprs, c)
processed.extend(v)
exprs.append(expr)

if e.lhs is expr.rhs:
# Optimization 2: `e` is of the form
# `r = IndexDerivative(...)`
# Rather than say
# `r = foo(IndexDerivative(...))`
# Since `r` is reusable (Optimization 1), we now have `r = r`,
# which can safely be discarded
pass
else:
exprs.append(expr)

dump(exprs, c)

return processed, weights, mapper


def _core(expr, c, weights, mapper, sregistry):
def _core(expr, c, weights, reusables, mapper, sregistry):
"""
Recursively carry out the core of `lower_index_derivatives`.
"""
Expand All @@ -62,7 +80,7 @@ def _core(expr, c, weights, mapper, sregistry):
args = []
processed = []
for a in expr.args:
e, clusters = _core(a, c, weights, mapper, sregistry)
e, clusters = _core(a, c, weights, reusables, mapper, sregistry)
args.append(e)
processed.extend(clusters)

Expand Down Expand Up @@ -97,8 +115,12 @@ def _core(expr, c, weights, mapper, sregistry):
extra = (c.ispace.itdims + dims,)
ispace = IterationSpace.union(c.ispace, ispace0, relations=extra)

name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=w.dtype)
try:
s = reusables.pop()
assert s.dtype is w.dtype
except KeyError:
name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=w.dtype)
expr0 = Eq(s, 0.)
ispace1 = ispace.project(lambda d: d is not dims[-1])
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_redundant_derivatives(self):
exprs = FindNodes(Expression).visit(op)
assert len(exprs) == 6
temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)]
assert len(temps) == 3
assert len(temps) == 2


class Test2Pass(object):
Expand Down

0 comments on commit f1db9d4

Please sign in to comment.