@@ -278,10 +278,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
278278 TVM_TRY_REWRITE_IF (floordiv (floormod (x, c2) + c1, c2) + floordiv (x, c2), floordiv (x + c1, c2),
279279 c2.Eval ()->value > 0 );
280280
281+ TVM_TRY_RECURSIVE_REWRITE (floordiv (x, 2 ) + floormod (x, 2 ), floordiv (x + 1 , 2 ));
282+
281283 // canonicalization rule
282284 // will try rewrite again after canonicalization.
285+
283286 TVM_TRY_RECURSIVE_REWRITE (matches_one_of (x + (c1 - y), (c1 - y) + x), (x - y) + c1);
284- TVM_TRY_RECURSIVE_REWRITE (matches_one_of (x + c1 + y, x + (c1 + y)), (x + y) + c1);
287+ TVM_TRY_RECURSIVE_REWRITE (matches_one_of ((x + c1) + y, x + (c1 + y), x + (y + c1)),
288+ (x + y) + c1);
285289 TVM_TRY_RECURSIVE_REWRITE (x + max (y, z), max (y, z) + x);
286290 TVM_TRY_RECURSIVE_REWRITE (x + min (y, z), min (y, z) + x);
287291
@@ -456,6 +460,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
456460 TVM_TRY_REWRITE_IF (floordiv (x - y, c1) * c1 - x, 0 - floormod (x - y, c1) - y,
457461 c1.Eval ()->value != 0 );
458462
463+ TVM_TRY_RECURSIVE_REWRITE (
464+ floordiv (x + c1, 2 ) - floordiv (x + c2, 2 ),
465+ floormod (x, 2 ) * (floormod (c1, 2 ) - floormod (c2, 2 )) + (floordiv (c1, 2 ) - floordiv (c2, 2 )));
466+ TVM_TRY_RECURSIVE_REWRITE (floordiv (x, 2 ) - floordiv (x + c2, 2 ),
467+ floormod (x, 2 ) * (0 - floormod (c2, 2 )) - floordiv (c2, 2 ));
468+ TVM_TRY_RECURSIVE_REWRITE (floordiv (x + c1, 2 ) - floordiv (x, 2 ),
469+ floormod (x, 2 ) * floormod (c1, 2 ) + floordiv (c1, 2 ));
470+
459471 TVM_TRY_REWRITE_IF (
460472 x * c2 - floordiv (x, c1) * c3, floormod (x, c1) * c2,
461473 c1.Eval ()->value != 0 && c3.Eval ()->value == c1.Eval ()->value * c2.Eval ()->value );
@@ -475,6 +487,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
475487 floordiv (x - y, c1) * c3 - x * c2, (0 - floormod (x - y, c1) - y) * c2,
476488 c1.Eval ()->value != 0 && c3.Eval ()->value == c1.Eval ()->value * c2.Eval ()->value );
477489
490+ TVM_TRY_RECURSIVE_REWRITE (floordiv (x + 1 , 2 ) - floormod (x, 2 ), floordiv (x, 2 ));
491+
478492 TVM_TRY_REWRITE_IF (floordiv (x + c1, c3) - floordiv (x + c2, c3),
479493 floordiv (floormod (x + floormod (c2, c3), c3) + (c1 - c2), c3),
480494 c3.Eval ()->value > 0 );
@@ -485,6 +499,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
485499 // will try rewrite again after canonicalization.
486500 TVM_TRY_REWRITE (x - c1, x + (0 - c1));
487501 TVM_TRY_RECURSIVE_REWRITE ((x + c1) - y, (x - y) + c1);
502+ TVM_TRY_RECURSIVE_REWRITE (x - (y + c1), (x - y) + (0 - c1));
488503 TVM_TRY_RECURSIVE_REWRITE (x - (y - z), (x + z) - y);
489504 TVM_TRY_RECURSIVE_REWRITE (x - y * c1, x + y * (0 - c1));
490505 } else if (op->dtype .is_float ()) {
@@ -864,6 +879,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
864879 TVM_TRY_REWRITE (floordiv (x, x), OneWithTypeLike (x));
865880 TVM_TRY_REWRITE (matches_one_of (floordiv (x * c1, x), floordiv (c1 * x, x)), c1);
866881
882+ TVM_TRY_REWRITE (floordiv (floormod (x, 2 ) + 1 , 2 ), floormod (x, 2 ));
883+
867884 // Rules involving 2-operands.
868885 TVM_TRY_REWRITE_IF (floordiv (min (x * c1, y), c2), min (x * floordiv (c1, c2), floordiv (y, c2)),
869886 c2.Eval ()->value > 0 && c1.Eval ()->value % c2.Eval ()->value == 0 );
@@ -975,6 +992,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
975992 TVM_TRY_REWRITE_IF (floormod (x * c1 + y, c2), floormod (x * floormod (c1, c2) + y, c2),
976993 c2.Eval ()->value > 0 );
977994
995+ TVM_TRY_RECURSIVE_REWRITE_IF (floormod (x + c1, 2 ), floormod (x, 2 ) * (-1 ) + 1 ,
996+ floormod (c1.Eval ()->value , 2 ) == 1 );
978997 TVM_TRY_REWRITE_IF (floormod (x + c1, c2), floormod (x, c2),
979998 c2.Eval ()->value > 0 && c1.Eval ()->value % c2.Eval ()->value == 0 );
980999
@@ -985,12 +1004,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
9851004
9861005 TVM_TRY_REWRITE (matches_one_of (floormod (x * y, y), floormod (y * x, y)), ZeroWithTypeLike (y));
9871006
988- // try modular analysis
9891007 if (floormod (x, c1).Match (ret)) {
990- ModularSet mod = analyzer_->modular_set (x.Eval ());
9911008 int64_t c1val = c1.Eval ()->value ;
992- if (mod->coeff % c1val == 0 && c1val > 0 ) {
993- return floormod (mod->base , c1).Eval ();
1009+ if (c1val > 0 ) {
1010+ // try modular analysis
1011+ ModularSet mod = analyzer_->modular_set (x.Eval ());
1012+ if (mod->coeff % c1val == 0 ) {
1013+ return floormod (mod->base , c1).Eval ();
1014+ }
1015+
1016+ // floormod(x,c1) is a no-op when x is already in the
1017+ // appropriate range.
1018+ ConstIntBound bound = analyzer_->const_int_bound (x.Eval ());
1019+ if (bound->min_value >= 0 && bound->max_value < c1val) {
1020+ return x.Eval ();
1021+ }
9941022 }
9951023 }
9961024 }
0 commit comments