Skip to content

Commit 70c157d

Browse files
authored
[Analyzer] Enhance ConstIntBoundAnalyzer and IntervalSet with modular set analysis (#18330)
* Enhance ConstIntBoundAnalyzer and IntervalSet with modular set analysis - Added modular set analysis to ConstIntBoundAnalyzer for tighter bounds when min_value equals max_value. - Introduced ComputeGCD function to calculate the GCD of two integers. - Updated Combine functions in IntervalSet to accept operation nodes for better type handling. - Enhanced tests for modular set bounds in both const integer bounds and interval sets. * replace gcd compute with ZeroAwareGCD * doc op node * replace Compute GCD with ZeroAwareGCD * add example * test fix * test fix * lint fix
1 parent ddf7bce commit 70c157d

File tree

6 files changed

+182
-67
lines changed

6 files changed

+182
-67
lines changed

src/arith/const_int_bound.cc

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ struct ConstIntBoundAnalyzer::Entry {
102102
class ConstIntBoundAnalyzer::Impl
103103
: public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
104104
public:
105+
explicit Impl(Analyzer* parent) : parent_(parent) {}
105106
/*! \brief additional bound info about expr in bound */
106107
struct BoundInfo {
107108
/*! \brief The expr */
@@ -278,6 +279,33 @@ class ConstIntBoundAnalyzer::Impl
278279

279280
if (b.min_value > 0) {
280281
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
282+
283+
// Try to get tighter bounds using modular set information
284+
if (parent_ && b.min_value == b.max_value) {
285+
ModularSet mod_a = parent_->modular_set(op->a);
286+
int64_t modulus = b.min_value;
287+
int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus);
288+
289+
// If gcd_coeff_mod > 1, we can get tighter bounds
290+
// The result will be of the form gcd_coeff_mod * k + (base % modulus)
291+
// where k ranges to cover [0, modulus - gcd_coeff_mod]
292+
//
293+
// Example: expr = (bx * 2048 + tx * 16) % 7168
294+
// where bx in [0, 3584), tx in [0, 128)
295+
// ModularSet(expr) = 16*k (coeff=16, base=0)
296+
// GCD(16, 7168) = 16
297+
// Result can only be {0, 16, 32, ..., 7152}
298+
// Without this optimization: bound = [0, 7167]
299+
// With this optimization: bound = [0, 7152]
300+
if (gcd_coeff_mod > 1) {
301+
int64_t base_mod = mod_a->base % modulus;
302+
if (base_mod < 0) base_mod += modulus;
303+
int64_t tight_max = modulus - gcd_coeff_mod + base_mod;
304+
if (tight_max >= modulus) tight_max -= modulus;
305+
return MakeBound(base_mod, tight_max);
306+
}
307+
}
308+
281309
if (a.min_value >= 0) {
282310
// 0 <= [a_min, a_max] < b_min
283311
if (a.max_value < b.min_value) return a;
@@ -324,6 +352,32 @@ class ConstIntBoundAnalyzer::Impl
324352

325353
if (b.min_value > 0) {
326354
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
355+
// Try to get tighter bounds using modular set information
356+
if (parent_ && b.min_value == b.max_value) {
357+
ModularSet mod_a = parent_->modular_set(op->a);
358+
int64_t modulus = b.min_value;
359+
int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus);
360+
361+
// If gcd_coeff_mod > 1, we can get tighter bounds
362+
// The result will be of the form gcd_coeff_mod * k + (base % modulus)
363+
// where k ranges to cover [0, modulus - gcd_coeff_mod]
364+
//
365+
// Example: expr = (bx * 2048 + tx * 16) % 7168
366+
// where bx in [0, 3584), tx in [0, 128)
367+
// ModularSet(expr) = 16*k (coeff=16, base=0)
368+
// GCD(16, 7168) = 16
369+
// Result can only be {0, 16, 32, ..., 7152}
370+
// Without this optimization: bound = [0, 7167]
371+
// With this optimization: bound = [0, 7152]
372+
if (gcd_coeff_mod > 1) {
373+
int64_t base_mod = mod_a->base % modulus;
374+
if (base_mod < 0) base_mod += modulus;
375+
int64_t tight_max = modulus - gcd_coeff_mod + base_mod;
376+
if (tight_max >= modulus) tight_max -= modulus;
377+
return MakeBound(base_mod, tight_max);
378+
}
379+
}
380+
327381
if (a.min_value >= 0) {
328382
// 0 <= [a_min, a_max] < b_min
329383
if (a.max_value < b.min_value) return a;
@@ -458,6 +512,8 @@ class ConstIntBoundAnalyzer::Impl
458512

459513
private:
460514
friend class ConstIntBoundAnalyzer;
515+
// parent analyzer
516+
Analyzer* parent_;
461517
// internal variable map
462518
std::unordered_map<Var, Entry> var_map_;
463519
// additional bound info
@@ -525,6 +581,7 @@ class ConstIntBoundAnalyzer::Impl
525581
// If the range of b does not have 0, use BinaryOpBoundary.
526582
return BinaryOpBoundary(a, b, op);
527583
}
584+
528585
/*!
529586
* \brief Compute x + y, aware of inf.
530587
* \param x The left operand.
@@ -805,7 +862,7 @@ std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con
805862
return impl_->EnterConstraint(constraint);
806863
}
807864

808-
ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {}
865+
ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
809866

810867
ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }
811868

src/arith/int_set.cc

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@
2727
#include <tvm/ffi/reflection/registry.h>
2828
#include <tvm/tir/expr.h>
2929
#include <tvm/tir/expr_functor.h>
30+
#include <tvm/tir/op.h>
3031

3132
#include <algorithm>
3233
#include <unordered_map>
3334
#include <utility>
3435

3536
#include "constraint_extract.h"
37+
#include "int_operator.h"
3638
#include "interval_set.h"
3739
#include "pattern_match.h"
3840

@@ -109,10 +111,15 @@ TVM_DECLARE_LOGICAL_OP(Not);
109111

110112
/*!
111113
* \brief Combine two interval set under arithmetic operations.
114+
* \param analyzer The analyzer for simplification and proving
115+
* \param a The first interval set
116+
* \param b The second interval set
117+
* \param op The operation node, used to extract dtype and other properties
112118
* \note this can possibly relax the set.
113119
*/
114-
template <typename Op>
115-
inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) {
120+
template <typename Op, typename OpNode>
121+
inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) {
122+
DataType dtype = op->dtype;
116123
if (a->IsSinglePoint() && b->IsSinglePoint()) {
117124
PrimExpr expr;
118125
if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) {
@@ -134,7 +141,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, Dat
134141

135142
template <>
136143
inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalSet b,
137-
DataType /* dtype */) {
144+
const tir::AddNode* /* op */) {
138145
if (a->IsSinglePoint() && b->IsSinglePoint()) {
139146
return IntervalSet::SinglePoint(a->min_value + b->min_value);
140147
}
@@ -149,7 +156,7 @@ inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalS
149156

150157
template <>
151158
inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalSet b,
152-
DataType /* dtype */) {
159+
const tir::SubNode* /* op */) {
153160
if (a->IsSinglePoint() && b->IsSinglePoint()) {
154161
return IntervalSet::SinglePoint(a->min_value - b->min_value);
155162
}
@@ -164,7 +171,7 @@ inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalS
164171

165172
template <>
166173
inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
167-
DataType /* dtype */) {
174+
const tir::MulNode* /* op */) {
168175
if (a->IsSinglePoint() && b->IsSinglePoint()) {
169176
return IntervalSet::SinglePoint(a->min_value * b->min_value);
170177
}
@@ -198,7 +205,7 @@ inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, Interval
198205

199206
template <>
200207
inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
201-
DataType /* dtype */) {
208+
const tir::DivNode* /* op */) {
202209
if (a->IsSinglePoint() && b->IsSinglePoint()) {
203210
return IntervalSet::SinglePoint(a->min_value / b->min_value);
204211
}
@@ -232,7 +239,7 @@ inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, Interval
232239

233240
template <>
234241
inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
235-
DataType /* dtype */) {
242+
const tir::ModNode* op) {
236243
if (a->IsSinglePoint() && b->IsSinglePoint()) {
237244
return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
238245
}
@@ -261,7 +268,7 @@ inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, Interval
261268

262269
template <>
263270
inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
264-
DataType /* dtype */) {
271+
const tir::FloorDivNode* /* op */) {
265272
if (a->IsSinglePoint() && b->IsSinglePoint()) {
266273
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
267274
}
@@ -295,7 +302,7 @@ inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, Int
295302

296303
template <>
297304
inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
298-
DataType /* dtype */) {
305+
const tir::FloorModNode* op) {
299306
if (a->IsSinglePoint() && b->IsSinglePoint()) {
300307
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
301308
}
@@ -321,6 +328,29 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, Int
321328
return IntervalSet(tmin, tmax);
322329
}
323330
}
331+
// Enhanced: Use ModularSet analysis for better bounds
332+
if (auto* div_imm = divisor.as<tir::IntImmNode>()) {
333+
int64_t div_val = div_imm->value;
334+
335+
// Analyze the modular properties of the dividend
336+
ModularSet dividend_mod = analyzer->modular_set(op->a);
337+
338+
if (dividend_mod.defined() && dividend_mod->coeff > 0) {
339+
// Calculate GCD of dividend coefficient and divisor
340+
int64_t gcd = ZeroAwareGCD(dividend_mod->coeff, div_val);
341+
342+
if (gcd > 1 && div_val % gcd == 0) {
343+
// The dividend is a multiple of gcd, and divisor is also a multiple of gcd
344+
// So the result is also a multiple of gcd, with max value = (div_val/gcd - 1) * gcd
345+
int64_t max_quotient = (div_val / gcd) - 1;
346+
int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd);
347+
348+
if (max_mod_result >= 0 && max_mod_result < div_val) {
349+
return IntervalSet(make_zero(op->dtype), make_const(op->dtype, max_mod_result));
350+
}
351+
}
352+
}
353+
}
324354
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
325355
} else {
326356
PrimExpr bound = abs(divisor) - 1;
@@ -333,7 +363,7 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, Int
333363

334364
template <>
335365
inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
336-
DataType /* dtype */) {
366+
const tir::MaxNode* /* op */) {
337367
if (a->IsSinglePoint() && b->IsSinglePoint()) {
338368
return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
339369
}
@@ -344,7 +374,7 @@ inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, Interval
344374

345375
template <>
346376
inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
347-
DataType /* dtype */) {
377+
const tir::MinNode* /* op */) {
348378
if (a->IsSinglePoint() && b->IsSinglePoint()) {
349379
return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
350380
}
@@ -475,19 +505,25 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
475505
if (op->lanes->IsInstance<IntImmNode>()) {
476506
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
477507
if (vstride > 0) {
478-
return Combine<Add>(analyzer_, base,
479-
IntervalSet(make_zero(t), make_const(t, vstride * (lanes - 1))),
480-
op->dtype);
508+
PrimExpr stride_expr = make_const(t, vstride * (lanes - 1));
509+
auto add_op = tir::Add(op->base, stride_expr);
510+
auto add_node = add_op.as<tir::AddNode>();
511+
return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), stride_expr), add_node);
481512
} else {
482-
return Combine<Add>(analyzer_, base,
483-
IntervalSet(make_const(t, vstride * (lanes - 1)), make_zero(t)),
484-
op->dtype);
513+
PrimExpr stride_expr = make_const(t, vstride * (lanes - 1));
514+
auto add_op = tir::Add(op->base, stride_expr);
515+
auto add_node = add_op.as<tir::AddNode>();
516+
return Combine<Add>(analyzer_, base, IntervalSet(stride_expr, make_zero(t)), add_node);
485517
}
486518
} else { /* Scalable vector */
487519
if (vstride > 0) {
488-
return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), op->dtype);
520+
auto add_op = tir::Add(op->base, make_zero(t));
521+
auto add_node = add_op.as<tir::AddNode>();
522+
return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), add_node);
489523
} else {
490-
return Combine<Add>(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), op->dtype);
524+
auto add_op = tir::Add(op->base, make_zero(t));
525+
auto add_node = add_op.as<tir::AddNode>();
526+
return Combine<Add>(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), add_node);
491527
}
492528
}
493529
}
@@ -563,7 +599,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
563599
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
564600
return IntervalSet::SinglePoint(ffi::GetRef<PrimExpr>(op));
565601
}
566-
return Combine<TOp>(analyzer_, a, b, op->dtype);
602+
return Combine<TOp>(analyzer_, a, b, op);
567603
}
568604

569605
// recursive depth

tests/python/arith/test_arith_const_int_bound.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,5 +298,17 @@ class TestRampBound(BaseCompare):
298298
)
299299

300300

301+
class TestModularSetBound(BaseCompare):
302+
analyzer = tvm.arith.Analyzer()
303+
tx = tvm.te.var("tx", dtype="int32")
304+
bx = tvm.te.var("bx", dtype="int32")
305+
306+
expr = (bx * 2048 + tx * 16) % 7168
307+
308+
test_case = tvm.testing.parameter(
309+
TestCase(expr, (0, 7152), {bx: (0, 3584), tx: (0, 128)}),
310+
)
311+
312+
301313
if __name__ == "__main__":
302314
tvm.testing.main()

tests/python/arith/test_arith_intset.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,5 +387,15 @@ def test_union_lower_bound():
387387
assert result.max_value.same_as(pos_inf)
388388

389389

390+
def test_modular_set():
391+
ck = IntSetChecker()
392+
x = tvm.te.var("x", dtype="int32")
393+
y = tvm.te.var("y", dtype="int32")
394+
expr = (x * 2048 + y * 16) % 7168
395+
ck.verify(
396+
expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152)
397+
)
398+
399+
390400
if __name__ == "__main__":
391401
tvm.testing.main()

0 commit comments

Comments
 (0)