|
20 | 20 |
|
21 | 21 | /*! |
22 | 22 | * \file tvm/relax/transform/eliminate_common_subexpr.cc |
23 | | - * \brief Eliminrate common subexpression pass. |
| 23 | + * \brief Eliminate common subexpression pass. |
24 | 24 | * |
25 | 25 | * Currently it removes common subexpressions within a Function. |
26 | 26 | */ |
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>

#include <algorithm>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../support/utils.h"
32 | 33 |
|
33 | 34 | namespace tvm { |
34 | 35 | namespace relax { |
35 | | - |
36 | | -// Checks if a given expression contains an impure subexpression |
37 | | -// Caches the results of checks to avoid revisiting subexpressions |
38 | | -class ImpurityDetector : public ExprVisitor { |
39 | | - public: |
40 | | - bool Detect(const Expr& expr) { |
41 | | - impure_found_ = false; |
42 | | - VisitExpr(expr); |
43 | | - return impure_found_; |
| 36 | +namespace { |
| 37 | +/* \brief Lookup key for subexpression replacements |
| 38 | + * |
| 39 | + * The lookup key must contain the expression being bound, along with |
| 40 | + * the struct info used for a match cast, if applicable. Using |
| 41 | + * `MatchCast` with StructuralEqual and StructuralHash would be almost |
| 42 | + * correct, but acts as a point of definition for symbolic variables |
| 43 | + * within the output struct info. As a result, it would erroneously |
| 44 | + * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and |
| 45 | + * `R.match_cast(A, R.Tensor([p,q]))`, even though they define |
| 46 | + * different symbolic variables. |
| 47 | + */ |
| 48 | +struct ReplacementKey { |
| 49 | + tvm::relax::Expr bound_value; |
| 50 | + tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt; |
| 51 | + |
| 52 | + explicit ReplacementKey(const tvm::relax::Binding& binding) |
| 53 | + : bound_value(GetBoundValue(binding)) { |
| 54 | + if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) { |
| 55 | + match_cast = ptr->struct_info; |
| 56 | + } |
44 | 57 | } |
45 | 58 |
|
46 | | - void VisitExpr(const Expr& expr) { |
47 | | - // already checked: do not revisit |
48 | | - if (purity_map_.count(expr)) { |
49 | | - impure_found_ = impure_found_ || !purity_map_.at(expr); |
50 | | - return; |
51 | | - } |
| 59 | + friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) { |
| 60 | + tvm::StructuralEqual eq; |
| 61 | + return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast); |
| 62 | + } |
| 63 | +}; |
52 | 64 |
|
53 | | - // in principle, we could stop checking once we find an impurity, |
54 | | - // but not doing so lets us fully populate the cache |
| 65 | +} // namespace |
| 66 | +} // namespace relax |
| 67 | +} // namespace tvm |
55 | 68 |
|
56 | | - // store the previous state so we could assess the purity of this subexpression alone |
57 | | - bool prev_state = impure_found_; |
58 | | - impure_found_ = false; |
59 | | - ExprVisitor::VisitExpr(expr); |
60 | | - // if impure_found_ remains false, then the expression is pure |
61 | | - purity_map_[expr] = !impure_found_; |
62 | | - impure_found_ = prev_state || impure_found_; |
| 69 | +/* \brief Definition of std::hash<ReplacementKey> |
| 70 | + * |
| 71 | + * Specialization of std::hash must occur outside of tvm::relax |
| 72 | + * namespace, and before its usage in the constructor of |
| 73 | + * `CommonSubexprEliminator`. |
| 74 | + */ |
| 75 | +template <> |
| 76 | +struct std::hash<tvm::relax::ReplacementKey> { |
| 77 | + std::size_t operator()(const tvm::relax::ReplacementKey& key) const { |
| 78 | + tvm::StructuralHash hasher; |
| 79 | + return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast)); |
63 | 80 | } |
| 81 | +}; |
64 | 82 |
|
65 | | - void VisitExpr_(const CallNode* call) { |
66 | | - // the only possible impurities can come from call nodes |
67 | | - bool is_impure = IsImpureCall(GetRef<Call>(call)); |
68 | | - impure_found_ = impure_found_ || is_impure; |
69 | | - ExprVisitor::VisitExpr_(call); |
70 | | - } |
| 83 | +namespace tvm { |
| 84 | +namespace relax { |
71 | 85 |
|
72 | | - private: |
73 | | - bool impure_found_ = false; |
74 | | - std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_; |
75 | | -}; |
| 86 | +namespace { |
76 | 87 |
|
77 | | -class SubexprCounter : public ExprVisitor { |
| 88 | +class CommonSubexprEliminator : public ExprMutator { |
78 | 89 | public: |
79 | | - static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) { |
80 | | - SubexprCounter visitor; |
81 | | - visitor(expr); |
82 | | - return visitor.count_map_; |
83 | | - } |
| 90 | + explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {} |
84 | 91 |
|
85 | | - // overriding VisitExpr ensures we do this for every subexpression |
86 | | - void VisitExpr(const Expr& e) override { |
87 | | - // Cases we ignore because we will not substitute them: |
88 | | - // 1. Vars of all kinds |
89 | | - // 2. Op nodes (nothing we can do) |
90 | | - // 3. PrimValue nodes (not much benefit from binding to a var) |
91 | | - // 4. StringImm nodes (not much benefit from binding to a var) |
92 | | - // 5. Scalar constants (not much benefit from binding to a var) |
93 | | - // 6. Shape expressions (exist to hold several PrimValue objects) |
94 | | - // 7. DataType nodes (no need to modify dtype nodes) |
95 | | - if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() || |
96 | | - e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() || |
97 | | - e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() || |
98 | | - e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() || |
99 | | - e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) { |
100 | | - // also if e has an impure subexpression, we will not deduplicate it |
101 | | - if (!impurity_detector_.Detect(e)) { |
102 | | - int count = 0; |
103 | | - if (count_map_.count(e)) { |
104 | | - count = count_map_.at(e); |
105 | | - } |
106 | | - count_map_[e] = count + 1; |
107 | | - } |
108 | | - } |
| 92 | + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { |
| 93 | + auto cache_vars = var_remap_; |
| 94 | + auto output = ExprMutator::VisitBindingBlock_(block); |
109 | 95 |
|
110 | | - // Only visit the interior of objects that we might still keep |
111 | | - // around. Otherwise, double-counting these would lead to extra |
112 | | - // variable bindings. |
113 | | - // |
114 | | - // Before: |
115 | | - // y = f(a+b) |
116 | | - // z = f(a+b) |
117 | | - // |
118 | | - // Expected: |
119 | | - // y = f(a+b) // De-duped from (y==z) |
120 | | - // z = y |
121 | | - // |
122 | | - // Erroneous output: |
123 | | - // c = a+b // Incorrect, a+b only has a single usage. |
124 | | - // y = f(c) // De-duped from |
125 | | - // z = y |
126 | | - // |
127 | | - if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) { |
128 | | - ExprVisitor::VisitExpr(e); |
| 96 | + for (auto& [key, replacements] : expr_replacements_) { |
| 97 | + replacements.erase( |
| 98 | + std::remove_if(replacements.begin(), replacements.end(), |
| 99 | + [](const Var& var) -> bool { return var->IsInstance<DataflowVarNode>(); }), |
| 100 | + replacements.end()); |
129 | 101 | } |
| 102 | + |
| 103 | + var_remap_ = cache_vars; |
| 104 | + return output; |
130 | 105 | } |
131 | 106 |
|
132 | | - // do not visit inner functions: we will do CSE within those |
133 | | - void VisitExpr_(const FunctionNode* func) override {} |
| 107 | + void VisitBinding(const Binding& binding) override { |
| 108 | + Expr bound_value = VisitExpr(GetBoundValue(binding)); |
| 109 | + |
| 110 | + Binding output_binding = [&]() -> Binding { |
| 111 | + if (binding.as<VarBindingNode>()) { |
| 112 | + return VarBinding(binding->var, bound_value); |
| 113 | + } else if (auto match_cast = binding.as<MatchCastNode>()) { |
| 114 | + return MatchCast(binding->var, bound_value, match_cast->struct_info); |
| 115 | + } else { |
| 116 | + LOG(FATAL) << "Binding must be either VarBinding or MatchCast, " |
| 117 | + << "but was " << binding->GetTypeKey(); |
| 118 | + } |
| 119 | + }(); |
134 | 120 |
|
135 | | - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes |
136 | | - void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} |
| 121 | + ReplacementKey lookup_key(output_binding); |
137 | 122 |
|
138 | | - private: |
139 | | - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_; |
140 | | - ImpurityDetector impurity_detector_; |
141 | | -}; |
| 123 | + if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) { |
| 124 | + VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value; |
142 | 125 |
|
143 | | -class CommonSubexprEliminator : public ExprMutator { |
144 | | - public: |
145 | | - explicit CommonSubexprEliminator( |
146 | | - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map, |
147 | | - bool call_only = false) |
148 | | - : count_map_(std::move(count_map)), call_only_(call_only) {} |
149 | | - |
150 | | - // overriding here ensures we visit every subexpression |
151 | | - Expr VisitExpr(const Expr& e) override { |
152 | | - if (call_only_ && !e->IsInstance<CallNode>()) { |
153 | | - return ExprMutator::VisitExpr(e); |
154 | | - } |
155 | | - if (count_map_.count(e) && count_map_.at(e) > 1) { |
156 | | - // if we already have a mapping for it, get it |
157 | | - if (replacements_.count(e)) { |
158 | | - return replacements_.at(e); |
159 | | - } |
160 | | - // Otherwise, insert a new binding for the current expression. |
161 | | - // Visit before emitting to do inner replacements |
162 | | - Expr new_e = ExprMutator::VisitExpr(e); |
163 | | - Var v = builder_->Emit(new_e); |
164 | | - replacements_[e] = v; |
165 | | - return v; |
166 | | - } |
167 | | - return ExprMutator::VisitExpr(e); |
168 | | - } |
| 126 | + } else if (ContainsImpureCall(bound_value)) { |
| 127 | + VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value; |
169 | 128 |
|
170 | | - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes |
171 | | - StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { |
172 | | - return struct_info; |
173 | | - } |
| 129 | + } else if (auto it = expr_replacements_.find(lookup_key); |
| 130 | + it != expr_replacements_.end() && it->second.size()) { |
| 131 | + VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second[0] |
| 132 | + << ". The duplicate binding of this value to " << binding->var |
| 133 | + << " will be replaced with a trivial binding, " |
| 134 | + << "and occurrences of " << binding->var << " will be replaced with " |
| 135 | + << it->second[0]; |
| 136 | + output_binding = VarBinding(binding->var, it->second[0]); |
| 137 | + var_remap_.insert({binding->var->vid, it->second[0]}); |
| 138 | + it->second.push_back(binding->var); |
174 | 139 |
|
175 | | - Expr VisitExpr_(const FunctionNode* op) override { |
176 | | - Function func = GetRef<Function>(op); |
| 140 | + } else { |
| 141 | + VLOG(1) << "Value " << bound_value << " is bound to " << binding->var |
| 142 | + << " and may be de-duplicated if it occurs again."; |
177 | 143 |
|
178 | | - auto cache = SubexprCounter::Count(op->body); |
179 | | - std::swap(cache, count_map_); |
180 | | - Expr output = ExprMutator::VisitExpr_(op); |
181 | | - std::swap(cache, count_map_); |
| 144 | + expr_replacements_[lookup_key].push_back(binding->var); |
| 145 | + } |
182 | 146 |
|
183 | | - return output; |
| 147 | + builder_->EmitNormalized(output_binding); |
184 | 148 | } |
185 | 149 |
|
186 | | - void VisitBinding_(const VarBindingNode* binding) override { |
187 | | - // no need to visit var def because the struct info isn't going to change |
188 | | - Expr new_value = RegisterBoundValue(binding->var, binding->value); |
189 | | - |
190 | | - if (new_value.same_as(binding->value)) { |
191 | | - builder_->EmitNormalized(GetRef<VarBinding>(binding)); |
| 150 | + Expr VisitExpr_(const FunctionNode* op) override { |
| 151 | + // If we have accumulated any state, visit the function in a fresh |
| 152 | + // copy of the mutator, to avoid replacing a child-scope |
| 153 | + // expression with a parent-scope binding, or vice versa. |
| 154 | + if (expr_replacements_.size() || var_remap_.size()) { |
| 155 | + return VisitWithCleanScope(GetRef<Expr>(op)); |
192 | 156 | } else { |
193 | | - // no need to renormalize new_value because all replacements are with vars |
194 | | - builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span)); |
| 157 | + return ExprMutator::VisitExpr_(op); |
195 | 158 | } |
196 | 159 | } |
197 | 160 |
|
198 | | - void VisitBinding_(const MatchCastNode* binding) override { |
199 | | - // no need to visit var def because the struct info isn't going to change |
200 | | - Expr new_value = RegisterBoundValue(binding->var, binding->value); |
201 | | - |
202 | | - // re-emit old binding if nothing changes |
203 | | - if (new_value.same_as(binding->value)) { |
204 | | - builder_->EmitNormalized(GetRef<MatchCast>(binding)); |
| 161 | + Expr VisitExpr_(const IfNode* op) override { |
| 162 | + Expr cond = VisitExpr(op->cond); |
| 163 | + Expr true_branch = VisitWithInnerScope(op->true_branch); |
| 164 | + Expr false_branch = VisitWithInnerScope(op->false_branch); |
| 165 | + if (op->cond.same_as(cond) && op->true_branch.same_as(true_branch) && |
| 166 | + op->false_branch.same_as(false_branch) && |
| 167 | + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { |
| 168 | + return GetRef<Expr>(op); |
205 | 169 | } else { |
206 | | - // no need to renormalize new_value because all replacements are with vars |
207 | | - builder_->EmitNormalized( |
208 | | - MatchCast(binding->var, new_value, binding->struct_info, binding->span)); |
| 170 | + return If(cond, true_branch, false_branch, op->span); |
209 | 171 | } |
210 | 172 | } |
211 | 173 |
|
212 | 174 | private: |
213 | | - Expr RegisterBoundValue(Var var, Expr bound_value) { |
214 | | - // special case: if we are processing a binding |
215 | | - // and this is the first time we've encountered it, |
216 | | - // we will use the binding's var for the mapping |
217 | | - bool newly_replaced = false; |
218 | | - if (count_map_.count(bound_value) && count_map_.at(bound_value) > 1 && |
219 | | - !replacements_.count(bound_value)) { |
220 | | - replacements_[bound_value] = var; |
221 | | - newly_replaced = true; |
222 | | - } |
| 175 | + Expr VisitWithInnerScope(Expr expr) { |
| 176 | + auto cached_vars = var_remap_; |
| 177 | + auto cached_exprs = expr_replacements_; |
| 178 | + auto output = VisitExpr(expr); |
| 179 | + var_remap_ = cached_vars; |
| 180 | + expr_replacements_ = cached_exprs; |
| 181 | + return output; |
| 182 | + } |
223 | 183 |
|
224 | | - if (newly_replaced) { |
225 | | - // If we've just added the mapping, using the overridden visitor will |
226 | | - // just return the var, which we don't want, so we will use |
227 | | - // the superclass VisitExpr to do inner substitutions |
228 | | - return ExprMutator::VisitExpr(bound_value); |
229 | | - } |
230 | | - return VisitExpr(bound_value); |
| 184 | + Expr VisitWithCleanScope(Expr expr) { |
| 185 | + CommonSubexprEliminator clean_mutator(call_only_); |
| 186 | + return clean_mutator.VisitExpr(expr); |
231 | 187 | } |
232 | 188 |
|
233 | | - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_; |
234 | | - std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_; |
235 | 189 | bool call_only_{false}; |
| 190 | + std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_; |
236 | 191 | }; |
237 | 192 |
|
| 193 | +} // namespace |
| 194 | + |
238 | 195 | Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) { |
239 | | - CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only); |
| 196 | + CommonSubexprEliminator mutator(call_only); |
240 | 197 | return mutator(expr); |
241 | 198 | } |
242 | 199 |
|
|
0 commit comments