Skip to content

Commit d4d439d

Browse files
adapt merge mulmod opt for OffsetOf computation
1 parent 9256609 commit d4d439d

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

src/tir/ir/buffer.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
7575
}
7676

7777
// Searches for the following types of expr:
78-
// mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
79-
// mod_l_expr = c
78+
// mult_expr = (a1 + a2 + ... + aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
79+
// mod_l_expr = c2
8080
// mod_r_expr = k1 * k2 * ... * ki
81-
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c)
81+
// where c1 is congruent to c2 modulo (k1 * k2 * ... * ki), i.e. floormod(c1 - c2, k1 * k2 * ... * ki) == 0
82+
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1)
8283
// Currently we will not search the add/mult combinations exhaustively
8384
// as it will take too much computation.
84-
inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr& mult_expr,
85+
inline std::pair<bool, PrimExpr> MergeMulModInner(arith::Analyzer* analyzer,
86+
const PrimExpr& mult_expr,
8587
const PrimExpr& mod_l_expr,
8688
const PrimExpr& mod_r_expr) {
8789
using namespace tir;
@@ -119,9 +121,10 @@ inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr& mult_expr,
119121
} else if (inner_div_ptr) {
120122
PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
121123
if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) &&
122-
expr_equal(inner_div_ptr->a, mod_l_expr)) {
124+
analyzer->CanProveEqual(floormod(inner_div_ptr->a - mod_l_expr, mod_r_expr), 0)) {
123125
// Found!
124-
PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
126+
PrimExpr ret =
127+
no_opt_sum.get() ? no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a;
125128
return std::make_pair(true, ret);
126129
} else {
127130
return std::make_pair(false, PrimExpr());
@@ -204,7 +207,7 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
204207
bool inner_find_opt = false;
205208
while (mult_it != mult_exprs.end()) {
206209
std::pair<bool, PrimExpr> ret =
207-
MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second);
210+
MergeMulModInner(analyzer, *mult_it, search_mod_it->first, search_mod_it->second);
208211
if (ret.first) {
209212
inner_find_opt = true;
210213
auto temp_mod_it = search_mod_it;

tests/python/unittest/test_tir_buffer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def assert_simplified_equal(index_simplified, index_direct):
137137

138138
idxd = tvm.tir.indexdiv
139139
idxm = tvm.tir.indexmod
140+
140141
# Test Case1
141142
index_simplified = A_stride.offset_of(
142143
(idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1)
@@ -174,16 +175,25 @@ def assert_simplified_equal(index_simplified, index_direct):
174175
j = te.size_var("j")
175176
k = te.size_var("k")
176177

177-
index_simplified = B.offset_of(
178+
index_simplified1 = B.offset_of(
178179
(
179180
idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14),
180181
idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14),
181182
idxm(idxd((i * 50176 + j * 28672 + k), 1024), 14),
182183
idxm((i * 50176 + j * 28672 + k), 1024),
183184
)
184185
)
186+
index_simplified2 = B.offset_of(
187+
(
188+
idxd(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14),
189+
idxm(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14),
190+
idxm(i * 7 + idxd(k, 1024), 14),
191+
idxm(k, 1024),
192+
)
193+
)
185194
index_direct = B.offset_of((0, 0, 0, (i * 50176 + j * 28672 + k)))
186-
assert_simplified_equal(index_simplified, index_direct)
195+
assert_simplified_equal(index_simplified1, index_direct)
196+
assert_simplified_equal(index_simplified2, index_direct)
187197

188198

189199
@tvm.testing.requires_llvm

0 commit comments

Comments
 (0)