Skip to content

Commit bcfa98d

Browse files
tqchen authored and junrushao committed
[ARITH] Enhance buffer shape bound deduction to include offset (apache#15228)
This PR enhances the buffer shape hint so that shape expressions like `n - 1` will deduce `n >= 1`.
1 parent 416ed05 commit bcfa98d

File tree

3 files changed

+55
-11
lines changed

3 files changed

+55
-11
lines changed

src/arith/analyzer.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,25 +65,42 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
6565
}
6666

6767
void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
68-
// split out the symbolic and non-symbolic part
68+
// decompose value as symbol * scale + offset
69+
int64_t offset = 0;
70+
PrimExpr symbol_scale = tir::make_const(value.dtype(), 0);
71+
72+
auto fcollect_sum = [&](PrimExpr val, int sign) {
73+
if (const auto* intimm = val.as<IntImmNode>()) {
74+
offset += intimm->value * sign;
75+
} else {
76+
if (sign > 0) {
77+
symbol_scale = symbol_scale + val;
78+
} else {
79+
symbol_scale = symbol_scale - val;
80+
}
81+
}
82+
};
83+
UnpackSum(value, fcollect_sum);
84+
85+
// split out the symbol and non-symbolic part
6986
int64_t cscale = 1;
70-
PrimExpr symbolic = tir::make_const(value.dtype(), 1);
71-
auto fcollect = [&](PrimExpr val) {
87+
PrimExpr symbol = tir::make_const(value.dtype(), 1);
88+
auto fcollect_prod = [&](PrimExpr val) {
7289
if (const auto* intimm = val.as<IntImmNode>()) {
7390
cscale *= intimm->value;
7491
} else {
75-
symbolic = symbolic * val;
92+
symbol = symbol * val;
7693
}
7794
};
78-
UnpackReduction<tir::MulNode>(value, fcollect);
95+
UnpackReduction<tir::MulNode>(symbol_scale, fcollect_prod);
7996
if (cscale <= 0) return;
8097
// override the constant int bound by marking it as non-negative
8198
// NOTE: there might be future opportunities of more bound hint
8299
// this is a simple step and covers all the current needs
83100
//
84101
// We may consider enhance the sub analyzer to directly take
85102
// MarkPositiveVar so their bounds do not overlap
86-
if (const auto* var_ptr = symbolic.as<VarNode>()) {
103+
if (const auto* var_ptr = symbol.as<VarNode>()) {
87104
Var var = GetRef<Var>(var_ptr);
88105
// skip non-index type, keep it to be compatible
89106
// with any_dim that do not represent any value
@@ -92,7 +109,8 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
92109
// mark the constant bound is sufficient
93110
// we cannot mark interval set as that will cause relaxation of the var
94111
// during bound proof which is not our intention
95-
this->const_int_bound.Update(var, ConstIntBound(0, ConstIntBound::kPosInf), allow_override);
112+
this->const_int_bound.Update(var, ConstIntBound(-offset, ConstIntBound::kPosInf),
113+
allow_override);
96114
}
97115
}
98116

src/arith/product_normal_form.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,24 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
4747
}
4848
}
4949

50+
/*!
51+
* \brief Unpack a chain of additions and subtractions, invoking fleaf on every leaf.
*
* AddNode and SubNode are traversed recursively; any other expression is
* treated as a leaf and passed to fleaf together with its accumulated sign.
52+
* \param value The expression value.
* \param fleaf The callback invoked as fleaf(leaf, sign) for each leaf expression.
* \param sign The sign applied to value (defaults to 1); it is negated for the
*        right-hand operand of every subtraction encountered on the way down.
53+
* \tparam FLeaf The callback function at leaf.
54+
*/
55+
template <typename FLeaf>
56+
inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) {
57+
// a + b: both operands keep the current sign
if (const tir::AddNode* node = value.as<tir::AddNode>()) {
58+
UnpackSum(node->a, fleaf, sign);
59+
UnpackSum(node->b, fleaf, sign);
60+
// a - b: the right operand contributes with the opposite sign
} else if (const tir::SubNode* node = value.as<tir::SubNode>()) {
61+
UnpackSum(node->a, fleaf, sign);
62+
UnpackSum(node->b, fleaf, -sign);
63+
} else {
64+
// neither Add nor Sub: report the leaf with its accumulated sign
fleaf(value, sign);
65+
}
66+
}
67+
5068
/*!
5169
* \brief Helper function to multiply extent and re-normalize.
5270
*

tests/python/unittest/test_tir_transform_simplify.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,10 +1734,6 @@ def before(A_ptr: T.handle("float32"), A_stride: T.int32):
17341734

17351735

17361736
class TestBufferShapeConstraint(BaseBeforeAfter):
1737-
"""If enabled, rewrite boolean expressions into AND of OR"""
1738-
1739-
convert_boolean_to_and_of_ors = True
1740-
17411737
def before(a: T.handle):
17421738
n = T.int64()
17431739
A = T.match_buffer(a, (n * 32,), "float32")
@@ -1749,5 +1745,17 @@ def expected(a: T.handle):
17491745
A[T.int64(0)] = T.float32(0)
17501746

17511747

1748+
class TestBufferShapeConstraintWithOffset(BaseBeforeAfter):
# The buffer extent n * 32 + 1 - 2 (i.e. n * 32 - 1) is a shape expression with
# a constant offset. `before` indexes with T.min(T.int64(1), n); `expected`
# indexes with T.int64(1), i.e. the min is expected to fold away once the
# shape hint yields n >= 1.
1749+
def before(a: T.handle):
1750+
n = T.int64()
1751+
A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
1752+
A[T.min(T.int64(1), n)] = T.float32(0)
1753+
1754+
def expected(a: T.handle):
1755+
n = T.int64()
1756+
A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
1757+
A[T.int64(1)] = T.float32(0)
1758+
1759+
17521760
if __name__ == "__main__":
17531761
tvm.testing.main()

0 commit comments

Comments
 (0)