Skip to content

Commit 1294926

Browse files
authored
[Arith] Implement statistics counters for RewriteSimplifier (#14532)
* [Arith] Implement statistics counters for RewriteSimplifier Previously, so long as `RewriteSimplifier` produced the same output, unit tests of its behavior would pass. This could hide severe performance regressions, such as the one resolved in #14528, which caused the runtime of two tests to increase from ~1.5 seconds to ~10 minutes each. This commit implements statistics counters in RewriteSimplifier, which are exposed through both the C++ and Python APIs, and uses these to guard against the known performance regression from #14528. * lint fixes * Updates based on review comments * Consistent int64_t with kMaxRecurDepth * Removed unused is_currently_visiting_ * Add missing \brief for RewriteSimplifierStatsNode * Use int64_t in ControlFlowGraph for max simplification steps
1 parent aa7d2bf commit 1294926

File tree

9 files changed

+173
-8
lines changed

9 files changed

+173
-8
lines changed

include/tvm/arith/analyzer.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,27 @@ class RewriteSimplifier {
346346
/*! \brief Return the currently enabled extensions */
347347
TVM_DLL Extension GetEnabledExtensions() const;
348348

349+
/*! \brief Return the statistics counters */
350+
TVM_DLL ObjectRef GetStatsCounters() const;
351+
352+
/*! \brief Reset the statistics counters */
353+
TVM_DLL void ResetStatsCounters();
354+
355+
/*! \brief Set the maximum allowed number of rewrite steps
356+
*
357+
* By default, the simplifier may perform as many steps as are
358+
* required. If a positive limit is set, then the simplifier will
359+
* throw an exception when exceeding that number of rewrite steps.
360+
* This allows tests to guard against performance regressions.
361+
*
362+
* Note: To maintain accurate usage counters, `Analyzer` instances
363+
* should be re-used wherever possible. For example, TIR
364+
* transformations should declare a single `Analyzer` that is used
365+
* throughout the pass, and utility functions should receive an
366+
* `Analyzer*` from their calling scope.
367+
*/
368+
TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);
369+
349370
private:
350371
friend class Analyzer;
351372
friend class ConstraintContext;

python/tvm/arith/analyzer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def __init__(self):
9696
self._modular_set = _mod("modular_set")
9797
self._simplify = _mod("Simplify")
9898
self._rewrite_simplify = _mod("rewrite_simplify")
99+
self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats")
100+
self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats")
99101
self._canonical_simplify = _mod("canonical_simplify")
100102
self._int_set = _mod("int_set")
101103
self._enter_constraint_context = _mod("enter_constraint_context")
@@ -167,6 +169,13 @@ def rewrite_simplify(self, expr):
167169
"""
168170
return self._rewrite_simplify(expr)
169171

172+
@property
173+
def rewrite_simplify_stats(self):
174+
return self._get_rewrite_simplify_stats()
175+
176+
def reset_rewrite_simplify_stats(self):
177+
self._reset_rewrite_simplify_stats()
178+
170179
def canonical_simplify(self, expr):
171180
"""Simplify expression via canonicalization.
172181

src/arith/analyzer.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
228228
} else if (name == "rewrite_simplify") {
229229
return PackedFunc(
230230
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); });
231+
} else if (name == "get_rewrite_simplify_stats") {
232+
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
233+
*ret = self->rewrite_simplify.GetStatsCounters();
234+
});
235+
} else if (name == "reset_rewrite_simplify_stats") {
236+
return PackedFunc(
237+
[self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); });
231238
} else if (name == "canonical_simplify") {
232239
return PackedFunc(
233240
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); });

src/arith/rewrite_simplify.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,33 @@ using namespace tir;
5858

5959
// macro for doing simple rewrite
6060
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
61+
RecordAttemptedRewrite(); \
6162
if ((SrcExpr).Match(ret)) { \
63+
RecordRewrite(); \
6264
return (ResExpr).Eval(); \
6365
}
6466

6567
// macro for rewrite + recursively rewrite ResExpr
6668
#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
69+
RecordAttemptedRewrite(); \
6770
if ((SrcExpr).Match(ret)) { \
71+
RecordRewrite(); \
6872
return RecursiveRewrite((ResExpr).Eval()); \
6973
}
7074

7175
// macro rewrite only if CondExpr is true after match.
7276
#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
77+
RecordAttemptedRewrite(); \
7378
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
79+
RecordRewrite(); \
7480
return (ResExpr).Eval(); \
7581
}
7682

7783
// macro rewrite + recursive_rewrite only if CondExpr is true after match.
7884
#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
85+
RecordAttemptedRewrite(); \
7986
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
87+
RecordRewrite(); \
8088
return RecursiveRewrite((ResExpr).Eval()); \
8189
}
8290

@@ -211,6 +219,11 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
211219
return CompareResult::kUnknown;
212220
}
213221

222+
PrimExpr RewriteSimplifier::Impl::VisitExpr(const PrimExpr& e) {
223+
stats_.nodes_visited++;
224+
return IRMutatorWithAnalyzer::VisitExpr(e);
225+
}
226+
214227
void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
215228
if (!can_override) {
216229
auto it = var_map_.find(var);
@@ -359,6 +372,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
359372
literal_constraints_.push_back(Not(negation));
360373
}
361374
}
375+
stats_.constraints_entered++;
362376
size_t new_literal_size = literal_constraints_.size();
363377
auto frecover = [old_literal_size, new_literal_size, this]() {
364378
ICHECK_EQ(literal_constraints_.size(), new_literal_size);
@@ -2150,9 +2164,30 @@ RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const {
21502164
return impl_->GetEnabledExtensions();
21512165
}
21522166

2167+
ObjectRef RewriteSimplifier::GetStatsCounters() const { return impl_->GetStatsCounters(); }
2168+
2169+
void RewriteSimplifier::ResetStatsCounters() { impl_->ResetStatsCounters(); }
2170+
2171+
void RewriteSimplifier::SetMaximumRewriteSteps(int64_t maximum) {
2172+
impl_->SetMaximumRewriteSteps(maximum);
2173+
}
2174+
21532175
RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}
21542176

21552177
RewriteSimplifier::~RewriteSimplifier() { delete impl_; }
21562178

2179+
TVM_REGISTER_NODE_TYPE(RewriteSimplifierStatsNode);
2180+
2181+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
2182+
.set_dispatch<RewriteSimplifierStatsNode>([](const ObjectRef& node, ReprPrinter* p) {
2183+
auto* ptr = node.as<RewriteSimplifierStatsNode>();
2184+
p->stream << "RewriteSimplifierStats(nodes_visited = " << ptr->nodes_visited
2185+
<< ", constraints_entered = " << ptr->constraints_entered
2186+
<< ", rewrites_attempted = " << ptr->rewrites_attempted
2187+
<< ", rewrites_performed = " << ptr->rewrites_performed
2188+
<< ", max_recursive_depth = " << ptr->max_recursive_depth
2189+
<< ", num_recursive_rewrites = " << ptr->num_recursive_rewrites << ")";
2190+
});
2191+
21572192
} // namespace arith
21582193
} // namespace tvm

src/arith/rewrite_simplify.h

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/arith/analyzer.h>
2828
#include <tvm/tir/op.h>
2929

30+
#include <algorithm>
3031
#include <unordered_map>
3132
#include <vector>
3233

@@ -39,6 +40,41 @@ namespace arith {
3940

4041
using namespace tir;
4142

43+
/*! \brief Usage counters for RewriteSimplifier
44+
*
45+
* These are intended for debug and testing purposes, to ensure that
46+
* PrimExpr simplifications and TIR passes do not require an excessive number of simplification steps.
47+
*/
48+
struct RewriteSimplifierStatsNode : Object {
49+
int64_t nodes_visited{0};
50+
int64_t constraints_entered{0};
51+
int64_t rewrites_attempted{0};
52+
int64_t rewrites_performed{0};
53+
int64_t max_recursive_depth{0};
54+
int64_t num_recursive_rewrites{0};
55+
56+
void VisitAttrs(AttrVisitor* v) {
57+
v->Visit("nodes_visited", &nodes_visited);
58+
v->Visit("constraints_entered", &constraints_entered);
59+
v->Visit("rewrites_attempted", &rewrites_attempted);
60+
v->Visit("rewrites_performed", &rewrites_performed);
61+
v->Visit("max_recursive_depth", &max_recursive_depth);
62+
v->Visit("num_recursive_rewrites", &num_recursive_rewrites);
63+
}
64+
65+
static constexpr const char* _type_key = "arith.RewriteSimplifierStats";
66+
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object);
67+
};
68+
69+
struct RewriteSimplifierStats : ObjectRef {
70+
explicit RewriteSimplifierStats(RewriteSimplifierStatsNode data) {
71+
data_ = make_object<RewriteSimplifierStatsNode>(data);
72+
}
73+
74+
TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode);
75+
TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode);
76+
};
77+
4278
/*!
4379
* \brief Rewrite-based simplifier.
4480
*
@@ -50,6 +86,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
5086

5187
explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}
5288

89+
PrimExpr VisitExpr(const PrimExpr& e) override;
90+
5391
void Update(const Var& var, const PrimExpr& info, bool override_info);
5492
PrimExpr VisitExpr_(const AddNode* op) override;
5593
PrimExpr VisitExpr_(const SubNode* op) override;
@@ -87,9 +125,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
87125
/*! \brief Return the currently enabled extensions */
88126
Extension GetEnabledExtensions() const;
89127

128+
RewriteSimplifierStats GetStatsCounters() const { return RewriteSimplifierStats(stats_); }
129+
130+
void ResetStatsCounters() { stats_ = {}; }
131+
132+
void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; }
133+
90134
protected:
135+
int64_t maximum_rewrite_steps_{0};
136+
RewriteSimplifierStatsNode stats_;
137+
138+
void RecordAttemptedRewrite() { stats_.rewrites_attempted++; }
139+
void RecordRewrite() {
140+
stats_.rewrites_performed++;
141+
142+
ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_)
143+
<< "RewriteSimplifier exceeded maximum number of rewrites allowed ("
144+
<< maximum_rewrite_steps_ << ")";
145+
}
146+
91147
// counter to record recursive rewrite depth.
92-
int recur_depth_{0};
148+
int64_t recur_depth_{0};
93149
// internal variable map
94150
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
95151

@@ -103,7 +159,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
103159
bool recursively_visiting_boolean_{false};
104160

105161
// maximum recursion depth allowed during a single pass.
106-
static const constexpr int kMaxRecurDepth = 5;
162+
static const constexpr int64_t kMaxRecurDepth = 5;
107163
/*!
108164
* \brief try to compare x against val.
109165
* \param x The expression to be evaluated.
@@ -177,8 +233,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
177233
// we limit maximum depth of recursive rewrite allowed to
178234
// avoid infinite loop
179235
PrimExpr RecursiveRewrite(const PrimExpr& x) {
236+
stats_.num_recursive_rewrites++;
180237
if (recur_depth_ >= kMaxRecurDepth) return x;
181238
++recur_depth_;
239+
stats_.max_recursive_depth = std::max(recur_depth_, stats_.max_recursive_depth);
182240
PrimExpr res = this->VisitExpr(x);
183241
--recur_depth_;
184242
return res;

src/tir/analysis/control_flow_graph.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -820,8 +820,9 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
820820
return buffer_touch;
821821
}
822822

823-
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
824-
: max_revisits_(max_revisits) {
823+
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplification_steps,
824+
size_t max_revisits)
825+
: max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) {
825826
ControlFlowGraphBuilder::Build(this, stmt);
826827
ForwardPropagateKnownValues();
827828
BackwardPropagateUnusedValues();
@@ -1377,6 +1378,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_fr
13771378
std::unordered_map<size_t, size_t> visit_count_lookup;
13781379

13791380
Analyzer analyzer;
1381+
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
13801382
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
13811383
arith::RewriteSimplifier::kTransitivelyProveInequalities |
13821384
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
@@ -1510,6 +1512,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_
15101512
std::unordered_map<size_t, size_t> visit_count_lookup;
15111513

15121514
Analyzer analyzer;
1515+
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
15131516
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
15141517
arith::RewriteSimplifier::kTransitivelyProveInequalities |
15151518
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |

src/tir/analysis/control_flow_graph.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,8 @@ class ControlFlowGraph {
399399
public:
400400
/* \brief Extract the touch pattern from a TIR statement
401401
*/
402-
explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5);
402+
explicit ControlFlowGraph(const Stmt& stmt, int64_t max_simplification_steps = 0,
403+
size_t max_revisits = 5);
403404

404405
/* \brief Check if a write is overwritten without impacting final results
405406
*
@@ -655,6 +656,9 @@ class ControlFlowGraph {
655656

656657
/*! \brief The maximum number of revisits while flowing constraints */
657658
size_t max_revisits_;
659+
660+
/*! \brief The maximum number of rewrite steps passed to the RewriteSimplifier (non-positive means no limit) */
661+
int64_t max_simplification_steps_;
658662
};
659663

660664
} // namespace tir

src/tir/transforms/remove_no_op.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,20 @@ namespace tir {
4242

4343
struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
4444
bool use_dataflow_analysis;
45+
int64_t max_simplification_steps;
4546

4647
TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
4748
TVM_ATTR_FIELD(use_dataflow_analysis)
4849
.describe(
4950
"If true, known buffer values are propagated and used "
5051
"to statically prove statements as no-ops.")
5152
.set_default(false);
53+
TVM_ATTR_FIELD(max_simplification_steps)
54+
.describe(
55+
"If non-zero, RewriteSimplifier will throw an error "
56+
"after the number of steps specified. "
57+
"For use in debug and testing purposes.")
58+
.set_default(0);
5259
}
5360
};
5461

@@ -291,14 +298,19 @@ Pass RemoveNoOp() {
291298

292299
RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp")
293300
.value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());
301+
294302
if (config->use_dataflow_analysis) {
295-
touch_pattern.emplace(f->body);
303+
touch_pattern.emplace(f->body, config->max_simplification_steps);
296304
}
297305

298306
arith::Analyzer analyzer;
307+
analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps);
299308

300-
auto* n = f.CopyOnWrite();
301-
n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
309+
{
310+
auto* write_ptr = f.CopyOnWrite();
311+
write_ptr->body = NoOpRemover::Apply(std::move(write_ptr->body), &analyzer,
312+
std::move(touch_pattern), nullptr);
313+
}
302314
return f;
303315
};
304316
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});

tests/python/unittest/test_tir_transform_remove_no_op.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,14 @@ def main(A: T.Buffer((16), "int32"), B: T.Buffer((16), "int32")) -> None:
8686

8787
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
8888
use_dataflow_analysis = False
89+
max_simplification_steps = 0
8990

9091
def transform(self):
9192
def inner(mod):
9293
config = {
9394
"tir.RemoveNoOp": {
9495
"use_dataflow_analysis": self.use_dataflow_analysis,
96+
"max_simplification_steps": self.max_simplification_steps,
9597
}
9698
}
9799
with tvm.transform.PassContext(config=config):
@@ -319,9 +321,16 @@ class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter)
319321
Similar to TestKeepPartiallyOverwrittenLoop, except the first loop
320322
has the same predicate as the second, and can therefore be
321323
removed.
324+
325+
In the past, this test has had performance regressions in which
326+
the runtime increased from a few seconds to nearly ten minutes.
327+
The "max_simplification_steps" parameter is set at twice the
328+
current number of steps required, in order to prevent similar
329+
performance regression.
322330
"""
323331

324332
use_dataflow_analysis = True
333+
max_simplification_steps = 200000
325334

326335
def before(A: T.Buffer(16, "int32")):
327336
for i in T.serial(16):
@@ -347,9 +356,16 @@ class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter):
347356
loop's predicate. So long as the regions written in the first
348357
loop are a subset of those written in the second loop, they can be
349358
removed.
359+
360+
In the past, this test has had performance regressions in which
361+
the runtime increased from a few seconds to nearly ten minutes.
362+
The "max_simplification_steps" parameter is set at twice the
363+
current number of steps required, in order to prevent similar
364+
performance regression.
350365
"""
351366

352367
use_dataflow_analysis = True
368+
max_simplification_steps = 200000
353369

354370
def before(A: T.Buffer(16, "int32")):
355371
for i in T.serial(16):

0 commit comments

Comments
 (0)