Skip to content

Commit

Permalink
compiler: Catch redundant IndexDerivatives early on
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Dec 19, 2023
1 parent 7272069 commit 8f50425
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
14 changes: 12 additions & 2 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.types import Eq, Temp
from devito.types import Eq, Symbol, Temp

__all__ = ['cse']

Expand Down Expand Up @@ -169,15 +169,25 @@ def _(exprs):
return mapper


@count.register(IndexDerivative)
@count.register(Indexed)
@count.register(Symbol)
def _(expr):
"""
Handler for objects preventing CSE to propagate through their arguments.
"""
return Counter()


@count.register(IndexDerivative)
def _(expr):
"""
Handler for symbol-binding objects. There can be many of them and therefore
they should be detected as common subexpressions, but it's either pointless
or forbidden to look inside them.
"""
return Counter([expr])


@count.register(Add)
@count.register(Mul)
@count.register(Pow)
Expand Down
4 changes: 4 additions & 0 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def lower_index_derivatives(clusters, mode=None, **kwargs):
if mode != 'noop':
clusters = fuse(clusters, toposort='maximal')

# At this point we can detect redundancies induced by inner derivatives that
# previously were just not detectable via e.g. plain CSE. For example, if
# there were two IndexDerivatives such as `(p.dx + m.dx).dx` and `m.dx.dx`
# then it's only after `_lower_index_derivatives` that they're detectable!
clusters = CDE(mapper).process(clusters)

return clusters
Expand Down
30 changes: 28 additions & 2 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
Substitutions, Coefficient, cos, sin)
from devito.arch.compiler import OneapiCompiler
from devito.ir import Expression, FindNodes, FindSymbols
from devito.parameters import switchconfig, configuration
from devito.types import Symbol

Expand Down Expand Up @@ -300,9 +301,34 @@ def test_transpose(self):
op1.apply(time_M=10, u=u1)
assert np.allclose(u.data, u1.data, rtol=10e-6)

def test_redundant_derivatives(self):
grid = Grid(shape=(10, 10, 10))

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
u = TimeFunction(name='u', grid=grid, space_order=4)

# It's the same `u.dx.dy` appearing multiple times, and the cost models
# must be smart enough to detect that!
eq = Eq(u.forward, (f*u.dx.dy + g*u.dx.dy + h*u.dx.dy +
(f*g)*u.dx.dy + (f*h)*u.dx.dy + (g*h)*u.dx.dy))

op = Operator(eq, opt=('advanced', {'expand': False,
'blocklevels': 0}))

# Check generated code
assert len(get_arrays(op)) == 0
assert op._profiler._sections['section0'].sops == 74
exprs = FindNodes(Expression).visit(op)
assert len(exprs) == 6
temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)]
assert len(temps) == 3


class Test2Pass(object):

@switchconfig(safe_math=True)
def test_v0(self):
grid = Grid(shape=(10, 10, 10))

Expand All @@ -312,14 +338,14 @@ def test_v0(self):
v1 = TimeFunction(name='v', grid=grid, space_order=8)

eqns = [Eq(u.forward, (u.dx.dy + v*u + 1.)),
Eq(v.forward, (v + u.dx.dy + 1.))]
Eq(v.forward, (v + u.dx.dz + 1.))]

op0 = Operator(eqns)
op1 = Operator(eqns, opt=('advanced', {'expand': False,
'openmp': True}))

# Check generated code
assert op1._profiler._sections['section0'].sops == 41
assert op1._profiler._sections['section0'].sops == 59
assert_structure(op1, ['t',
't,x0_blk0,y0_blk0,x,y,z',
't,x0_blk0,y0_blk0,x,y,z,i0',
Expand Down

0 comments on commit 8f50425

Please sign in to comment.