Merged
Commits
26 commits
1600441
Use privacy annotation to decide whether to include global_symbol
slyubomirsky Jun 21, 2023
4cd11a2
Remove unused import
slyubomirsky Jun 21, 2023
dabcb1e
Simplify function frame construction
slyubomirsky Jun 21, 2023
ebd1b80
Remove debug prints
slyubomirsky Jun 21, 2023
e2ecadc
Whitespace fix
slyubomirsky Jun 22, 2023
f2c999d
Include global_symbol in importers
slyubomirsky Jun 22, 2023
e012683
Formatting fix
slyubomirsky Jun 22, 2023
24769c0
Add private field to function builder in BlockBuilder
slyubomirsky Jun 22, 2023
aaee0cd
Whitespace fix
slyubomirsky Jun 23, 2023
801000b
Set the global symbol for functions outside a module too
slyubomirsky Jun 23, 2023
d1f7c01
Formatting fix
slyubomirsky Jun 23, 2023
5b6c814
Print privacy attribute for functions not in an IRModule
slyubomirsky Jun 23, 2023
954dfd3
Fix placement of pylint overrides in ir.py
slyubomirsky Jun 23, 2023
249ec0e
Fix line length
slyubomirsky Jun 23, 2023
886a63e
Check for nested function case in function.cc instead
slyubomirsky Jun 23, 2023
4b4b52f
Print the global symbol if it doesn't match the name for some reason
slyubomirsky Jun 23, 2023
8064cfe
Correctly coerce the attribute type
slyubomirsky Jun 23, 2023
25e4712
Tweak pylint override again
slyubomirsky Jun 23, 2023
8c7bab1
Add pylint override for trailing whitespace in printer tests
slyubomirsky Jun 25, 2023
b815449
Fix whitespace
slyubomirsky Jun 26, 2023
68142e1
Remove trailing whitespace altogether instead of trying to override it
slyubomirsky Jun 26, 2023
4e272ae
Fix test_utils
slyubomirsky Jun 26, 2023
5314469
Fix test_tvmscript_parser_op_datatype
slyubomirsky Jun 26, 2023
a0df076
Fix global symbols in torch dynamo importer
slyubomirsky Jun 26, 2023
d2a4902
Fix test_dataflow_pattern
slyubomirsky Jun 26, 2023
5ac02e8
Use tvm::attr::kGlobalSymbol instead of string literal
slyubomirsky Jun 26, 2023
6 changes: 4 additions & 2 deletions python/tvm/script/parser/relax/entry.py
@@ -45,9 +45,11 @@
# this formulation allows us to support having @R.function
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function, FType]:
def function(
f: Optional[FType] = None, pure: bool = True, private: bool = False
) -> Union[Function, FType]:
# pylint: disable=unused-argument
# (pure isn't used here, but is used later in parsing)
# (pure and private aren't used here, but are used later in parsing)

# need to inspect the stack first because is_defined_in_class expects the outer class
# to be in a particular position in the stack
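
For reference, a minimal sketch of the decorator forms this signature accepts (assuming a TVM build that includes this change; the example functions and their bodies are placeholders, not code from the PR):

from tvm.script import relax as R

@R.function  # public by default: the function name becomes its global_symbol
def public_fn(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
    return R.add(x, x)

@R.function(pure=False)  # existing optional argument: mark the function as possibly impure
def impure_fn(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
    return R.add(x, x)  # a realistic body would contain side-effecting calls

@R.function(private=True)  # new optional argument: no global_symbol is attached
def private_fn(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
    return R.multiply(x, x)
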
29 changes: 18 additions & 11 deletions python/tvm/script/parser/relax/parser.py
@@ -21,7 +21,7 @@
from typing import Any, Dict, Optional

from tvm import relax, tir
from tvm.ir import GlobalVar, structural_equal
from tvm.ir import GlobalVar, make_node, structural_equal
from tvm.relax import Expr, StructInfo
from tvm.relax.utils import convert_to_expr
from tvm.script.ir_builder.relax.frame import BlockFrame
@@ -178,7 +178,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo))
self.var_table.add(node.name, local_func_var)

purity = find_purity_annotation(node)
purity = find_decorator_annotation(node, "pure")
# don't handle the privacy annotation here because it's only relevant for global funcs

with self.var_table.with_frame():
with self.with_dispatch_token("relax"):
Expand All @@ -204,20 +205,19 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
self.visit_body(node.body)


def find_purity_annotation(node: doc.FunctionDef) -> bool:
def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool:
"""
Check the value of `pure` in the function decorator.
Returns the annotated purity if present, otherwise defaulting to True.
This allows for specifying the purity in the function signature.
Check the value of the given annotation (argument name) in the function decorator.
Returns the value of the annotation if present, otherwise the default value.
"""
# look for the pure argument in the function decorator
# look for the named argument in the function decorator
for dec in node.decorator_list:
if not isinstance(dec, doc.Call) or dec.func.attr != "function":
continue
for keyword in dec.keywords:
if keyword.arg == "pure":
if keyword.arg == annotation:
return keyword.value.value
return True
return default


@dispatch.register(token="relax", type_name="tvm_declare_function")
@@ -238,8 +238,15 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
params.append(relax.Var(arg.arg, param_sinfo))

is_pure = find_purity_annotation(node)
func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure)
is_pure = find_decorator_annotation(node, "pure")

# if the global function is not private, then use its name as the global symbol
is_private = find_decorator_annotation(node, "private", default=False)
attrs = None
if not is_private:
attrs = make_node("DictAttrs", global_symbol=node.name)

func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure, attrs=attrs)
return I.decl_function(node.name, func_signature)


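
A small illustration of the attrs construction used in visit_tvm_declare_function above (a sketch; it assumes tvm.ir.make_node builds a DictAttrs from keyword fields, exactly as the parser code does):

from tvm.ir import make_node

# For a public (non-private) forward declaration named "my_func", the parser records:
attrs = make_node("DictAttrs", global_symbol="my_func")
# A declaration under @R.function(private=True) passes attrs=None instead,
# so no global_symbol is attached to the declared function signature.
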
4 changes: 3 additions & 1 deletion src/relax/transform/gradient.cc
@@ -333,7 +333,9 @@ class GradientMutator : private ExprMutator {
}

GradientMutator mutator(mod, require_grads_value, target_index);
Function new_func_transformed = Downcast<Function>(mutator.VisitExpr(new_func));
// remove the global symbol if the original had one (the adjoint does not need a global symbol)
Function new_func_transformed =
WithoutAttr(Downcast<Function>(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol);

IRModule new_module = GetRef<IRModule>(mod.CopyOnWrite());
new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed);
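
The intended effect, seen from the Python side (a hedged sketch: the Gradient pass invocation below is an assumption about its signature, and `mod` stands for any module containing a differentiable Relax function named "main"):

from tvm import relax

# Hypothetical invocation; the actual Gradient signature may differ.
mod = relax.transform.Gradient("main")(mod)
adjoint = mod["main_adjoint"]
# Per the WithoutAttr call above, the adjoint should carry no global symbol.
assert adjoint.attrs is None or "global_symbol" not in adjoint.attrs
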
5 changes: 4 additions & 1 deletion src/relax/transform/lift_transform_params.cc
@@ -251,7 +251,10 @@ class TransformParamsLifter : public ExprMutator {
lift_plan_ = planner.Plan(func, num_input);

// Step 2: Add the lifted function to the module
builder_->AddFunction(lift_plan_.f_transform_params, new_func_name);
// (The lifted function should be public so we add a global symbol to it)
Inline review comment (Contributor Author):
I am assuming we want them to be public.

auto lift_func =
WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol, new_func_name);
builder_->AddFunction(lift_func, new_func_name);

// Step 3: Update the current function.

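
The WithAttr call is the C++ counterpart of the Python-level idiom for making a function externally visible; a sketch, where `func` stands for any existing relax.Function and the name mirrors the lifted-function naming used by this pass:

# Attach a global symbol so the function remains public / externally linkable.
public_func = func.with_attr("global_symbol", "main_transform_params")
assert public_func.attrs["global_symbol"] == "main_transform_params"
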
12 changes: 12 additions & 0 deletions src/script/ir_builder/relax/frame.cc
@@ -74,6 +74,18 @@ void FunctionFrameNode::ExitWithScope() {
"function scope, if it's defined in a Module";
const IRModuleFrame& frame = opt_frame.value();
const String& func_name = name.value_or("");
// If the function has already been declared (i.e., it is global), see if there is
// already a global symbol defined for it (i.e., it is not private).
// If yes, add it to the current function's attributes (unless one was manually defined)
if (frame->global_var_map.count(func_name)) {
auto decl = frame->functions.at(frame->global_var_map.at(func_name));
if (decl->attrs.defined()) {
auto attr_dict = decl->attrs.get()->dict;
if (attr_dict.count("global_symbol") && !attrs.count("global_symbol")) {
func = std::move(WithAttr(func, tvm::attr::kGlobalSymbol, attr_dict.at("global_symbol")));
}
}
}
if (!frame->global_var_map.count(func_name)) {
// First time visiting the function.
ir::DeclFunction(func_name, func);
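
This declare-then-define path is exercised when a function is referenced before it is defined within a module; a hedged sketch, following the cross-reference style used in the tests later in this diff:

from tvm.script import ir as I, relax as R

@I.ir_module
class ForwardRef:
    @R.function
    def caller(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        cls = ForwardRef
        # `callee` is referenced before its definition, so the parser declares it first;
        # the frame code above then copies the declared global_symbol onto the definition.
        gv = cls.callee(x)
        return gv

    @R.function
    def callee(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        return R.add(x, x)
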
40 changes: 34 additions & 6 deletions src/script/printer/relax/function.cc
@@ -52,17 +52,45 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
(*f)->func_vars = nullptr;
// Step 4. Print attributes
if (n->attrs.defined() && !n->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(Relax(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(n->attrs, n_p->Attr("attrs"))})));
// if the function is a global function and has a global symbol,
// then don't print the global symbol (it will be implicit from not being private)
if (d->frames.size() == 3 && n->attrs->dict.count("global_symbol")) {
Inline review comment (Contributor Author):
Not fond of having this magic number, but it was the only way I could figure out to check that the function is at the top level of a module. Possibly a hack--is there a better way I could check for that?

Inline review comment (Contributor Author):
@yongwww might there be a more robust way to check if we're inside a module?

Map<String, ObjectRef> new_attrs;
for (auto kv : n->attrs->dict) {
if (kv.first != "global_symbol") {
new_attrs.Set(kv.first, kv.second);
}
}
if (!new_attrs.empty()) {
(*f)->stmts.push_back(ExprStmtDoc(
Relax(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(DictAttrs(new_attrs), n_p->Attr("attrs"))})));
}
} else {
(*f)->stmts.push_back(
ExprStmtDoc(Relax(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(n->attrs, n_p->Attr("attrs"))})));
}
}
// Step 5. Prepare the decorator (include purity if it's impure)
ExprDoc decorator = Relax(d, "function");
Array<ExprDoc, void> pos_args = {};
Array<String, void> dec_keys;
Array<ExprDoc, void> dec_values;
if (!n->is_pure) {
Array<ExprDoc> pos_args = {};
decorator = std::move(decorator->Call(
pos_args, {"pure"}, {LiteralDoc::Boolean(false, Optional<ObjectPath>())}));
dec_keys.push_back("pure");
dec_values.push_back(LiteralDoc::Boolean(false, Optional<ObjectPath>()));
}
// if the function is global and does not have a global symbol, indicate that it's private
if (d->frames.size() == 3 &&
(!n->attrs.defined() || !n->attrs->dict.count("global_symbol"))) {
dec_keys.push_back("private");
dec_values.push_back(LiteralDoc::Boolean(true, Optional<ObjectPath>()));
}
if (dec_keys.size()) {
decorator = std::move(decorator->Call(pos_args, dec_keys, dec_values));
}

// Step 6. Print body
Array<StmtDoc> body =
PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"), d, /*use_ret=*/true);
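
Putting the printer changes together, the expected round-trip behaviour looks roughly like this (a sketch; the exact printed text is an assumption):

from tvm.script import ir as I, relax as R

@I.ir_module
class Mod:
    @R.function(private=True)
    def f(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        return R.add(x, x)

script = Mod.script()
# The private function should be printed with the private=True decorator argument,
# and no global_symbol should appear in the output since none was attached.
assert "private=True" in script
assert "global_symbol" not in script
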
6 changes: 2 additions & 4 deletions tests/python/relax/test_transform_attach_global_symbol.py
@@ -43,7 +43,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

@R.function
@R.function(private=True)
def main(
x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")
) -> R.Tensor:
@@ -74,7 +74,6 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
def main(
x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")
) -> R.Tensor:
R.func_attr({"global_symbol": "main"})
m, n, k = T.int64(), T.int64(), T.int64()
gv0 = R.call_tir(Expected.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32"))
return gv0
@@ -94,7 +93,7 @@ class Before:
def tir_zeros(x: T.Buffer((2), "float32")) -> None:
x[0] = T.float32(0)

@R.function
@R.function(private=True)
def main() -> R.Tensor:
gv0 = R.call_tir(Before.tir_zeros, (), R.Tensor((2,), dtype="float32"))
return gv0
@@ -110,7 +109,6 @@ def tir_zeros(x: T.Buffer((2), "float32")) -> None:

@R.function
def main() -> R.Tensor:
R.func_attr({"global_symbol": "main"})
gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,), dtype="float32"))
return gv0

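
Marking the inputs private gives AttachGlobalSymbol something to attach; a hedged sketch of how the pass is presumably invoked in these tests (Before and Expected refer to the module classes above):

import tvm
from tvm import relax

After = relax.transform.AttachGlobalSymbol()(Before)
tvm.ir.assert_structural_equal(After, Expected)
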
3 changes: 3 additions & 0 deletions tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -378,6 +378,7 @@ def expected1(
w2: R.Tensor((2, 640, 640), dtype="float32"),
w3: R.Tensor((3, 4, 640, 640), dtype="float32"),
) -> R.Tensor:
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv = R.concat((w0, w2), axis=2)
lv1 = R.matmul(x, lv, out_dtype="float32")
@@ -458,6 +459,7 @@ def expected1(
b0: R.Tensor((640,), dtype="float32"),
b1: R.Tensor((640,), dtype="float32"),
) -> R.Tensor:
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv, out_dtype="float32")
@@ -515,6 +517,7 @@ def expected(
w3: R.Tensor((640, 640), dtype="float32"),
w4: R.Tensor((640, 640), dtype="float32"),
) -> R.Tensor:
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv, out_dtype="float32")
8 changes: 4 additions & 4 deletions tests/python/relax/test_transform_dead_code_elimination.py
@@ -168,7 +168,7 @@ def tir_add(
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]

@R.function
@R.function(private=True)
def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")):
gv0 = R.add(x, w)
return gv0
@@ -202,7 +202,7 @@ def tir_add(
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]

@R.function
@R.function(private=True)
def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")):
gv0 = R.add(x, w)
return gv0
@@ -239,7 +239,7 @@ def tir_add(
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]

@R.function
@R.function(private=True)
def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")):
gv0 = R.add(x, w)
return gv0
@@ -310,7 +310,7 @@ def unused_func1(
vi, vj = T.axis.remap("SS", [i, j])
z[vi, vj] = x[vi, vj] + y[vi, vj]

@R.function
@R.function(private=True)
def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")):
gv0 = R.add(x, w)
return gv0
14 changes: 7 additions & 7 deletions tests/python/relax/test_transform_fuse_ops.py
@@ -938,7 +938,7 @@ def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "floa
T.writes(B[v_i0, v_i1, v_i2, v_i3])
B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0))

@R.function
@R.function(private=True)
def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Expected
@@ -1080,7 +1080,7 @@ def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)), "float32"),
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

@R.function
@R.function(private=True)
def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28: R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1), dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Expected
@@ -1091,7 +1091,7 @@ def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1
R.output(gv)
return gv

@R.function
@R.function(private=True)
def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"), lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,), dtype="float32")) -> R.Tensor((2, 320), dtype="float32"):
cls = Expected
R.func_attr({"Primitive": 1})
@@ -1226,7 +1226,7 @@ def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"),
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

@R.function
@R.function(private=True)
def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"), lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Expected
@@ -1268,7 +1268,7 @@ def main(x: R.Tensor(["n", "m"], "float32")):

@I.ir_module
class Expected:
@R.function
@R.function(private=True)
def fused_add_exp_squeeze(
x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
) -> R.Tensor(["n", "m"], dtype="float32"):
@@ -1306,7 +1306,7 @@ def main(s: R.Shape(["n"])):

@I.ir_module
class Expected:
@R.function
@R.function(private=True)
def fused_full_trilu_broadcast_to(
s: R.Shape(["n"]),
) -> R.Tensor([1, 1, "n", "n"], "float32"):
@@ -1354,7 +1354,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object):

@I.ir_module
class Expected:
@R.function
@R.function(private=True)
def fused_full_trilu_broadcast_to(
s: R.Shape(["n"]),
) -> R.Tensor([1, 1, "n", "n"], "float32"):