Skip to content

Commit 083eedf

Browse files
committed
[Arith] Implement statistics counters for RewriteSimplifier
Previously, so long as `RewriteSimplifier` produced the same output, unit tests of its behavior would pass. This could hide severe performance regressions, such as the one resolved in #14528, which caused the runtime of two tests to increase from ~1.5 seconds to ~10 minutes each. This commit implements statistics counters in `RewriteSimplifier`, which are exposed through both the C++ and Python APIs, and uses them to guard against the known performance regression from #14528.
1 parent 1113de2 commit 083eedf

File tree

9 files changed

+154
-6
lines changed

9 files changed

+154
-6
lines changed

include/tvm/arith/analyzer.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,14 @@ class RewriteSimplifier {
330330
/*! \brief Return the currently enabled extensions */
331331
TVM_DLL Extension GetEnabledExtensions() const;
332332

333+
/*! \brief Return the statistics counters */
334+
TVM_DLL ObjectRef GetStatsCounters() const;
335+
336+
/*! \brief Reset the statistics counters */
337+
TVM_DLL void ResetStatsCounters();
338+
339+
TVM_DLL void SetMaximumRewriteSteps(int maximum);
340+
333341
private:
334342
friend class Analyzer;
335343
friend class ConstraintContext;

python/tvm/arith/analyzer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def __init__(self):
8787
self._modular_set = _mod("modular_set")
8888
self._simplify = _mod("Simplify")
8989
self._rewrite_simplify = _mod("rewrite_simplify")
90+
self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats")
91+
self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats")
9092
self._canonical_simplify = _mod("canonical_simplify")
9193
self._int_set = _mod("int_set")
9294
self._enter_constraint_context = _mod("enter_constraint_context")
@@ -157,6 +159,13 @@ def rewrite_simplify(self, expr):
157159
"""
158160
return self._rewrite_simplify(expr)
159161

162+
@property
163+
def rewrite_simplify_stats(self):
164+
return self._get_rewrite_simplify_stats()
165+
166+
def reset_rewrite_simplify_stats(self):
167+
self._reset_rewrite_simplify_stats()
168+
160169
def canonical_simplify(self, expr):
161170
"""Simplify expression via canonicalization.
162171

src/arith/analyzer.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
175175
} else if (name == "rewrite_simplify") {
176176
return PackedFunc(
177177
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); });
178+
} else if (name == "get_rewrite_simplify_stats") {
179+
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
180+
*ret = self->rewrite_simplify.GetStatsCounters();
181+
});
182+
} else if (name == "reset_rewrite_simplify_stats") {
183+
return PackedFunc(
184+
[self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); });
178185
} else if (name == "canonical_simplify") {
179186
return PackedFunc(
180187
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); });

src/arith/rewrite_simplify.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,33 @@ using namespace tir;
5757

5858
// macro for doing simple rewrite
5959
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
60+
RecordAttemptedRewrite(); \
6061
if ((SrcExpr).Match(ret)) { \
62+
RecordRewrite(); \
6163
return (ResExpr).Eval(); \
6264
}
6365

6466
// macro for rewrite + recursively rewrite ResExpr
6567
#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
68+
RecordAttemptedRewrite(); \
6669
if ((SrcExpr).Match(ret)) { \
70+
RecordRewrite(); \
6771
return RecursiveRewrite((ResExpr).Eval()); \
6872
}
6973

7074
// macro rewrite only if CondExpr is true after match.
7175
#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
76+
RecordAttemptedRewrite(); \
7277
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
78+
RecordRewrite(); \
7379
return (ResExpr).Eval(); \
7480
}
7581

7682
// macro rewrite + recursive_rewrite only if CondExpr is true after match.
7783
#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
84+
RecordAttemptedRewrite(); \
7885
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
86+
RecordRewrite(); \
7987
return RecursiveRewrite((ResExpr).Eval()); \
8088
}
8189

@@ -185,6 +193,11 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
185193
return CompareResult::kUnknown;
186194
}
187195

196+
PrimExpr RewriteSimplifier::Impl::VisitExpr(const PrimExpr& e) {
197+
stats_.nodes_visited++;
198+
return IRMutatorWithAnalyzer::VisitExpr(e);
199+
}
200+
188201
void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
189202
if (!can_override) {
190203
auto it = var_map_.find(var);
@@ -324,6 +337,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
324337
literal_constraints_.push_back(Not(negation));
325338
}
326339
}
340+
stats_.constraints_entered++;
327341
size_t new_literal_size = literal_constraints_.size();
328342
auto frecover = [old_literal_size, new_literal_size, this]() {
329343
ICHECK_EQ(literal_constraints_.size(), new_literal_size);
@@ -2107,9 +2121,30 @@ RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const {
21072121
return impl_->GetEnabledExtensions();
21082122
}
21092123

2124+
ObjectRef RewriteSimplifier::GetStatsCounters() const { return impl_->GetStatsCounters(); }
2125+
2126+
void RewriteSimplifier::ResetStatsCounters() { impl_->ResetStatsCounters(); }
2127+
2128+
void RewriteSimplifier::SetMaximumRewriteSteps(int maximum) {
2129+
impl_->SetMaximumRewriteSteps(maximum);
2130+
}
2131+
21102132
RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}
21112133

21122134
RewriteSimplifier::~RewriteSimplifier() { delete impl_; }
21132135

2136+
TVM_REGISTER_NODE_TYPE(RewriteSimplifierStatsNode);
2137+
2138+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
2139+
.set_dispatch<RewriteSimplifierStatsNode>([](const ObjectRef& node, ReprPrinter* p) {
2140+
auto* ptr = node.as<RewriteSimplifierStatsNode>();
2141+
p->stream << "RewriteSimplifierStats(nodes_visited = " << ptr->nodes_visited
2142+
<< ", constraints_entered = " << ptr->constraints_entered
2143+
<< ", rewrites_attempted = " << ptr->rewrites_attempted
2144+
<< ", rewrites_performed = " << ptr->rewrites_performed
2145+
<< ", max_recursive_depth = " << ptr->max_recursive_depth
2146+
<< ", num_recursive_rewrites = " << ptr->num_recursive_rewrites << ")";
2147+
});
2148+
21142149
} // namespace arith
21152150
} // namespace tvm

src/arith/rewrite_simplify.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,36 @@ namespace arith {
3939

4040
using namespace tir;
4141

42+
struct RewriteSimplifierStatsNode : Object {
43+
int nodes_visited{0};
44+
int constraints_entered{0};
45+
int rewrites_attempted{0};
46+
int rewrites_performed{0};
47+
int max_recursive_depth{0};
48+
int num_recursive_rewrites{0};
49+
50+
void VisitAttrs(AttrVisitor* v) {
51+
v->Visit("nodes_visited", &nodes_visited);
52+
v->Visit("constraints_entered", &constraints_entered);
53+
v->Visit("rewrites_attempted", &rewrites_attempted);
54+
v->Visit("rewrites_performed", &rewrites_performed);
55+
v->Visit("max_recursive_depth", &max_recursive_depth);
56+
v->Visit("num_recursive_rewrites", &num_recursive_rewrites);
57+
}
58+
59+
static constexpr const char* _type_key = "arith.RewriteSimplifierStats";
60+
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object);
61+
};
62+
63+
struct RewriteSimplifierStats : ObjectRef {
64+
RewriteSimplifierStats(RewriteSimplifierStatsNode data) {
65+
data_ = make_object<RewriteSimplifierStatsNode>(data);
66+
}
67+
68+
TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode);
69+
TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode);
70+
};
71+
4272
/*!
4373
* \brief Rewrite-based simplifier.
4474
*
@@ -50,6 +80,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
5080

5181
explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}
5282

83+
PrimExpr VisitExpr(const PrimExpr& e) override;
84+
5385
void Update(const Var& var, const PrimExpr& info, bool override_info);
5486
PrimExpr VisitExpr_(const AddNode* op) override;
5587
PrimExpr VisitExpr_(const SubNode* op) override;
@@ -87,7 +119,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
87119
/*! \brief Return the currently enabled extensions */
88120
Extension GetEnabledExtensions() const;
89121

122+
RewriteSimplifierStats GetStatsCounters() const { return RewriteSimplifierStats(stats_); }
123+
124+
void ResetStatsCounters() { stats_ = {}; }
125+
126+
void SetMaximumRewriteSteps(int maximum) { maximum_rewrite_steps_ = maximum; };
127+
90128
protected:
129+
int maximum_rewrite_steps_{0};
130+
RewriteSimplifierStatsNode stats_;
131+
132+
void RecordAttemptedRewrite() { stats_.rewrites_attempted++; }
133+
void RecordRewrite() {
134+
stats_.rewrites_performed++;
135+
136+
ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_)
137+
<< "RewriteSimplifier exceeded maximum number of rewrites allowed ("
138+
<< maximum_rewrite_steps_ << ")";
139+
}
140+
141+
bool is_currently_visiting_{false};
142+
91143
// counter to record recursive rewrite depth.
92144
int recur_depth_{0};
93145
// internal variable map
@@ -178,8 +230,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
178230
// we limit maximum depth of recursive rewrite allowed to
179231
// avoid infinite loop
180232
PrimExpr RecursiveRewrite(const PrimExpr& x) {
233+
stats_.num_recursive_rewrites++;
181234
if (recur_depth_ >= kMaxRecurDepth) return x;
182235
++recur_depth_;
236+
stats_.max_recursive_depth = std::max(recur_depth_, stats_.max_recursive_depth);
183237
PrimExpr res = this->VisitExpr(x);
184238
--recur_depth_;
185239
return res;

src/tir/analysis/control_flow_graph.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -820,8 +820,9 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
820820
return buffer_touch;
821821
}
822822

823-
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
824-
: max_revisits_(max_revisits) {
823+
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int max_simplification_steps,
824+
size_t max_revisits)
825+
: max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) {
825826
ControlFlowGraphBuilder::Build(this, stmt);
826827
ForwardPropagateKnownValues();
827828
BackwardPropagateUnusedValues();
@@ -1377,6 +1378,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_fr
13771378
std::unordered_map<size_t, size_t> visit_count_lookup;
13781379

13791380
Analyzer analyzer;
1381+
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
13801382
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
13811383
arith::RewriteSimplifier::kTransitivelyProveInequalities |
13821384
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
@@ -1510,6 +1512,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_
15101512
std::unordered_map<size_t, size_t> visit_count_lookup;
15111513

15121514
Analyzer analyzer;
1515+
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
15131516
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
15141517
arith::RewriteSimplifier::kTransitivelyProveInequalities |
15151518
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |

src/tir/analysis/control_flow_graph.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,8 @@ class ControlFlowGraph {
399399
public:
400400
/* \brief Extract the touch pattern from a TIR statement
401401
*/
402-
explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5);
402+
explicit ControlFlowGraph(const Stmt& stmt, int max_simplification_steps = 0,
403+
size_t max_revisits = 5);
403404

404405
/* \brief Check if a write is overwritten without impacting final results
405406
*
@@ -655,6 +656,9 @@ class ControlFlowGraph {
655656

656657
/*! \brief The maximum number of revisits while flowing constraints */
657658
size_t max_revisits_;
659+
660+
/*! \brief The maximum number of rewrite steps the RewriteSimplifier may perform (no limit if <= 0) */
661+
int max_simplification_steps_;
658662
};
659663

660664
} // namespace tir

src/tir/transforms/remove_no_op.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,20 @@ namespace tir {
4242

4343
struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
4444
bool use_dataflow_analysis;
45+
int max_simplification_steps;
4546

4647
TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
4748
TVM_ATTR_FIELD(use_dataflow_analysis)
4849
.describe(
4950
"If true, known buffer values are propagated and used "
5051
"to statically prove statements as no-ops.")
5152
.set_default(false);
53+
TVM_ATTR_FIELD(max_simplification_steps)
54+
.describe(
55+
"If non-zero, RewriteSimplifier will throw an error "
56+
"after the number of steps specified. "
57+
"For use in debug and testing purposes.")
58+
.set_default(0);
5259
}
5360
};
5461

@@ -316,14 +323,19 @@ Pass RemoveNoOp() {
316323

317324
RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp")
318325
.value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());
326+
319327
if (config->use_dataflow_analysis) {
320-
touch_pattern.emplace(f->body);
328+
touch_pattern.emplace(f->body, config->max_simplification_steps);
321329
}
322330

323331
arith::Analyzer analyzer;
332+
analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps);
324333

325-
auto* n = f.CopyOnWrite();
326-
n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
334+
{
335+
auto* write_ptr = f.CopyOnWrite();
336+
write_ptr->body = NoOpRemover::Apply(std::move(write_ptr->body), &analyzer,
337+
std::move(touch_pattern), nullptr);
338+
}
327339
return f;
328340
};
329341
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});

tests/python/unittest/test_tir_transform_remove_no_op.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,14 @@ def main(A: T.Buffer((16), "int32"), B: T.Buffer((16), "int32")) -> None:
8686

8787
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
8888
use_dataflow_analysis = False
89+
max_simplification_steps = 0
8990

9091
def transform(self):
9192
def inner(mod):
9293
config = {
9394
"tir.RemoveNoOp": {
9495
"use_dataflow_analysis": self.use_dataflow_analysis,
96+
"max_simplification_steps": self.max_simplification_steps,
9597
}
9698
}
9799
with tvm.transform.PassContext(config=config):
@@ -319,9 +321,16 @@ class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter)
319321
Similar to TestKeepPartiallyOverwrittenLoop, except the first loop
320322
has the same predicate as the second, and can therefore be
321323
removed.
324+
325+
In the past, this test has had performance regressions in which
326+
the runtime increased from a few seconds to nearly ten minutes.
327+
The "max_simplification_steps" parameter is set at twice the
328+
current number of steps required, in order to prevent similar
329+
performance regressions.
322330
"""
323331

324332
use_dataflow_analysis = True
333+
max_simplification_steps = 200000
325334

326335
def before(A: T.Buffer(16, "int32")):
327336
for i in T.serial(16):
@@ -347,9 +356,16 @@ class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter):
347356
loop's predicate. So long as the regions written in the first
348357
loop are a subset of those written in the second loop, they can be
349358
removed.
359+
360+
In the past, this test has had performance regressions in which
361+
the runtime increased from a few seconds to nearly ten minutes.
362+
The "max_simplification_steps" parameter is set at twice the
363+
current number of steps required, in order to prevent similar
364+
performance regressions.
350365
"""
351366

352367
use_dataflow_analysis = True
368+
max_simplification_steps = 200000
353369

354370
def before(A: T.Buffer(16, "int32")):
355371
for i in T.serial(16):

0 commit comments

Comments
 (0)