Skip to content

Commit dc53a6c

Browse files
authored
[Arith] Simplify the result of non-divisible floordiv (#15881)
1 parent fa4aeee commit dc53a6c

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/arith/iter_affine_map.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,6 +1904,15 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P
19041904
/* lower_factor = */ padded->lower_factor * rhs,
19051905
/* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)),
19061906
/* scale = */ padded->scale);
1907+
} else if (is_one(padded->lower_factor) &&
1908+
analyzer_->CanProveEqual(padded->extent, padded->source->extent)) {
1909+
// floordiv(floormod(floordiv(iter, lower_factor), ext), c)
1910+
// = floordiv(iter, c)
1911+
// when lower_factor = 1 and ext = iter.extent
1912+
new_split = IterSplitExpr(padded->source,
1913+
/* lower_factor = */ rhs,
1914+
/* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)),
1915+
/* scale = */ padded->scale);
19071916
} else {
19081917
new_split = IterSplitExpr(IterMark(padded, padded->extent),
19091918
/* lower_factor = */ rhs,

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,11 +1227,29 @@ def test_iter_map_simplify_unit_loop_order():
12271227

12281228

12291229
def assert_normalize_to_iter_sum(index, input_iters, args, base):
1230+
"""Assert the result of arith.normalize_to_iter_sum is correct
1231+
1232+
Parameters
1233+
----------
1234+
index : tvm.tir.PrimExpr
1235+
The index to be normalized
1236+
input_iters : Mapping[Var, Range]
1237+
The input iterators
1238+
args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]]
1239+
The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be
1240+
either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the
1241+
iterator normalized to PrimExpr and the second element is the scale.
1242+
base : tvm.tir.PrimExpr
1243+
The expected base
1244+
"""
12301245
res = tvm.arith.normalize_to_iter_sum(index, input_iters)
12311246

12321247
assert isinstance(res, tvm.arith.IterSumExpr)
12331248
assert len(res.args) == len(args)
12341249
for split, item in zip(res.args, args):
1250+
if isinstance(item, tvm.arith.IterSplitExpr):
1251+
tvm.ir.assert_structural_equal(split, item)
1252+
continue
12351253
tvm.testing.assert_prim_expr_equal(split.scale, item[1])
12361254
tvm.testing.assert_prim_expr_equal(
12371255
tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1]
@@ -1245,6 +1263,7 @@ def test_normalize_to_iter_sum():
12451263
z = tvm.tir.Var("z", "int64")
12461264
a = tvm.tir.Var("a", "int64")
12471265
n = tvm.tir.Var("n", "int64")
1266+
flm = tvm.tir.floormod
12481267

12491268
assert_normalize_to_iter_sum(
12501269
z + ((y + x * 4 + 2) * n) + 3,
@@ -1285,6 +1304,21 @@ def test_normalize_to_iter_sum():
12851304
0,
12861305
)
12871306

1307+
# non-divisible
1308+
assert_normalize_to_iter_sum(
1309+
x // 5,
1310+
var_dom([(x, 4096)]),
1311+
[
1312+
tvm.arith.IterSplitExpr(
1313+
tvm.arith.IterMark(x, 4096),
1314+
lower_factor=tvm.tir.const(5, "int64"),
1315+
extent=tvm.tir.const(820, "int64"),
1316+
scale=tvm.tir.const(1, "int64"),
1317+
)
1318+
],
1319+
0,
1320+
)
1321+
12881322
# iter simplify
12891323
assert_normalize_to_iter_sum(
12901324
z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),

0 commit comments

Comments
 (0)