@@ -57,25 +57,33 @@ using namespace tir;
5757
5858// macro for doing simple rewrite
5959#define TVM_TRY_REWRITE (SrcExpr, ResExpr ) \
60+ RecordAttemptedRewrite (); \
6061 if ((SrcExpr).Match(ret)) { \
62+ RecordRewrite (); \
6163 return (ResExpr).Eval (); \
6264 }
6365
6466// macro for rewrite + recursively rewrite ResExpr
6567#define TVM_TRY_RECURSIVE_REWRITE (SrcExpr, ResExpr ) \
68+ RecordAttemptedRewrite (); \
6669 if ((SrcExpr).Match(ret)) { \
70+ RecordRewrite (); \
6771 return RecursiveRewrite ((ResExpr).Eval ()); \
6872 }
6973
7074// macro rewrite only if CondExor is true after match.
7175#define TVM_TRY_REWRITE_IF (SrcExpr, ResExpr, CondExpr ) \
76+ RecordAttemptedRewrite (); \
7277 if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
78+ RecordRewrite (); \
7379 return (ResExpr).Eval (); \
7480 }
7581
7682// macro rewrite + recursive_rewrite only if CondExor is true after match.
7783#define TVM_TRY_RECURSIVE_REWRITE_IF (SrcExpr, ResExpr, CondExpr ) \
84+ RecordAttemptedRewrite (); \
7885 if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
86+ RecordRewrite (); \
7987 return RecursiveRewrite ((ResExpr).Eval ()); \
8088 }
8189
@@ -185,6 +193,11 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
185193 return CompareResult::kUnknown ;
186194}
187195
196+ PrimExpr RewriteSimplifier::Impl::VisitExpr (const PrimExpr& e) {
197+ stats_.nodes_visited ++;
198+ return IRMutatorWithAnalyzer::VisitExpr (e);
199+ }
200+
188201void RewriteSimplifier::Impl::Update (const Var& var, const PrimExpr& info, bool can_override) {
189202 if (!can_override) {
190203 auto it = var_map_.find (var);
@@ -324,6 +337,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
324337 literal_constraints_.push_back (Not (negation));
325338 }
326339 }
340+ stats_.constraints_entered ++;
327341 size_t new_literal_size = literal_constraints_.size ();
328342 auto frecover = [old_literal_size, new_literal_size, this ]() {
329343 ICHECK_EQ (literal_constraints_.size (), new_literal_size);
@@ -2107,9 +2121,30 @@ RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const {
21072121 return impl_->GetEnabledExtensions ();
21082122}
21092123
2124+ ObjectRef RewriteSimplifier::GetStatsCounters () const { return impl_->GetStatsCounters (); }
2125+
2126+ void RewriteSimplifier::ResetStatsCounters () { impl_->ResetStatsCounters (); }
2127+
2128+ void RewriteSimplifier::SetMaximumRewriteSteps (int maximum) {
2129+ impl_->SetMaximumRewriteSteps (maximum);
2130+ }
2131+
21102132RewriteSimplifier::RewriteSimplifier (Analyzer* parent) : impl_(new Impl(parent)) {}
21112133
21122134RewriteSimplifier::~RewriteSimplifier () { delete impl_; }
21132135
2136+ TVM_REGISTER_NODE_TYPE (RewriteSimplifierStatsNode);
2137+
2138+ TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
2139+ .set_dispatch<RewriteSimplifierStatsNode>([](const ObjectRef& node, ReprPrinter* p) {
2140+ auto * ptr = node.as <RewriteSimplifierStatsNode>();
2141+ p->stream << " RewriteSimplifierStats(nodes_visited = " << ptr->nodes_visited
2142+ << " , constraints_entered = " << ptr->constraints_entered
2143+ << " , rewrites_attempted = " << ptr->rewrites_attempted
2144+ << " , rewrites_performed = " << ptr->rewrites_performed
2145+ << " , max_recursive_depth = " << ptr->max_recursive_depth
2146+ << " , num_recursive_rewrites = " << ptr->num_recursive_rewrites << " )" ;
2147+ });
2148+
21142149} // namespace arith
21152150} // namespace tvm
0 commit comments