Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode {
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(String name_hint, Type type = {});
TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
};

// PrimExprs that are useful as runtime containers.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace relay {
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
*/
struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
/*! \brief The metadata attached to the call node. */
/*! \brief Additional metadata attached to the call node. Should be replaced by explict fields. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
Expand Down
32 changes: 27 additions & 5 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,40 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func);
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
namespace attr {
/*! \brief Mark the function as a primitive function. */

/*!
* \brief Mark the function as representing a sub-graph which is to be lowered or compiled as
* a unit. For example, the function may represent a kernel which TVM will lower to a PrimFunc.
* If present should be bound to \p Integer(1). May be accompanied by "Compiler", see below.
* The function body should be considered opaque by Relay, and many passes simply ignore these
* functions.
*
* Type: Integer
*/
constexpr const char* kPrimitive = "Primitive";

/*!
* \brief Mark the function as externally implemented, ie bound in a runtime::Module within the
* IRModule's "external_mods" attribute. If present should be bound to \p Integer(1). Generally
* the only attribute when present.
*
* Type: Integer
*/
constexpr const char* kExtern = "Extern";

/*!
* \brief Indicate the compiler that should be used for building this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
* \brief Indicates the name of the external codegen 'compiler' that should be used to lower
* or compile the function other than TVM's default lowering pipeline. The name may correspond
* to a TargetKind name. There may be a global function registered under 'relay.ext.{name}'.
*
* Type: String
*/
constexpr const char* kCompiler = "Compiler";

/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
Expand All @@ -193,6 +214,7 @@ constexpr const char* kInline = "Inline";
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";

} // namespace attr

} // namespace relay
Expand Down
70 changes: 52 additions & 18 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,24 +802,6 @@ def Inline():
return _ffi_api.Inline()


def InlineComposites(target):
"""Perform inlining on the given Relay IR module. The functions originate
from the MergeComposite pass based on an input pattern table will fold back
to main. Currently, this is used for the TRT BYOC which expects a single
primitive function to operate on.

Parameters
----------
target: str
The byoc target for which ops need to fold back to primitive function.
Returns
-------
ret: tvm.transform.Pass
The registered pass that performs inlining for a Relay IR module.
"""
return _ffi_api.InlineComposites(target)


def gradient(expr, mod=None, mode="higher_order"):
"""
Transform the input function,
Expand Down Expand Up @@ -1386,3 +1368,55 @@ def SplitArgs(max_function_args):
The registered pass for constant folding.
"""
return _ffi_api.SplitArgs(max_function_args)


def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
"""Outlines all literal functions in direct call positions which have a "Compiler"
attribute.

The outlined functions are bound to unique global vars according to their existing
"global_symbol" attribute. At most one function with the same global symbol is outlined.

If compiler_filter is non-empty only functions with that as their attribute value are
outlined.

This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism
to prepare the IRModule before custom lowering.

Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.

Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter)


def MarkCompilerFunctionsAsExtern(compiler_filter=""):
"""Marks all global functions which have a "Compiler" attribute matching
compiler_filter as 'extern'.

The function's attributes are replaced with a single "Extern" attribute, and
all calls to the function are switched to use the 'call_lowered' calling convention.

If compiler_filter is non-empty only functions with that as their attribute value are
outlined.

This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to
cleanup the IRModule after custom lowering.

Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.

Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)
3 changes: 2 additions & 1 deletion src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(String name_hint, Type type) {
GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
n->checked_type_ = std::move(type);
n->span = std::move(span);
data_ = std::move(n);
}

Expand Down
4 changes: 1 addition & 3 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,6 @@ struct Tokenizer {
int line = this->line;
int column = this->col;

ICHECK_EQ(Peek(), '[');
Next();
std::stringstream type_key;
while (More() && Peek() != ']') {
type_key << Next();
Expand Down Expand Up @@ -498,7 +496,7 @@ struct Tokenizer {
auto token = NewToken(TokenType::kQuestion);
Next();
return token;
} else if (MatchString("meta")) {
} else if (MatchString("meta[")) {
return TokenizeMetaRef();
} else if (next == '#') {
return TokenizeAttr();
Expand Down
8 changes: 4 additions & 4 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,16 @@ class TECompilerImpl : public TECompilerNode {
if (const auto* function_node = kv2.second.as<FunctionNode>()) {
// Abandon the existing function annotations.

// Unfortuantely, Optional<DictAttrs>() is indistinguishable from
// Unfortunately, Optional<DictAttrs>() is indistinguishable from
// NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes, we
// need pass in DictAttrs<Map<String, ObjectRef>()), which is a DictAttrs containing no
// attributes.
Function function =
WithFields(GetRef<Function>(function_node), function_node->params,
function_node->body, function_node->ret_type, function_node->type_params,
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
// Mark function as 'extern' using the "ExternalSymbol" attribute.
function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
// Mark function as 'extern'.
function = WithAttr(std::move(function), attr::kExtern, Integer(1));
module->Add(kv2.first, function);
}
}
Expand Down Expand Up @@ -688,7 +688,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {

Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
function_node->GetAttr<String>(attr::kExternalSymbol)) {
function_node->HasNonzeroAttr(attr::kExtern)) {
// Nothing to lower inside primitive/external functions.
return GetRef<Function>(function_node);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ void VMCompiler::LowerImpl(IRModule mod) {
for (const auto& pair : context_.module->functions) {
auto gvar = pair.first;
if (auto* n = pair.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kExternalSymbol).defined()) {
if (n->HasNonzeroAttr(attr::kExtern)) {
// Already compiled during lowering.
continue;
}
Expand Down Expand Up @@ -1131,7 +1131,7 @@ size_t VMCompiler::PopulateGlobalMap() {
// Excludes PrimFuncs and externs, which are managed by the primitive_map_.
for (const auto& kv : context_.module->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
if (!function_node->HasNonzeroAttr(attr::kExtern)) {
context_.global_map.emplace(kv.first, context_.global_map.size());
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ FuncType FunctionNode::func_type_annotation() const {
const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) {
if (const auto* function_node = base_func.as<FunctionNode>()) {
if (!function_node->GetAttr<String>(attr::kCompiler).defined() &&
!function_node->GetAttr<String>(attr::kExternalSymbol).defined() &&
!function_node->HasNonzeroAttr(attr::kExtern) &&
!function_node->HasNonzeroAttr(attr::kSkipOptimization)) {
return function_node;
}
Expand Down
1 change: 1 addition & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT f
- **out**: `(b, m, n)`.

)code" TVM_ADD_FILELINE)
.set_attrs_type<BatchMatmulAttrs>()
.set_num_inputs(2)
.add_argument("tensor_a", "3D Tensor", "The first input.")
.add_argument("tensor_b", "3D Tensor", "The second input.")
Expand Down
Loading