Skip to content

Commit db4cd4a

Browse files
follow review suggestion
1 parent bf740d5 commit db4cd4a

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

src/arith/detect_linear_equation.cc

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ bool DetectClipBound(const PrimExpr& cond,
204204
if (!op->a.dtype().is_int()) return false;
205205
canonical = op->a - op->b;
206206
} else if (const EQNode* op = cond.as<EQNode>()) {
207+
if (!op->a.dtype().is_int()) return false;
207208
canonical = op->a - op->b;
208209
is_eq = true;
209210
} else {
@@ -214,31 +215,40 @@ bool DetectClipBound(const PrimExpr& cond,
214215
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
215216
ret.coeff = analyzer.Simplify(ret.coeff);
216217
IntervalEntry& p = (*bmap)[var.get()];
218+
219+
Optional<PrimExpr> min_value;
220+
Optional<PrimExpr> max_value;
217221
if (is_const_int(ret.coeff, 1)) {
218222
// var + shift >=0 -> var >= -shift
219-
if (p.min_value.defined()) {
220-
p.min_value = max(p.min_value, -ret.base);
221-
} else {
222-
p.min_value = -ret.base;
223+
min_value = -ret.base;
224+
if (is_eq) {
225+
max_value = min_value;
223226
}
227+
} else if (is_const_int(ret.coeff, -1)) {
228+
// -var + shift >=0 -> var <= shift
229+
max_value = ret.base;
224230
if (is_eq) {
225-
p.max_value = p.min_value;
231+
min_value = max_value;
226232
}
227-
return true;
228233
}
229-
if (is_const_int(ret.coeff, -1)) {
230-
// -var + shift >=0 -> var <= shift
231-
if (p.max_value.defined()) {
232-
p.max_value = min(p.max_value, ret.base);
234+
if (!min_value.defined() && !max_value.defined()) {
235+
return false;
236+
}
237+
if (min_value.defined()) {
238+
if (p.min_value.defined()) {
239+
p.min_value = max(p.min_value, min_value.value());
233240
} else {
234-
p.max_value = ret.base;
241+
p.min_value = min_value.value();
235242
}
236-
if (is_eq) {
237-
p.min_value = p.max_value;
243+
}
244+
if (max_value.defined()) {
245+
if (p.max_value.defined()) {
246+
p.max_value = min(p.max_value, max_value.value());
247+
} else {
248+
p.max_value = max_value.value();
238249
}
239-
return true;
240250
}
241-
return false;
251+
return true;
242252
}
243253

244254
template <typename OP>

0 commit comments

Comments
 (0)