Skip to content

Commit d5224cb

Browse files
committed
- Revert inlining helper pass.
1 parent cd1318b commit d5224cb

File tree

3 files changed

+48
-202
lines changed

3 files changed

+48
-202
lines changed

src/relay/transforms/compiler_function_utils.cc

Lines changed: 35 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,12 @@
2727
#include "../op/call/call.h"
2828
#include "tvm/relay/analysis.h"
2929
#include "tvm/relay/expr_functor.h"
30-
#include "tvm/relay/transform.h"
3130

3231
namespace tvm {
3332
namespace relay {
3433
namespace transforms {
3534
namespace {
3635

37-
/*!
38-
* \brief Returns the \p FunctionNode of if \p expr if it is a "Compiler" function which should
39-
* be processed by a pass using \p compiler_filter. Otherwise returns null.
40-
*/
41-
const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler_filter) {
42-
if (const auto* function_node = expr.as<FunctionNode>()) {
43-
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
44-
if (opt_compiler.defined() &&
45-
(compiler_filter.empty() || opt_compiler.value() == compiler_filter)) {
46-
return function_node;
47-
}
48-
}
49-
return nullptr;
50-
}
51-
5236
/*!
5337
* \brief Rewrite calls to inlined "Compiler" functions to global functions. The given
5438
* module will be extended with the newly outlined functions.
@@ -60,31 +44,35 @@ class Outliner : public MixedModeMutator {
6044

6145
Expr Rewrite_(const CallNode* pre, const Expr& post) final {
6246
Call new_call = Downcast<Call>(post);
63-
if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) {
64-
auto function = GetRef<Function>(function_node);
65-
DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
66-
<< "' attribute should not have free variables";
67-
// Ask the cache to supply a unique global var for this function.
68-
GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
69-
// Depending on the cache's implementation, two structurally equal (but not object
70-
// equal) functions may be assigned the same global symbol. If so we'll lift it just
71-
// once, but rewrite all the calls.
72-
if (!mod_->ContainGlobalVar(global_symbol->name_hint)) {
73-
function =
74-
WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
75-
mod_->Add(global_symbol, function);
47+
if (const auto* function_node = new_call->op.as<FunctionNode>()) {
48+
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
49+
if (opt_compiler.defined() &&
50+
(compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) {
51+
auto function = GetRef<Function>(function_node);
52+
DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
53+
<< "' attribute should not have free variables";
54+
// Ask the cache to supply a unique global var for this function.
55+
GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
56+
// Depending on the cache's implementation, two structurally equal (but not object equal)
57+
// functions may be assigned the same global symbol. If so we'll lift it just once, but
58+
// rewrite all the calls.
59+
if (!mod_->ContainGlobalVar(global_symbol->name_hint)) {
60+
function =
61+
WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
62+
mod_->Add(global_symbol, function);
63+
}
64+
// Update the call.
65+
return WithFields(new_call, global_symbol);
7666
}
77-
// Update the call.
78-
return WithFields(new_call, global_symbol);
7967
}
8068
return post;
8169
}
8270

8371
private:
8472
/*!
8573
* \brief A cached mapping from functions to global variables. Depending on the implementation
86-
* the cache may generate fresh symbols or require the function to already have a
87-
* "global_symbol" attribute, and may share symbols between structurally equal functions.
74+
* the cache may generate fresh symbols or require the function to already have a "global_symbol"
75+
* attribute, and may share symbols between structurally equal functions.
8876
*/
8977
GlobalSymbolCache* cache_;
9078
/*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
@@ -93,72 +81,6 @@ class Outliner : public MixedModeMutator {
9381
IRModule mod_;
9482
};
9583

96-
/*!
97-
* \brief Inline immediate calls to "Composite" functions.
98-
*/
99-
class InnerInliner : public MixedModeMutator {
100-
public:
101-
InnerInliner() = default;
102-
103-
private:
104-
using MixedModeMutator::Rewrite_;
105-
106-
Expr Rewrite_(const CallNode* pre, const Expr& post) final {
107-
Call new_call = Downcast<Call>(post);
108-
if (const auto* function_node = new_call->op.as<FunctionNode>()) {
109-
ICHECK(function_node->GetAttr<String>(attr::kComposite).defined());
110-
ICHECK_EQ(function_node->params.size(), new_call->args.size());
111-
Map<Var, Expr> subst;
112-
for (size_t i = 0; i < new_call->args.size(); ++i) {
113-
subst.Set(function_node->params[i], new_call->args[i]);
114-
}
115-
return Bind(function_node->body, subst);
116-
}
117-
return post;
118-
}
119-
};
120-
121-
/*!
122-
* \brief Inline calls to global "Compiler" functions with global var in \p global_vars.
123-
* Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body
124-
* are inlined.
125-
*/
126-
class OuterInliner : public MixedModeMutator {
127-
public:
128-
OuterInliner(IRModule mod, Array<GlobalVar> global_vars_)
129-
: mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {}
130-
131-
private:
132-
using MixedModeMutator::Rewrite_;
133-
134-
Expr Rewrite_(const CallNode* pre, const Expr& post) final {
135-
Call new_call = Downcast<Call>(post);
136-
if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
137-
auto global_var = GetRef<GlobalVar>(global_var_node);
138-
if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) {
139-
BaseFunc base_func = mod_->Lookup(global_var);
140-
const auto* function_node = base_func.as<FunctionNode>();
141-
ICHECK(function_node);
142-
ICHECK(function_node->GetAttr<String>(attr::kCompiler).defined());
143-
ICHECK_EQ(function_node->params.size(), new_call->args.size());
144-
Map<Var, Expr> subst;
145-
for (size_t i = 0; i < new_call->args.size(); ++i) {
146-
subst.Set(function_node->params[i], new_call->args[i]);
147-
}
148-
Expr new_body = InnerInliner().VisitExpr(function_node->body);
149-
return Bind(new_body, subst);
150-
}
151-
}
152-
return post;
153-
}
154-
155-
private:
156-
/*! \brief Original module we are processing. */
157-
IRModule mod_;
158-
/*! \brief Global vars of functions to inline. */
159-
Array<GlobalVar> global_vars_;
160-
};
161-
16284
} // namespace
16385

16486
GlobalSymbolCache::~GlobalSymbolCache() = default;
@@ -184,18 +106,17 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
184106
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
185107
[cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
186108
IRModule mod, transform::PassContext ctx) {
187-
VLOG(1) << "OutlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
188-
IRModule output_mod = mod->ShallowCopy();
109+
IRModule output_mod = GetRef<IRModule>(mod.CopyOnWrite());
189110
for (const auto& kv : mod->functions) {
190-
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
111+
const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second);
112+
if (function_node) {
191113
Expr new_body =
192114
Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body);
193115
Function new_function =
194116
WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
195117
output_mod->Add(kv.first, new_function);
196118
}
197119
}
198-
VLOG(1) << "OutlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod);
199120
return output_mod;
200121
};
201122

@@ -211,57 +132,31 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
211132
transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
212133
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
213134
[compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
214-
VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod);
215135
IRModule output_mod = mod->ShallowCopy();
216136
for (const auto& kv : mod->functions) {
217-
if (const auto* function_node = AsFunctionNode(kv.second, compiler_filter)) {
218-
auto new_function =
219-
WithFields(GetRef<Function>(function_node), function_node->params,
220-
function_node->body, function_node->ret_type, function_node->type_params,
221-
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
222-
new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1));
223-
output_mod->Update(kv.first, new_function);
137+
if (const auto* function_node = kv.second.as<FunctionNode>()) {
138+
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
139+
if (opt_compiler.defined() &&
140+
(compiler_filter.empty() || opt_compiler.value() == compiler_filter)) {
141+
auto new_function = WithFields(
142+
GetRef<Function>(function_node), function_node->params, function_node->body,
143+
function_node->ret_type, function_node->type_params,
144+
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
145+
new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1));
146+
output_mod->Update(kv.first, new_function);
147+
}
224148
}
225149
}
226-
VLOG(1) << "MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint(output_mod);
227150
return output_mod;
228151
};
229152

230153
return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {});
231154
}
232155

233-
transform::Pass InlineCompilerFunctions(Array<GlobalVar> global_vars) {
234-
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
235-
[global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) {
236-
VLOG(1) << "InlineCompilerFunctions with global_vars: " << PrettyPrint(global_vars);
237-
if (global_vars.empty()) {
238-
return mod;
239-
}
240-
VLOG(1) << "InlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
241-
IRModule output_mod = mod->ShallowCopy();
242-
for (const auto& kv : mod->functions) {
243-
if (std::find(global_vars.begin(), global_vars.end(), kv.first) != global_vars.end()) {
244-
output_mod->Remove(kv.first);
245-
} else if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
246-
Expr new_body = OuterInliner(mod, global_vars).VisitExpr(function_node->body);
247-
Function new_function =
248-
WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
249-
output_mod->Add(kv.first, new_function);
250-
}
251-
}
252-
VLOG(1) << "InlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod);
253-
return output_mod;
254-
};
255-
256-
return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsImpl", {});
257-
}
258-
259156
TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols")
260157
.set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols);
261158
TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
262159
.set_body_typed(MarkCompilerFunctionsAsExtern);
263-
TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctions")
264-
.set_body_typed(InlineCompilerFunctions);
265160

266161
} // namespace transforms
267162
} // namespace relay

src/relay/transforms/compiler_function_utils.h

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
* \brief Helper passes for working with functions with the "Compiler" attribute.
2323
*
2424
* Those wishing to use the "RelayToTIR" custom pass machinery to do IRModule-at-a-time external
25-
* codegen may find the following helpers useful:
25+
* codegen may find the following two helper passes useful:
2626
*
27-
* - The \p OutlineCompilerFunctionsWithExistingGlobalSymbols pass will lift inline functions with
28-
* a matching "Compiler" attribute to be global functions, using the "global_symbol" attribute
27+
* - \p OutlineCompilerFunctionsWithExistingGlobalSymbols will lift inline functions with a
28+
* matching "Compiler" attribute to be global functions, using the "global_symbol" attribute
2929
* already assigned. Can be used before custom lowering.
3030
*
3131
* Note that ideally "Compiler" attributed functions would be made global functions as early as
@@ -36,22 +36,15 @@
3636
*
3737
* See also OutlineCompilerFunctionsMutator in src/relay/backend/contrib/ethosu/codegen.cc.
3838
*
39-
* - (The \p OutlineCompilerFunctions pass is a more general version of the above which can use
40-
* a custom cache to both allocate "global_symbol" names and ensure two structurally equal
41-
* functions are assigned the same name, and thus lowered only once. This is used by Collage
42-
* when preparing the optimally partitioned IRModule).
39+
* - (\p OutlineCompilerFunctions is a more general version of the above which can use a custom
40+
* cache to both allocate "global_symbol" names and ensure two strucurally equal functions are
41+
* assigned the same name, and thus lowered only once. This is used by Collage when preparing
42+
* the optimally partitioned IRModule).
4343
*
44-
* - The \p MarkCompilerFunctionsAsExtern pass will update the attributes of global functions
45-
* with a matching "Compiler" attribute to have just the "Extern" attribute. That will signal
46-
* the function has been dealt with. However calls to such functions will be left unchanged.
47-
* Can be used after lowering to cleanup the IRModule.
48-
*
49-
* - The \p InlineCompilerFunctions pass can selectively inline global functions with a matching
50-
* "Compiler" attribute who's name appears in the given set. Obviously it's more sensible to
51-
* not create that function in the first place, however some external codegen have rules to
52-
* accept or reject partitionings based on the overall partitioned function body. This pass
53-
* can be used do the legwork, and will take care to not only inline the outer "Compiler"
54-
* annotated funcition, but also any "Composite" annotated functions in its body.
44+
* - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler"
45+
* attribute with the same function with just an "Extern" attribute, signalling the function
46+
* has been dealt with. However calls to such functions will be left unchanged. Can be used
47+
* after lowering to cleanup the IRModule.
5548
*/
5649

5750
#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_
@@ -133,16 +126,6 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
133126
*/
134127
transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");
135128

136-
/*!
137-
* \brief A pass to inline all global "Compiler" functions which are bound to a global var
138-
* in \p global_vars. Both the global function and any "Composite" functions it its body are
139-
* inlined.
140-
*
141-
* This pass may be useful for external codegen which needs to undo partitioning based on
142-
* properties of the entire partition.
143-
*/
144-
transform::Pass InlineCompilerFunctions(Array<GlobalVar> global_vars);
145-
146129
} // namespace transforms
147130
} // namespace relay
148131
} // namespace tvm

tests/python/relay/transform/test_compiler_function_utils.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def make_consts(dtype, shapes):
4242
}
4343

4444

45-
def original_mod():
45+
def inlined_mod():
4646
return tvm.parser.parse(
4747
"""
4848
#[version = "0.0.5"]
@@ -143,35 +143,10 @@ def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i
143143
)
144144

145145

146-
def expected_inlined_mod():
147-
return tvm.parser.parse(
148-
"""
149-
#[version = "0.0.5"]
150-
def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) {
151-
%0 = nn.dense(%x0, meta[relay.Constant][0], units=2304);
152-
%1 = add(%0, meta[relay.Constant][1]);
153-
%2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16],
154-
Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] {
155-
%6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16],
156-
PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] {
157-
nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True)
158-
};
159-
%6(%y_3_i0, %y_3_i1)
160-
};
161-
%3 = %2(%x3, meta[relay.Constant][2]);
162-
(%1, %3)
163-
}
164-
""",
165-
"from_string",
166-
None,
167-
metatable,
168-
)
169-
170-
171146
def test_outline_compiler_functions_with_existing_global_symbols():
172147
actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols(
173148
"cutlass"
174-
)(original_mod())
149+
)(inlined_mod())
175150
tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True)
176151

177152

@@ -182,12 +157,5 @@ def test_mark_compiler_functions_as_extern():
182157
tvm.ir.assert_structural_equal(actual_extern_mod, expected_extern_mod(), map_free_vars=True)
183158

184159

185-
def test_inline_compiler_functions():
186-
mod = expected_outlined_mod()
187-
gv = mod.get_global_var("tvmgen_default_cutlass_main_0")
188-
actual_inlined_mod = tvm.relay.transform.InlineCompilerFunctions([gv])(mod)
189-
tvm.ir.assert_structural_equal(actual_inlined_mod, expected_inlined_mod(), map_free_vars=True)
190-
191-
192160
if __name__ == "__main__":
193161
tvm.testing.main()

0 commit comments

Comments
 (0)