diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index b8e5db483f4f..7e1d8fb3fb89 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -102,6 +102,7 @@ struct ConstIntBoundAnalyzer::Entry { class ConstIntBoundAnalyzer::Impl : public ExprFunctor { public: + explicit Impl(Analyzer* parent) : parent_(parent) {} /*! \brief additional bound info about expr in bound */ struct BoundInfo { /*! \brief The expr */ @@ -278,6 +279,33 @@ class ConstIntBoundAnalyzer::Impl if (b.min_value > 0) { int64_t b_max_cap = InfAwareAdd(b.max_value, -1); + + // Try to get tighter bounds using modular set information + if (parent_ && b.min_value == b.max_value) { + ModularSet mod_a = parent_->modular_set(op->a); + int64_t modulus = b.min_value; + int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus); + + // If gcd_coeff_mod > 1, we can get tighter bounds + // The result will be of the form gcd_coeff_mod * k + (base % modulus) + // where k ranges to cover [0, modulus - gcd_coeff_mod] + // + // Example: expr = (bx * 2048 + tx * 16) % 7168 + // where bx in [0, 3584), tx in [0, 128) + // ModularSet(expr) = 16*k (coeff=16, base=0) + // GCD(16, 7168) = 16 + // Result can only be {0, 16, 32, ..., 7152} + // Without this optimization: bound = [0, 7167] + // With this optimization: bound = [0, 7152] + if (gcd_coeff_mod > 1) { + int64_t base_mod = mod_a->base % modulus; + if (base_mod < 0) base_mod += modulus; + int64_t tight_max = modulus - gcd_coeff_mod + base_mod; + if (tight_max >= modulus) tight_max -= modulus; + return MakeBound(base_mod, tight_max); + } + } + if (a.min_value >= 0) { // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; @@ -324,6 +352,32 @@ class ConstIntBoundAnalyzer::Impl if (b.min_value > 0) { int64_t b_max_cap = InfAwareAdd(b.max_value, -1); + // Try to get tighter bounds using modular set information + if (parent_ && b.min_value == b.max_value) { + ModularSet mod_a = parent_->modular_set(op->a); + int64_t modulus = b.min_value; + int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus); + + // If gcd_coeff_mod > 1, we can get tighter bounds + // The result will be of the form gcd_coeff_mod * k + (base % modulus) + // where k ranges to cover [0, modulus - gcd_coeff_mod] + // + // Example: expr = (bx * 2048 + tx * 16) % 7168 + // where bx in [0, 3584), tx in [0, 128) + // ModularSet(expr) = 16*k (coeff=16, base=0) + // GCD(16, 7168) = 16 + // Result can only be {0, 16, 32, ..., 7152} + // Without this optimization: bound = [0, 7167] + // With this optimization: bound = [0, 7152] + if (gcd_coeff_mod > 1) { + int64_t base_mod = mod_a->base % modulus; + if (base_mod < 0) base_mod += modulus; + int64_t tight_max = modulus - gcd_coeff_mod + base_mod; + if (tight_max >= modulus) tight_max -= modulus; + return MakeBound(base_mod, tight_max); + } + } + if (a.min_value >= 0) { // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; @@ -458,6 +512,8 @@ class ConstIntBoundAnalyzer::Impl private: friend class ConstIntBoundAnalyzer; + // parent analyzer + Analyzer* parent_; // internal variable map std::unordered_map var_map_; // additional bound info @@ -525,6 +581,7 @@ class ConstIntBoundAnalyzer::Impl // If the range of b does not have 0, use BinaryOpBoundary. return BinaryOpBoundary(a, b, op); } + /*! * \brief Compute x + y, aware of inf. * \param x The left operand. @@ -805,7 +862,7 @@ std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con return impl_->EnterConstraint(constraint); } -ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index aa15284b3e03..1433ceb70fc0 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -27,12 +27,14 @@ #include #include #include +#include #include #include #include #include "constraint_extract.h" +#include "int_operator.h" #include "interval_set.h" #include "pattern_match.h" @@ -109,10 +111,15 @@ TVM_DECLARE_LOGICAL_OP(Not); /*! * \brief Combine two interval set under arithmetic operations. + * \param analyzer The analyzer for simplification and proving + * \param a The first interval set + * \param b The second interval set + * \param op The operation node, used to extract dtype and other properties * \note this can possibly relax the set. */ -template -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { +template +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) { + DataType dtype = op->dtype; if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr expr; if (auto res = TryConstFold(a->min_value, b->min_value)) { @@ -134,7 +141,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, Dat template <> inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::AddNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -149,7 +156,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS template <> inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::SubNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -164,7 +171,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MulNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -198,7 +205,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::DivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -232,7 +239,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::ModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -261,7 +268,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::FloorDivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -295,7 +302,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::FloorModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -321,6 +328,29 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int return IntervalSet(tmin, tmax); } } + // Enhanced: Use ModularSet analysis for better bounds + if (auto* div_imm = divisor.as()) { + int64_t div_val = div_imm->value; + + // Analyze the modular properties of the dividend + ModularSet dividend_mod = analyzer->modular_set(op->a); + + if (dividend_mod.defined() && dividend_mod->coeff > 0) { + // Calculate GCD of dividend coefficient and divisor + int64_t gcd = ZeroAwareGCD(dividend_mod->coeff, div_val); + + if (gcd > 1 && div_val % gcd == 0) { + // The dividend is a multiple of gcd, and divisor is also a multiple of gcd + // So the result is also a multiple of gcd, with max value = (div_val/gcd - 1) * gcd + int64_t max_quotient = (div_val / gcd) - 1; + int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd); + + if (max_mod_result >= 0 && max_mod_result < div_val) { + return IntervalSet(make_zero(op->dtype), make_const(op->dtype, max_mod_result)); + } + } + } + } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; @@ -333,7 +363,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int template <> inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MaxNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -344,7 +374,7 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MinNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -475,19 +505,25 @@ class IntervalSetEvaluator : public ExprFunctor { if (op->lanes->IsInstance()) { int lanes = static_cast(Downcast(op->lanes)->value); if (vstride > 0) { - return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * (lanes - 1))), - op->dtype); + PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + auto add_op = tir::Add(op->base, stride_expr); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(make_zero(t), stride_expr), add_node); } else { - return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * (lanes - 1)), make_zero(t)), - op->dtype); + PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + auto add_op = tir::Add(op->base, stride_expr); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(stride_expr, make_zero(t)), add_node); } } else { /* Scalable vector */ if (vstride > 0) { - return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), op->dtype); + auto add_op = tir::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), add_node); } else { - return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), op->dtype); + auto add_op = tir::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), add_node); } } } @@ -563,7 +599,7 @@ class IntervalSetEvaluator : public ExprFunctor { if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntervalSet::SinglePoint(ffi::GetRef(op)); } - return Combine(analyzer_, a, b, op->dtype); + return Combine(analyzer_, a, b, op); } // recursive depth diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index 14bfec2328f2..8728df7e3f3a 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -298,5 +298,17 @@ class TestRampBound(BaseCompare): ) +class TestModularSetBound(BaseCompare): + analyzer = tvm.arith.Analyzer() + tx = tvm.te.var("tx", dtype="int32") + bx = tvm.te.var("bx", dtype="int32") + + expr = (bx * 2048 + tx * 16) % 7168 + + test_case = tvm.testing.parameter( + TestCase(expr, (0, 7152), {bx: (0, 3584), tx: (0, 128)}), + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/arith/test_arith_intset.py b/tests/python/arith/test_arith_intset.py index 18865a73df45..04014ca30095 100644 --- a/tests/python/arith/test_arith_intset.py +++ b/tests/python/arith/test_arith_intset.py @@ -387,5 +387,15 @@ def test_union_lower_bound(): assert result.max_value.same_as(pos_inf) +def test_modular_set(): + ck = IntSetChecker() + x = tvm.te.var("x", dtype="int32") + y = tvm.te.var("y", dtype="int32") + expr = (x * 2048 + y * 16) % 7168 + ck.verify( + expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152) + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py index 057cd0e9f7ae..b901c3ce1372 100644 --- a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py @@ -846,21 +846,21 @@ def _create_schedule(): 1.0, 0.0, 0.0, - 25.000000042995662, - 20.000001375860553, - 23.00000017198264, - 14.000088052430122, + 25.00000004, + 19.99718086, + 23.00000017, + 13.99726771, 1.0, 0.0, 0.0, - 18.00000550343433, - 20.00562591970089, - 2.321928094887362, - 23.00000017198264, - 18.00000550343433, - 21.000000687930438, - 12.0003521774803, - 12.0003521774803, + 18.0000055, + 20.00000138, + 2.32192809, + 23.00000017, + 17.997185, + 21.00000069, + 11.99753235, + 12.00035218, ], rtol=1e-5, atol=1e-5, @@ -872,21 +872,21 @@ def _create_schedule(): 0.0, 1.0, 0.0, - 25.000000042995662, - 12.0003521774803, - 23.00000017198264, - 9.002815015607053, + 25.00000004, + 11.00070427, + 23.00000017, + 5.04439412, 1.0, 0.0, 0.0, - 6.022367813028454, - 11.98049663618346, - 8.005624549193879, - 17.000011006847668, - 4.087462841250339, - 15.000044026886828, - 1.584962500721156, - 4.087462841250339, + 6.02236781, + 11.98049664, + 8.00562455, + 17.00001101, + 3.169925, + 15.00004403, + 0.169925, + 4.08746284, ], rtol=1e-5, atol=1e-5, @@ -1052,21 +1052,21 @@ def _create_schedule(): 1.0, 0.0, 0.0, - 22.00000034396526, - 20.000001375860553, - 20.000001375860553, - 14.000088052430122, + 22.00000034, + 19.85798251, + 20.00000138, + 13.85807816, 1.0, 0.0, 0.0, - 15.000044026886828, - 20.17555076886471, - 2.321928094887362, - 20.000001375860553, - 18.00000550343433, - 18.00000550343433, - 12.0003521774803, - 4.087462841250339, + 15.00004403, + 20.04456622, + 2.32192809, + 20.00000138, + 17.85798707, + 18.0000055, + 11.8583696, + 4.08746284, ], rtol=1e-5, atol=1e-5, @@ -1078,20 +1078,20 @@ def _create_schedule(): 0.0, 1.0, 0.0, - 22.00000034396526, - 9.002815015607053, - 20.000001375860553, - 3.169925001442312, + 22.00000034, + 7.01122726, + 20.00000138, + 4.08746284, 1.0, 0.0, 0.0, 3.169925001442312, - 9.61654884377899, + 4.08746284, 8.005624549193879, 14.000088052430122, - 1.584962500721156, - 12.0003521774803, - 0.044394119358453436, + 0.5849625, + 12.00035218, + 0.08746284, 4.087462841250339, ], rtol=1e-5, diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index c8a095280230..426272584bb5 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -852,7 +852,7 @@ def tir_workload( v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)]) T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) - for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30): + for rv0, rv1 in T.grid((v_ax2 % 3 * 4 + 16) // 12 + 1, (v_ax3 % 3 * 10 + 40) // 30 + 1): with T.block("adaptive_pool_sum"): v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0) v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1) @@ -870,7 +870,7 @@ def tir_workload( T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) - adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30)) + adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", (v_ax2 % 3 * 4 + 16) // 12 + 1) * T.Cast("float32", (v_ax3 % 3 * 10 + 40) // 30 + 1)) # fmt: on def te_workload():