Skip to content

Commit 7d79650

Browse files
tqchenGhosts381937
andcommitted
[ARITH] Canonicalize mul-coefficient to rhs
This PR updates the rewrite simplify logic to canonicalize mul-coefficient to rhs. This change is consistent with rest of the code base and allows better simplification of more cases. A test case of floormod with linear offset is added. Co-authored-by: Ghosts381937 <[email protected]>
1 parent 8c9026d commit 7d79650

File tree

3 files changed

+29
-16
lines changed

3 files changed

+29
-16
lines changed

src/arith/rewrite_simplify.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
446446
// mul co-efficient folding
447447
TVM_TRY_REWRITE(x + x, x * 2);
448448

449-
TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), x * (y + 1));
449+
TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), (y + 1) * x);
450450

451451
TVM_TRY_REWRITE(matches_one_of(x * y + x * z, y * x + x * z, x * y + z * x, y * x + z * x),
452-
x * (y + z));
452+
(y + z) * x);
453453

454454
// DivMod rules
455455
// truc div
@@ -563,12 +563,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
563563
TVM_TRY_REWRITE(matches_one_of(max(x, y) - y, x - min(y, x)), max(x - y, 0));
564564
TVM_TRY_REWRITE(matches_one_of(x - min(x, y), max(y, x) - y), max(0, x - y));
565565

566-
// mul co-efficient folding
566+
// mul co-efficient folding: pefer co-effiicent to stay at rhs
567567
TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
568-
TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), x * (y - 1));
569-
TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), x * (1 - y));
568+
TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), (y - 1) * x);
569+
TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), (1 - y) * x);
570570
TVM_TRY_REWRITE(matches_one_of(x * y - x * z, y * x - x * z, x * y - z * x, y * x - z * x),
571-
x * (y - z));
571+
(y - z) * x);
572572

573573
// constant cancelation
574574
TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));

tests/python/arith/test_arith_rewrite_simplify.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -391,28 +391,28 @@ class TestAddIndex(BaseCompare):
391391
TestCase(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)),
392392
TestCase(tvm.te.min(0, 1 - x * 4) + x * 4, tvm.te.min(x * 4, 1)),
393393
TestCase(tvm.te.min(2 - x * 4, 0) + x * 4, tvm.te.min(x * 4, 2)),
394-
TestCase(x * y + x * 10, x * (y + 10)),
395-
TestCase(y * x + x * 10, x * (y + 10)),
396-
TestCase(y * x + 10 * x, x * (y + 10)),
397-
TestCase(x * y + 10 * x, x * (y + 10)),
394+
TestCase(x * y + x * 10, (y + 10) * x),
395+
TestCase(y * x + x * 10, (y + 10) * x),
396+
TestCase(y * x + 10 * x, (y + 10) * x),
397+
TestCase(x * y + 10 * x, (y + 10) * x),
398398
TestCase((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y)),
399-
TestCase(y * x + x, x * (y + 1)),
400-
TestCase(x * y + x, x * (y + 1)),
399+
TestCase(y * x + x, (y + 1) * x),
400+
TestCase(x * y + x, (y + 1) * x),
401401
TestCase((x + 10) + 13, x + 23),
402402
TestCase((x + 10) + (13 + z), x + z + 23),
403-
TestCase(x * y + 10 * x, x * (y + 10)),
404-
TestCase(y * x + x * 3, x * (y + 3)),
403+
TestCase(x * y + 10 * x, (y + 10) * x),
404+
TestCase(y * x + x * 3, (y + 3) * x),
405405
TestCase(x + 3 + y, x + y + 3),
406406
TestCase((3 - y) + x, x - y + 3),
407407
# canonicalization
408408
TestCase(x + 2 + 3 + 4 + x, x * 2 + 9),
409409
TestCase(x + 2 + 3 + 4 + x * 3, x * 4 + 9),
410410
# DivMod rules
411411
# trunc div
412-
TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), tmod(x, 8) * (y + 10)),
412+
TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), (y + 10) * tmod(x, 8)),
413413
TestCase(tdiv(x, 8) * 8 + tmod(x, 8), x),
414414
# floor div
415-
TestCase(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)),
415+
TestCase(y * flm(x, 8) + 10 * flm(x, 8), (y + 10) * flm(x, 8)),
416416
TestCase(fld(x, 8) * 8 + flm(x, 8), x),
417417
TestCase(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)),
418418
)

tests/python/arith/test_arith_simplify.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,18 @@ def test_regression_simplify_inf_recursion():
131131
ana.rewrite_simplify(res)
132132

133133

134+
def test_simplify_floor_mod_with_linear_offset():
135+
"""
136+
Test that the floor_mod is simplified correctly when the offset is linear.
137+
"""
138+
ana = tvm.arith.Analyzer()
139+
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")
140+
expr1 = (past_decoder_sequence_length + 1) * 64
141+
divisor1 = (past_decoder_sequence_length + 1) * 32
142+
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0)
143+
divisor2 = 32 * (past_decoder_sequence_length + 1)
144+
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0)
145+
146+
134147
if __name__ == "__main__":
135148
tvm.testing.main()

0 commit comments

Comments
 (0)