diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 9e5b1414edf4..3e5b8834ebca 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -65,17 +65,34 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
 }
 
 void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
-  // split out the symbolic and non-symbolic part
+  // decompose value as symbol * scale + offset
+  int64_t offset = 0;
+  PrimExpr symbol_scale = tir::make_const(value.dtype(), 0);
+
+  auto fcollect_sum = [&](PrimExpr val, int sign) {
+    if (const auto* intimm = val.as<IntImmNode>()) {
+      offset += intimm->value * sign;
+    } else {
+      if (sign > 0) {
+        symbol_scale = symbol_scale + val;
+      } else {
+        symbol_scale = symbol_scale - val;
+      }
+    }
+  };
+  UnpackSum(value, fcollect_sum);
+
+  // split out the symbol and non-symbolic part
   int64_t cscale = 1;
-  PrimExpr symbolic = tir::make_const(value.dtype(), 1);
-  auto fcollect = [&](PrimExpr val) {
+  PrimExpr symbol = tir::make_const(value.dtype(), 1);
+  auto fcollect_prod = [&](PrimExpr val) {
     if (const auto* intimm = val.as<IntImmNode>()) {
       cscale *= intimm->value;
     } else {
-      symbolic = symbolic * val;
+      symbol = symbol * val;
     }
   };
-  UnpackReduction<tir::MulNode>(value, fcollect);
+  UnpackReduction<tir::MulNode>(symbol_scale, fcollect_prod);
   if (cscale <= 0) return;
   // override the constant int bound by marking it as non-negative
   // NOTE: there might be future opportunities of more bound hint
@@ -83,7 +100,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
   //
   // We may consider enhance the sub analyzer to directly take
   // MarkPositiveVar so their bounds do not overlap
-  if (const auto* var_ptr = symbolic.as<VarNode>()) {
+  if (const auto* var_ptr = symbol.as<VarNode>()) {
     Var var = GetRef<Var>(var_ptr);
     // skip non-index type, keep it to be compatible
     // with any_dim that do not represent any value
@@ -92,7 +109,12 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
     // mark the constant bound is sufficient
     // we cannot mark interval set as that will cause relaxation of the var
     // during bound proof which is not our intention
-    this->const_int_bound.Update(var, ConstIntBound(0, ConstIntBound::kPosInf), allow_override);
+    // symbol * cscale + offset >= 0 with cscale > 0 means symbol >= ceil(-offset / cscale);
+    // integer division truncates toward zero, so only a positive remainder needs rounding up
+    int64_t min_value = -offset / cscale;
+    if (-offset % cscale > 0) ++min_value;
+    this->const_int_bound.Update(var, ConstIntBound(min_value, ConstIntBound::kPosInf),
+                                 allow_override);
   }
 }
 
diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h
index 768a3a2b8bab..d27ca76650e0 100644
--- a/src/arith/product_normal_form.h
+++ b/src/arith/product_normal_form.h
@@ -47,6 +47,26 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
   }
 }
 
+/*!
+ * \brief Unpack chain of add sub by calling each leaf via fleaf
+ * \param value The expression value.
+ * \param fleaf The callback, invoked as fleaf(leaf, sign) for every leaf.
+ * \param sign The sign (1 or -1) applied to value; negated for the right side of a sub.
+ * \tparam FLeaf The callback function at leaf.
+ */
+template <typename FLeaf>
+inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) {
+  if (const tir::AddNode* node = value.as<tir::AddNode>()) {
+    UnpackSum(node->a, fleaf, sign);
+    UnpackSum(node->b, fleaf, sign);
+  } else if (const tir::SubNode* node = value.as<tir::SubNode>()) {
+    UnpackSum(node->a, fleaf, sign);
+    UnpackSum(node->b, fleaf, -sign);
+  } else {
+    fleaf(value, sign);
+  }
+}
+
 /*!
  * \brief Helper function to multiply extent and and re-normalize.
  *
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 79fd5e143418..c779d92f9c47 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -1734,10 +1734,6 @@ def before(A_ptr: T.handle("float32"), A_stride: T.int32):
 
 
 class TestBufferShapeConstraint(BaseBeforeAfter):
-    """If enabled, rewrite boolean expressions into AND of OR"""
-
-    convert_boolean_to_and_of_ors = True
-
     def before(a: T.handle):
         n = T.int64()
         A = T.match_buffer(a, (n * 32,), "float32")
@@ -1749,5 +1745,17 @@ def expected(a: T.handle):
         A[T.int64(0)] = T.float32(0)
 
 
+class TestBufferShapeConstraintWithOffset(BaseBeforeAfter):
+    def before(a: T.handle):
+        n = T.int64()
+        A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
+        A[T.min(T.int64(1), n)] = T.float32(0)
+
+    def expected(a: T.handle):
+        n = T.int64()
+        A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
+        A[T.int64(1)] = T.float32(0)
+
+
 if __name__ == "__main__":
     tvm.testing.main()