diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index ed2c40da72a1..377f8bb7c9b1 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -722,6 +722,10 @@ class IterMapRewriter : public ExprMutator { IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter if (expr->args.size() < 1) return expr; + if (auto opt = TryCombineSplitFromSameSource(expr)) { + expr = opt.value(); + if (expr->args.size() < 1) return expr; + } Optional opt = TryFuseIters(expr, check_level_); if (opt.defined()) { return opt.value(); @@ -995,9 +999,6 @@ class IterMapRewriter : public ExprMutator { * \return The sum with the fused IterMark and extra offset if succeed. */ Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level) { - if (auto opt = TryCombineSplitFromSameSource(expr)) { - expr = opt.value(); - } // select the iterators in order std::vector visited(expr->args.size(), false); int base_index = FindBaseIter(expr, visited, NullOpt); @@ -1553,6 +1554,10 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o return IterSumExpr(); } else if (sum->args.size() == 1) { return sum; + } else if (auto opt = TryCombineSplitFromSameSource(sum)) { + if (opt.value()->args.size() == 1) { + return opt.value(); + } } auto opt_fused = TryFuseIters(sum, check_level_); if (!opt_fused) { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index cbca1bb325d8..640d7592ad88 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -83,7 +83,7 @@ def assert_iter_sum_pattern( tvm.ir.assert_structural_equal(sum_expr, expect_iter) -def assert_iter_map_simplfy( +def assert_iter_map_simplify( expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True ): keys = list(expect_dict.keys()) @@ -1120,28 +1120,28 @@ def test_iter_map_simplify_symbolic_case(): def simple_fuse0(x): return (x // n) * n + x % n - assert_iter_map_simplfy({simple_fuse0(x): x}, var_dom([(x, n * 32)])) + assert_iter_map_simplify({simple_fuse0(x): x}, var_dom([(x, n * 32)])) - assert_iter_map_simplfy({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)])) + assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)])) def fsymbolic_fuse0(x): return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n - assert_iter_map_simplfy({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)])) + assert_iter_map_simplify({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)])) - assert_iter_map_simplfy({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)])) + assert_iter_map_simplify({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)])) def fsymbolic_fuse1(x): return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n - assert_iter_map_simplfy({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)])) + assert_iter_map_simplify({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)])) - assert_iter_map_simplfy({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)])) + assert_iter_map_simplify({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)])) def fsymbolic_fuse2(i): return (i // (n * n) * n + i % (n * n) // n) * n + i % n - assert_iter_map_simplfy({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)])) + assert_iter_map_simplify({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)])) def test_iter_map_simplify_symbolic_predicate(): @@ -1155,7 +1155,7 @@ def simple_fuse0(x): return (x // n) * n + x % n z = x * 32 + y - assert_iter_map_simplfy( + assert_iter_map_simplify( {simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16) ) @@ -1163,13 +1163,26 @@ def fsymbolic_fuse2(i): return (i // (n * n) * n + i % (n * n) // n) * n + i % n z = x * 64 + y - assert_iter_map_simplfy( + assert_iter_map_simplify( {fsymbolic_fuse2(z): z}, var_dom([(x, (n * n + 1) // 2), (y, 64)]), predicate=(z < n * n * 32), ) +def test_iter_map_simplify_symbolic_reshape(): + n = tvm.tir.Var("n", "int64") + fused = tvm.tir.Var("fused", "int64") + + ax0 = (fused // 4096) // n + ax1 = (fused // 4096) % n + ax2 = fused % 4096 + + rhs_index = ((ax2 // 4096 + ax0 * n + ax1) % n) * 4096 + ax2 % 4096 + + assert_iter_map_simplify({rhs_index: fused}, var_dom([(fused, n * 4096)])) + + def test_iter_map_simplify_unit_loop_order(): """Test itermap simplify""" x = tvm.tir.Var("x", "int64") @@ -1178,18 +1191,18 @@ def test_iter_map_simplify_unit_loop_order(): # trivial iterators can be found at any when comparing via scale # ensure order unchange - assert_iter_map_simplfy( + assert_iter_map_simplify( {x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False ) # Even with simplifcation, it should follow the original order - assert_iter_map_simplfy( + assert_iter_map_simplify( {x + y + (z // 4) * 4 + z % 4: z + x + y}, var_dom([(x, 1), (y, 1), (z, 32)]), simplify_trivial_iterators=False, ) - assert_iter_map_simplfy( + assert_iter_map_simplify( {y + 64 - x % 2 * 64: y + 64 - x % 2 * 64}, var_dom([(x, 6), (y, 64)]), simplify_trivial_iterators=False, @@ -1197,7 +1210,7 @@ def test_iter_map_simplify_unit_loop_order(): # When we have iterators that have same scale but one of them come # with unit extent, we should prioritize unit extent - assert_iter_map_simplfy( + assert_iter_map_simplify( {x // 128 + y + z: y + x // 128 + z}, var_dom([(x, 128), (y, 128), (z, 1)]), simplify_trivial_iterators=False,