Skip to content

Commit 492fc5b

Browse files
tqchen and Ghosts381937 committed
[ARITH] Canonicalize mul-coefficient to rhs
This PR updates the rewrite-simplify logic to canonicalize the multiplication coefficient to the right-hand side (rhs), e.g. rewriting `x * y + x` to `(y + 1) * x` rather than `x * (y + 1)`. This change is consistent with the rest of the code base and allows better simplification in more cases. A test case for floormod with a linear offset is added. Co-authored-by: Ghosts381937 <[email protected]>
1 parent 8c9026d commit 492fc5b

File tree

4 files changed

+34
-21
lines changed

4 files changed

+34
-21
lines changed

src/arith/rewrite_simplify.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
446446
// mul co-efficient folding
447447
TVM_TRY_REWRITE(x + x, x * 2);
448448

449-
TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), x * (y + 1));
449+
TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), (y + 1) * x);
450450

451451
TVM_TRY_REWRITE(matches_one_of(x * y + x * z, y * x + x * z, x * y + z * x, y * x + z * x),
452-
x * (y + z));
452+
(y + z) * x);
453453

454454
// DivMod rules
455455
// trunc div
@@ -563,12 +563,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
563563
TVM_TRY_REWRITE(matches_one_of(max(x, y) - y, x - min(y, x)), max(x - y, 0));
564564
TVM_TRY_REWRITE(matches_one_of(x - min(x, y), max(y, x) - y), max(0, x - y));
565565

566-
// mul co-efficient folding
566+
// mul co-efficient folding: prefer co-efficient to stay at rhs
567567
TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
568-
TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), x * (y - 1));
569-
TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), x * (1 - y));
568+
TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), (y - 1) * x);
569+
TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), (1 - y) * x);
570570
TVM_TRY_REWRITE(matches_one_of(x * y - x * z, y * x - x * z, x * y - z * x, y * x - z * x),
571-
x * (y - z));
571+
(y - z) * x);
572572

573573
// constant cancelation
574574
TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));

tests/python/arith/test_arith_rewrite_simplify.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -391,28 +391,28 @@ class TestAddIndex(BaseCompare):
391391
TestCase(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)),
392392
TestCase(tvm.te.min(0, 1 - x * 4) + x * 4, tvm.te.min(x * 4, 1)),
393393
TestCase(tvm.te.min(2 - x * 4, 0) + x * 4, tvm.te.min(x * 4, 2)),
394-
TestCase(x * y + x * 10, x * (y + 10)),
395-
TestCase(y * x + x * 10, x * (y + 10)),
396-
TestCase(y * x + 10 * x, x * (y + 10)),
397-
TestCase(x * y + 10 * x, x * (y + 10)),
394+
TestCase(x * y + x * 10, (y + 10) * x),
395+
TestCase(y * x + x * 10, (y + 10) * x),
396+
TestCase(y * x + 10 * x, (y + 10) * x),
397+
TestCase(x * y + 10 * x, (y + 10) * x),
398398
TestCase((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y)),
399-
TestCase(y * x + x, x * (y + 1)),
400-
TestCase(x * y + x, x * (y + 1)),
399+
TestCase(y * x + x, (y + 1) * x),
400+
TestCase(x * y + x, (y + 1) * x),
401401
TestCase((x + 10) + 13, x + 23),
402402
TestCase((x + 10) + (13 + z), x + z + 23),
403-
TestCase(x * y + 10 * x, x * (y + 10)),
404-
TestCase(y * x + x * 3, x * (y + 3)),
403+
TestCase(x * y + 10 * x, (y + 10) * x),
404+
TestCase(y * x + x * 3, (y + 3) * x),
405405
TestCase(x + 3 + y, x + y + 3),
406406
TestCase((3 - y) + x, x - y + 3),
407407
# canonicalization
408408
TestCase(x + 2 + 3 + 4 + x, x * 2 + 9),
409409
TestCase(x + 2 + 3 + 4 + x * 3, x * 4 + 9),
410410
# DivMod rules
411411
# trunc div
412-
TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), tmod(x, 8) * (y + 10)),
412+
TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), (y + 10) * tmod(x, 8)),
413413
TestCase(tdiv(x, 8) * 8 + tmod(x, 8), x),
414414
# floor div
415-
TestCase(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)),
415+
TestCase(y * flm(x, 8) + 10 * flm(x, 8), (y + 10) * flm(x, 8)),
416416
TestCase(fld(x, 8) * 8 + flm(x, 8), x),
417417
TestCase(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)),
418418
)
@@ -436,10 +436,10 @@ class TestSubIndex(BaseCompare):
436436
TestCase(y - tvm.te.max(x, y), tvm.te.min(y - x, 0)),
437437
# mul co-efficient folding
438438
TestCase(x - x, 0),
439-
TestCase(x * y - x, x * (y + (-1))),
440-
TestCase(x * y - 10 * x, x * (y + (-10))),
441-
TestCase(y * x - x * z, x * (y - z)),
442-
TestCase(y * x - z * x, x * (y - z)),
439+
TestCase(x * y - x, (y + (-1)) * x),
440+
TestCase(x * y - 10 * x, (y + (-10)) * x),
441+
TestCase(y * x - x * z, (y - z) * x),
442+
TestCase(y * x - z * x, (y - z) * x),
443443
TestCase(x + 10 - 20, x + (-10)),
444444
# 4-operands pattern
445445
TestCase((x + y) - (x + z), y - z),

tests/python/arith/test_arith_simplify.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,18 @@ def test_regression_simplify_inf_recursion():
131131
ana.rewrite_simplify(res)
132132

133133

134+
def test_simplify_floor_mod_with_linear_offset():
135+
"""
136+
Test that the floor_mod is simplified correctly when the offset is linear.
137+
"""
138+
ana = tvm.arith.Analyzer()
139+
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")
140+
expr1 = (past_decoder_sequence_length + 1) * 64
141+
divisor1 = (past_decoder_sequence_length + 1) * 32
142+
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0)
143+
divisor2 = 32 * (past_decoder_sequence_length + 1)
144+
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0)
145+
146+
134147
if __name__ == "__main__":
135148
tvm.testing.main()

tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def test_no_normalization_without_commoning():
352352
def func_distributivity(
353353
B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
354354
) -> None:
355-
B[i1] = x * (y + z)
355+
B[i1] = (y + z) * x
356356
B[i2] = x * y + x * z
357357

358358

0 commit comments

Comments
 (0)