Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1600441
Use privacy annotation to decide whether to include global_symbol
slyubomirsky Jun 21, 2023
4cd11a2
Remove unused import
slyubomirsky Jun 21, 2023
dabcb1e
Simplify function frame construction
slyubomirsky Jun 21, 2023
ebd1b80
Remove debug prints
slyubomirsky Jun 21, 2023
e2ecadc
Whitespace fix
slyubomirsky Jun 22, 2023
f2c999d
Include global_symbol in importers
slyubomirsky Jun 22, 2023
e012683
Formatting fix
slyubomirsky Jun 22, 2023
24769c0
Add private field to function builder in BlockBuilder
slyubomirsky Jun 22, 2023
aaee0cd
Whitespace fix
slyubomirsky Jun 23, 2023
801000b
Set the global symbol for functions outside a module too
slyubomirsky Jun 23, 2023
d1f7c01
Formatting fix
slyubomirsky Jun 23, 2023
5b6c814
Print privacy attribute for functions not in an IRModule
slyubomirsky Jun 23, 2023
954dfd3
Fix placement of pylint overrides in ir.py
slyubomirsky Jun 23, 2023
249ec0e
Fix line length
slyubomirsky Jun 23, 2023
886a63e
Check for nested function case in function.cc instead
slyubomirsky Jun 23, 2023
4b4b52f
Print the global symbol if it doesn't match the name for some reason
slyubomirsky Jun 23, 2023
8064cfe
Correctly coerce the attribute type
slyubomirsky Jun 23, 2023
25e4712
Tweak pylint override again
slyubomirsky Jun 23, 2023
8c7bab1
Add pylint override for trailing whitespace in printer tests
slyubomirsky Jun 25, 2023
b815449
Fix whitespace
slyubomirsky Jun 26, 2023
68142e1
Remove trailing whitespace altogether instead of trying to override it
slyubomirsky Jun 26, 2023
4e272ae
Fix test_utils
slyubomirsky Jun 26, 2023
5314469
Fix test_tvmscript_parser_op_datatype
slyubomirsky Jun 26, 2023
a0df076
Fix global symbols in torch dynamo importer
slyubomirsky Jun 26, 2023
d2a4902
Fix test_dataflow_pattern
slyubomirsky Jun 26, 2023
5ac02e8
Use tvm::attr::kGlobalSymbol instead of string literal
slyubomirsky Jun 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class FunctionFrameNode : public SeqExprFrameNode {
Optional<tvm::relax::StructInfo> ret_struct_info;
/*! \brief Whether the function is annotated as pure */
Optional<Bool> is_pure;
/*! \brief Whether the function is annotated as private */
Optional<Bool> is_private;
/*! \brief The function attributes. */
Map<String, ObjectRef> attrs;
/*! \brief The block builder to create Relax function. */
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ namespace relax {
/*!
* \brief Start a function frame.
* \param is_pure Whether the function is annotated as pure.
* \param is_private Whether the function is annotated as private.
* \return The created ir_builder Function frame.
*/
TVM_DLL FunctionFrame Function(const Bool& is_pure);
TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);

/*!
* \brief Add a parameter to the last function frame.
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def function(
name: str,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
attrs: Optional[Dict[str, Object]] = None,
private: bool = False,
) -> FunctionScope:
"""Annotate a Relax function.

Expand All @@ -215,6 +216,12 @@ def function(
attrs : Dict[str, Object], optional
The function attrs

private : bool, optional
Whether the function is annotated as private.
If the function is private, it will not have a global symbol attribute.
If it is not private and not an inner function, then it will have
a global symbol attribute (mapped to the function's name)

Returns
-------
ret: FunctionScope
Expand All @@ -233,6 +240,11 @@ def function(
)
if attrs is None:
attrs = {}
# The block builder does not permit nesting functions, per above comment,
# so no further check should be needed
if not private:
attrs["global_symbol"] = name

return FunctionScope(self, name, params, attrs)

def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope:
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relax/frontend/torch/dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def _capture(graph_module: fx.GraphModule, example_inputs):
keep_params_as_input=keep_params_as_input,
unwrap_unit_return_tuple=True,
)
mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"]
new_name = f"subgraph_{len(mod.get_global_vars())}"
mod[new_name] = mod_["main"].with_attr("global_symbol", new_name)
return graph_module.forward

dynamo.reset()
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relax/training/setup_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,10 @@ def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRM

# Add optimizer function.
self._optimizer.init(params)
mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function()
# Need the global symbol to match the function's name
mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function().with_attr(
"global_symbol", self.OPTIMIZER_FUNC
)

# Module attrs
mod = mod.with_attrs(
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,24 @@
############################### Function ################################


def function(is_pure: bool = True) -> frame.FunctionFrame:
def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFrame:
"""Start a function frame.
Parameters
----------
is_pure: bool
Whether the function is annotated as pure.

is_private : bool
Whether the function is annotated as private.

Returns
-------
frame: FunctionFrame
The constructed function frame.
"""
return _ffi_api.Function(is_pure) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Function( # type: ignore[attr-defined] # pylint: disable=no-member
is_pure, is_private
)


def arg(name: py_str, struct_info: StructInfo) -> Var:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class Parser(doc.NodeVisitor):
dispatch_tokens: List[str]
function_annotations: Optional[Dict[str, Dict[str, Any]]]
var_table: VarTable
inside_function: bool # whether we are within a function

def __init__(
self,
Expand All @@ -250,6 +251,7 @@ def __init__(
self.dispatch_tokens = ["default"]
self.function_annotations = function_annotations
self.var_table = VarTable()
self.inside_function = False

def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
"""The main parse method for parser.
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@
# this formulation allows us to support having @R.function
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function, FType]:
def function(
f: Optional[FType] = None, pure: bool = True, private: bool = False
) -> Union[Function, FType]:
# pylint: disable=unused-argument
# (pure isn't used here, but is used later in parsing)
# (pure and private aren't used here, but are used later in parsing)

# need to inspect the stack first because is_defined_in_class expects the outer class
# to be in a particular position in the stack
Expand Down
27 changes: 17 additions & 10 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non

@dispatch.register(token="relax", type_name="FunctionDef")
def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
is_inner_function = self.inside_function
self.inside_function = True

# reserve a var for local function
func_val = self.var_table.get().get(node.name)
if not func_val and is_recursive(node):
Expand All @@ -178,11 +181,14 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo))
self.var_table.add(node.name, local_func_var)

purity = find_purity_annotation(node)
purity = find_decorator_annotation(node, "pure")
# treat the function as private if we are inside another function
# or if it has a privacy annotation
privacy = is_inner_function or find_decorator_annotation(node, "private", default=False)

with self.var_table.with_frame():
with self.with_dispatch_token("relax"):
with R.function(is_pure=purity):
with R.function(is_pure=purity, is_private=privacy):
R.func_name(node.name)
collect_symbolic_var_from_params(self, node)

Expand All @@ -202,22 +208,22 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
self.report_error(stmt, "inline prim_func is disallowed in Relax IR")

self.visit_body(node.body)
self.inside_function = is_inner_function


def find_purity_annotation(node: doc.FunctionDef) -> bool:
def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool:
"""
Check the value of `pure` in the function decorator.
Returns the annotated purity if present, otherwise defaulting to True.
This allows for specifying the purity in the function signature.
Check the value of given annotation (argument name) in the function decorator.
Returns the value of the annotation if present, otherwise giving the default value.
"""
# look for the pure argument in the function decorator
# look for the named argument in the function decorator
for dec in node.decorator_list:
if not isinstance(dec, doc.Call) or dec.func.attr != "function":
continue
for keyword in dec.keywords:
if keyword.arg == "pure":
if keyword.arg == annotation:
return keyword.value.value
return True
return default


@dispatch.register(token="relax", type_name="tvm_declare_function")
Expand All @@ -238,7 +244,8 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
params.append(relax.Var(arg.arg, param_sinfo))

is_pure = find_purity_annotation(node)
is_pure = find_decorator_annotation(node, "pure")

func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure)
return I.decl_function(node.name, func_signature)

Expand Down
4 changes: 3 additions & 1 deletion src/relax/training/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class AppendLossMutator : private ExprMutator {
Function new_loss_func = CopyWithNewVars(loss_function);

AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs);
auto new_func_transformed = Downcast<Function>(mutator.VisitExpr(new_func));
auto new_func_transformed =
WithAttr(Downcast<Function>(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol,
new_func_name.value_or(func_name + "_loss"));

auto new_module = GetRef<IRModule>(mod.CopyOnWrite());
auto new_var = GlobalVar(new_func_name.value_or(func_name + "_loss"));
Expand Down
8 changes: 6 additions & 2 deletions src/relax/transform/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,14 @@ class GradientMutator : private ExprMutator {
}

GradientMutator mutator(mod, require_grads_value, target_index);
Function new_func_transformed = Downcast<Function>(mutator.VisitExpr(new_func));

// make the adjoint public
auto new_name = func_name + "_adjoint";
Function new_func_transformed = WithAttr(Downcast<Function>(mutator.VisitExpr(new_func)),
tvm::attr::kGlobalSymbol, new_name);

IRModule new_module = GetRef<IRModule>(mod.CopyOnWrite());
new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed);
new_module->Add(GlobalVar(new_name), new_func_transformed);
return new_module;
}

Expand Down
5 changes: 4 additions & 1 deletion src/relax/transform/lift_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,10 @@ class TransformParamsLifter : public ExprMutator {
lift_plan_ = planner.Plan(func, num_input);

// Step 2: Add the lifted function to the module
builder_->AddFunction(lift_plan_.f_transform_params, new_func_name);
// (The lifted function should be public so we add a global symbol to it)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am assuming we want them to be public.

auto lift_func =
WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol, new_func_name);
builder_->AddFunction(lift_func, new_func_name);

// Step 3: Update the current function.

Expand Down
5 changes: 5 additions & 0 deletions src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ void FunctionFrameNode::ExitWithScope() {
"`return` to return an Expr";
this->block_builder->BeginScope(params);
Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value()));
// if the function is not private, add a global symbol to its attributes
if (!is_private.value_or(Bool(false))->value && name.defined() &&
!attrs.count(tvm::attr::kGlobalSymbol)) {
attrs.Set(tvm::attr::kGlobalSymbol, name.value());
}
auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
this->block_builder->EndScope();
tvm::relax::Function func(/*params=*/params,
Expand Down
3 changes: 2 additions & 1 deletion src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)

/////////////////////////////// Function ////////////////////////////////

FunctionFrame Function(const Bool& is_pure) {
FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
const IRBuilder& ir_builder = IRBuilder::Current();
Optional<tvm::IRModule> mod = NullOpt;
Expand All @@ -61,6 +61,7 @@ FunctionFrame Function(const Bool& is_pure) {
}
n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod);
n->is_pure = is_pure;
n->is_private = is_private;
return FunctionFrame(n);
}

Expand Down
69 changes: 62 additions & 7 deletions src/script/printer/relax/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,36 @@ namespace tvm {
namespace script {
namespace printer {

bool AtTopLevelFunction(const IRDocsifier& d) {
// fewer than 2 frames: not in a function at all
if (d->frames.size() < 2) {
return false;
}
// if the first frame is a RelaxFrame, then this is not inside a module.
// 2 frames => we are at a function (more than 2 => nested function)
if (d->frames[0]->IsInstance<RelaxFrameNode>()) {
return d->frames.size() == 2;
}
// otherwise the first two frames pertain to an IR module,
// so 3 frames => we are at a top-level function (more than 3 => nested function)
return d->frames.size() == 3;
}

TVM_REGISTER_NODE_TYPE(RelaxFrameNode);

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::Function>("", [](relax::Function n, ObjectPath n_p, IRDocsifier d) -> Doc {
std::unordered_set<const tir::VarNode*> func_vars;
With<RelaxFrame> f(d);
IdDoc func_name = d->Define(n, f(), FindFunctionName(d, n).value_or("main"));

IdDoc func_name("");
// if we are binding a local definition, then calling d->Define
// will result in a repeated definition and an incorrect displayed name
if (Optional<String> name = GetBindingName(d)) {
func_name = std::move(IdDoc(name.value()));
} else {
func_name = std::move(d->Define(n, f(), FindFunctionName(d, n).value_or("main")));
}
(*f)->AddDispatchToken(d, "relax");
(*f)->is_func = true;
(*f)->func_vars = &func_vars;
Expand All @@ -52,17 +75,49 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
(*f)->func_vars = nullptr;
// Step 4. Print attributes
if (n->attrs.defined() && !n->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(Relax(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(n->attrs, n_p->Attr("attrs"))})));
// If the function is a global function and has a global symbol,
// then don't print the global symbol (it will be implicit from not being private).
// For a function without an IR module whose global symbol
// doesn't match the function name, we should still print the global symbol attribute.
if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) &&
Downcast<String>(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) {
Map<String, ObjectRef> new_attrs;
for (auto kv : n->attrs->dict) {
if (kv.first != tvm::attr::kGlobalSymbol) {
new_attrs.Set(kv.first, kv.second);
}
}
if (!new_attrs.empty()) {
(*f)->stmts.push_back(ExprStmtDoc(
Relax(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(DictAttrs(new_attrs), n_p->Attr("attrs"))})));
}
} else {
(*f)->stmts.push_back(
ExprStmtDoc(Relax(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(n->attrs, n_p->Attr("attrs"))})));
}
}
// Step 5. Prepare the decorator (include purity if it's impure)
ExprDoc decorator = Relax(d, "function");
Array<ExprDoc, void> pos_args = {};
Array<String, void> dec_keys;
Array<ExprDoc, void> dec_values;
if (!n->is_pure) {
Array<ExprDoc> pos_args = {};
decorator = std::move(decorator->Call(
pos_args, {"pure"}, {LiteralDoc::Boolean(false, Optional<ObjectPath>())}));
dec_keys.push_back("pure");
dec_values.push_back(LiteralDoc::Boolean(false, Optional<ObjectPath>()));
}
// if the function is global or is not in a module and does not have a global symbol,
// indicate that it's private
if (AtTopLevelFunction(d) &&
(!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) {
dec_keys.push_back("private");
dec_values.push_back(LiteralDoc::Boolean(true, Optional<ObjectPath>()));
}
if (dec_keys.size()) {
decorator = std::move(decorator->Call(pos_args, dec_keys, dec_values));
}

// Step 6. Print body
Array<StmtDoc> body =
PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"), d, /*use_ret=*/true);
Expand Down
Loading