Skip to content

Commit

Permalink
Merge pull request #2284 from devitocodes/fix-estimate-cost
Browse files Browse the repository at this point in the history
compiler: Fix handling of redundant derivatives
  • Loading branch information
FabioLuporini authored Dec 19, 2023
2 parents b7016f2 + 8f50425 commit 063d07d
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 26 deletions.
14 changes: 12 additions & 2 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.types import Eq, Temp
from devito.types import Eq, Symbol, Temp

__all__ = ['cse']

Expand Down Expand Up @@ -169,15 +169,25 @@ def _(exprs):
return mapper


@count.register(IndexDerivative)
@count.register(Indexed)
@count.register(Symbol)
def _(expr):
"""
Handler for objects preventing CSE to propagate through their arguments.
"""
return Counter()


@count.register(IndexDerivative)
def _(expr):
"""
Handler for symbol-binding objects. There can be many of them and therefore
they should be detected as common subexpressions, but it's either pointless
or forbidden to look inside them.
"""
return Counter([expr])


@count.register(Add)
@count.register(Mul)
@count.register(Pow)
Expand Down
4 changes: 4 additions & 0 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def lower_index_derivatives(clusters, mode=None, **kwargs):
if mode != 'noop':
clusters = fuse(clusters, toposort='maximal')

# At this point we can detect redundancies induced by inner derivatives that
# previously were just not detectable via e.g. plain CSE. For example, if
# there were two IndexDerivatives such as `(p.dx + m.dx).dx` and `m.dx.dx`
# then it's only after `_lower_index_derivatives` that they're detectable!
clusters = CDE(mapper).process(clusters)

return clusters
Expand Down
63 changes: 41 additions & 22 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def estimate_cost(exprs, estimate=False):
# We don't use SymPy's count_ops because we do not count integer arithmetic
# (e.g., array index functions such as i+1 in A[i+1])
# Also, the routine below is *much* faster than count_ops
seen = {}
flops = 0
for expr in as_tuple(exprs):
# TODO: this if-then should be part of singledispatch too, but because
Expand All @@ -103,7 +104,7 @@ def estimate_cost(exprs, estimate=False):
else:
e = expr

flops += _estimate_cost(e, estimate)[0]
flops += _estimate_cost(e, estimate, seen)[0]

return flops
except:
Expand All @@ -121,11 +122,27 @@ def estimate_cost(exprs, estimate=False):
}


def dont_count_if_seen(func):
"""
This decorator is used to avoid counting the same expression multiple
times. This is necessary because the same expression may appear multiple
times in the same expression tree or even across different expressions.
"""
def wrapper(expr, estimate, seen):
try:
_, flags = seen[expr]
flops = 0
except KeyError:
flops, flags = seen[expr] = func(expr, estimate, seen)
return flops, flags
return wrapper


@singledispatch
def _estimate_cost(expr, estimate):
def _estimate_cost(expr, estimate, seen):
# Retval: flops (int), flag (bool)
# The flag tells wether it's an integer expression (implying flops==0) or not
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if all(flags):
# `expr` is an operation involving integer operands only
Expand All @@ -138,28 +155,28 @@ def _estimate_cost(expr, estimate):

@_estimate_cost.register(Tuple)
@_estimate_cost.register(CallFromPointer)
def _(expr, estimate):
def _(expr, estimate, seen):
try:
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
except ValueError:
flops, flags = [], []
return sum(flops), all(flags)


@_estimate_cost.register(Integer)
def _(expr, estimate):
def _(expr, estimate, seen):
return 0, True


@_estimate_cost.register(Number)
@_estimate_cost.register(ReservedWord)
def _(expr, estimate):
def _(expr, estimate, seen):
return 0, False


@_estimate_cost.register(Symbol)
@_estimate_cost.register(Indexed)
def _(expr, estimate):
def _(expr, estimate, seen):
try:
if issubclass(expr.dtype, np.integer):
return 0, True
Expand All @@ -169,27 +186,27 @@ def _(expr, estimate):


@_estimate_cost.register(Mul)
def _(expr, estimate):
flops, flags = _estimate_cost.registry[object](expr, estimate)
def _(expr, estimate, seen):
flops, flags = _estimate_cost.registry[object](expr, estimate, seen)
if {S.One, S.NegativeOne}.intersection(expr.args):
flops -= 1
return flops, flags


@_estimate_cost.register(INT)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], True
def _(expr, estimate, seen):
return _estimate_cost(expr.base, estimate, seen)[0], True


@_estimate_cost.register(Cast)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], False
def _(expr, estimate, seen):
return _estimate_cost(expr.base, estimate, seen)[0], False


@_estimate_cost.register(Function)
def _(expr, estimate):
def _(expr, estimate, seen):
if q_routine(expr):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if isinstance(expr, DefFunction):
# Bypass user-defined or language-specific functions
Expand All @@ -207,8 +224,8 @@ def _(expr, estimate):


@_estimate_cost.register(Pow)
def _(expr, estimate):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
def _(expr, estimate, seen):
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if estimate:
if expr.exp.is_Number:
Expand All @@ -229,13 +246,15 @@ def _(expr, estimate):


@_estimate_cost.register(Derivative)
def _(expr, estimate):
return _estimate_cost(expr._evaluate(expand=False), estimate)
@dont_count_if_seen
def _(expr, estimate, seen):
return _estimate_cost(expr._evaluate(expand=False), estimate, seen)


@_estimate_cost.register(IndexDerivative)
def _(expr, estimate):
flops, _ = _estimate_cost(expr.expr, estimate)
@dont_count_if_seen
def _(expr, estimate, seen):
flops, _ = _estimate_cost(expr.expr, estimate, seen)

# It's an increment
flops += 1
Expand Down
2 changes: 2 additions & 0 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ def test_factorize(expr, expected):
('Eq(fb, fd.dx)', 10, True),
('Eq(fb, fd.dx._evaluate(expand=False))', 10, False),
('Eq(fb, fd.dx.dy + fa.dx)', 66, False),
# Ensure redundancies aren't counted
('Eq(t0, fd.dx.dy + fa*fd.dx.dy)', 62, True),
])
def test_estimate_cost(expr, expected, estimate):
# Note: integer arithmetic isn't counted
Expand Down
30 changes: 28 additions & 2 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
Substitutions, Coefficient, cos, sin)
from devito.arch.compiler import OneapiCompiler
from devito.ir import Expression, FindNodes, FindSymbols
from devito.parameters import switchconfig, configuration
from devito.types import Symbol

Expand Down Expand Up @@ -300,9 +301,34 @@ def test_transpose(self):
op1.apply(time_M=10, u=u1)
assert np.allclose(u.data, u1.data, rtol=10e-6)

def test_redundant_derivatives(self):
grid = Grid(shape=(10, 10, 10))

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
u = TimeFunction(name='u', grid=grid, space_order=4)

# It's the same `u.dx.dy` appearing multiple times, and the cost models
# must be smart enough to detect that!
eq = Eq(u.forward, (f*u.dx.dy + g*u.dx.dy + h*u.dx.dy +
(f*g)*u.dx.dy + (f*h)*u.dx.dy + (g*h)*u.dx.dy))

op = Operator(eq, opt=('advanced', {'expand': False,
'blocklevels': 0}))

# Check generated code
assert len(get_arrays(op)) == 0
assert op._profiler._sections['section0'].sops == 74
exprs = FindNodes(Expression).visit(op)
assert len(exprs) == 6
temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)]
assert len(temps) == 3


class Test2Pass(object):

@switchconfig(safe_math=True)
def test_v0(self):
grid = Grid(shape=(10, 10, 10))

Expand All @@ -312,14 +338,14 @@ def test_v0(self):
v1 = TimeFunction(name='v', grid=grid, space_order=8)

eqns = [Eq(u.forward, (u.dx.dy + v*u + 1.)),
Eq(v.forward, (v + u.dx.dy + 1.))]
Eq(v.forward, (v + u.dx.dz + 1.))]

op0 = Operator(eqns)
op1 = Operator(eqns, opt=('advanced', {'expand': False,
'openmp': True}))

# Check generated code
assert op1._profiler._sections['section0'].sops == 41
assert op1._profiler._sections['section0'].sops == 59
assert_structure(op1, ['t',
't,x0_blk0,y0_blk0,x,y,z',
't,x0_blk0,y0_blk0,x,y,z,i0',
Expand Down

0 comments on commit 063d07d

Please sign in to comment.