Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix handling of redundant derivatives #2284

Merged
merged 2 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.types import Eq, Temp
from devito.types import Eq, Symbol, Temp

__all__ = ['cse']

Expand Down Expand Up @@ -169,15 +169,25 @@ def _(exprs):
return mapper


@count.register(IndexDerivative)
@count.register(Indexed)
@count.register(Symbol)
def _(expr):
"""
Handler for objects that prevent CSE from propagating through their arguments.
"""
return Counter()


@count.register(IndexDerivative)
def _(expr):
    """
    Handler for symbol-binding objects.

    Such objects may occur many times, so each occurrence is reported as a
    candidate common subexpression. Their arguments are deliberately not
    visited, since looking inside them is either pointless or forbidden.
    """
    # The object itself is counted exactly once; nothing beneath it is counted
    return Counter({expr: 1})


@count.register(Add)
@count.register(Mul)
@count.register(Pow)
Expand Down
4 changes: 4 additions & 0 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def lower_index_derivatives(clusters, mode=None, **kwargs):
if mode != 'noop':
clusters = fuse(clusters, toposort='maximal')

# At this point we can detect redundancies induced by inner derivatives that
# previously were just not detectable via e.g. plain CSE. For example, if
# there were two IndexDerivatives such as `(p.dx + m.dx).dx` and `m.dx.dx`
# then it's only after `_lower_index_derivatives` that they're detectable!
clusters = CDE(mapper).process(clusters)

return clusters
Expand Down
63 changes: 41 additions & 22 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def estimate_cost(exprs, estimate=False):
# We don't use SymPy's count_ops because we do not count integer arithmetic
# (e.g., array index functions such as i+1 in A[i+1])
# Also, the routine below is *much* faster than count_ops
seen = {}
flops = 0
for expr in as_tuple(exprs):
# TODO: this if-then should be part of singledispatch too, but because
Expand All @@ -103,7 +104,7 @@ def estimate_cost(exprs, estimate=False):
else:
e = expr

flops += _estimate_cost(e, estimate)[0]
flops += _estimate_cost(e, estimate, seen)[0]

return flops
except:
Expand All @@ -121,11 +122,27 @@ def estimate_cost(exprs, estimate=False):
}


def dont_count_if_seen(func):
    """
    Decorator to avoid counting the cost of the same expression multiple times.

    This is necessary because the same expression may appear multiple times
    in the same expression tree or even across different expressions.

    Parameters
    ----------
    func : callable
        A cost handler with signature ``func(expr, estimate, seen)`` returning
        a ``(flops, flags)`` pair.

    Returns
    -------
    callable
        A handler with the same signature. The first time a given ``expr`` is
        encountered, the result of ``func`` is returned and recorded in
        ``seen`` (a mutable mapping keyed by ``expr``). On any later encounter
        of an equal ``expr``, ``func`` is not called again: ``0`` flops are
        returned together with the originally recorded flag.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(expr, estimate, seen):
        try:
            _, flags = seen[expr]
        except KeyError:
            # First occurrence: fall through and compute the actual cost.
            # `func` is intentionally called outside of this handler so that
            # any exception it raises isn't chained to this KeyError
            pass
        else:
            # Already accounted for, so it contributes no additional flops
            return 0, flags

        flops, flags = seen[expr] = func(expr, estimate, seen)
        return flops, flags

    return wrapper


@singledispatch
def _estimate_cost(expr, estimate):
def _estimate_cost(expr, estimate, seen):
# Retval: flops (int), flag (bool)
# The flag tells whether it's an integer expression (implying flops==0) or not
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if all(flags):
# `expr` is an operation involving integer operands only
Expand All @@ -138,28 +155,28 @@ def _estimate_cost(expr, estimate):

@_estimate_cost.register(Tuple)
@_estimate_cost.register(CallFromPointer)
def _(expr, estimate):
def _(expr, estimate, seen):
try:
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
except ValueError:
flops, flags = [], []
return sum(flops), all(flags)


@_estimate_cost.register(Integer)
def _(expr, estimate):
def _(expr, estimate, seen):
return 0, True


@_estimate_cost.register(Number)
@_estimate_cost.register(ReservedWord)
def _(expr, estimate):
def _(expr, estimate, seen):
return 0, False


@_estimate_cost.register(Symbol)
@_estimate_cost.register(Indexed)
def _(expr, estimate):
def _(expr, estimate, seen):
try:
if issubclass(expr.dtype, np.integer):
return 0, True
Expand All @@ -169,27 +186,27 @@ def _(expr, estimate):


@_estimate_cost.register(Mul)
def _(expr, estimate):
flops, flags = _estimate_cost.registry[object](expr, estimate)
def _(expr, estimate, seen):
flops, flags = _estimate_cost.registry[object](expr, estimate, seen)
if {S.One, S.NegativeOne}.intersection(expr.args):
flops -= 1
return flops, flags


@_estimate_cost.register(INT)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], True
def _(expr, estimate, seen):
return _estimate_cost(expr.base, estimate, seen)[0], True


@_estimate_cost.register(Cast)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], False
def _(expr, estimate, seen):
return _estimate_cost(expr.base, estimate, seen)[0], False


@_estimate_cost.register(Function)
def _(expr, estimate):
def _(expr, estimate, seen):
if q_routine(expr):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if isinstance(expr, DefFunction):
# Bypass user-defined or language-specific functions
Expand All @@ -207,8 +224,8 @@ def _(expr, estimate):


@_estimate_cost.register(Pow)
def _(expr, estimate):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
def _(expr, estimate, seen):
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if estimate:
if expr.exp.is_Number:
Expand All @@ -229,13 +246,15 @@ def _(expr, estimate):


@_estimate_cost.register(Derivative)
def _(expr, estimate):
return _estimate_cost(expr._evaluate(expand=False), estimate)
@dont_count_if_seen
def _(expr, estimate, seen):
return _estimate_cost(expr._evaluate(expand=False), estimate, seen)


@_estimate_cost.register(IndexDerivative)
def _(expr, estimate):
flops, _ = _estimate_cost(expr.expr, estimate)
@dont_count_if_seen
def _(expr, estimate, seen):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be done for EvalDerivative to catch prematurely evaluated derivatives?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary.
EvalDerivatives are Adds, so they will be treated as such. EvalDerivatives are lower-level, more rudimentary objects, so it's OK for them to be treated this way.

flops, _ = _estimate_cost(expr.expr, estimate, seen)

# It's an increment
flops += 1
Expand Down
2 changes: 2 additions & 0 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ def test_factorize(expr, expected):
('Eq(fb, fd.dx)', 10, True),
('Eq(fb, fd.dx._evaluate(expand=False))', 10, False),
('Eq(fb, fd.dx.dy + fa.dx)', 66, False),
# Ensure redundancies aren't counted
('Eq(t0, fd.dx.dy + fa*fd.dx.dy)', 62, True),
])
def test_estimate_cost(expr, expected, estimate):
# Note: integer arithmetic isn't counted
Expand Down
30 changes: 28 additions & 2 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
Substitutions, Coefficient, cos, sin)
from devito.arch.compiler import OneapiCompiler
from devito.ir import Expression, FindNodes, FindSymbols
from devito.parameters import switchconfig, configuration
from devito.types import Symbol

Expand Down Expand Up @@ -300,9 +301,34 @@ def test_transpose(self):
op1.apply(time_M=10, u=u1)
assert np.allclose(u.data, u1.data, rtol=10e-6)

def test_redundant_derivatives(self):
grid = Grid(shape=(10, 10, 10))

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
u = TimeFunction(name='u', grid=grid, space_order=4)

# It's the same `u.dx.dy` appearing multiple times, and the cost models
# must be smart enough to detect that!
eq = Eq(u.forward, (f*u.dx.dy + g*u.dx.dy + h*u.dx.dy +
(f*g)*u.dx.dy + (f*h)*u.dx.dy + (g*h)*u.dx.dy))

op = Operator(eq, opt=('advanced', {'expand': False,
'blocklevels': 0}))

# Check generated code
assert len(get_arrays(op)) == 0
assert op._profiler._sections['section0'].sops == 74
exprs = FindNodes(Expression).visit(op)
assert len(exprs) == 6
temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)]
assert len(temps) == 3


class Test2Pass(object):

@switchconfig(safe_math=True)
def test_v0(self):
grid = Grid(shape=(10, 10, 10))

Expand All @@ -312,14 +338,14 @@ def test_v0(self):
v1 = TimeFunction(name='v', grid=grid, space_order=8)

eqns = [Eq(u.forward, (u.dx.dy + v*u + 1.)),
Eq(v.forward, (v + u.dx.dy + 1.))]
Eq(v.forward, (v + u.dx.dz + 1.))]

op0 = Operator(eqns)
op1 = Operator(eqns, opt=('advanced', {'expand': False,
'openmp': True}))

# Check generated code
assert op1._profiler._sections['section0'].sops == 41
assert op1._profiler._sections['section0'].sops == 59
assert_structure(op1, ['t',
't,x0_blk0,y0_blk0,x,y,z',
't,x0_blk0,y0_blk0,x,y,z,i0',
Expand Down
Loading