|
20 | 20 |
|
21 | 21 | #include "../utils.h" |
22 | 22 |
|
23 | | -namespace tvm { |
24 | | -namespace tir { |
25 | | - |
26 | | -double CountFlop(const IRModule& mod) { |
27 | | - struct TResult { |
28 | | - using TTable = std::unordered_map<int32_t, double>; |
29 | | - |
30 | | - TResult() = default; |
31 | | - |
32 | | - explicit TResult(const tvm::DataType& dtype) { Add(dtype); } |
33 | | - |
34 | | - void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } |
35 | | - |
36 | | - TResult operator+=(const TResult& rhs) { |
37 | | - for (const auto& kv : rhs.data_) { |
38 | | - data_[kv.first] += kv.second; |
39 | | - } |
40 | | - return *this; |
41 | | - } |
42 | | - |
43 | | - TResult operator*=(int64_t rhs) { |
44 | | - for (auto& kv : data_) { |
45 | | - kv.second *= rhs; |
46 | | - } |
47 | | - return *this; |
48 | | - } |
49 | | - |
50 | | - TResult MaxWith(const TResult& rhs) { |
51 | | - for (const auto& kv : rhs.data_) { |
52 | | - double& v = data_[kv.first]; |
53 | | - if (v < kv.second) { |
54 | | - v = kv.second; |
55 | | - } |
56 | | - } |
57 | | - return *this; |
58 | | - } |
59 | | - |
60 | | - struct DType { |
61 | | - uint8_t code : 8; |
62 | | - uint8_t bits : 8; |
63 | | - uint16_t lanes : 16; |
64 | | - }; |
65 | | - static_assert(sizeof(DType) == 4, "Incorrect size of DType"); |
66 | | - |
67 | | - static String Int2Str(int32_t dtype) { |
68 | | - union { |
69 | | - DType dst; |
70 | | - int32_t src; |
71 | | - } converter; |
72 | | - converter.src = dtype; |
73 | | - static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"}; |
74 | | - std::ostringstream os; |
75 | | - os << type_code_tab[converter.dst.code]; |
76 | | - os << static_cast<int>(converter.dst.bits); |
77 | | - if (converter.dst.lanes != 1) { |
78 | | - os << "x" << static_cast<int>(converter.dst.lanes); |
79 | | - } |
80 | | - return os.str(); |
81 | | - } |
82 | | - |
83 | | - static int32_t DataType2Int(const tvm::DataType& dtype) { |
84 | | - union { |
85 | | - DType src; |
86 | | - int32_t dst; |
87 | | - } converter; |
88 | | - converter.src.code = dtype.code(); |
89 | | - converter.src.bits = dtype.bits(); |
90 | | - converter.src.lanes = dtype.lanes(); |
91 | | - return converter.dst; |
92 | | - } |
93 | | - |
94 | | - TTable data_; |
95 | | - }; |
96 | | - |
97 | | - class FlopCounter : public ExprFunctor<TResult(const PrimExpr& n)>, |
98 | | - public StmtFunctor<TResult(const Stmt& n)> { |
99 | | - public: |
100 | | - ~FlopCounter() {} |
101 | | - |
102 | | - TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } |
103 | | - TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } |
104 | | - |
105 | | - TResult VisitStmt_(const IfThenElseNode* branch) override { |
106 | | - TResult cond = VisitExpr(branch->condition); |
107 | | - cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case)); |
108 | | - return cond; |
109 | | - } |
110 | | - |
111 | | - TResult VisitStmt_(const BufferStoreNode* store) override { |
112 | | - TResult result = VisitExpr(store->value); |
113 | | - for (const PrimExpr& e : store->indices) { |
114 | | - result += VisitExpr(e); |
115 | | - } |
116 | | - return result; |
117 | | - } |
118 | | - |
119 | | - TResult VisitStmt_(const SeqStmtNode* seq) override { |
120 | | - TResult result; |
121 | | - for (const Stmt& stmt : seq->seq) { |
122 | | - result += VisitStmt(stmt); |
123 | | - } |
124 | | - return result; |
125 | | - } |
126 | | - |
127 | | - TResult VisitStmt_(const BlockRealizeNode* block) override { |
128 | | - return VisitStmt(block->block->body); |
129 | | - } |
130 | | - |
131 | | - TResult VisitStmt_(const BlockNode* block) override { |
132 | | - TResult result; |
133 | | - if (block->init.defined()) { |
134 | | - result += VisitStmt(block->init.value()); |
135 | | - } |
136 | | - result += VisitStmt(block->body); |
137 | | - return result; |
138 | | - } |
139 | | - |
140 | | - TResult VisitStmt_(const ForNode* loop) override { |
141 | | - TResult result = VisitStmt(loop->body); |
142 | | - const auto* int_imm = loop->extent.as<IntImmNode>(); |
143 | | - ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " |
144 | | - << loop->extent->GetTypeKey(); |
145 | | - result *= int_imm->value; |
146 | | - return result; |
147 | | - } |
148 | | - |
149 | | -#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \ |
150 | | - TResult VisitExpr_(const Node* op) final { \ |
151 | | - TResult result(op->dtype); \ |
152 | | - result += VisitExpr(op->a); \ |
153 | | - result += VisitExpr(op->b); \ |
154 | | - return result; \ |
155 | | - } |
156 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode); |
157 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode); |
158 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode); |
159 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode); |
160 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode); |
161 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode); |
162 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode); |
163 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode); |
164 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode); |
165 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode); |
166 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode); |
167 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode); |
168 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode); |
169 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode); |
170 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode); |
171 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode); |
172 | | - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode); |
173 | | -#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY |
174 | | - TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } |
175 | | - TResult VisitExpr_(const VarNode* op) override { return TResult(); } |
176 | | - TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); } |
177 | | - TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); } |
178 | | - TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } |
179 | | - TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } |
180 | | - TResult VisitExpr_(const NotNode* op) override { |
181 | | - TResult result(op->dtype); |
182 | | - result += VisitExpr(op->a); |
183 | | - return result; |
184 | | - } |
185 | | - TResult VisitExpr_(const SelectNode* op) override { |
186 | | - TResult cond = VisitExpr(op->condition); |
187 | | - cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value)); |
188 | | - return cond; |
189 | | - } |
190 | | - TResult VisitExpr_(const CallNode* op) override { |
191 | | - TResult ret; |
192 | | - for (const auto& x : op->args) { |
193 | | - ret += VisitExpr(x); |
194 | | - } |
195 | | - return ret; |
196 | | - } |
197 | | - }; |
198 | | - FlopCounter counter; |
199 | | - TResult result; |
200 | | - for (const auto& kv : mod->functions) { |
201 | | - const BaseFunc& base_func = kv.second; |
202 | | - if (const auto* prim_func = base_func.as<PrimFuncNode>()) { |
203 | | - result += counter.VisitStmt(prim_func->body); |
204 | | - } |
205 | | - } |
206 | | - double cnt = 0.0; |
207 | | - int i32 = TResult::DataType2Int(tvm::DataType::Int(32)); |
208 | | - int i64 = TResult::DataType2Int(tvm::DataType::Int(64)); |
209 | | - int u1 = TResult::DataType2Int(tvm::DataType::UInt(1)); |
210 | | - for (const auto& kv : result.data_) { |
211 | | - if (kv.first != i32 && kv.first != i64 && kv.first != u1) { |
212 | | - cnt += kv.second; |
213 | | - } |
214 | | - } |
215 | | - return cnt; |
216 | | -} |
217 | | - |
218 | | -} // namespace tir |
219 | | -} // namespace tvm |
220 | | - |
221 | 23 | namespace tvm { |
222 | 24 | namespace meta_schedule { |
223 | 25 |
|
@@ -312,7 +114,7 @@ class EchoStatisticsNode : public MeasureCallbackNode { |
312 | 114 | for (const TuneContext& task : tasks) { |
313 | 115 | task_info.push_back(TaskInfo(GetTaskName(task, task_id))); |
314 | 116 | TaskInfo& info = task_info.back(); |
315 | | - info.flop = tir::CountFlop(task->mod.value()); |
| 117 | + info.flop = tir::EstimateTIRFlops(task->mod.value()); |
316 | 118 | ++task_id; |
317 | 119 | } |
318 | 120 | } |
|
0 commit comments