Skip to content

Commit

Permalink
compiler: switch cse key to namedtuple
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 16, 2024
1 parent f639e82 commit 46b15c4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 31 deletions.
61 changes: 33 additions & 28 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from collections import Counter, OrderedDict
from collections import Counter, OrderedDict, namedtuple
from functools import singledispatch

from sympy import Add, Function, Indexed, Mul, Pow
from sympy import Add, Function, Indexed, Mul, Pow, Eq as Eqs
try:
from sympy.core.core import ordering_of_classes
except ImportError:
# Moved in 1.13
from sympy.core.basic import ordering_of_classes

from devito.finite_differences.differentiable import IndexDerivative
from devito.ir import Cluster, Scope, cluster_pass, ClusterizedEq
from devito.ir import Cluster, Scope, cluster_pass
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.tools import as_list, frozendict
from devito.types import Eq, Symbol, Temp

__all__ = ['cse']


Counted = namedtuple('Candidate', 'expr, conditionals')


class CTemp(Temp):

"""
Expand Down Expand Up @@ -90,14 +93,12 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):

while True:
# Detect redundancies
counted = count(processed, None).items()
targets = OrderedDict([(k, estimate_cost(k[0], True))
counted = count(processed).items()
targets = OrderedDict([(k, estimate_cost(k.expr, True))
for k, v in counted if v > 1])

# Rule out Dimension-independent data dependencies
targets = OrderedDict([(k, v) for k, v in targets.items()
if not k[0].free_symbols & exclude])

if not k.expr.free_symbols & exclude])
if not targets or max(targets.values()) < min_cost:
break

Expand All @@ -112,13 +113,12 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):
updated = []
for e in processed:
pe = e
pe_c = e.conditionals
for (k, c), v in chosen:
if not c == pe_c:
for k, v in chosen:
if not k.conditionals == e.conditionals:
continue
pe, changed = _uxreplace(pe, {k: v})
pe, changed = _uxreplace(pe, {k.expr: v})
if changed and v not in scheduled:
updated.append(pe.func(v, k, operation=None))
updated.append(pe.func(v, k.expr, operation=None))
scheduled.append(v)
updated.append(pe)
processed = updated
Expand Down Expand Up @@ -160,59 +160,64 @@ def _compact_temporaries(exprs, exclude):


@singledispatch
def count(expr, conds):
def count(expr):
"""
Construct a mapper `expr -> #occurrences` for each sub-expression in `expr`.
"""
mapper = Counter()
for a in expr.args:
mapper.update(count(a, None))
mapper.update(count(a))
return mapper


@count.register(list)
@count.register(tuple)
def _(exprs, conds):
def _(exprs):
mapper = Counter()
for e in exprs:
mapper.update(count(e, None))
mapper.update(count(e))

return mapper


@count.register(ClusterizedEq)
def _(exprs, conds):
conditionals = exprs.conditionals
return count(exprs.rhs, conditionals)
@count.register(Eqs)
def _(expr):
mapper = count(expr.rhs)
try:
cond = expr.conditionals
except AttributeError:
cond = frozendict()
return {Counted(e, cond): v for e, v in mapper.items()}


@count.register(Indexed)
@count.register(Symbol)
def _(expr, conds):
def _(expr):
"""
Handler for objects preventing CSE to propagate through their arguments.
"""
return Counter()


@count.register(IndexDerivative)
def _(expr, conds):
def _(expr):
"""
Handler for symbol-binding objects. There can be many of them and therefore
they should be detected as common subexpressions, but it's either pointless
or forbidden to look inside them.
"""
return Counter([(expr, conds)])
return Counter([expr])


@count.register(Add)
@count.register(Mul)
@count.register(Pow)
@count.register(Function)
def _(expr, conds):
def _(expr):
mapper = Counter()
for a in expr.args:
mapper.update(count(a, conds))
mapper.update(count(a))

mapper[(expr, conds)] += 1
mapper[expr] += 1

return mapper
4 changes: 2 additions & 2 deletions devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import cached_property

from devito.finite_differences import default_rules
from devito.tools import as_tuple
from devito.tools import as_tuple, frozendict
from devito.types.lazy import Evaluable

__all__ = ['Eq', 'Inc', 'ReduceMax', 'ReduceMin']
Expand Down Expand Up @@ -139,7 +139,7 @@ def implicit_dims(self):

@property
def conditionals(self):
return None
return frozendict()

@cached_property
def _uses_symbolic_coefficients(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2783,7 +2783,7 @@ def test_fullopt(self):

assert summary1[('section0', None)].ops == 31
assert summary1[('section1', None)].ops == 88
assert summary1[('section2', None)].ops == 28
assert summary1[('section2', None)].ops == 25
assert np.isclose(summary1[('section0', None)].oi, 1.767, atol=0.001)

assert np.allclose(u0.data, u1.data, atol=10e-5)
Expand Down

0 comments on commit 46b15c4

Please sign in to comment.