2727#include < tvm/ffi/reflection/registry.h>
2828#include < tvm/tir/expr.h>
2929#include < tvm/tir/expr_functor.h>
30+ #include < tvm/tir/op.h>
3031
3132#include < algorithm>
3233#include < unordered_map>
3334#include < utility>
3435
3536#include " constraint_extract.h"
37+ #include " int_operator.h"
3638#include " interval_set.h"
3739#include " pattern_match.h"
3840
@@ -109,10 +111,15 @@ TVM_DECLARE_LOGICAL_OP(Not);
109111
110112/* !
111113 * \brief Combine two interval set under arithmetic operations.
114+ * \param analyzer The analyzer for simplification and proving
115+ * \param a The first interval set
116+ * \param b The second interval set
117+ * \param op The operation node, used to extract dtype and other properties
112118 * \note this can possibly relax the set.
113119 */
114- template <typename Op>
115- inline IntervalSet Combine (Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) {
120+ template <typename Op, typename OpNode>
121+ inline IntervalSet Combine (Analyzer* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) {
122+ DataType dtype = op->dtype ;
116123 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
117124 PrimExpr expr;
118125 if (auto res = TryConstFold<Op>(a->min_value , b->min_value )) {
@@ -134,7 +141,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, Dat
134141
135142template <>
136143inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalSet b,
137- DataType /* dtype */ ) {
144+ const tir::AddNode* /* op */ ) {
138145 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
139146 return IntervalSet::SinglePoint (a->min_value + b->min_value );
140147 }
@@ -149,7 +156,7 @@ inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalS
149156
150157template <>
151158inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalSet b,
152- DataType /* dtype */ ) {
159+ const tir::SubNode* /* op */ ) {
153160 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
154161 return IntervalSet::SinglePoint (a->min_value - b->min_value );
155162 }
@@ -164,7 +171,7 @@ inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalS
164171
165172template <>
166173inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
167- DataType /* dtype */ ) {
174+ const tir::MulNode* /* op */ ) {
168175 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
169176 return IntervalSet::SinglePoint (a->min_value * b->min_value );
170177 }
@@ -198,7 +205,7 @@ inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, Interval
198205
199206template <>
200207inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
201- DataType /* dtype */ ) {
208+ const tir::DivNode* /* op */ ) {
202209 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
203210 return IntervalSet::SinglePoint (a->min_value / b->min_value );
204211 }
@@ -232,7 +239,7 @@ inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, Interval
232239
233240template <>
234241inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
235- DataType /* dtype */ ) {
242+ const tir::ModNode* op ) {
236243 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
237244 return IntervalSet::SinglePoint (truncmod (a->min_value , b->min_value ));
238245 }
@@ -261,7 +268,7 @@ inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, Interval
261268
262269template <>
263270inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
264- DataType /* dtype */ ) {
271+ const tir::FloorDivNode* /* op */ ) {
265272 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
266273 return IntervalSet::SinglePoint (floordiv (a->min_value , b->min_value ));
267274 }
@@ -295,7 +302,7 @@ inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, Int
295302
296303template <>
297304inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
298- DataType /* dtype */ ) {
305+ const tir::FloorModNode* op ) {
299306 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
300307 return IntervalSet::SinglePoint (floormod (a->min_value , b->min_value ));
301308 }
@@ -321,6 +328,29 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, Int
321328 return IntervalSet (tmin, tmax);
322329 }
323330 }
331+ // Enhanced: Use ModularSet analysis for better bounds
332+ if (auto * div_imm = divisor.as <tir::IntImmNode>()) {
333+ int64_t div_val = div_imm->value ;
334+
335+ // Analyze the modular properties of the dividend
336+ ModularSet dividend_mod = analyzer->modular_set (op->a );
337+
338+ if (dividend_mod.defined () && dividend_mod->coeff > 0 ) {
339+ // Calculate GCD of dividend coefficient and divisor
340+ int64_t gcd = ZeroAwareGCD (dividend_mod->coeff , div_val);
341+
342+ if (gcd > 1 && div_val % gcd == 0 ) {
343+ // The dividend is a multiple of gcd, and divisor is also a multiple of gcd
344+ // So the result is also a multiple of gcd, with max value = (div_val/gcd - 1) * gcd
345+ int64_t max_quotient = (div_val / gcd) - 1 ;
346+ int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd);
347+
348+ if (max_mod_result >= 0 && max_mod_result < div_val) {
349+ return IntervalSet (make_zero (op->dtype ), make_const (op->dtype , max_mod_result));
350+ }
351+ }
352+ }
353+ }
324354 return IntervalSet (make_zero (divisor.dtype ()), divisor - 1 );
325355 } else {
326356 PrimExpr bound = abs (divisor) - 1 ;
@@ -333,7 +363,7 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, Int
333363
334364template <>
335365inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
336- DataType /* dtype */ ) {
366+ const tir::MaxNode* /* op */ ) {
337367 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
338368 return IntervalSet::SinglePoint (max (a->min_value , b->min_value ));
339369 }
@@ -344,7 +374,7 @@ inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, Interval
344374
345375template <>
346376inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
347- DataType /* dtype */ ) {
377+ const tir::MinNode* /* op */ ) {
348378 if (a->IsSinglePoint () && b->IsSinglePoint ()) {
349379 return IntervalSet::SinglePoint (min (a->min_value , b->min_value ));
350380 }
@@ -475,19 +505,25 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
475505 if (op->lanes ->IsInstance <IntImmNode>()) {
476506 int lanes = static_cast <int >(Downcast<IntImm>(op->lanes )->value );
477507 if (vstride > 0 ) {
478- return Combine<Add>(analyzer_, base,
479- IntervalSet (make_zero (t), make_const (t, vstride * (lanes - 1 ))),
480- op->dtype );
508+ PrimExpr stride_expr = make_const (t, vstride * (lanes - 1 ));
509+ auto add_op = tir::Add (op->base , stride_expr);
510+ auto add_node = add_op.as <tir::AddNode>();
511+ return Combine<Add>(analyzer_, base, IntervalSet (make_zero (t), stride_expr), add_node);
481512 } else {
482- return Combine<Add>(analyzer_, base,
483- IntervalSet (make_const (t, vstride * (lanes - 1 )), make_zero (t)),
484- op->dtype );
513+ PrimExpr stride_expr = make_const (t, vstride * (lanes - 1 ));
514+ auto add_op = tir::Add (op->base , stride_expr);
515+ auto add_node = add_op.as <tir::AddNode>();
516+ return Combine<Add>(analyzer_, base, IntervalSet (stride_expr, make_zero (t)), add_node);
485517 }
486518 } else { /* Scalable vector */
487519 if (vstride > 0 ) {
488- return Combine<Add>(analyzer_, base, IntervalSet (make_zero (t), pos_inf ()), op->dtype );
520+ auto add_op = tir::Add (op->base , make_zero (t));
521+ auto add_node = add_op.as <tir::AddNode>();
522+ return Combine<Add>(analyzer_, base, IntervalSet (make_zero (t), pos_inf ()), add_node);
489523 } else {
490- return Combine<Add>(analyzer_, base, IntervalSet (neg_inf (), make_zero (t)), op->dtype );
524+ auto add_op = tir::Add (op->base , make_zero (t));
525+ auto add_node = add_op.as <tir::AddNode>();
526+ return Combine<Add>(analyzer_, base, IntervalSet (neg_inf (), make_zero (t)), add_node);
491527 }
492528 }
493529 }
@@ -563,7 +599,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
563599 if (MatchPoint (a, op->a ) && MatchPoint (b, op->b )) {
564600 return IntervalSet::SinglePoint (ffi::GetRef<PrimExpr>(op));
565601 }
566- return Combine<TOp>(analyzer_, a, b, op-> dtype );
602+ return Combine<TOp>(analyzer_, a, b, op);
567603 }
568604
569605 // recursive depth
0 commit comments