@@ -189,6 +189,7 @@ bool DetectClipBound(const PrimExpr& cond,
189189 PostOrderVisit (cond, fvisit);
190190 if (flag != 1 ) return false ;
191191 // canonical form: exp >= 0
192+ bool is_eq = false ;
192193 PrimExpr canonical;
193194 if (const LTNode* op = cond.as <LTNode>()) {
194195 if (!op->a .dtype ().is_int ()) return false ;
@@ -202,6 +203,10 @@ bool DetectClipBound(const PrimExpr& cond,
202203 } else if (const GENode* op = cond.as <GENode>()) {
203204 if (!op->a .dtype ().is_int ()) return false ;
204205 canonical = op->a - op->b ;
206+ } else if (const EQNode* op = cond.as <EQNode>()) {
207+ if (!op->a .dtype ().is_int ()) return false ;
208+ canonical = op->a - op->b ;
209+ is_eq = true ;
205210 } else {
206211 return false ;
207212 }
@@ -210,25 +215,40 @@ bool DetectClipBound(const PrimExpr& cond,
210215 if (!LinearEqDetector (var).Detect (canonical, &ret)) return false ;
211216 ret.coeff = analyzer.Simplify (ret.coeff );
212217 IntervalEntry& p = (*bmap)[var.get ()];
218+
219+ Optional<PrimExpr> min_value;
220+ Optional<PrimExpr> max_value;
213221 if (is_const_int (ret.coeff , 1 )) {
214222 // var + shift >=0 -> var >= -shift
223+ min_value = -ret.base ;
224+ if (is_eq) {
225+ max_value = min_value;
226+ }
227+ } else if (is_const_int (ret.coeff , -1 )) {
228+ // -var + shift >=0 -> var <= shift
229+ max_value = ret.base ;
230+ if (is_eq) {
231+ min_value = max_value;
232+ }
233+ }
234+ if (!min_value.defined () && !max_value.defined ()) {
235+ return false ;
236+ }
237+ if (min_value.defined ()) {
215238 if (p.min_value .defined ()) {
216- p.min_value = max (p.min_value , -ret. base );
239+ p.min_value = max (p.min_value , min_value. value () );
217240 } else {
218- p.min_value = -ret. base ;
241+ p.min_value = min_value. value () ;
219242 }
220- return true ;
221243 }
222- if (is_const_int (ret.coeff , -1 )) {
223- // -var + shift >=0 -> var <= shift
244+ if (max_value.defined ()) {
224245 if (p.max_value .defined ()) {
225- p.max_value = min (p.max_value , ret. base );
246+ p.max_value = min (p.max_value , max_value. value () );
226247 } else {
227- p.max_value = ret. base ;
248+ p.max_value = max_value. value () ;
228249 }
229- return true ;
230250 }
231- return false ;
251+ return true ;
232252}
233253
234254template <typename OP>
0 commit comments