Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Patch CSE in presence of conditionals #2392

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,26 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):

# Create temporaries
hit = max(targets.values())
temps = [Eq(make(), k) for k, v in targets.items() if v == hit]
chosen = [(k, make()) for k, v in targets.items() if v == hit]

# Apply replacements
# The extracted temporaries are inserted before the first expression
# that contains it
scheduled = []
updated = []
for e in processed:
pe = e
for t in temps:
pe, changed = _uxreplace(pe, {t.rhs: t.lhs})
if changed and t not in updated:
updated.append(t)
for k, v in chosen:
pe, changed = _uxreplace(pe, {k: v})
if changed and v not in scheduled:
updated.append(pe.func(v, k, operation=None))
scheduled.append(v)
updated.append(pe)
processed = updated
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use updated henceforth, rather than using processed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdym ? we keep on looping and processed needs to be updated? what change are you proposing ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore me then

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed an indent


# Update `exclude` for the same reasons as above -- to rule out CSE across
# Dimension-independent data dependences
exclude.update({t.lhs for t in temps})
exclude.update(scheduled)

# At this point we may have useless temporaries (e.g., r0=r1). Let's drop them
processed = _compact_temporaries(processed, exclude)
Expand Down
25 changes: 24 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt)
sin, sqrt, Ge)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.finite_differences.differentiable import diffify
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
Expand Down Expand Up @@ -168,6 +168,29 @@ def test_cse_temp_order():
assert type(args[2]) is CTemp


def test_cse_w_conditionals():
grid = Grid(shape=(10, 10, 10))
x, _, _ = grid.dimensions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: x = grid.dimensions[0] would be tidier imo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tidier but (nitpicking) the one above performs an implicit assertion

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fairs


cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4),
indirect=True)

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
a0 = Function(name='a0', grid=grid)
a1 = Function(name='a1', grid=grid)

eqns = [Eq(h, a0, implicit_dims=cd),
Eq(a0, a0 + f*g, implicit_dims=cd),
Eq(a1, a1 + f*g, implicit_dims=cd)]

op = Operator(eqns)

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 1


@pytest.mark.parametrize('expr,expected', [
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
('fa[x]**2', 'fa[x]*fa[x]'),
Expand Down
Loading