Skip to content

Commit dba987c

Browse files
authored
[Arith] Simplifications for floormod(x, 2) (#13936)
* [Arith] Simplifications for floormod(x, 2)

Because `floormod(x,2)` has only two possible values, it can be simplified more aggressively than most FloorMod expressions. The additional simplifications are derived from `floormod(x,2) + floormod(x+1,2) == 1`, which is true for denominator `2`, along with the usual `floordiv(x,2)*2 + floormod(x,2) == x`, which is true for all denominators.

This initially arose from an index expression `floormod(x + 1, 2) * 8192`, for `x ∈ [0, 2)`. This commit allows the expression to be re-written as `x * (-8192) + 8192` and recognized as a strided access.
1 parent f5db8b7 commit dba987c

File tree

6 files changed

+128
-52
lines changed

6 files changed

+128
-52
lines changed

src/arith/iter_affine_map.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,11 @@ class IterMapRewriter : public ExprMutator {
898898
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
899899

900900
static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
901+
if (sign < 0 && is_const_int(rhs->extent, 2)) {
902+
lhs->base -= rhs->scale;
903+
sign = 1;
904+
}
905+
901906
tir::ExprDeepEqual equal;
902907
for (size_t i = 0; i < lhs->args.size(); ++i) {
903908
IterSplitExpr lvalue = lhs->args[i];

src/arith/rewrite_simplify.cc

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
278278
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
279279
c2.Eval()->value > 0);
280280

281+
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2));
282+
281283
// canonicalization rule
282284
// will try rewrite again after canonicalization.
285+
283286
TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + (c1 - y), (c1 - y) + x), (x - y) + c1);
284-
TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + c1 + y, x + (c1 + y)), (x + y) + c1);
287+
TVM_TRY_RECURSIVE_REWRITE(matches_one_of((x + c1) + y, x + (c1 + y), x + (y + c1)),
288+
(x + y) + c1);
285289
TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
286290
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);
287291

@@ -456,6 +460,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
456460
TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
457461
c1.Eval()->value != 0);
458462

463+
TVM_TRY_RECURSIVE_REWRITE(
464+
floordiv(x + c1, 2) - floordiv(x + c2, 2),
465+
floormod(x, 2) * (floormod(c1, 2) - floormod(c2, 2)) + (floordiv(c1, 2) - floordiv(c2, 2)));
466+
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floordiv(x + c2, 2),
467+
floormod(x, 2) * (0 - floormod(c2, 2)) - floordiv(c2, 2));
468+
TVM_TRY_RECURSIVE_REWRITE(floordiv(x + c1, 2) - floordiv(x, 2),
469+
floormod(x, 2) * floormod(c1, 2) + floordiv(c1, 2));
470+
459471
TVM_TRY_REWRITE_IF(
460472
x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
461473
c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
@@ -475,6 +487,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
475487
floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
476488
c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
477489

490+
TVM_TRY_RECURSIVE_REWRITE(floordiv(x + 1, 2) - floormod(x, 2), floordiv(x, 2));
491+
478492
TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
479493
floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
480494
c3.Eval()->value > 0);
@@ -485,6 +499,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
485499
// will try rewrite again after canonicalization.
486500
TVM_TRY_REWRITE(x - c1, x + (0 - c1));
487501
TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
502+
TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1));
488503
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
489504
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
490505
} else if (op->dtype.is_float()) {
@@ -864,6 +879,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
864879
TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x));
865880
TVM_TRY_REWRITE(matches_one_of(floordiv(x * c1, x), floordiv(c1 * x, x)), c1);
866881

882+
TVM_TRY_REWRITE(floordiv(floormod(x, 2) + 1, 2), floormod(x, 2));
883+
867884
// Rules involving 2-operands.
868885
TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
869886
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
@@ -975,6 +992,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
975992
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
976993
c2.Eval()->value > 0);
977994

995+
TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1,
996+
floormod(c1.Eval()->value, 2) == 1);
978997
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
979998
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
980999

@@ -985,12 +1004,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
9851004

9861005
TVM_TRY_REWRITE(matches_one_of(floormod(x * y, y), floormod(y * x, y)), ZeroWithTypeLike(y));
9871006

988-
// try modular analysis
9891007
if (floormod(x, c1).Match(ret)) {
990-
ModularSet mod = analyzer_->modular_set(x.Eval());
9911008
int64_t c1val = c1.Eval()->value;
992-
if (mod->coeff % c1val == 0 && c1val > 0) {
993-
return floormod(mod->base, c1).Eval();
1009+
if (c1val > 0) {
1010+
// try modular analysis
1011+
ModularSet mod = analyzer_->modular_set(x.Eval());
1012+
if (mod->coeff % c1val == 0) {
1013+
return floormod(mod->base, c1).Eval();
1014+
}
1015+
1016+
// floormod(x,c1) is a no-op when x is already in the
1017+
// appropriate range.
1018+
ConstIntBound bound = analyzer_->const_int_bound(x.Eval());
1019+
if (bound->min_value >= 0 && bound->max_value < c1val) {
1020+
return x.Eval();
1021+
}
9941022
}
9951023
}
9961024
}

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,18 @@ def test_compound():
199199
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))
200200

201201

202+
def test_compound_floormod_two():
203+
x = tvm.tir.Var("x", "int32")
204+
fld = tvm.tir.floordiv
205+
flm = tvm.tir.floormod
206+
207+
# terms with an extent of 2 are normalized to a positive scale
208+
assert_iter_sum_pattern(
209+
expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
210+
dom_map=var_dom([(x, 8)]),
211+
)
212+
213+
202214
def test_predicate():
203215
x = tvm.tir.Var("x", "int32")
204216
y = tvm.tir.Var("y", "int32")

tests/python/unittest/test_arith_rewrite_simplify.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,39 @@ class TestFloormodIndex(BaseCompare):
564564
)
565565

566566

567+
class TestFloorModTwo(BaseCompare):
568+
"""Special-case simplifications for FloorMod(expr,2)
569+
570+
Because FloorMod(expr,2) has only two possible values, it can be
571+
simplified more aggressively than most FloorMod expressions. Some
572+
of these have analogues for other denominators (e.g. x%3 + (x+1)%3
573+
+ (x+2)%3 == 0 + 1 + 2), but they don't appear as often and
574+
require identifying more related terms in order to apply.
575+
576+
(x + c1)//2 - (x + c2)//2 => (x%2)*(c1%2 - c2%2) + (c1//2 - c2//2)
577+
"""
578+
579+
x, y, z = te.var("x"), te.var("y"), te.var("z")
580+
test_case = tvm.testing.parameter(
581+
# Removing offsets from floormod
582+
TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
583+
TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
584+
TestCase(flm(x, 2) + flm(x + 1, 2), 1),
585+
TestCase(flm(x + 1, 2) + flm(x, 2), 1),
586+
# Difference of floordiv yields floormod
587+
TestCase(fld(x + 1, 2) - fld(x, 2), flm(x, 2)),
588+
TestCase(fld(x, 2) - fld(x - 1, 2), flm(x, 2) * -1 + 1),
589+
TestCase(fld(x + 5, 2) - fld(x - 2, 2), flm(x, 2) + 3),
590+
TestCase(fld(x + 5, 2) - fld(x - 3, 2), 4),
591+
TestCase(fld(flm(x, 2) + 1, 2), flm(x, 2)),
592+
# Sum of floordiv and floormod to yield floordiv
593+
TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
594+
TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
595+
# Removal of floormod where possible
596+
TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
597+
)
598+
599+
567600
class TestMinIndex(BaseCompare):
568601
x, y, z = te.var("x"), te.var("y"), te.var("z")
569602
test_case = tvm.testing.parameter(

tests/python/unittest/test_meta_schedule_space_cuda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5
736736
for ax0_ax1_ax2_ax3_fused in T.serial((i4_0 % 2 + 1) // 2 * 96 + 96):
737737
with T.block("PadInput_shared"):
738738
v0 = T.axis.spatial(1, 0)
739-
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * ((i4_0 % 2 + 1) // 2 + 1)) // 96)
739+
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (i4_0 % 2 + 1)) // 96)
740740
v2 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
741741
v3 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
742742
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])

0 commit comments

Comments
 (0)