Skip to content

Commit

Permalink
compiler: switch cse key to namedtuple
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 16, 2024
1 parent f639e82 commit 65d7d69
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 30 deletions.
55 changes: 30 additions & 25 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import Counter, OrderedDict
from collections import Counter, OrderedDict, namedtuple
from functools import singledispatch

from sympy import Add, Function, Indexed, Mul, Pow
Expand All @@ -13,12 +13,15 @@
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.tools import as_list, frozendict
from devito.types import Eq, Symbol, Temp

__all__ = ['cse']


Counted = namedtuple('Candidate', 'expr, conditionals')


class CTemp(Temp):

"""
Expand Down Expand Up @@ -90,14 +93,12 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):

while True:
# Detect redundancies
counted = count(processed, None).items()
targets = OrderedDict([(k, estimate_cost(k[0], True))
counted = count(processed).items()
targets = OrderedDict([(k, estimate_cost(k.expr, True))
for k, v in counted if v > 1])

# Rule out Dimension-independent data dependencies
targets = OrderedDict([(k, v) for k, v in targets.items()
if not k[0].free_symbols & exclude])

if not k.expr.free_symbols & exclude])
if not targets or max(targets.values()) < min_cost:
break

Expand All @@ -112,13 +113,12 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):
updated = []
for e in processed:
pe = e
pe_c = e.conditionals
for (k, c), v in chosen:
if not c == pe_c:
for k, v in chosen:
if not k.conditionals == e.conditionals:
continue
pe, changed = _uxreplace(pe, {k: v})
pe, changed = _uxreplace(pe, {k.expr: v})
if changed and v not in scheduled:
updated.append(pe.func(v, k, operation=None))
updated.append(pe.func(v, k.expr, operation=None))
scheduled.append(v)
updated.append(pe)
processed = updated
Expand Down Expand Up @@ -160,59 +160,64 @@ def _compact_temporaries(exprs, exclude):


@singledispatch
def count(expr, conds):
def count(expr):
"""
Construct a mapper `expr -> #occurrences` for each sub-expression in `expr`.
"""
mapper = Counter()
for a in expr.args:
mapper.update(count(a, None))
mapper.update(count(a))
return mapper


@count.register(list)
@count.register(tuple)
def _(exprs, conds):
def _(exprs):
mapper = Counter()
for e in exprs:
mapper.update(count(e, None))
mapper.update(count(e))

return mapper


@count.register(ClusterizedEq)
def _(exprs, conds):
conditionals = exprs.conditionals
return count(exprs.rhs, conditionals)
def _(expr):
mapper = count(expr.rhs)
try:
cond = expr.conditionals
except AttributeError:
cond = frozendict()
return {Counted(e, cond): v for e, v in mapper.items()}


@count.register(Indexed)
@count.register(Symbol)
def _(expr, conds):
def _(expr):
"""
Handler for objects preventing CSE to propagate through their arguments.
"""
return Counter()


@count.register(IndexDerivative)
def _(expr, conds):
def _(expr):
"""
Handler for symbol-binding objects. There can be many of them and therefore
they should be detected as common subexpressions, but it's either pointless
or forbidden to look inside them.
"""
return Counter([(expr, conds)])
return Counter([expr])


@count.register(Add)
@count.register(Mul)
@count.register(Pow)
@count.register(Function)
def _(expr, conds):
def _(expr):
mapper = Counter()
for a in expr.args:
mapper.update(count(a, conds))
mapper.update(count(a))

mapper[(expr, conds)] += 1
mapper[expr] += 1

return mapper
4 changes: 0 additions & 4 deletions devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,6 @@ def substitutions(self):
def implicit_dims(self):
return self._implicit_dims

@property
def conditionals(self):
return None

@cached_property
def _uses_symbolic_coefficients(self):
return bool(self._symbolic_functions)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2783,7 +2783,7 @@ def test_fullopt(self):

assert summary1[('section0', None)].ops == 31
assert summary1[('section1', None)].ops == 88
assert summary1[('section2', None)].ops == 28
assert summary1[('section2', None)].ops == 25
assert np.isclose(summary1[('section0', None)].oi, 1.767, atol=0.001)

assert np.allclose(u0.data, u1.data, atol=10e-5)
Expand Down

0 comments on commit 65d7d69

Please sign in to comment.