Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,42 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
}

void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
// split out the symbolic and non-symbolic part
// decompose value as symbol * scale + offset
int64_t offset = 0;
PrimExpr symbol_scale = tir::make_const(value.dtype(), 0);

auto fcollect_sum = [&](PrimExpr val, int sign) {
if (const auto* intimm = val.as<IntImmNode>()) {
offset += intimm->value * sign;
} else {
if (sign > 0) {
symbol_scale = symbol_scale + val;
} else {
symbol_scale = symbol_scale - val;
}
}
};
UnpackSum(value, fcollect_sum);

// split out the symbol and non-symbolic part
int64_t cscale = 1;
PrimExpr symbolic = tir::make_const(value.dtype(), 1);
auto fcollect = [&](PrimExpr val) {
PrimExpr symbol = tir::make_const(value.dtype(), 1);
auto fcollect_prod = [&](PrimExpr val) {
if (const auto* intimm = val.as<IntImmNode>()) {
cscale *= intimm->value;
} else {
symbolic = symbolic * val;
symbol = symbol * val;
}
};
UnpackReduction<tir::MulNode>(value, fcollect);
UnpackReduction<tir::MulNode>(symbol_scale, fcollect_prod);
if (cscale <= 0) return;
// override the constant int bound by marking it as non-negative
// NOTE: there might be future opportunities of more bound hint
// this is a simple step and covers all the current needs
//
// We may consider enhance the sub analyzer to directly take
// MarkPositiveVar so their bounds do not overlap
if (const auto* var_ptr = symbolic.as<VarNode>()) {
if (const auto* var_ptr = symbol.as<VarNode>()) {
Var var = GetRef<Var>(var_ptr);
// skip non-index type, keep it to be compatible
// with any_dim that do not represent any value
Expand All @@ -92,7 +109,8 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
// mark the constant bound is sufficient
// we cannot mark interval set as that will cause relaxation of the var
// during bound proof which is not our intention
this->const_int_bound.Update(var, ConstIntBound(0, ConstIntBound::kPosInf), allow_override);
this->const_int_bound.Update(var, ConstIntBound(-offset, ConstIntBound::kPosInf),
allow_override);
}
}

Expand Down
18 changes: 18 additions & 0 deletions src/arith/product_normal_form.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
}
}

/**
* \brief Unpack chain of add sub by calling each leaf via fleaf
* \param value The expression value.
* \tparam FLeaf The callback function at leaf.
*/
template <typename FLeaf>
inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) {
if (const tir::AddNode* node = value.as<tir::AddNode>()) {
UnpackSum(node->a, fleaf, sign);
UnpackSum(node->b, fleaf, sign);
} else if (const tir::SubNode* node = value.as<tir::SubNode>()) {
UnpackSum(node->a, fleaf, sign);
UnpackSum(node->b, fleaf, -sign);
} else {
fleaf(value, sign);
}
}

/*!
* \brief Helper function to multiply extent and and re-normalize.
*
Expand Down
16 changes: 12 additions & 4 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,10 +1734,6 @@ def before(A_ptr: T.handle("float32"), A_stride: T.int32):


class TestBufferShapeConstraint(BaseBeforeAfter):
"""If enabled, rewrite boolean expressions into AND of OR"""

convert_boolean_to_and_of_ors = True

def before(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (n * 32,), "float32")
Expand All @@ -1749,5 +1745,17 @@ def expected(a: T.handle):
A[T.int64(0)] = T.float32(0)


class TestBufferShapeConstraintWithOffset(BaseBeforeAfter):
def before(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
A[T.min(T.int64(1), n)] = T.float32(0)

def expected(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
A[T.int64(1)] = T.float32(0)


if __name__ == "__main__":
tvm.testing.main()