Skip to content

Commit c59bc29

Browse files
authored
[Arith] Add simplification rule for x - max(x+y, z) (#14271)
This parallels an existing simplification rule for `x - min(x,y, z)`, applying the same cancellation for `max`.
1 parent 852f97d commit c59bc29

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/arith/rewrite_simplify.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
387387

388388
TVM_TRY_REWRITE(matches_one_of(x - min(x + y, z), x - min(y + x, z)), max(0 - y, x - z));
389389
TVM_TRY_REWRITE(matches_one_of(x - min(z, x + y), x - min(z, y + x)), max(x - z, 0 - y));
390+
TVM_TRY_REWRITE(matches_one_of(x - max(x + y, z), x - max(y + x, z)), min(0 - y, x - z));
391+
TVM_TRY_REWRITE(matches_one_of(x - max(z, x + y), x - max(z, y + x)), min(x - z, 0 - y));
390392

391393
TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x));
392394
TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x));

tests/python/unittest/test_arith_rewrite_simplify.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,10 @@ class TestSubIndex(BaseCompare):
360360
TestCase(tvm.te.max(x, y) - tvm.te.max(y, x), 0),
361361
TestCase(tvm.te.min(x, y) - tvm.te.min(x + 10, y + 10), -10),
362362
TestCase(tvm.te.min(x + 10, y + 1) - tvm.te.min(x, y - 9), 10),
363+
TestCase(x - tvm.te.max(x + y, 0), tvm.te.min(0 - y, x)),
364+
TestCase(x - tvm.te.max(0, x + y), tvm.te.min(x, 0 - y)),
365+
TestCase(x - tvm.te.min(x + y, 0), tvm.te.max(0 - y, x)),
366+
TestCase(x - tvm.te.min(0, x + y), tvm.te.max(x, 0 - y)),
363367
# DivMod patterns
364368
# truc div
365369
TestCase(x - tdiv(x, 3) * 3, tmod(x, 3)),

0 commit comments

Comments
 (0)