Skip to content

Commit c5075dc

Browse files
authored
[TIR] not estimating the flops when there is a default estimated flops as attr (#14379)
* not estimating the flops when there is a default estimated flops as attr * add unittests * lint fix * make unittest simpler
1 parent 4a2a3b5 commit c5075dc

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

src/tir/analysis/estimate_flops.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,15 @@ double EstimateTIRFlops(const Stmt& stmt) {
208208
double EstimateTIRFlops(const IRModule& mod) {
209209
FlopEstimator counter;
210210
TResult result;
211-
VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) {
212-
result += counter.VisitStmt(f->body); //
211+
double cached_result = 0;
212+
VisitPrimFuncs(mod, [&result, &counter, &cached_result](const PrimFuncNode* f) {
213+
if (auto cached = f->attrs.GetAttr<Integer>("estimated_flops")) {
214+
cached_result += cached.value()->value;
215+
} else {
216+
result += counter.VisitStmt(f->body); //
217+
}
213218
});
214-
return PostprocessResults(result);
219+
return PostprocessResults(result) + cached_result;
215220
}
216221

217222
TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double {

tests/python/unittest/test_tir_analysis_estimate_tir_flops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,35 @@ def test_flops_with_if():
7777
assert flops == 16
7878

7979

80+
@T.prim_func
81+
def flops_with_forloop_as_expression(A: T.Buffer(1)):
82+
for i in T.serial(0, 16):
83+
for k in T.serial(0, i):
84+
A[0] = A[0] + 1
85+
86+
87+
@T.prim_func
88+
def flops_override(A: T.Buffer(16, "float32")):
89+
T.func_attr({"estimated_flops": 32})
90+
for i in range(16):
91+
A[0] = A[0] + 1
92+
93+
94+
def test_estimate_flops_forloop_as_experssion():
95+
flops = estimate_tir_flops(
96+
IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)})
97+
)
98+
assert flops == 32
99+
100+
# test whether the user estimated flop would over ride
101+
flops = estimate_tir_flops(IRModule({"main": flops_override}))
102+
assert flops == 32
103+
104+
105+
def test_exception():
106+
with pytest.raises(tvm.TVMError):
107+
flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression}))
108+
109+
80110
if __name__ == "__main__":
81111
tvm.testing.main()

0 commit comments

Comments
 (0)