From 7daffaa879f1109454e1160a6688c99598eee363 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Thu, 9 Feb 2023 09:38:23 -0600 Subject: [PATCH 1/9] [Arith] Simplifications for floormod(x, 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Because `floormod(x,2)` has only two possible values, it can be simplified more aggressively than most FloorMod expressions. The additional simplifications are derived from `floormod(x,2) + floormod(x+1,2) == 1`, which is true for denominator `2`, along with the usual `floordiv(x,2)*2 + floormod(x,2) == x`, which is true for all denominators. This initially arose from an index expression `floormod(x + 1, 2) * 8192`, for `x ∈ [0, 2)`. This commit allows the expression to be re-written as `x * (-8192) + 8192` and recognized as a strided access. --- src/arith/rewrite_simplify.cc | 39 ++++++++++++++++--- .../unittest/test_arith_rewrite_simplify.py | 32 +++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ce2c3e1a962e..ac2b70a2aa31 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -278,10 +278,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), c2.Eval()->value > 0); + TVM_TRY_REWRITE(floormod(x + 1, 2) + floormod(x, 2), OneWithTypeLike(x)); + TVM_TRY_REWRITE(floormod(x, 2) + floormod(x + 1, 2), OneWithTypeLike(x)); + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2)); + // canonicalization rule // will try rewrite again after canonicalization. + TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + (c1 - y), (c1 - y) + x), (x - y) + c1); - TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + c1 + y, x + (c1 + y)), (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(matches_one_of((x + c1) + y, x + (c1 + y), x + (y + c1)), + (x + y) + c1); TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x); TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x); @@ -454,6 +460,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y, c1.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE( + floordiv(x + c1, 2) - floordiv(x + c2, 2), + floormod(x, 2) * (floormod(c1, 2) - floormod(c2, 2)) + floordiv(c1, 2) - floordiv(c2, 2)); + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floordiv(x + c2, 2), + floormod(x, 2) * (0 - floormod(c2, 2)) - floordiv(c2, 2)); + TVM_TRY_RECURSIVE_REWRITE(floordiv(x + c1, 2) - floordiv(x, 2), + floormod(x, 2) * floormod(c1, 2) + floordiv(c1, 2)); + TVM_TRY_REWRITE_IF( x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); @@ -473,6 +487,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2, c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_RECURSIVE_REWRITE(floordiv(x + 1, 2) - floormod(x, 2), floordiv(x, 2)); + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floormod(x - 1, 2), floordiv(x - 1, 2)); + TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3), floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3), c3.Eval()->value > 0); @@ -483,6 +500,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { // will try rewrite again after canonicalization. TVM_TRY_REWRITE(x - c1, x + (0 - c1)); TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1)); TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); } else if (op->dtype.is_float()) { @@ -973,6 +991,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), c2.Eval()->value > 0); + TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1, + floormod(c1.Eval()->value, 2) == 1); TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -983,12 +1003,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE(matches_one_of(floormod(x * y, y), floormod(y * x, y)), ZeroWithTypeLike(y)); - // try modular analysis if (floormod(x, c1).Match(ret)) { - ModularSet mod = analyzer_->modular_set(x.Eval()); int64_t c1val = c1.Eval()->value; - if (mod->coeff % c1val == 0 && c1val > 0) { - return floormod(mod->base, c1).Eval(); + if (c1val > 0) { + // try modular analysis + ModularSet mod = analyzer_->modular_set(x.Eval()); + if (mod->coeff % c1val == 0) { + return floormod(mod->base, c1).Eval(); + } + + // floormod(x,c1) is a no-op when x is already in the + // appropriate range. + ConstIntBound bound = analyzer_->const_int_bound(x.Eval()); + if (bound->min_value >= 0 && bound->max_value < c1val) { + return x.Eval(); + } } } } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 9be5b55ed825..e66f7f75aa79 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -560,6 +560,38 @@ class TestFloormodIndex(BaseCompare): ) +class TestFloorModTwo(BaseCompare): + """Special-case simplifications for FloorMod(expr,2) + + Because FloorMod(expr,2) has only two possible values, it can be + simplified more aggressively than most FloorMod expressions. Some + of these have analogues for other denominators (e.g. x%3 + (x+1)%3 + + (x+2)%3 == 0 + 1 + 2), but they don't appear as often and + require identifying more related terms in order to apply. + + (x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c1%2 ) + (c1//2 - c2//2) + """ + + x, y, z = te.var("x"), te.var("y"), te.var("z") + test_case = tvm.testing.parameter( + # Removing offsets from floormod + TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1), + TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1), + TestCase(flm(x, 2) + flm(x + 1, 2), 1), + TestCase(flm(x + 1, 2) + flm(x, 2), 1), + # Difference of floordiv yields floormod + TestCase(fld(x + 1, 2) - fld(x, 2), flm(x, 2)), + TestCase(fld(x, 2) - fld(x - 1, 2), flm(x, 2) * -1 + 1), + TestCase(fld(x + 5, 2) - fld(x - 2, 2), flm(x, 2) + 3), + TestCase(fld(x + 5, 2) - fld(x - 3, 2), 4), + # Sum of floordiv and floormod to yield floordiv + TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)), + TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)), + # Removal of floormod where possible + TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]), + ) + + class TestMinIndex(BaseCompare): x, y, z = te.var("x"), te.var("y"), te.var("z") test_case = tvm.testing.parameter( From fb8b22bf39cfee5d1aa6ad9f2eed0cb8360550d7 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 10 Feb 2023 12:19:26 -0600 Subject: [PATCH 2/9] Added a missing simplification --- src/arith/rewrite_simplify.cc | 2 ++ tests/python/unittest/test_arith_rewrite_simplify.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ac2b70a2aa31..2dbdc74d9430 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -880,6 +880,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x)); TVM_TRY_REWRITE(matches_one_of(floordiv(x * c1, x), floordiv(c1 * x, x)), c1); + TVM_TRY_REWRITE(floordiv(floormod(x, 2) + 1, 2), floormod(x, 2)); + // Rules involving 2-operands. TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index e66f7f75aa79..23d0543bc4f7 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -584,6 +584,7 @@ class TestFloorModTwo(BaseCompare): TestCase(fld(x, 2) - fld(x - 1, 2), flm(x, 2) * -1 + 1), TestCase(fld(x + 5, 2) - fld(x - 2, 2), flm(x, 2) + 3), TestCase(fld(x + 5, 2) - fld(x - 3, 2), 4), + TestCase(fld(flm(x, 2) + 1, 2), flm(x, 2)), # Sum of floordiv and floormod to yield floordiv TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)), TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)), From 532362739dbf9e3378df953be7d22708341dd890 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 10 Feb 2023 15:10:24 -0600 Subject: [PATCH 3/9] Update reference examples for inject software pipeline unit tests --- ..._tir_transform_inject_software_pipeline.py | 104 +++++++++--------- 1 file changed, 51 insertions(+), 53 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 1e5fd8843ba3..66f61d15948b 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -139,8 +139,8 @@ def transformed_simple_compute( for i in T.serial(0, 15): with T.block(): T.reads([A[tx, i + 1]]) - T.writes([B[(i + 1) % 2, tx, 0]]) - B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + T.writes([B[1 - i % 2, tx, 0]]) + B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2) with T.block(): T.reads([B[i % 2, tx, 0]]) T.writes([C[tx, i]]) @@ -202,8 +202,8 @@ def transformed_simple_compute_with_other_annotation( ): with T.block(): T.reads([A[tx, i + 1]]) - T.writes([B[(i + 1) % 2, tx, 0]]) - B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + T.writes([B[1 - i % 2, tx, 0]]) + B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2) with T.block(): T.reads([B[i % 2, tx, 0]]) T.writes([C[tx, i]]) @@ -266,7 +266,7 @@ def transformed_three_stage_compute( T.where(i == 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) + C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2) with T.block(): T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) @@ -278,7 +278,7 @@ def transformed_three_stage_compute( with T.block(): T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) + C[1 - i % 2, tx, 0] = B[1 - i % 2, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i]) @@ -291,7 +291,7 @@ def transformed_three_stage_compute( T.where(i < 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) + C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i + 14]) @@ -391,12 +391,12 @@ def transformed_dag_interleaving( BS[tx, 0] = B[tx, i + 1] + T.float32(2) with T.block(): T.reads(AS[tx, 0]) - T.writes(AL[(i + 1) % 2, 0, 0]) - AL[(i + 1) % 2, 0, 0] = AS[tx, 0] + T.writes(AL[1 - i % 2, 0, 0]) + AL[1 - i % 2, 0, 0] = AS[tx, 0] with T.block(): T.reads(BS[tx, 0]) - T.writes(BL[(i + 1) % 2, 0, 0]) - BL[(i + 1) % 2, 0, 0] = BS[tx, 0] + T.writes(BL[1 - i % 2, 0, 0]) + BL[1 - i % 2, 0, 0] = BS[tx, 0] with T.block(): T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0]) T.writes(C[tx, i]) @@ -475,12 +475,12 @@ def transformed_nested_pipeline_simple( for i in T.serial(0, 15): with T.block(): T.reads([A[tx, i + 1, 0:16]]) - T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + T.writes([A_shared[1 - i % 2, tx, 0, 0:16]]) for j in T.serial(0, 16): with T.block(): T.reads([A[tx, i + 1, j]]) - T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) - A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + T.writes([A_shared[1 - i % 2, tx, 0, j]]) + A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j] with T.block(): T.reads([A_shared[i % 2, tx, i, 0]]) T.writes([B[0, tx, i, 0]]) @@ -491,10 +491,10 @@ def transformed_nested_pipeline_simple( for j in T.serial(0, 15): with T.block(): T.reads([A_shared[i % 2, tx, i, j + 1]]) - T.writes([B[(j + 1) % 2, tx, i, 0]]) - B[(j + 1) % 2, tx, i, 0] = A_shared[ - i % 2, tx, 0, j + 1 - ] * T.float32(2) + T.writes([B[1 - j % 2, tx, i, 0]]) + B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32( + 2 + ) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) @@ -516,8 +516,8 @@ def transformed_nested_pipeline_simple( for j in T.serial(0, 15): with T.block(): T.reads([A_shared[1, tx, 15, j + 1]]) - T.writes([B[(j + 1) % 2, tx, 15, 0]]) - B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + T.writes([B[1 - j % 2, tx, 15, 0]]) + B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) @@ -603,30 +603,30 @@ def transformed_nested_pipeline_prefetch_inner( for i in T.serial(0, 15): with T.block(): T.reads([A[tx, i + 1, 0:16]]) - T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + T.writes([A_shared[1 - i % 2, tx, 0, 0:16]]) for j in T.serial(0, 16): with T.block(): T.reads([A[tx, i + 1, j]]) - T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) - A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + T.writes([A_shared[1 - i % 2, tx, 0, j]]) + A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j] with T.block(): T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) for j in T.serial(0, 15): with T.block(): T.reads([A_shared[i % 2, tx, i, j + 1]]) - T.writes([B[(j + 1) % 2, tx, i, 0]]) - B[(j + 1) % 2, tx, i, 0] = A_shared[ - i % 2, tx, 0, j + 1 - ] * T.float32(2) + T.writes([B[1 - j % 2, tx, i, 0]]) + B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32( + 2 + ) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) with T.block(): - T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]]) + T.reads([A_shared[1 - i % 2, tx, i + 1, 0]]) T.writes([B[0, tx, i + 1, 0]]) - B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2) + B[0, tx, i + 1, 0] = A_shared[1 - i % 2, tx, 0, 0] * T.float32(2) with T.block(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) @@ -640,8 +640,8 @@ def transformed_nested_pipeline_prefetch_inner( for j in T.serial(0, 15): with T.block(): T.reads([A_shared[1, tx, 15, j + 1]]) - T.writes([B[(j + 1) % 2, tx, 15, 0]]) - B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + T.writes([B[1 - j % 2, tx, 15, 0]]) + B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) @@ -768,8 +768,8 @@ def transformed_nested_pipeline_interleaving( for j in T.serial(0, 15): with T.block(): T.reads([A_local[tx, i, j + 1]]) - T.writes([B[(j + 1) % 2, tx, i, 0]]) - B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) + T.writes([B[1 - j % 2, tx, i, 0]]) + B[1 - j % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) @@ -799,8 +799,8 @@ def transformed_nested_pipeline_interleaving( for j in T.serial(0, 15): with T.block(): T.reads([A_local[tx, 15, j + 1]]) - T.writes([B[(j + 1) % 2, tx, 15, 0]]) - B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) + T.writes([B[1 - j % 2, tx, 15, 0]]) + B[1 - j % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) @@ -929,27 +929,25 @@ def transformed_nested_pipeline_double_buffer( for j in T.serial(0, 15): with T.block(): T.reads([A_local[i % 2, tx, i, j + 1]]) - T.writes([B[(j + 1) % 2, tx, i, 0]]) - B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32( - 2 - ) + T.writes([B[1 - j % 2, tx, i, 0]]) + B[1 - j % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) with T.block(): T.reads([A_shared[tx, 0, 0:16]]) - T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]]) + T.writes([A_local[1 - i % 2, 0, 0, 0:16]]) for j in T.serial(0, 16): with T.block(): T.reads([A_shared[tx, 0, j]]) - T.writes([A_local[(i + 1) % 2, 0, 0, j]]) + T.writes([A_local[1 - i % 2, 0, 0, j]]) T.block_attr({"double_buffer_scope": 0}) - A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j] + A_local[1 - i % 2, 0, 0, j] = A_shared[tx, i + 1, j] with T.block(): - T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]]) + T.reads([A_local[1 - i % 2, tx, i + 1, 0]]) T.writes([B[0, tx, i + 1, 0]]) - B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2) + B[0, tx, i + 1, 0] = A_local[1 - i % 2, 0, 0, 0] * T.float32(2) with T.block(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) @@ -963,8 +961,8 @@ def transformed_nested_pipeline_double_buffer( for j in T.serial(0, 15): with T.block(): T.reads([A_local[1, tx, 15, j + 1]]) - T.writes([B[(j + 1) % 2, tx, 15, 0]]) - B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) + T.writes([B[1 - j % 2, tx, 15, 0]]) + B[1 - j % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) @@ -1135,7 +1133,7 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): with T.block(): T.where(i + 1 < 16) T.reads(A[tx, i + 1]) - T.writes(B[(i + 1) % 2, tx, 0]) + T.writes(B[1 - i % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) @@ -1350,8 +1348,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N B[i % 2, tx, 0] = A[tx, i] * T.float32(2) with T.block(): T.where(i == 1 and i - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[1 - i % 2, tx, 0]) + T.writes(C[1 - i % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1372,8 +1370,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2) with T.block(): T.where(i + 2 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[1 - i % 2, tx, 0]) + T.writes(C[1 - i % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1394,8 +1392,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N for i in T.unroll(2): with T.block(): T.where(i + 16 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[1 - i % 2, tx, 0]) + T.writes(C[1 - i % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 0 - i): From 31803c7751397803562a137069e49fcda3c8268f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Feb 2023 09:20:25 -0600 Subject: [PATCH 4/9] Update DetectIterMap to recognize floormod 2 patterns --- src/arith/iter_affine_map.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index af6e47b7a066..4013c827a14f 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -898,6 +898,11 @@ class IterMapRewriter : public ExprMutator { PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs); static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { + if (sign < 0 && is_const_int(rhs->extent, 2)) { + lhs->base -= rhs->scale; + sign = 1; + } + tir::ExprDeepEqual equal; for (size_t i = 0; i < lhs->args.size(); ++i) { IterSplitExpr lvalue = lhs->args[i]; From 8825bc44e50bb20c71913b03366dfb9d21e0ed32 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Mar 2023 14:27:36 -0600 Subject: [PATCH 5/9] Removed redundant rewrites --- src/arith/rewrite_simplify.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 2dbdc74d9430..f424e5f55679 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -278,8 +278,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), c2.Eval()->value > 0); - TVM_TRY_REWRITE(floormod(x + 1, 2) + floormod(x, 2), OneWithTypeLike(x)); - TVM_TRY_REWRITE(floormod(x, 2) + floormod(x + 1, 2), OneWithTypeLike(x)); TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2)); // canonicalization rule @@ -488,7 +486,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); TVM_TRY_RECURSIVE_REWRITE(floordiv(x + 1, 2) - floormod(x, 2), floordiv(x, 2)); - TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floormod(x - 1, 2), floordiv(x - 1, 2)); TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3), floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3), From da8bc8c3e9776321d2a63dba50e4f14ccbba24ff Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Mar 2023 14:27:58 -0600 Subject: [PATCH 6/9] Added unit test for floormod(x,2) behavior in affine iter --- tests/python/unittest/test_arith_iter_affine_map.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 0d24b59bb45e..0bb4c98b7b15 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -199,6 +199,18 @@ def test_compound(): assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)])) +def test_compound_floormod_two(): + x = tvm.tir.Var("x", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + # extent of 2 are normalized to positive scale + assert_iter_sum_pattern( + expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)}, + dom_map=var_dom([(x, 8)]), + ) + + def test_predicate(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") From e90bae9792b0cd4343025ed4ed0fd121917cdee9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Mar 2023 14:29:51 -0600 Subject: [PATCH 7/9] Add parentheses to generate `x+(c1+c2)` instead of `(x+c1)+c2)` --- src/arith/rewrite_simplify.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f424e5f55679..8fdfb6986e49 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -460,7 +460,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_RECURSIVE_REWRITE( floordiv(x + c1, 2) - floordiv(x + c2, 2), - floormod(x, 2) * (floormod(c1, 2) - floormod(c2, 2)) + floordiv(c1, 2) - floordiv(c2, 2)); + floormod(x, 2) * (floormod(c1, 2) - floormod(c2, 2)) + (floordiv(c1, 2) - floordiv(c2, 2))); TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floordiv(x + c2, 2), floormod(x, 2) * (0 - floormod(c2, 2)) - floordiv(c2, 2)); TVM_TRY_RECURSIVE_REWRITE(floordiv(x + c1, 2) - floordiv(x, 2), From 538f6a2c2bb0b8046e011dda85ba72bf93d37503 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Mar 2023 09:44:04 -0600 Subject: [PATCH 8/9] Resolve test breakage from merge --- .../unittest/test_tir_transform_inject_software_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index f9aa530f6df8..7e59172bdd83 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1133,7 +1133,7 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): with T.block(): T.where(i + 1 < 16) T.reads(A[tx, i + 1]) - T.writes(B[1 - i % 2, tx, 0]) + T.writes(B[(i + 1) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) From b3ed8ee163280d493c5ec3ce6e08fc2454f5bf7a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 31 Mar 2023 09:39:22 -0500 Subject: [PATCH 9/9] Update expected output in meta-schedule unit test --- tests/python/unittest/test_meta_schedule_space_cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index bc674064d1d6..ed416e1fbec6 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -736,7 +736,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 for ax0_ax1_ax2_ax3_fused in T.serial((i4_0 % 2 + 1) // 2 * 96 + 96): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * ((i4_0 % 2 + 1) // 2 + 1)) // 96) + v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (i4_0 % 2 + 1)) // 96) v2 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32) v3 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(inputs[v0, v1 - 1, v2 - 1, v3])