From 65d7d697ddebfc5bc5661bcf6539327d596851f6 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 16 Jul 2024 08:12:53 -0400 Subject: [PATCH] compiler: switch cse key to namedtuple --- devito/passes/clusters/cse.py | 55 +++++++++++++++++++---------------- devito/types/equation.py | 4 --- tests/test_dse.py | 2 +- 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index f892a259592..9ac028f8df0 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -1,4 +1,4 @@ -from collections import Counter, OrderedDict +from collections import Counter, OrderedDict, namedtuple from functools import singledispatch from sympy import Add, Function, Indexed, Mul, Pow @@ -13,12 +13,15 @@ from devito.passes.clusters.utils import makeit_ssa from devito.symbolics import estimate_cost, q_leaf from devito.symbolics.manipulation import _uxreplace -from devito.tools import as_list +from devito.tools import as_list, frozendict from devito.types import Eq, Symbol, Temp __all__ = ['cse'] +Counted = namedtuple('Candidate', 'expr, conditionals') + + class CTemp(Temp): """ @@ -90,14 +93,12 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): while True: # Detect redundancies - counted = count(processed, None).items() - targets = OrderedDict([(k, estimate_cost(k[0], True)) + counted = count(processed).items() + targets = OrderedDict([(k, estimate_cost(k.expr, True)) for k, v in counted if v > 1]) - # Rule out Dimension-independent data dependencies targets = OrderedDict([(k, v) for k, v in targets.items() - if not k[0].free_symbols & exclude]) - + if not k.expr.free_symbols & exclude]) if not targets or max(targets.values()) < min_cost: break @@ -112,13 +113,12 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): updated = [] for e in processed: pe = e - pe_c = e.conditionals - for (k, c), v in chosen: - if not c == pe_c: + for k, v in chosen: + if not k.conditionals == e.conditionals: continue - pe, changed = _uxreplace(pe, {k: v}) + pe, changed = _uxreplace(pe, {k.expr: v}) if changed and v not in scheduled: - updated.append(pe.func(v, k, operation=None)) + updated.append(pe.func(v, k.expr, operation=None)) scheduled.append(v) updated.append(pe) processed = updated @@ -160,34 +160,39 @@ def _compact_temporaries(exprs, exclude): @singledispatch -def count(expr, conds): +def count(expr): """ Construct a mapper `expr -> #occurrences` for each sub-expression in `expr`. """ mapper = Counter() for a in expr.args: - mapper.update(count(a, None)) + mapper.update(count(a)) return mapper @count.register(list) @count.register(tuple) -def _(exprs, conds): +def _(exprs): mapper = Counter() for e in exprs: - mapper.update(count(e, None)) + mapper.update(count(e)) + return mapper @count.register(ClusterizedEq) -def _(exprs, conds): - conditionals = exprs.conditionals - return count(exprs.rhs, conditionals) +def _(expr): + mapper = count(expr.rhs) + try: + cond = expr.conditionals + except AttributeError: + cond = frozendict() + return {Counted(e, cond): v for e, v in mapper.items()} @count.register(Indexed) @count.register(Symbol) -def _(expr, conds): +def _(expr): """ Handler for objects preventing CSE to propagate through their arguments. """ @@ -195,24 +200,24 @@ def _(expr, conds): @count.register(IndexDerivative) -def _(expr, conds): +def _(expr): """ Handler for symbol-binding objects. There can be many of them and therefore they should be detected as common subexpressions, but it's either pointless or forbidden to look inside them. """ - return Counter([(expr, conds)]) + return Counter([expr]) @count.register(Add) @count.register(Mul) @count.register(Pow) @count.register(Function) -def _(expr, conds): +def _(expr): mapper = Counter() for a in expr.args: - mapper.update(count(a, conds)) + mapper.update(count(a)) - mapper[(expr, conds)] += 1 + mapper[expr] += 1 return mapper diff --git a/devito/types/equation.py b/devito/types/equation.py index 3847e9f4f71..c4af3385c38 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -137,10 +137,6 @@ def substitutions(self): def implicit_dims(self): return self._implicit_dims - @property - def conditionals(self): - return None - @cached_property def _uses_symbolic_coefficients(self): return bool(self._symbolic_functions) diff --git a/tests/test_dse.py b/tests/test_dse.py index f6965f18bc2..6817e9db6bb 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -2783,7 +2783,7 @@ def test_fullopt(self): assert summary1[('section0', None)].ops == 31 assert summary1[('section1', None)].ops == 88 - assert summary1[('section2', None)].ops == 28 + assert summary1[('section2', None)].ops == 25 assert np.isclose(summary1[('section0', None)].oi, 1.767, atol=0.001) assert np.allclose(u0.data, u1.data, atol=10e-5)