Skip to content

Commit 85556df

Browse files
[TIR][FIX] update FlopEstimator to include missing nodes (#17598)
* updated estimate flops to include the AllocateNode * added AttrStmtNode * added a visit to the AttrStmtNode body * added a visit to value of AttrStmtNode
1 parent 050b23f commit 85556df

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/tir/analysis/estimate_flops.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
138138
}
139139

140140
TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
141+
TResult VisitStmt_(const AttrStmtNode* op) override {
142+
TResult result = VisitStmt(op->body);
143+
result += VisitExpr(op->value);
144+
return result;
145+
}
141146
TResult VisitStmt_(const BufferStoreNode* store) override { return VisitExpr(store->value); }
142147
TResult VisitStmt_(const BlockRealizeNode* block) override {
143148
return VisitStmt(block->block->body);
@@ -186,6 +191,7 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
186191
TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
187192
TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); }
188193
TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); }
194+
TResult VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); }
189195
TResult VisitStmt_(const DeclBufferNode* op) override { return VisitStmt(op->body); }
190196

191197
TResult VisitStmt_(const SeqStmtNode* seq) override {

0 commit comments

Comments
 (0)