Skip to content

Commit 5ece9dd

Browse files
committed
[BYOC] Two helper passes for external codegen using RelayToTIR custom pass machinery
(See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md). For reasons explained in the above thread I'm moving CUTLASS to be IRModule-at-a-time external codegen using a custom RelayToTIR pass instead of the traditional function-at-a-time external codegen using a relay.ext.cutlass registered function. This means some of the rewriing done on-the-fly by LowerTEPass now needs to be done by the custom pass directly. This PR supplies two passes which ease that burden: - Before starting the CUTLASS-specific processing, make sure all "Compiler" attributed functions have unique global definitions (ie are outlined). Though functions start in this form after BYOC partitioning, under Graph and AOT compilation flows those functions are then inlined to pass through the 'codegen' keyhole which assumes the whole model is just one self-contained main function. This pass will undo that. (I gave up trying to just remove the inlining in the first place.) - After the CUTLASS-specific processing the now compiled "Compiler" attributed functions need to marked as 'extern'. The te_compiler.cc uses the "ExternalSymbol" attribute for that, but since a) the symbol name is never needed, on the presense of the attribute is significant downstream and b) "ExternalSymbol" is easy to confuse with "global_symbol", I just replaced "ExternalSymbol" with "Extern" with an Integer(1) (cf "Primitive"). The outlining pass is a little more general than necessary because it (will also) be used by Collage to rewrite the IRModule into optimally partitioned form while making maximal reuse of partition functions. Hence the abstract GlobalSymbolCache.
1 parent bc492ac commit 5ece9dd

File tree

15 files changed

+585
-39
lines changed

15 files changed

+585
-39
lines changed

include/tvm/ir/expr.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode {
260260
*/
261261
class GlobalVar : public RelayExpr {
262262
public:
263-
TVM_DLL explicit GlobalVar(String name_hint, Type type = {});
263+
TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});
264264

265265
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
266+
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
266267
};
267268

268269
// PrimExprs that are useful as runtime containers.

include/tvm/relay/attrs/call.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace relay {
3535
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
3636
*/
3737
struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
38-
/*! \brief The metadata attached to the call node. */
38+
/*! \brief Additional metadata attached to the call node. Should be replaced by explict fields. */
3939
Map<String, ObjectRef> metadata;
4040

4141
TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {

include/tvm/relay/function.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,19 +170,34 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func);
170170
* \brief namespace of the attributes that can be attached to a relay::Function.
171171
*/
172172
namespace attr {
173-
/*! \brief Mark the function as a primitive function. */
173+
174+
/*!
175+
* \brief Mark the function as a primitive function. Should be bound to \p Integer(1).
176+
*
177+
* Type: Integer
178+
*/
174179
constexpr const char* kPrimitive = "Primitive";
180+
181+
/*!
182+
* \brief Mark the function as being 'extern', ie implemented in a runtime::Module. Should be bound
183+
* to \p Integer(1). Typically accompanied by "Primitive".
184+
*
185+
* Type: Integer
186+
*/
187+
constexpr const char* kExtern = "Extern";
188+
175189
/*!
176-
* \brief Indicate the compiler that should be used for building this function.
190+
* \brief Indicate the external codegen 'compiler' that should be used for building this function.
177191
* When this is unset or set to "default", the default compilation pipeline will be used.
192+
*
193+
* Type: String
178194
*/
179195
constexpr const char* kCompiler = "Compiler";
196+
180197
/*! \brief Indicate if the function is a closure. */
181198
constexpr const char* kClosure = "Closure";
182199
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
183200
constexpr const char* kParams = "__params__";
184-
/*! \brief Store the unique external symbol for external compilers. */
185-
constexpr const char* kExternalSymbol = "ExternalSymbol";
186201
/*! \brief Mark if the function should be avoided being optimized. */
187202
constexpr const char* kSkipOptimization = "SkipOptimization";
188203
/*! \brief Treat the function as a composite operator. */
@@ -193,6 +208,7 @@ constexpr const char* kInline = "Inline";
193208
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
194209
/*! \brief Mark the function as only composed of reshape operations. */
195210
constexpr const char* kReshapeOnly = "relay.reshape_only";
211+
196212
} // namespace attr
197213

198214
} // namespace relay

python/tvm/relay/transform/transform.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -802,24 +802,6 @@ def Inline():
802802
return _ffi_api.Inline()
803803

804804

805-
def InlineComposites(target):
806-
"""Perform inlining on the given Relay IR module. The functions originate
807-
from the MergeComposite pass based on an input pattern table will fold back
808-
to main. Currently, this is used for the TRT BYOC which expects a single
809-
primitive function to operate on.
810-
811-
Parameters
812-
----------
813-
target: str
814-
The byoc target for which ops need to fold back to primitive function.
815-
Returns
816-
-------
817-
ret: tvm.transform.Pass
818-
The registered pass that performs inlining for a Relay IR module.
819-
"""
820-
return _ffi_api.InlineComposites(target)
821-
822-
823805
def gradient(expr, mod=None, mode="higher_order"):
824806
"""
825807
Transform the input function,
@@ -1386,3 +1368,51 @@ def SplitArgs(max_function_args):
13861368
The registered pass for constant folding.
13871369
"""
13881370
return _ffi_api.SplitArgs(max_function_args)
1371+
1372+
1373+
def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
1374+
"""A pass to outline all literal functions in direct call positions which have a "Compiler"
1375+
attribute. The functions are bound to unique global vars according to their existing
1376+
"global_symbol" attribute. At most one function with the same global symbol is outlined.
1377+
1378+
If compiler_filter is non-empty only functions with that as their attribute value are
1379+
outlined.
1380+
1381+
This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism
1382+
to prepare the IRModule before custom lowering.
1383+
1384+
Parameters
1385+
----------
1386+
compiler_filter : String
1387+
If non-empty, the 'compiler' attribute to filter on.
1388+
1389+
Returns
1390+
-------
1391+
ret : tvm.transform.Pass
1392+
The pass.
1393+
"""
1394+
return _ffi_api.OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter)
1395+
1396+
1397+
def MarkCompilerFunctionsAsExtern(compiler_filter=""):
1398+
"""A pass to mark all global functions which have a "Compiler" attribute matching
1399+
compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and
1400+
rewrite all calls to such functions to use the 'call_lowered' calling convention.
1401+
1402+
If compiler_filter is non-empty only functions with that as their attribute value are
1403+
outlined.
1404+
1405+
This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to
1406+
cleanup the IRModule after custom lowering.
1407+
1408+
Parameters
1409+
----------
1410+
compiler_filter : String
1411+
If non-empty, the 'compiler' attribute to filter on.
1412+
1413+
Returns
1414+
-------
1415+
ret : tvm.transform.Pass
1416+
The pass.
1417+
"""
1418+
return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)

src/ir/expr.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
141141
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
142142
});
143143

144-
GlobalVar::GlobalVar(String name_hint, Type type) {
144+
GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
145145
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
146146
n->name_hint = std::move(name_hint);
147147
n->checked_type_ = std::move(type);
148+
n->span = std::move(span);
148149
data_ = std::move(n);
149150
}
150151

src/parser/tokenizer.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,6 @@ struct Tokenizer {
295295
int line = this->line;
296296
int column = this->col;
297297

298-
ICHECK_EQ(Peek(), '[');
299-
Next();
300298
std::stringstream type_key;
301299
while (More() && Peek() != ']') {
302300
type_key << Next();
@@ -498,7 +496,7 @@ struct Tokenizer {
498496
auto token = NewToken(TokenType::kQuestion);
499497
Next();
500498
return token;
501-
} else if (MatchString("meta")) {
499+
} else if (MatchString("meta[")) {
502500
return TokenizeMetaRef();
503501
} else if (next == '#') {
504502
return TokenizeAttr();

src/relay/backend/te_compiler.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class TECompilerImpl : public TECompilerNode {
168168
if (const auto* function_node = kv2.second.as<FunctionNode>()) {
169169
// Abandon the existing function annotations.
170170

171-
// Unfortuantely, Optional<DictAttrs>() is indistinguishable from
171+
// Unfortunately, Optional<DictAttrs>() is indistinguishable from
172172
// NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes, we
173173
// need pass in DictAttrs<Map<String, ObjectRef>()), which is a DictAttrs containing no
174174
// attributes.
@@ -177,7 +177,7 @@ class TECompilerImpl : public TECompilerNode {
177177
function_node->body, function_node->ret_type, function_node->type_params,
178178
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
179179
// Mark function as 'extern' using the "ExternalSymbol" attribute.
180-
function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
180+
function = WithAttr(std::move(function), attr::kExtern, Integer(1));
181181
module->Add(kv2.first, function);
182182
}
183183
}
@@ -689,7 +689,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
689689

690690
Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
691691
if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
692-
function_node->GetAttr<String>(attr::kExternalSymbol)) {
692+
function_node->HasNonzeroAttr(attr::kExtern)) {
693693
// Nothing to lower inside primitive/external functions.
694694
return GetRef<Function>(function_node);
695695
} else {

src/relay/backend/vm/compiler.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ void VMCompiler::LowerImpl(IRModule mod) {
920920
for (const auto& pair : context_.module->functions) {
921921
auto gvar = pair.first;
922922
if (auto* n = pair.second.as<FunctionNode>()) {
923-
if (n->GetAttr<String>(attr::kExternalSymbol).defined()) {
923+
if (n->HasNonzeroAttr(attr::kExtern)) {
924924
// Already compiled during lowering.
925925
continue;
926926
}
@@ -1129,7 +1129,7 @@ size_t VMCompiler::PopulateGlobalMap() {
11291129
// Excludes PrimFuncs and externs, which are managed by the primitive_map_.
11301130
for (const auto& kv : context_.module->functions) {
11311131
if (const auto* function_node = kv.second.as<FunctionNode>()) {
1132-
if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
1132+
if (!function_node->HasNonzeroAttr(attr::kExtern)) {
11331133
context_.global_map.emplace(kv.first, context_.global_map.size());
11341134
}
11351135
}

src/relay/ir/function.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ FuncType FunctionNode::func_type_annotation() const {
112112
const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) {
113113
if (const auto* function_node = base_func.as<FunctionNode>()) {
114114
if (!function_node->GetAttr<String>(attr::kCompiler).defined() &&
115-
!function_node->GetAttr<String>(attr::kExternalSymbol).defined() &&
115+
!function_node->HasNonzeroAttr(attr::kExtern) &&
116116
!function_node->HasNonzeroAttr(attr::kSkipOptimization)) {
117117
return function_node;
118118
}

src/relay/op/nn/nn.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,7 @@ Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT f
10121012
- **out**: `(b, m, n)`.
10131013
10141014
)code" TVM_ADD_FILELINE)
1015+
.set_attrs_type<BatchMatmulAttrs>()
10151016
.set_num_inputs(2)
10161017
.add_argument("tensor_a", "3D Tensor", "The first input.")
10171018
.add_argument("tensor_b", "3D Tensor", "The second input.")

0 commit comments

Comments
 (0)