Skip to content

Commit 47a1fbd

Browse files
committed
Benchmark with/without the BoundUsingReciprocal function
1 parent b2aa44f commit 47a1fbd

File tree

3 files changed

+67
-12
lines changed

3 files changed

+67
-12
lines changed

src/arith/const_int_bound.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <algorithm>
2929
#include <optional>
3030

31+
#include "../support/utils.h"
3132
#include "constraint_extract.h"
3233
#include "int_operator.h"
3334
#include "pattern_match.h"
@@ -240,8 +241,10 @@ class ConstIntBoundAnalyzer::Impl
240241
ret.min_value = InfAwareAdd(a.min_value, b.min_value);
241242
ret.max_value = InfAwareAdd(a.max_value, b.max_value);
242243

243-
if (auto bound = BoundUsingReciprocal(GetRef<PrimExpr>(op))) {
244-
ret = Intersect(ret, bound.value());
244+
if (support::BoolEnvironmentVar("TVM_ENABLE_RECIPROCAL_PATTERN_MATCH")) {
245+
if (auto bound = BoundUsingReciprocal(GetRef<PrimExpr>(op))) {
246+
ret = Intersect(ret, bound.value());
247+
}
245248
}
246249

247250
return ret;
@@ -254,11 +257,13 @@ class ConstIntBoundAnalyzer::Impl
254257
ret.min_value = InfAwareAdd(a.min_value, -b.max_value);
255258
ret.max_value = InfAwareAdd(a.max_value, -b.min_value);
256259

257-
if (auto bound = BoundUsingReciprocal(GetRef<Sub>(op))) {
258-
ret = Intersect(ret, bound.value());
259-
}
260-
if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) {
261-
ret = Intersect(ret, Negative(bound.value()));
260+
if (support::BoolEnvironmentVar("TVM_ENABLE_RECIPROCAL_PATTERN_MATCH")) {
261+
if (auto bound = BoundUsingReciprocal(GetRef<Sub>(op))) {
262+
ret = Intersect(ret, bound.value());
263+
}
264+
if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) {
265+
ret = Intersect(ret, Negative(bound.value()));
266+
}
262267
}
263268
return ret;
264269
}

tests/python/arith/test_arith_const_int_bound.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,33 @@ def __name__(self):
4343
return str(self.expr)
4444

4545

46+
with_reciprocal_pattern_match = tvm.testing.parameter(
47+
by_dict={
48+
"with_updated_const_int_analyzer": True,
49+
"without_updated_const_int_analyzer": False,
50+
}
51+
)
52+
53+
import pytest
54+
55+
56+
@pytest.fixture(autouse=True)
57+
def set_reciprocal_pattern_match(with_reciprocal_pattern_match):
58+
import os
59+
60+
var_name = "TVM_ENABLE_RECIPROCAL_PATTERN_MATCH"
61+
old_value = os.environ.get(var_name)
62+
os.environ[var_name] = str(int(with_reciprocal_pattern_match))
63+
yield
64+
65+
if old_value is None:
66+
del os.environ[var_name]
67+
else:
68+
os.environ = old_value
69+
70+
4671
class BaseCompare:
47-
def test_const_bounds(self, test_case):
72+
def test_const_bounds(self, test_case, benchmark):
4873
analyzer = tvm.arith.Analyzer()
4974

5075
for var, bounds in test_case.known_bounds.items():
@@ -54,7 +79,7 @@ def test_const_bounds(self, test_case):
5479
if test_case.constraint is not None:
5580
stack.enter_context(analyzer.constraint_scope(test_case.constraint))
5681

57-
bounds = analyzer.const_int_bound(test_case.expr)
82+
bounds = benchmark(analyzer.const_int_bound, test_case.expr)
5883

5984
if test_case.expected_bounds[0] is None:
6085
assert bounds.max_value == test_case.expected_bounds[1]

tests/python/arith/test_arith_rewrite_simplify.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,43 @@ def __name__(self):
5050
return str(self.before)
5151

5252

53+
with_reciprocal_pattern_match = tvm.testing.parameter(
54+
by_dict={
55+
"with_updated_const_int_analyzer": True,
56+
"without_updated_const_int_analyzer": False,
57+
}
58+
)
59+
60+
61+
@pytest.fixture(autouse=True)
62+
def set_reciprocal_pattern_match(with_reciprocal_pattern_match):
63+
import os
64+
65+
var_name = "TVM_ENABLE_RECIPROCAL_PATTERN_MATCH"
66+
old_value = os.environ.get(var_name)
67+
os.environ[var_name] = str(int(with_reciprocal_pattern_match))
68+
yield
69+
70+
if old_value is None:
71+
del os.environ[var_name]
72+
else:
73+
os.environ = old_value
74+
75+
5376
class BaseCompare:
54-
def test_simplify(self, test_case):
77+
def test_simplify(self, test_case, benchmark):
5578
analyzer = tvm.arith.Analyzer()
5679

5780
if inspect.isclass(test_case.expected) and issubclass(test_case.expected, Exception):
5881
with pytest.raises(test_case.expected):
5982
with analyzer.constraint_scope(test_case.constraint):
60-
analyzer.rewrite_simplify(test_case.before)
83+
# analyzer.rewrite_simplify(test_case.before)
84+
benchmark(analyzer.rewrite_simplify, test_case.before)
6185
else:
6286

6387
with analyzer.constraint_scope(test_case.constraint):
64-
after = analyzer.rewrite_simplify(test_case.before)
88+
# after = analyzer.rewrite_simplify(test_case.before)
89+
after = benchmark(analyzer.rewrite_simplify, test_case.before)
6590

6691
assert tvm.ir.structural_equal(after, test_case.expected), (
6792
f"Rewrite didn't match expected.\n"

0 commit comments

Comments
 (0)