From 8f504254e58da782658ff5307983d046169e098c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 18 Dec 2023 11:44:57 +0000 Subject: [PATCH] compiler: Catch redundant IndexDerivatives early on --- devito/passes/clusters/cse.py | 14 +++++++++++-- devito/passes/clusters/derivatives.py | 4 ++++ tests/test_unexpansion.py | 30 +++++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index ddff97f7c6..5e4ce40d36 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -10,7 +10,7 @@ from devito.symbolics import estimate_cost, q_leaf from devito.symbolics.manipulation import _uxreplace from devito.tools import as_list -from devito.types import Eq, Temp +from devito.types import Eq, Symbol, Temp __all__ = ['cse'] @@ -169,8 +169,8 @@ def _(exprs): return mapper -@count.register(IndexDerivative) @count.register(Indexed) +@count.register(Symbol) def _(expr): """ Handler for objects preventing CSE to propagate through their arguments. @@ -178,6 +178,16 @@ def _(expr): return Counter() +@count.register(IndexDerivative) +def _(expr): + """ + Handler for symbol-binding objects. There can be many of them and therefore + they should be detected as common subexpressions, but it's either pointless + or forbidden to look inside them. + """ + return Counter([expr]) + + @count.register(Add) @count.register(Mul) @count.register(Pow) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 01e5afea6a..4b82fdc262 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -19,6 +19,10 @@ def lower_index_derivatives(clusters, mode=None, **kwargs): if mode != 'noop': clusters = fuse(clusters, toposort='maximal') + # At this point we can detect redundancies induced by inner derivatives that + # previously were just not detectable via e.g. plain CSE. For example, if + # there were two IndexDerivatives such as `(p.dx + m.dx).dx` and `m.dx.dx` + # then it's only after `_lower_index_derivatives` that they're detectable! clusters = CDE(mapper).process(clusters) return clusters diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 0de2d55de7..8a4dcbbfed 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -5,6 +5,7 @@ from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator, Substitutions, Coefficient, cos, sin) from devito.arch.compiler import OneapiCompiler +from devito.ir import Expression, FindNodes, FindSymbols from devito.parameters import switchconfig, configuration from devito.types import Symbol @@ -300,9 +301,34 @@ def test_transpose(self): op1.apply(time_M=10, u=u1) assert np.allclose(u.data, u1.data, rtol=10e-6) + def test_redundant_derivatives(self): + grid = Grid(shape=(10, 10, 10)) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = Function(name='h', grid=grid) + u = TimeFunction(name='u', grid=grid, space_order=4) + + # It's the same `u.dx.dy` appearing multiple times, and the cost models + # must be smart enough to detect that! + eq = Eq(u.forward, (f*u.dx.dy + g*u.dx.dy + h*u.dx.dy + + (f*g)*u.dx.dy + (f*h)*u.dx.dy + (g*h)*u.dx.dy)) + + op = Operator(eq, opt=('advanced', {'expand': False, + 'blocklevels': 0})) + + # Check generated code + assert len(get_arrays(op)) == 0 + assert op._profiler._sections['section0'].sops == 74 + exprs = FindNodes(Expression).visit(op) + assert len(exprs) == 6 + temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)] + assert len(temps) == 3 + class Test2Pass(object): + @switchconfig(safe_math=True) def test_v0(self): grid = Grid(shape=(10, 10, 10)) @@ -312,14 +338,14 @@ def test_v0(self): v1 = TimeFunction(name='v', grid=grid, space_order=8) eqns = [Eq(u.forward, (u.dx.dy + v*u + 1.)), - Eq(v.forward, (v + u.dx.dy + 1.))] + Eq(v.forward, (v + u.dx.dz + 1.))] op0 = Operator(eqns) op1 = Operator(eqns, opt=('advanced', {'expand': False, 'openmp': True})) # Check generated code - assert op1._profiler._sections['section0'].sops == 41 + assert op1._profiler._sections['section0'].sops == 59 assert_structure(op1, ['t', 't,x0_blk0,y0_blk0,x,y,z', 't,x0_blk0,y0_blk0,x,y,z,i0',