Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,27 @@ class RewriteSimplifier {
/*! \brief Return the currently enabled extensions */
TVM_DLL Extension GetEnabledExtensions() const;

/*! \brief Return the statistics counters */
TVM_DLL ObjectRef GetStatsCounters() const;

/*! \brief Reset the statistics counters */
TVM_DLL void ResetStatsCounters();

/*! \brief Set the maximum allowed number of rewrite steps
*
* By default, the simplifier may perform as many steps as are
* required. If a positive limit is set, then the simplifier will
* throw an exception when exceeding that number of rewrite steps.
* This allows tests to guard against performance regressions.
*
* Note: To maintain accurate usage counters, `Analyzer` instances
* should be re-used wherever possible. For example, TIR
* transformations should declare a single `Analyzer` that is used
* throughout the pass, and utility functions should receive an
* `Analyzer*` from their calling scope.
*/
TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);

private:
friend class Analyzer;
friend class ConstraintContext;
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def __init__(self):
self._modular_set = _mod("modular_set")
self._simplify = _mod("Simplify")
self._rewrite_simplify = _mod("rewrite_simplify")
self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats")
self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats")
self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")
Expand Down Expand Up @@ -166,6 +168,13 @@ def rewrite_simplify(self, expr):
"""
return self._rewrite_simplify(expr)

@property
def rewrite_simplify_stats(self):
return self._get_rewrite_simplify_stats()

def reset_rewrite_simplify_stats(self):
self._reset_rewrite_simplify_stats()

def canonical_simplify(self, expr):
"""Simplify expression via canonicalization.

Expand Down
7 changes: 7 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
} else if (name == "rewrite_simplify") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); });
} else if (name == "get_rewrite_simplify_stats") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
*ret = self->rewrite_simplify.GetStatsCounters();
});
} else if (name == "reset_rewrite_simplify_stats") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); });
} else if (name == "canonical_simplify") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); });
Expand Down
35 changes: 35 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,33 @@ using namespace tir;

// macro for doing simple rewrite
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret)) { \
RecordRewrite(); \
return (ResExpr).Eval(); \
}

// macro for rewrite + recursively rewrite ResExpr
#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret)) { \
RecordRewrite(); \
return RecursiveRewrite((ResExpr).Eval()); \
}

// macro rewrite only if CondExor is true after match.
#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
RecordRewrite(); \
return (ResExpr).Eval(); \
}

// macro rewrite + recursive_rewrite only if CondExor is true after match.
#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
RecordAttemptedRewrite(); \
if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \
RecordRewrite(); \
return RecursiveRewrite((ResExpr).Eval()); \
}

Expand Down Expand Up @@ -211,6 +219,11 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
return CompareResult::kUnknown;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr(const PrimExpr& e) {
stats_.nodes_visited++;
return IRMutatorWithAnalyzer::VisitExpr(e);
}

void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
if (!can_override) {
auto it = var_map_.find(var);
Expand Down Expand Up @@ -350,6 +363,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
literal_constraints_.push_back(Not(negation));
}
}
stats_.constraints_entered++;
size_t new_literal_size = literal_constraints_.size();
auto frecover = [old_literal_size, new_literal_size, this]() {
ICHECK_EQ(literal_constraints_.size(), new_literal_size);
Expand Down Expand Up @@ -2141,9 +2155,30 @@ RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const {
return impl_->GetEnabledExtensions();
}

ObjectRef RewriteSimplifier::GetStatsCounters() const { return impl_->GetStatsCounters(); }

void RewriteSimplifier::ResetStatsCounters() { impl_->ResetStatsCounters(); }

void RewriteSimplifier::SetMaximumRewriteSteps(int64_t maximum) {
impl_->SetMaximumRewriteSteps(maximum);
}

RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}

RewriteSimplifier::~RewriteSimplifier() { delete impl_; }

TVM_REGISTER_NODE_TYPE(RewriteSimplifierStatsNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RewriteSimplifierStatsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* ptr = node.as<RewriteSimplifierStatsNode>();
p->stream << "RewriteSimplifierStats(nodes_visited = " << ptr->nodes_visited
<< ", constraints_entered = " << ptr->constraints_entered
<< ", rewrites_attempted = " << ptr->rewrites_attempted
<< ", rewrites_performed = " << ptr->rewrites_performed
<< ", max_recursive_depth = " << ptr->max_recursive_depth
<< ", num_recursive_rewrites = " << ptr->num_recursive_rewrites << ")";
});

} // namespace arith
} // namespace tvm
64 changes: 62 additions & 2 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <unordered_map>
#include <vector>

Expand All @@ -39,6 +40,41 @@ namespace arith {

using namespace tir;

/* Record of
*
* These are intended for debug and testing purposes, to ensure that
* PrimExpr simplifications and TIR passes do not require an excessive
*/
struct RewriteSimplifierStatsNode : Object {
int64_t nodes_visited{0};
int64_t constraints_entered{0};
int64_t rewrites_attempted{0};
int64_t rewrites_performed{0};
int64_t max_recursive_depth{0};
int64_t num_recursive_rewrites{0};

void VisitAttrs(AttrVisitor* v) {
v->Visit("nodes_visited", &nodes_visited);
v->Visit("constraints_entered", &constraints_entered);
v->Visit("rewrites_attempted", &rewrites_attempted);
v->Visit("rewrites_performed", &rewrites_performed);
v->Visit("max_recursive_depth", &max_recursive_depth);
v->Visit("num_recursive_rewrites", &num_recursive_rewrites);
}

static constexpr const char* _type_key = "arith.RewriteSimplifierStats";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object);
};

struct RewriteSimplifierStats : ObjectRef {
explicit RewriteSimplifierStats(RewriteSimplifierStatsNode data) {
data_ = make_object<RewriteSimplifierStatsNode>(data);
}

TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode);
};

/*!
* \brief Rewrite-based simplifier.
*
Expand All @@ -50,6 +86,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {

explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}

PrimExpr VisitExpr(const PrimExpr& e) override;

void Update(const Var& var, const PrimExpr& info, bool override_info);
PrimExpr VisitExpr_(const AddNode* op) override;
PrimExpr VisitExpr_(const SubNode* op) override;
Expand Down Expand Up @@ -87,9 +125,29 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
/*! \brief Return the currently enabled extensions */
Extension GetEnabledExtensions() const;

RewriteSimplifierStats GetStatsCounters() const { return RewriteSimplifierStats(stats_); }

void ResetStatsCounters() { stats_ = {}; }

void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; }

protected:
int64_t maximum_rewrite_steps_{0};
RewriteSimplifierStatsNode stats_;

void RecordAttemptedRewrite() { stats_.rewrites_attempted++; }
void RecordRewrite() {
stats_.rewrites_performed++;

ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_)
<< "RewriteSimplifier exceeded maximum number of rewrites allowed ("
<< maximum_rewrite_steps_ << ")";
}

bool is_currently_visiting_{false};

// counter to record recursive rewrite depth.
int recur_depth_{0};
int64_t recur_depth_{0};
// internal variable map
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;

Expand All @@ -103,7 +161,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
bool recursively_visiting_boolean_{false};

// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;
static const constexpr int64_t kMaxRecurDepth = 5;
/*!
* \brief try to compare x against val.
* \param x The expression to be evaluated.
Expand Down Expand Up @@ -177,8 +235,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
PrimExpr RecursiveRewrite(const PrimExpr& x) {
stats_.num_recursive_rewrites++;
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
stats_.max_recursive_depth = std::max(recur_depth_, stats_.max_recursive_depth);
PrimExpr res = this->VisitExpr(x);
--recur_depth_;
return res;
Expand Down
7 changes: 5 additions & 2 deletions src/tir/analysis/control_flow_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,9 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
return buffer_touch;
}

ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
: max_revisits_(max_revisits) {
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int max_simplification_steps,
size_t max_revisits)
: max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) {
ControlFlowGraphBuilder::Build(this, stmt);
ForwardPropagateKnownValues();
BackwardPropagateUnusedValues();
Expand Down Expand Up @@ -1377,6 +1378,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_fr
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
Expand Down Expand Up @@ -1510,6 +1512,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_);
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
Expand Down
6 changes: 5 additions & 1 deletion src/tir/analysis/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ class ControlFlowGraph {
public:
/* \brief Extract the touch pattern from a TIR statement
*/
explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5);
explicit ControlFlowGraph(const Stmt& stmt, int max_simplification_steps = 0,
size_t max_revisits = 5);

/* \brief Check if a write is overwritten without impacting final results
*
Expand Down Expand Up @@ -655,6 +656,9 @@ class ControlFlowGraph {

/*! \brief The maximum number of revisits while flowing constraints */
size_t max_revisits_;

/*! \brief The maximum number of revisits while flowing constraints */
int max_simplification_steps_;
};

} // namespace tir
Expand Down
18 changes: 15 additions & 3 deletions src/tir/transforms/remove_no_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@ namespace tir {

struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
bool use_dataflow_analysis;
int64_t max_simplification_steps;

TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
TVM_ATTR_FIELD(use_dataflow_analysis)
.describe(
"If true, known buffer values are propagated and used "
"to statically prove statements as no-ops.")
.set_default(false);
TVM_ATTR_FIELD(max_simplification_steps)
.describe(
"If non-zero, RewriteSimplifier will throw an error "
"after the number of steps specified. "
"For use in debug and testing purposes.")
.set_default(0);
}
};

Expand Down Expand Up @@ -316,14 +323,19 @@ Pass RemoveNoOp() {

RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp")
.value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());

if (config->use_dataflow_analysis) {
touch_pattern.emplace(f->body);
touch_pattern.emplace(f->body, config->max_simplification_steps);
}

arith::Analyzer analyzer;
analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps);

auto* n = f.CopyOnWrite();
n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason for adding curly braces here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They aren't necessary for correctness here, but were added present to limit the scope of write_ptr. It doesn't have much impact in a function of this size, especially with the return statement just afterward, but for longer functions it can help to limit the scope of variables that are only needed for a few following lines.

That said, the implementation no longer requires updating the signature of NoOpRemover::Apply, so this change has become unrelated to the overall PR.

auto* write_ptr = f.CopyOnWrite();
write_ptr->body = NoOpRemover::Apply(std::move(write_ptr->body), &analyzer,
std::move(touch_pattern), nullptr);
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
Expand Down
Loading