Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement commute() #224

Merged
merged 4 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/exo/API_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,39 @@ def reorder_stmts(proc, block_cursor):
loopir = Schedules.DoReorderStmt(loopir, s1, s2).result()
return Procedure(loopir, _provenance_eq_Procedure=proc)

@sched_op([ExprCursorA(many=True)])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this operation apply to a single location or many locations?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with both, do you have a suggestion?

def commute_expr(proc, expr_cursors):
"""
commute the binary operation of '+' and '*'.

args:
expr_cursors - a list of cursors to the binary operation

rewrite:
`a * b` <-- expr_cursor
-->
`b * a`

or

`a + b` <-- expr_cursor
-->
`b + a`
"""

exprs = [ ec._impl._node() for ec in expr_cursors ]
for e in exprs:
if not isinstance(e, LoopIR.BinOp) or (e.op != '+' and e.op != '*'):
raise TypeError(f"only '+' or '*' can commute, got {e.op}")
if any(not e.type.is_numeric() for e in exprs):
raise TypeError("only numeric (not index or size) expressions "
"can commute by commute_expr()")

loopir = proc._loopir_proc
loopir = Schedules.DoCommuteExpr(loopir, exprs).result()
return Procedure(loopir, _provenance_eq_Procedure=proc)


@sched_op([ExprCursorA(many=True), NameA, BoolA])
def bind_expr(proc, expr_cursors, new_name, cse=False):
"""
Expand Down
14 changes: 14 additions & 0 deletions src/exo/LoopIR_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,19 @@ def map_e(self, e):
return super().map_e(e)


class _DoCommuteExpr(LoopIR_Rewrite):
def __init__(self, proc, exprs):
self.exprs = exprs
super().__init__(proc)
self.proc = InferEffects(self.proc).result()

def map_e(self, e):
if e in self.exprs:
assert isinstance(e, LoopIR.BinOp)
return e.update(lhs=e.rhs, rhs=e.lhs)
else:
return super().map_e(e)


class _BindExpr(LoopIR_Rewrite):
def __init__(self, proc, new_name, exprs, cse=False):
Expand Down Expand Up @@ -3643,3 +3656,4 @@ class Schedules:
DoLiftAllocSimple = _DoLiftAllocSimple
DoFissionAfterSimple = _DoFissionAfterSimple
DoProductLoop = _DoProductLoop
DoCommuteExpr = _DoCommuteExpr
1 change: 1 addition & 0 deletions src/exo/stdlib/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
delete_pass,
reorder_stmts,
bind_expr,
commute_expr,
#
# subprocedure oriented operations
extract_subproc,
Expand Down
2 changes: 2 additions & 0 deletions tests/golden/test_schedules/test_commute.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def foo(x: R[3] @ DRAM, y: R[3] @ DRAM, z: R @ DRAM):
z = y[2] * x[0]
2 changes: 2 additions & 0 deletions tests/golden/test_schedules/test_commute3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def foo(x: R[3] @ DRAM, y: R[3] @ DRAM, z: R @ DRAM):
z = (x[1] + y[1] + y[2]) * (x[0] + y[0])
32 changes: 32 additions & 0 deletions tests/test_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,38 @@
from exo import ParseFragmentError
from exo.stdlib.scheduling import *

def test_commute(golden):
@proc
def foo(x : R[3], y : R[3], z : R):
z = x[0]*y[2]
assert str(commute_expr(foo, 'x[0] * y[_]')) == golden

def test_commute2():
@proc
def foo(x : R[3], y : R[3], z : R):
z = x[0] + y[0] + x[1] + y[1]

with pytest.raises(SchedulingError, match='failed to find matches'):
# TODO: Currently, expression pattern matching fails to find
# 'y[0]+x[1]' because LoopIR.BinOp is structured as (x[0], (y[0], (x[1], y[1]))).
# I think pattern matching should be powerful to find this.
commute_expr(foo, 'y[0] + x[1]')

def test_commute3(golden):
@proc
def foo(x : R[3], y : R[3], z : R):
z = (x[0] + y[0]) * (x[1] + y[1] + y[2])
assert str(commute_expr(foo, '(x[_] + y[_]) * (x[_] + y[_] + y[_])')) == golden

def test_commute4():
@proc
def foo(x : R[3], y : R[3], z : R):
z = x[0] - y[2]

with pytest.raises(TypeError, match="can commute"):
commute_expr(foo, 'x[0] - y[_]')


def test_product_loop(golden):
@proc
def foo(n : size):
Expand Down