Skip to content

Commit 0c2ab1b

Browse files
[Arith] Support eq in detect_clip_bound (#13746)
* Support eq in detect_clip_bound * follow review suggestion
1 parent 1bc8cf8 commit 0c2ab1b

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

src/arith/detect_linear_equation.cc

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ bool DetectClipBound(const PrimExpr& cond,
189189
PostOrderVisit(cond, fvisit);
190190
if (flag != 1) return false;
191191
// canonical form: exp >= 0
192+
bool is_eq = false;
192193
PrimExpr canonical;
193194
if (const LTNode* op = cond.as<LTNode>()) {
194195
if (!op->a.dtype().is_int()) return false;
@@ -202,6 +203,10 @@ bool DetectClipBound(const PrimExpr& cond,
202203
} else if (const GENode* op = cond.as<GENode>()) {
203204
if (!op->a.dtype().is_int()) return false;
204205
canonical = op->a - op->b;
206+
} else if (const EQNode* op = cond.as<EQNode>()) {
207+
if (!op->a.dtype().is_int()) return false;
208+
canonical = op->a - op->b;
209+
is_eq = true;
205210
} else {
206211
return false;
207212
}
@@ -210,25 +215,40 @@ bool DetectClipBound(const PrimExpr& cond,
210215
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
211216
ret.coeff = analyzer.Simplify(ret.coeff);
212217
IntervalEntry& p = (*bmap)[var.get()];
218+
219+
Optional<PrimExpr> min_value;
220+
Optional<PrimExpr> max_value;
213221
if (is_const_int(ret.coeff, 1)) {
214222
// var + shift >=0 -> var >= -shift
223+
min_value = -ret.base;
224+
if (is_eq) {
225+
max_value = min_value;
226+
}
227+
} else if (is_const_int(ret.coeff, -1)) {
228+
// -var + shift >=0 -> var <= shift
229+
max_value = ret.base;
230+
if (is_eq) {
231+
min_value = max_value;
232+
}
233+
}
234+
if (!min_value.defined() && !max_value.defined()) {
235+
return false;
236+
}
237+
if (min_value.defined()) {
215238
if (p.min_value.defined()) {
216-
p.min_value = max(p.min_value, -ret.base);
239+
p.min_value = max(p.min_value, min_value.value());
217240
} else {
218-
p.min_value = -ret.base;
241+
p.min_value = min_value.value();
219242
}
220-
return true;
221243
}
222-
if (is_const_int(ret.coeff, -1)) {
223-
// -var + shift >=0 -> var <= shift
244+
if (max_value.defined()) {
224245
if (p.max_value.defined()) {
225-
p.max_value = min(p.max_value, ret.base);
246+
p.max_value = min(p.max_value, max_value.value());
226247
} else {
227-
p.max_value = ret.base;
248+
p.max_value = max_value.value();
228249
}
229-
return true;
230250
}
231-
return false;
251+
return true;
232252
}
233253

234254
template <typename OP>

tests/python/unittest/test_arith_detect_clip_bound.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,18 @@ def test_basic():
3939
tvm.testing.assert_prim_expr_equal(m[2], 4)
4040

4141

42+
def test_trivial_eq():
43+
a = te.var("a")
44+
b = te.var("b")
45+
m = tvm.arith.detect_clip_bound(b == 3, [a, b])
46+
tvm.testing.assert_prim_expr_equal(m[2], 3)
47+
tvm.testing.assert_prim_expr_equal(m[3], 3)
48+
m = tvm.arith.detect_clip_bound(tvm.tir.all(a == 4, b == 3), [a, b])
49+
tvm.testing.assert_prim_expr_equal(m[0], 4)
50+
tvm.testing.assert_prim_expr_equal(m[1], 4)
51+
tvm.testing.assert_prim_expr_equal(m[2], 3)
52+
tvm.testing.assert_prim_expr_equal(m[3], 3)
53+
54+
4255
if __name__ == "__main__":
4356
test_basic()

0 commit comments

Comments
 (0)