Skip to content

Commit 864fd5c

Browse files
authored
[Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr (#16599)
* [Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr Update the `relax.transform.EliminateCommonSubexpr` pass to handle `R.match_cast` bindings, where the argument of the `R.match_cast` has also been de-duplicated. * Fix unit test failures * Add unit test for avoiding leak of dataflow var * Track all legal de-duplications, in case the first is a DataflowVar * De-duplicate within an if/else, using bindings before the if/else
1 parent e715814 commit 864fd5c

File tree

2 files changed

+411
-190
lines changed

2 files changed

+411
-190
lines changed

src/relax/transform/eliminate_common_subexpr.cc

Lines changed: 125 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -20,223 +20,180 @@
2020

2121
/*!
2222
* \file tvm/relax/transform/eliminate_common_subexpr.cc
23-
* \brief Eliminrate common subexpression pass.
23+
* \brief Eliminate common subexpression pass.
2424
*
2525
* Currently it removes common subexpressions within a Function.
2626
*/
27+
#include <tvm/relax/analysis.h>
2728
#include <tvm/relax/expr_functor.h>
2829
#include <tvm/relax/transform.h>
2930
#include <tvm/relax/utils.h>
3031

31-
#include "utils.h"
32+
#include "../../support/utils.h"
3233

3334
namespace tvm {
3435
namespace relax {
35-
36-
// Checks if a given expression contains an impure subexpression
37-
// Caches the results of checks to avoid revisiting subexpressions
38-
class ImpurityDetector : public ExprVisitor {
39-
public:
40-
bool Detect(const Expr& expr) {
41-
impure_found_ = false;
42-
VisitExpr(expr);
43-
return impure_found_;
36+
namespace {
37+
/*! \brief Lookup key for subexpression replacements
38+
*
39+
* The lookup key must contain the expression being bound, along with
40+
* the struct info used for a match cast, if applicable. Using
41+
* `MatchCast` with StructuralEqual and StructuralHash would be almost
42+
* correct, but a `MatchCast` acts as a point of definition for symbolic variables
43+
* within the output struct info. As a result, it would erroneously
44+
* de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and
45+
* `R.match_cast(A, R.Tensor([p,q]))`, even though they define
46+
* different symbolic variables.
47+
*/
48+
struct ReplacementKey {
49+
tvm::relax::Expr bound_value;
50+
tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt;
51+
52+
explicit ReplacementKey(const tvm::relax::Binding& binding)
53+
: bound_value(GetBoundValue(binding)) {
54+
if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) {
55+
match_cast = ptr->struct_info;
56+
}
4457
}
4558

46-
void VisitExpr(const Expr& expr) {
47-
// already checked: do not revisit
48-
if (purity_map_.count(expr)) {
49-
impure_found_ = impure_found_ || !purity_map_.at(expr);
50-
return;
51-
}
59+
friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) {
60+
tvm::StructuralEqual eq;
61+
return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast);
62+
}
63+
};
5264

53-
// in principle, we could stop checking once we find an impurity,
54-
// but not doing so lets us fully populate the cache
65+
} // namespace
66+
} // namespace relax
67+
} // namespace tvm
5568

56-
// store the previous state so we could assess the purity of this subexpression alone
57-
bool prev_state = impure_found_;
58-
impure_found_ = false;
59-
ExprVisitor::VisitExpr(expr);
60-
// if impure_found_ remains false, then the expression is pure
61-
purity_map_[expr] = !impure_found_;
62-
impure_found_ = prev_state || impure_found_;
69+
/*! \brief Definition of std::hash<ReplacementKey>
70+
*
71+
* Specialization of std::hash must occur outside of tvm::relax
72+
* namespace, and before its usage in the constructor of
73+
* `CommonSubexprEliminator`.
74+
*/
75+
template <>
76+
struct std::hash<tvm::relax::ReplacementKey> {
77+
std::size_t operator()(const tvm::relax::ReplacementKey& key) const {
78+
tvm::StructuralHash hasher;
79+
return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast));
6380
}
81+
};
6482

65-
void VisitExpr_(const CallNode* call) {
66-
// the only possible impurities can come from call nodes
67-
bool is_impure = IsImpureCall(GetRef<Call>(call));
68-
impure_found_ = impure_found_ || is_impure;
69-
ExprVisitor::VisitExpr_(call);
70-
}
83+
namespace tvm {
84+
namespace relax {
7185

72-
private:
73-
bool impure_found_ = false;
74-
std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_;
75-
};
86+
namespace {
7687

77-
class SubexprCounter : public ExprVisitor {
88+
class CommonSubexprEliminator : public ExprMutator {
7889
public:
79-
static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) {
80-
SubexprCounter visitor;
81-
visitor(expr);
82-
return visitor.count_map_;
83-
}
90+
explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {}
8491

85-
// overriding VisitExpr ensures we do this for every subexpression
86-
void VisitExpr(const Expr& e) override {
87-
// Cases we ignore because we will not substitute them:
88-
// 1. Vars of all kinds
89-
// 2. Op nodes (nothing we can do)
90-
// 3. PrimValue nodes (not much benefit from binding to a var)
91-
// 4. StringImm nodes (not much benefit from binding to a var)
92-
// 5. Scalar constants (not much benefit from binding to a var)
93-
// 6. Shape expressions (exist to hold several PrimValue objects)
94-
// 7. DataType nodes (no need to modify dtype nodes)
95-
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
96-
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
97-
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
98-
e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
99-
e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
100-
// also if e has an impure subexpression, we will not deduplicate it
101-
if (!impurity_detector_.Detect(e)) {
102-
int count = 0;
103-
if (count_map_.count(e)) {
104-
count = count_map_.at(e);
105-
}
106-
count_map_[e] = count + 1;
107-
}
108-
}
92+
BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
93+
auto cache_vars = var_remap_;
94+
auto output = ExprMutator::VisitBindingBlock_(block);
10995

110-
// Only visit the interior of objects that we might still keep
111-
// around. Otherwise, double-counting these would lead to extra
112-
// variable bindings.
113-
//
114-
// Before:
115-
// y = f(a+b)
116-
// z = f(a+b)
117-
//
118-
// Expected:
119-
// y = f(a+b) // De-duped from (y==z)
120-
// z = y
121-
//
122-
// Erroneous output:
123-
// c = a+b // Incorrect, a+b only has a single usage.
124-
// y = f(c) // De-duped from
125-
// z = y
126-
//
127-
if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) {
128-
ExprVisitor::VisitExpr(e);
96+
for (auto& [key, replacements] : expr_replacements_) {
97+
replacements.erase(
98+
std::remove_if(replacements.begin(), replacements.end(),
99+
[](const Var& var) -> bool { return var->IsInstance<DataflowVarNode>(); }),
100+
replacements.end());
129101
}
102+
103+
var_remap_ = cache_vars;
104+
return output;
130105
}
131106

132-
// do not visit inner functions: we will do CSE within those
133-
void VisitExpr_(const FunctionNode* func) override {}
107+
void VisitBinding(const Binding& binding) override {
108+
Expr bound_value = VisitExpr(GetBoundValue(binding));
109+
110+
Binding output_binding = [&]() -> Binding {
111+
if (binding.as<VarBindingNode>()) {
112+
return VarBinding(binding->var, bound_value);
113+
} else if (auto match_cast = binding.as<MatchCastNode>()) {
114+
return MatchCast(binding->var, bound_value, match_cast->struct_info);
115+
} else {
116+
LOG(FATAL) << "Binding must be either VarBinding or MatchCast, "
117+
<< "but was " << binding->GetTypeKey();
118+
}
119+
}();
134120

135-
// we are not going to do replacements inside struct info to avoid binding lots of reused shapes
136-
void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
121+
ReplacementKey lookup_key(output_binding);
137122

138-
private:
139-
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
140-
ImpurityDetector impurity_detector_;
141-
};
123+
if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) {
124+
VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value;
142125

143-
class CommonSubexprEliminator : public ExprMutator {
144-
public:
145-
explicit CommonSubexprEliminator(
146-
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map,
147-
bool call_only = false)
148-
: count_map_(std::move(count_map)), call_only_(call_only) {}
149-
150-
// overriding here ensures we visit every subexpression
151-
Expr VisitExpr(const Expr& e) override {
152-
if (call_only_ && !e->IsInstance<CallNode>()) {
153-
return ExprMutator::VisitExpr(e);
154-
}
155-
if (count_map_.count(e) && count_map_.at(e) > 1) {
156-
// if we already have a mapping for it, get it
157-
if (replacements_.count(e)) {
158-
return replacements_.at(e);
159-
}
160-
// Otherwise, insert a new binding for the current expression.
161-
// Visit before emitting to do inner replacements
162-
Expr new_e = ExprMutator::VisitExpr(e);
163-
Var v = builder_->Emit(new_e);
164-
replacements_[e] = v;
165-
return v;
166-
}
167-
return ExprMutator::VisitExpr(e);
168-
}
126+
} else if (ContainsImpureCall(bound_value)) {
127+
VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value;
169128

170-
// we are not going to do replacements inside struct info to avoid binding lots of reused shapes
171-
StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
172-
return struct_info;
173-
}
129+
} else if (auto it = expr_replacements_.find(lookup_key);
130+
it != expr_replacements_.end() && it->second.size()) {
131+
VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second[0]
132+
<< ". The duplicate binding of this value to " << binding->var
133+
<< " will be replaced with a trivial binding, "
134+
<< "and occurrences of " << binding->var << " will be replaced with "
135+
<< it->second[0];
136+
output_binding = VarBinding(binding->var, it->second[0]);
137+
var_remap_.insert({binding->var->vid, it->second[0]});
138+
it->second.push_back(binding->var);
174139

175-
Expr VisitExpr_(const FunctionNode* op) override {
176-
Function func = GetRef<Function>(op);
140+
} else {
141+
VLOG(1) << "Value " << bound_value << " is bound to " << binding->var
142+
<< " and may be de-duplicated if it occurs again.";
177143

178-
auto cache = SubexprCounter::Count(op->body);
179-
std::swap(cache, count_map_);
180-
Expr output = ExprMutator::VisitExpr_(op);
181-
std::swap(cache, count_map_);
144+
expr_replacements_[lookup_key].push_back(binding->var);
145+
}
182146

183-
return output;
147+
builder_->EmitNormalized(output_binding);
184148
}
185149

186-
void VisitBinding_(const VarBindingNode* binding) override {
187-
// no need to visit var def because the struct info isn't going to change
188-
Expr new_value = RegisterBoundValue(binding->var, binding->value);
189-
190-
if (new_value.same_as(binding->value)) {
191-
builder_->EmitNormalized(GetRef<VarBinding>(binding));
150+
Expr VisitExpr_(const FunctionNode* op) override {
151+
// If we have accumulated any state, visit the function in a fresh
152+
// copy of the mutator, to avoid replacing a child-scope
153+
// expression with a parent-scope binding, or vice versa.
154+
if (expr_replacements_.size() || var_remap_.size()) {
155+
return VisitWithCleanScope(GetRef<Expr>(op));
192156
} else {
193-
// no need to renormalize new_value because all replacements are with vars
194-
builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span));
157+
return ExprMutator::VisitExpr_(op);
195158
}
196159
}
197160

198-
void VisitBinding_(const MatchCastNode* binding) override {
199-
// no need to visit var def because the struct info isn't going to change
200-
Expr new_value = RegisterBoundValue(binding->var, binding->value);
201-
202-
// re-emit old binding if nothing changes
203-
if (new_value.same_as(binding->value)) {
204-
builder_->EmitNormalized(GetRef<MatchCast>(binding));
161+
Expr VisitExpr_(const IfNode* op) override {
162+
Expr cond = VisitExpr(op->cond);
163+
Expr true_branch = VisitWithInnerScope(op->true_branch);
164+
Expr false_branch = VisitWithInnerScope(op->false_branch);
165+
if (op->cond.same_as(cond) && op->true_branch.same_as(true_branch) &&
166+
op->false_branch.same_as(false_branch) &&
167+
VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) {
168+
return GetRef<Expr>(op);
205169
} else {
206-
// no need to renormalize new_value because all replacements are with vars
207-
builder_->EmitNormalized(
208-
MatchCast(binding->var, new_value, binding->struct_info, binding->span));
170+
return If(cond, true_branch, false_branch, op->span);
209171
}
210172
}
211173

212174
private:
213-
Expr RegisterBoundValue(Var var, Expr bound_value) {
214-
// special case: if we are processing a binding
215-
// and this is the first time we've encountered it,
216-
// we will use the binding's var for the mapping
217-
bool newly_replaced = false;
218-
if (count_map_.count(bound_value) && count_map_.at(bound_value) > 1 &&
219-
!replacements_.count(bound_value)) {
220-
replacements_[bound_value] = var;
221-
newly_replaced = true;
222-
}
175+
Expr VisitWithInnerScope(Expr expr) {
176+
auto cached_vars = var_remap_;
177+
auto cached_exprs = expr_replacements_;
178+
auto output = VisitExpr(expr);
179+
var_remap_ = cached_vars;
180+
expr_replacements_ = cached_exprs;
181+
return output;
182+
}
223183

224-
if (newly_replaced) {
225-
// If we've just added the mapping, using the overridden visitor will
226-
// just return the var, which we don't want, so we will use
227-
// the superclass VisitExpr to do inner substitutions
228-
return ExprMutator::VisitExpr(bound_value);
229-
}
230-
return VisitExpr(bound_value);
184+
Expr VisitWithCleanScope(Expr expr) {
185+
CommonSubexprEliminator clean_mutator(call_only_);
186+
return clean_mutator.VisitExpr(expr);
231187
}
232188

233-
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
234-
std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
235189
bool call_only_{false};
190+
std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_;
236191
};
237192

193+
} // namespace
194+
238195
Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) {
239-
CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only);
196+
CommonSubexprEliminator mutator(call_only);
240197
return mutator(expr);
241198
}
242199

0 commit comments

Comments
 (0)