From 727206984e1b485262fe7d96e118258ca936c7fe Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 15 Dec 2023 17:10:19 +0000 Subject: [PATCH 1/2] compiler: Improve estimate_cost --- devito/symbolics/inspection.py | 63 ++++++++++++++++++++++------------ tests/test_dse.py | 2 ++ 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 45b5dce754..99e752abce 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -91,6 +91,7 @@ def estimate_cost(exprs, estimate=False): # We don't use SymPy's count_ops because we do not count integer arithmetic # (e.g., array index functions such as i+1 in A[i+1]) # Also, the routine below is *much* faster than count_ops + seen = {} flops = 0 for expr in as_tuple(exprs): # TODO: this if-then should be part of singledispatch too, but because @@ -103,7 +104,7 @@ def estimate_cost(exprs, estimate=False): else: e = expr - flops += _estimate_cost(e, estimate)[0] + flops += _estimate_cost(e, estimate, seen)[0] return flops except: @@ -121,11 +122,27 @@ def estimate_cost(exprs, estimate=False): } +def dont_count_if_seen(func): + """ + This decorator is used to avoid counting the same expression multiple + times. This is necessary because the same expression may appear multiple + times in the same expression tree or even across different expressions. + """ + def wrapper(expr, estimate, seen): + try: + _, flags = seen[expr] + flops = 0 + except KeyError: + flops, flags = seen[expr] = func(expr, estimate, seen) + return flops, flags + return wrapper + + @singledispatch -def _estimate_cost(expr, estimate): +def _estimate_cost(expr, estimate, seen): # Retval: flops (int), flag (bool) # The flag tells wether it's an integer expression (implying flops==0) or not - flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args]) + flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) flops = sum(flops) if all(flags): # `expr` is an operation involving integer operands only @@ -138,28 +155,28 @@ def _estimate_cost(expr, estimate): @_estimate_cost.register(Tuple) @_estimate_cost.register(CallFromPointer) -def _(expr, estimate): +def _(expr, estimate, seen): try: - flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args]) + flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) except ValueError: flops, flags = [], [] return sum(flops), all(flags) @_estimate_cost.register(Integer) -def _(expr, estimate): +def _(expr, estimate, seen): return 0, True @_estimate_cost.register(Number) @_estimate_cost.register(ReservedWord) -def _(expr, estimate): +def _(expr, estimate, seen): return 0, False @_estimate_cost.register(Symbol) @_estimate_cost.register(Indexed) -def _(expr, estimate): +def _(expr, estimate, seen): try: if issubclass(expr.dtype, np.integer): return 0, True @@ -169,27 +186,27 @@ def _(expr, estimate): @_estimate_cost.register(Mul) -def _(expr, estimate): - flops, flags = _estimate_cost.registry[object](expr, estimate) +def _(expr, estimate, seen): + flops, flags = _estimate_cost.registry[object](expr, estimate, seen) if {S.One, S.NegativeOne}.intersection(expr.args): flops -= 1 return flops, flags @_estimate_cost.register(INT) -def _(expr, estimate): - return _estimate_cost(expr.base, estimate)[0], True +def _(expr, estimate, seen): + return _estimate_cost(expr.base, estimate, seen)[0], True @_estimate_cost.register(Cast) -def _(expr, estimate): - return _estimate_cost(expr.base, estimate)[0], False +def _(expr, estimate, seen): + return _estimate_cost(expr.base, estimate, seen)[0], False @_estimate_cost.register(Function) -def _(expr, estimate): +def _(expr, estimate, seen): if q_routine(expr): - flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args]) + flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) flops = sum(flops) if isinstance(expr, DefFunction): # Bypass user-defined or language-specific functions @@ -207,8 +224,8 @@ def _(expr, estimate): @_estimate_cost.register(Pow) -def _(expr, estimate): - flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args]) +def _(expr, estimate, seen): + flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) flops = sum(flops) if estimate: if expr.exp.is_Number: @@ -229,13 +246,15 @@ def _(expr, estimate): @_estimate_cost.register(Derivative) -def _(expr, estimate): - return _estimate_cost(expr._evaluate(expand=False), estimate) +@dont_count_if_seen +def _(expr, estimate, seen): + return _estimate_cost(expr._evaluate(expand=False), estimate, seen) @_estimate_cost.register(IndexDerivative) -def _(expr, estimate): - flops, _ = _estimate_cost(expr.expr, estimate) +@dont_count_if_seen +def _(expr, estimate, seen): + flops, _ = _estimate_cost(expr.expr, estimate, seen) # It's an increment flops += 1 diff --git a/tests/test_dse.py b/tests/test_dse.py index 5fc88b5d94..8f4d756a8b 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -270,6 +270,8 @@ def test_factorize(expr, expected): ('Eq(fb, fd.dx)', 10, True), ('Eq(fb, fd.dx._evaluate(expand=False))', 10, False), ('Eq(fb, fd.dx.dy + fa.dx)', 66, False), + # Ensure redundancies aren't counted + ('Eq(t0, fd.dx.dy + fa*fd.dx.dy)', 62, True), ]) def test_estimate_cost(expr, expected, estimate): # Note: integer arithmetic isn't counted From 8f504254e58da782658ff5307983d046169e098c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 18 Dec 2023 11:44:57 +0000 Subject: [PATCH 2/2] compiler: Catch redundant IndexDerivatives early on --- devito/passes/clusters/cse.py | 14 +++++++++++-- devito/passes/clusters/derivatives.py | 4 ++++ tests/test_unexpansion.py | 30 +++++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index ddff97f7c6..5e4ce40d36 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -10,7 +10,7 @@ from devito.symbolics import estimate_cost, q_leaf from devito.symbolics.manipulation import _uxreplace from devito.tools import as_list -from devito.types import Eq, Temp +from devito.types import Eq, Symbol, Temp __all__ = ['cse'] @@ -169,8 +169,8 @@ def _(exprs): return mapper -@count.register(IndexDerivative) @count.register(Indexed) +@count.register(Symbol) def _(expr): """ Handler for objects preventing CSE to propagate through their arguments. @@ -178,6 +178,16 @@ def _(expr): return Counter() +@count.register(IndexDerivative) +def _(expr): + """ + Handler for symbol-binding objects. There can be many of them and therefore + they should be detected as common subexpressions, but it's either pointless + or forbidden to look inside them. + """ + return Counter([expr]) + + @count.register(Add) @count.register(Mul) @count.register(Pow) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 01e5afea6a..4b82fdc262 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -19,6 +19,10 @@ def lower_index_derivatives(clusters, mode=None, **kwargs): if mode != 'noop': clusters = fuse(clusters, toposort='maximal') + # At this point we can detect redundancies induced by inner derivatives that + # previously were just not detectable via e.g. plain CSE. For example, if + # there were two IndexDerivatives such as `(p.dx + m.dx).dx` and `m.dx.dx` + # then it's only after `_lower_index_derivatives` that they're detectable! clusters = CDE(mapper).process(clusters) return clusters diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 0de2d55de7..8a4dcbbfed 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -5,6 +5,7 @@ from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator, Substitutions, Coefficient, cos, sin) from devito.arch.compiler import OneapiCompiler +from devito.ir import Expression, FindNodes, FindSymbols from devito.parameters import switchconfig, configuration from devito.types import Symbol @@ -300,9 +301,34 @@ def test_transpose(self): op1.apply(time_M=10, u=u1) assert np.allclose(u.data, u1.data, rtol=10e-6) + def test_redundant_derivatives(self): + grid = Grid(shape=(10, 10, 10)) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = Function(name='h', grid=grid) + u = TimeFunction(name='u', grid=grid, space_order=4) + + # It's the same `u.dx.dy` appearing multiple times, and the cost models + # must be smart enough to detect that! + eq = Eq(u.forward, (f*u.dx.dy + g*u.dx.dy + h*u.dx.dy + + (f*g)*u.dx.dy + (f*h)*u.dx.dy + (g*h)*u.dx.dy)) + + op = Operator(eq, opt=('advanced', {'expand': False, + 'blocklevels': 0})) + + # Check generated code + assert len(get_arrays(op)) == 0 + assert op._profiler._sections['section0'].sops == 74 + exprs = FindNodes(Expression).visit(op) + assert len(exprs) == 6 + temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)] + assert len(temps) == 3 + class Test2Pass(object): + @switchconfig(safe_math=True) def test_v0(self): grid = Grid(shape=(10, 10, 10)) @@ -312,14 +338,14 @@ def test_v0(self): v1 = TimeFunction(name='v', grid=grid, space_order=8) eqns = [Eq(u.forward, (u.dx.dy + v*u + 1.)), - Eq(v.forward, (v + u.dx.dy + 1.))] + Eq(v.forward, (v + u.dx.dz + 1.))] op0 = Operator(eqns) op1 = Operator(eqns, opt=('advanced', {'expand': False, 'openmp': True})) # Check generated code - assert op1._profiler._sections['section0'].sops == 41 + assert op1._profiler._sections['section0'].sops == 59 assert_structure(op1, ['t', 't,x0_blk0,y0_blk0,x,y,z', 't,x0_blk0,y0_blk0,x,y,z,i0',