Skip to content

Commit 9f75fe7

Browse files
committed
[TIR][MetaSchedule] Estimate TIR FLOPs
1 parent 8ebdf6e commit 9f75fe7

File tree

5 files changed

+289
-204
lines changed

5 files changed

+289
-204
lines changed

include/tvm/tir/analysis.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,20 @@ inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
7272
}
7373
}
7474

75+
/*!
76+
* \brief Estimate the FLOPs of a TIR fragment.
77+
* \param stmt The TIR fragment to be estimated.
78+
* \return The estimated FLOPs.
79+
*/
80+
TVM_DLL double EstimateTIRFlops(const Stmt& stmt);
81+
82+
/*!
83+
* \brief Estimate the FLOPs of TIRs in an IRModule.
84+
* \param mod The IRModule to be estimated.
85+
* \return The estimated FLOPs.
86+
*/
87+
TVM_DLL double EstimateTIRFlops(const IRModule& mod);
88+
7589
/*!
7690
* \brief Find undefined vars in the statement.
7791
* \param stmt The function to be checked.

python/tvm/tir/analysis/analysis.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616
# under the License.
1717
"""Wrapping existing analysis utils."""
1818
# pylint: disable=invalid-name
19-
from typing import Dict, List
19+
from typing import Dict, List, Union
2020

2121
from tvm import Object
22-
from tvm.tir.stmt import Block, BufferRegion
23-
from tvm.tir.stmt import PrimExpr
22+
from tvm.ir import IRModule
2423
from tvm.tir.expr import Var
25-
from . import _ffi_api
26-
from ..function import PrimFunc
24+
from tvm.tir.stmt import Block, BufferRegion, PrimExpr
25+
2726
from .. import Buffer, Stmt
27+
from ..function import PrimFunc
28+
from . import _ffi_api
2829

2930

3031
def expr_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:
@@ -199,6 +200,22 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
199200
return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member
200201

201202

203+
def estimate_tir_flops(stmt_or_mod: Union[Stmt, IRModule]) -> float:
204+
"""Estimate the FLOPs of a TIR fragment.
205+
206+
Parameters
207+
----------
208+
stmt_or_mod: Union[Stmt, IRModule]
209+
The TIR fragment or IRModule to be estimated.
210+
211+
Returns
212+
-------
213+
flops: float
214+
The estimated FLOPs.
215+
"""
216+
return _ffi_api.EstimateTIRFlops(stmt_or_mod) # type: ignore # pylint: disable=no-member
217+
218+
202219
# NOTE: relay_func_type in the following two functions should be relay.FuncType; however, that would
203220
# introduce a cyclic dependency. We make do with Object.
204221

src/meta_schedule/measure_callback/echo_statistics.cc

Lines changed: 1 addition & 199 deletions
Original file line numberDiff line numberDiff line change
@@ -20,204 +20,6 @@
2020

2121
#include "../utils.h"
2222

23-
namespace tvm {
24-
namespace tir {
25-
26-
double CountFlop(const IRModule& mod) {
27-
struct TResult {
28-
using TTable = std::unordered_map<int32_t, double>;
29-
30-
TResult() = default;
31-
32-
explicit TResult(const tvm::DataType& dtype) { Add(dtype); }
33-
34-
void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; }
35-
36-
TResult operator+=(const TResult& rhs) {
37-
for (const auto& kv : rhs.data_) {
38-
data_[kv.first] += kv.second;
39-
}
40-
return *this;
41-
}
42-
43-
TResult operator*=(int64_t rhs) {
44-
for (auto& kv : data_) {
45-
kv.second *= rhs;
46-
}
47-
return *this;
48-
}
49-
50-
TResult MaxWith(const TResult& rhs) {
51-
for (const auto& kv : rhs.data_) {
52-
double& v = data_[kv.first];
53-
if (v < kv.second) {
54-
v = kv.second;
55-
}
56-
}
57-
return *this;
58-
}
59-
60-
struct DType {
61-
uint8_t code : 8;
62-
uint8_t bits : 8;
63-
uint16_t lanes : 16;
64-
};
65-
static_assert(sizeof(DType) == 4, "Incorrect size of DType");
66-
67-
static String Int2Str(int32_t dtype) {
68-
union {
69-
DType dst;
70-
int32_t src;
71-
} converter;
72-
converter.src = dtype;
73-
static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"};
74-
std::ostringstream os;
75-
os << type_code_tab[converter.dst.code];
76-
os << static_cast<int>(converter.dst.bits);
77-
if (converter.dst.lanes != 1) {
78-
os << "x" << static_cast<int>(converter.dst.lanes);
79-
}
80-
return os.str();
81-
}
82-
83-
static int32_t DataType2Int(const tvm::DataType& dtype) {
84-
union {
85-
DType src;
86-
int32_t dst;
87-
} converter;
88-
converter.src.code = dtype.code();
89-
converter.src.bits = dtype.bits();
90-
converter.src.lanes = dtype.lanes();
91-
return converter.dst;
92-
}
93-
94-
TTable data_;
95-
};
96-
97-
class FlopCounter : public ExprFunctor<TResult(const PrimExpr& n)>,
98-
public StmtFunctor<TResult(const Stmt& n)> {
99-
public:
100-
~FlopCounter() {}
101-
102-
TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); }
103-
TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); }
104-
105-
TResult VisitStmt_(const IfThenElseNode* branch) override {
106-
TResult cond = VisitExpr(branch->condition);
107-
cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case));
108-
return cond;
109-
}
110-
111-
TResult VisitStmt_(const BufferStoreNode* store) override {
112-
TResult result = VisitExpr(store->value);
113-
for (const PrimExpr& e : store->indices) {
114-
result += VisitExpr(e);
115-
}
116-
return result;
117-
}
118-
119-
TResult VisitStmt_(const SeqStmtNode* seq) override {
120-
TResult result;
121-
for (const Stmt& stmt : seq->seq) {
122-
result += VisitStmt(stmt);
123-
}
124-
return result;
125-
}
126-
127-
TResult VisitStmt_(const BlockRealizeNode* block) override {
128-
return VisitStmt(block->block->body);
129-
}
130-
131-
TResult VisitStmt_(const BlockNode* block) override {
132-
TResult result;
133-
if (block->init.defined()) {
134-
result += VisitStmt(block->init.value());
135-
}
136-
result += VisitStmt(block->body);
137-
return result;
138-
}
139-
140-
TResult VisitStmt_(const ForNode* loop) override {
141-
TResult result = VisitStmt(loop->body);
142-
const auto* int_imm = loop->extent.as<IntImmNode>();
143-
ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: "
144-
<< loop->extent->GetTypeKey();
145-
result *= int_imm->value;
146-
return result;
147-
}
148-
149-
#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \
150-
TResult VisitExpr_(const Node* op) final { \
151-
TResult result(op->dtype); \
152-
result += VisitExpr(op->a); \
153-
result += VisitExpr(op->b); \
154-
return result; \
155-
}
156-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode);
157-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode);
158-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode);
159-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode);
160-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode);
161-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode);
162-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode);
163-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode);
164-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode);
165-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode);
166-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode);
167-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode);
168-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode);
169-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode);
170-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode);
171-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode);
172-
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode);
173-
#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY
174-
TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); }
175-
TResult VisitExpr_(const VarNode* op) override { return TResult(); }
176-
TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); }
177-
TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
178-
TResult VisitExpr_(const IntImmNode* op) override { return TResult(); }
179-
TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
180-
TResult VisitExpr_(const NotNode* op) override {
181-
TResult result(op->dtype);
182-
result += VisitExpr(op->a);
183-
return result;
184-
}
185-
TResult VisitExpr_(const SelectNode* op) override {
186-
TResult cond = VisitExpr(op->condition);
187-
cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
188-
return cond;
189-
}
190-
TResult VisitExpr_(const CallNode* op) override {
191-
TResult ret;
192-
for (const auto& x : op->args) {
193-
ret += VisitExpr(x);
194-
}
195-
return ret;
196-
}
197-
};
198-
FlopCounter counter;
199-
TResult result;
200-
for (const auto& kv : mod->functions) {
201-
const BaseFunc& base_func = kv.second;
202-
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
203-
result += counter.VisitStmt(prim_func->body);
204-
}
205-
}
206-
double cnt = 0.0;
207-
int i32 = TResult::DataType2Int(tvm::DataType::Int(32));
208-
int i64 = TResult::DataType2Int(tvm::DataType::Int(64));
209-
int u1 = TResult::DataType2Int(tvm::DataType::UInt(1));
210-
for (const auto& kv : result.data_) {
211-
if (kv.first != i32 && kv.first != i64 && kv.first != u1) {
212-
cnt += kv.second;
213-
}
214-
}
215-
return cnt;
216-
}
217-
218-
} // namespace tir
219-
} // namespace tvm
220-
22123
namespace tvm {
22224
namespace meta_schedule {
22325

@@ -312,7 +114,7 @@ class EchoStatisticsNode : public MeasureCallbackNode {
312114
for (const TuneContext& task : tasks) {
313115
task_info.push_back(TaskInfo(GetTaskName(task, task_id)));
314116
TaskInfo& info = task_info.back();
315-
info.flop = tir::CountFlop(task->mod.value());
117+
info.flop = tir::EstimateTIRFlops(task->mod.value());
316118
++task_id;
317119
}
318120
}

0 commit comments

Comments
 (0)