Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: fix cse with different conditionals #2410

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
compiler: fix cse with different conditionals
mloubout committed Jul 16, 2024

Verified

This commit was signed with the committer’s verified signature.
mloubout Mathias Louboutin
commit f639e823d3086e0bb7a97661b5d1bbcad75ab6d3
2 changes: 1 addition & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def detect_accesses(exprs):
other_dims = set()
for e in as_tuple(exprs):
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
other_dims.update(e.implicit_dims)
other_dims.update(e.implicit_dims or {})
other_dims = filter_sorted(other_dims)
mapper[None] = Stencil([(i, 0) for i in other_dims])

Expand Down
40 changes: 25 additions & 15 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sympy.core.basic import ordering_of_classes

from devito.finite_differences.differentiable import IndexDerivative
from devito.ir import Cluster, Scope, cluster_pass
from devito.ir import Cluster, Scope, cluster_pass, ClusterizedEq
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
Expand Down Expand Up @@ -90,12 +90,13 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):

while True:
# Detect redundancies
counted = count(processed).items()
targets = OrderedDict([(k, estimate_cost(k, True)) for k, v in counted if v > 1])
counted = count(processed, None).items()
targets = OrderedDict([(k, estimate_cost(k[0], True))
mloubout marked this conversation as resolved.
Show resolved Hide resolved
for k, v in counted if v > 1])

# Rule out Dimension-independent data dependencies
targets = OrderedDict([(k, v) for k, v in targets.items()
if not k.free_symbols & exclude])
if not k[0].free_symbols & exclude])
mloubout marked this conversation as resolved.
Show resolved Hide resolved

if not targets or max(targets.values()) < min_cost:
break
Expand All @@ -111,7 +112,10 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):
updated = []
for e in processed:
pe = e
for k, v in chosen:
pe_c = e.conditionals
mloubout marked this conversation as resolved.
Show resolved Hide resolved
for (k, c), v in chosen:
if not c == pe_c:
continue
pe, changed = _uxreplace(pe, {k: v})
if changed and v not in scheduled:
updated.append(pe.func(v, k, operation=None))
Expand Down Expand Up @@ -156,53 +160,59 @@ def _compact_temporaries(exprs, exclude):


@singledispatch
def count(expr):
def count(expr, conds):
"""
Construct a mapper `expr -> #occurrences` for each sub-expression in `expr`.
"""
mapper = Counter()
for a in expr.args:
mapper.update(count(a))
mapper.update(count(a, None))
mloubout marked this conversation as resolved.
Show resolved Hide resolved
return mapper


@count.register(list)
@count.register(tuple)
def _(exprs):
def _(exprs, conds):
mapper = Counter()
for e in exprs:
mapper.update(count(e))
mapper.update(count(e, None))
return mapper


@count.register(ClusterizedEq)
def _(exprs, conds):
mloubout marked this conversation as resolved.
Show resolved Hide resolved
conditionals = exprs.conditionals
mloubout marked this conversation as resolved.
Show resolved Hide resolved
return count(exprs.rhs, conditionals)
mloubout marked this conversation as resolved.
Show resolved Hide resolved


@count.register(Indexed)
@count.register(Symbol)
def _(expr):
def _(expr, conds):
"""
Handler for objects preventing CSE to propagate through their arguments.
"""
return Counter()


@count.register(IndexDerivative)
def _(expr):
def _(expr, conds):
"""
Handler for symbol-binding objects. There can be many of them and therefore
they should be detected as common subexpressions, but it's either pointless
or forbidden to look inside them.
"""
return Counter([expr])
return Counter([(expr, conds)])


@count.register(Add)
@count.register(Mul)
@count.register(Pow)
@count.register(Function)
def _(expr):
def _(expr, conds):
mapper = Counter()
for a in expr.args:
mapper.update(count(a))
mapper.update(count(a, conds))

mapper[expr] += 1
mapper[(expr, conds)] += 1

return mapper
4 changes: 4 additions & 0 deletions devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def substitutions(self):
def implicit_dims(self):
return self._implicit_dims

@property
def conditionals(self):
return None

@cached_property
def _uses_symbolic_coefficients(self):
return bool(self._symbolic_functions)
Expand Down
53 changes: 51 additions & 2 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt, Ge)
sin, sqrt, Ge, Lt)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.finite_differences.differentiable import diffify
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
Expand Down Expand Up @@ -191,6 +191,55 @@ def test_cse_w_conditionals():
assert len(FindNodes(Conditional).visit(op)) == 1


def test_cse_w_multi_conditionals():
grid = Grid(shape=(10, 10, 10))
x, _, _ = grid.dimensions

cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4),
indirect=True)

cd2 = ConditionalDimension(name='cd2', parent=x, condition=Lt(x, 4),
indirect=True)

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
a0 = Function(name='a0', grid=grid)
a1 = Function(name='a1', grid=grid)
a2 = Function(name='a2', grid=grid)
a3 = Function(name='a3', grid=grid)

eq0 = Eq(h, a0, implicit_dims=cd)
eq1 = Eq(a0, a0 + f*g, implicit_dims=cd)
eq2 = Eq(a1, a1 + f*g, implicit_dims=cd)
eq3 = Eq(a2, a2 + f*g, implicit_dims=cd2)
eq4 = Eq(a3, a3 + f*g, implicit_dims=cd2)

op = Operator([eq0, eq1, eq3])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 0

op = Operator([eq0, eq1, eq3, eq4])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 1

op = Operator([eq0, eq1, eq2, eq3, eq4])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 2


@pytest.mark.parametrize('expr,expected', [
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
('fa[x]**2', 'fa[x]*fa[x]'),
Expand Down Expand Up @@ -2734,7 +2783,7 @@ def test_fullopt(self):

assert summary1[('section0', None)].ops == 31
assert summary1[('section1', None)].ops == 88
assert summary1[('section2', None)].ops == 25
assert summary1[('section2', None)].ops == 28
mloubout marked this conversation as resolved.
Show resolved Hide resolved
assert np.isclose(summary1[('section0', None)].oi, 1.767, atol=0.001)

assert np.allclose(u0.data, u1.data, atol=10e-5)
Expand Down