2727#include " ../op/call/call.h"
2828#include " tvm/relay/analysis.h"
2929#include " tvm/relay/expr_functor.h"
30- #include " tvm/relay/transform.h"
3130
3231namespace tvm {
3332namespace relay {
3433namespace transforms {
3534namespace {
3635
37- /* !
38- * \brief Returns the \p FunctionNode of if \p expr if it is a "Compiler" function which should
39- * be processed by a pass using \p compiler_filter. Otherwise returns null.
40- */
41- const FunctionNode* AsFunctionNode (const Expr& expr, const std::string& compiler_filter) {
42- if (const auto * function_node = expr.as <FunctionNode>()) {
43- Optional<String> opt_compiler = function_node->GetAttr <String>(attr::kCompiler );
44- if (opt_compiler.defined () &&
45- (compiler_filter.empty () || opt_compiler.value () == compiler_filter)) {
46- return function_node;
47- }
48- }
49- return nullptr ;
50- }
51-
5236/* !
5337 * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given
5438 * module will be extended with the newly outlined functions.
@@ -60,31 +44,35 @@ class Outliner : public MixedModeMutator {
6044
6145 Expr Rewrite_ (const CallNode* pre , const Expr& post ) final {
6246 Call new_call = Downcast<Call>(post );
63- if (const auto * function_node = AsFunctionNode (new_call->op , compiler_filter_)) {
64- auto function = GetRef<Function>(function_node);
65- DCHECK (FreeVars (function).empty ()) << " Function marked with '" << attr::kCompiler
66- << " ' attribute should not have free variables" ;
67- // Ask the cache to supply a unique global var for this function.
68- GlobalVar global_symbol = cache_->GetGlobalSymbol (function);
69- // Depending on the cache's implementation, two structurally equal (but not object
70- // equal) functions may be assigned the same global symbol. If so we'll lift it just
71- // once, but rewrite all the calls.
72- if (!mod_->ContainGlobalVar (global_symbol->name_hint )) {
73- function =
74- WithAttr (std::move (function), tvm::attr::kGlobalSymbol , global_symbol->name_hint );
75- mod_->Add (global_symbol, function);
47+ if (const auto * function_node = new_call->op .as <FunctionNode>()) {
48+ Optional<String> opt_compiler = function_node->GetAttr <String>(attr::kCompiler );
49+ if (opt_compiler.defined () &&
50+ (compiler_filter_.empty () || opt_compiler.value () == compiler_filter_)) {
51+ auto function = GetRef<Function>(function_node);
52+ DCHECK (FreeVars (function).empty ()) << " Function marked with '" << attr::kCompiler
53+ << " ' attribute should not have free variables" ;
54+ // Ask the cache to supply a unique global var for this function.
55+ GlobalVar global_symbol = cache_->GetGlobalSymbol (function);
56+ // Depending on the cache's implementation, two structurally equal (but not object equal)
57+ // functions may be assigned the same global symbol. If so we'll lift it just once, but
58+ // rewrite all the calls.
59+ if (!mod_->ContainGlobalVar (global_symbol->name_hint )) {
60+ function =
61+ WithAttr (std::move (function), tvm::attr::kGlobalSymbol , global_symbol->name_hint );
62+ mod_->Add (global_symbol, function);
63+ }
64+ // Update the call.
65+ return WithFields (new_call, global_symbol);
7666 }
77- // Update the call.
78- return WithFields (new_call, global_symbol);
7967 }
8068 return post ;
8169 }
8270
8371 private:
8472 /* !
8573 * \brief A cached mapping from functions to global variables. Depending on the implementation
86- * the cache may generate fresh symbols or require the function to already have a
87- * "global_symbol" attribute, and may share symbols between structurally equal functions.
74+ * the cache may generate fresh symbols or require the function to already have a "global_symbol"
75+ * attribute, and may share symbols between structurally equal functions.
8876 */
8977 GlobalSymbolCache* cache_;
9078 /* ! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
@@ -93,72 +81,6 @@ class Outliner : public MixedModeMutator {
9381 IRModule mod_;
9482};
9583
96- /* !
97- * \brief Inline immediate calls to "Composite" functions.
98- */
99- class InnerInliner : public MixedModeMutator {
100- public:
101- InnerInliner () = default ;
102-
103- private:
104- using MixedModeMutator::Rewrite_;
105-
106- Expr Rewrite_ (const CallNode* pre , const Expr& post ) final {
107- Call new_call = Downcast<Call>(post );
108- if (const auto * function_node = new_call->op .as <FunctionNode>()) {
109- ICHECK (function_node->GetAttr <String>(attr::kComposite ).defined ());
110- ICHECK_EQ (function_node->params .size (), new_call->args .size ());
111- Map<Var, Expr> subst;
112- for (size_t i = 0 ; i < new_call->args .size (); ++i) {
113- subst.Set (function_node->params [i], new_call->args [i]);
114- }
115- return Bind (function_node->body , subst);
116- }
117- return post ;
118- }
119- };
120-
121- /* !
122- * \brief Inline calls to global "Compiler" functions with global var in \p global_vars.
123- * Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body
124- * are inlined.
125- */
126- class OuterInliner : public MixedModeMutator {
127- public:
128- OuterInliner (IRModule mod, Array<GlobalVar> global_vars_)
129- : mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {}
130-
131- private:
132- using MixedModeMutator::Rewrite_;
133-
134- Expr Rewrite_ (const CallNode* pre , const Expr& post ) final {
135- Call new_call = Downcast<Call>(post );
136- if (const auto * global_var_node = new_call->op .as <GlobalVarNode>()) {
137- auto global_var = GetRef<GlobalVar>(global_var_node);
138- if (std::find (global_vars_.begin (), global_vars_.end (), global_var) != global_vars_.end ()) {
139- BaseFunc base_func = mod_->Lookup (global_var);
140- const auto * function_node = base_func.as <FunctionNode>();
141- ICHECK (function_node);
142- ICHECK (function_node->GetAttr <String>(attr::kCompiler ).defined ());
143- ICHECK_EQ (function_node->params .size (), new_call->args .size ());
144- Map<Var, Expr> subst;
145- for (size_t i = 0 ; i < new_call->args .size (); ++i) {
146- subst.Set (function_node->params [i], new_call->args [i]);
147- }
148- Expr new_body = InnerInliner ().VisitExpr (function_node->body );
149- return Bind (new_body, subst);
150- }
151- }
152- return post ;
153- }
154-
155- private:
156- /* ! \brief Original module we are processing. */
157- IRModule mod_;
158- /* ! \brief Global vars of functions to inline. */
159- Array<GlobalVar> global_vars_;
160- };
161-
16284} // namespace
16385
16486GlobalSymbolCache::~GlobalSymbolCache () = default ;
@@ -184,18 +106,17 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
184106 runtime::TypedPackedFunc<IRModule (IRModule, transform::PassContext)> pass_func =
185107 [cache = std::move (cache), compiler_filter = std::move (compiler_filter)](
186108 IRModule mod, transform::PassContext ctx) {
187- VLOG (1 ) << " OutlineCompilerFunctions input:" << std::endl << PrettyPrint (mod);
188- IRModule output_mod = mod->ShallowCopy ();
109+ IRModule output_mod = GetRef<IRModule>(mod.CopyOnWrite ());
189110 for (const auto & kv : mod->functions ) {
190- if (const auto * function_node = AsOptimizableFunctionNode (kv.second )) {
111+ const FunctionNode* function_node = AsOptimizableFunctionNode (kv.second );
112+ if (function_node) {
191113 Expr new_body =
192114 Outliner (cache.get (), compiler_filter, output_mod).VisitExpr (function_node->body );
193115 Function new_function =
194116 WithFields (GetRef<Function>(function_node), /* opt_params=*/ {}, new_body);
195117 output_mod->Add (kv.first , new_function);
196118 }
197119 }
198- VLOG (1 ) << " OutlineCompilerFunctions result:" << std::endl << PrettyPrint (output_mod);
199120 return output_mod;
200121 };
201122
@@ -211,57 +132,31 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
211132transform::Pass MarkCompilerFunctionsAsExtern (std::string compiler_filter) {
212133 runtime::TypedPackedFunc<IRModule (IRModule, transform::PassContext)> pass_func =
213134 [compiler_filter = std::move (compiler_filter)](IRModule mod, transform::PassContext ctx) {
214- VLOG (1 ) << " MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint (mod);
215135 IRModule output_mod = mod->ShallowCopy ();
216136 for (const auto & kv : mod->functions ) {
217- if (const auto * function_node = AsFunctionNode (kv.second , compiler_filter)) {
218- auto new_function =
219- WithFields (GetRef<Function>(function_node), function_node->params ,
220- function_node->body , function_node->ret_type , function_node->type_params ,
221- /* erase attributes */ DictAttrs (Map<String, ObjectRef>()));
222- new_function = WithAttr (std::move (new_function), attr::kExtern , Integer (1 ));
223- output_mod->Update (kv.first , new_function);
137+ if (const auto * function_node = kv.second .as <FunctionNode>()) {
138+ Optional<String> opt_compiler = function_node->GetAttr <String>(attr::kCompiler );
139+ if (opt_compiler.defined () &&
140+ (compiler_filter.empty () || opt_compiler.value () == compiler_filter)) {
141+ auto new_function = WithFields (
142+ GetRef<Function>(function_node), function_node->params , function_node->body ,
143+ function_node->ret_type , function_node->type_params ,
144+ /* erase attributes */ DictAttrs (Map<String, ObjectRef>()));
145+ new_function = WithAttr (std::move (new_function), attr::kExtern , Integer (1 ));
146+ output_mod->Update (kv.first , new_function);
147+ }
224148 }
225149 }
226- VLOG (1 ) << " MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint (output_mod);
227150 return output_mod;
228151 };
229152
230153 return tvm::transform::CreateModulePass (pass_func, 0 , " MarkCompilerFunctionsAsExtern" , {});
231154}
232155
233- transform::Pass InlineCompilerFunctions (Array<GlobalVar> global_vars) {
234- runtime::TypedPackedFunc<IRModule (IRModule, transform::PassContext)> pass_func =
235- [global_vars = std::move (global_vars)](IRModule mod, transform::PassContext ctx) {
236- VLOG (1 ) << " InlineCompilerFunctions with global_vars: " << PrettyPrint (global_vars);
237- if (global_vars.empty ()) {
238- return mod;
239- }
240- VLOG (1 ) << " InlineCompilerFunctions input:" << std::endl << PrettyPrint (mod);
241- IRModule output_mod = mod->ShallowCopy ();
242- for (const auto & kv : mod->functions ) {
243- if (std::find (global_vars.begin (), global_vars.end (), kv.first ) != global_vars.end ()) {
244- output_mod->Remove (kv.first );
245- } else if (const auto * function_node = AsOptimizableFunctionNode (kv.second )) {
246- Expr new_body = OuterInliner (mod, global_vars).VisitExpr (function_node->body );
247- Function new_function =
248- WithFields (GetRef<Function>(function_node), /* opt_params=*/ {}, new_body);
249- output_mod->Add (kv.first , new_function);
250- }
251- }
252- VLOG (1 ) << " InlineCompilerFunctions result:" << std::endl << PrettyPrint (output_mod);
253- return output_mod;
254- };
255-
256- return tvm::transform::CreateModulePass (pass_func, 0 , " InlineCompilerFunctionsImpl" , {});
257- }
258-
259156TVM_REGISTER_GLOBAL (" relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols" )
260157 .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols);
261158TVM_REGISTER_GLOBAL (" relay._transform.MarkCompilerFunctionsAsExtern" )
262159 .set_body_typed(MarkCompilerFunctionsAsExtern);
263- TVM_REGISTER_GLOBAL (" relay._transform.InlineCompilerFunctions" )
264- .set_body_typed(InlineCompilerFunctions);
265160
266161} // namespace transforms
267162} // namespace relay
0 commit comments