Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cc2bb1f
Implement privacy annotation in TIR
slyubomirsky Jun 26, 2023
bf1d06f
Fix broken TIR tests
slyubomirsky Jun 27, 2023
a17ae54
Treat it as an error to specify the global symbol if marked private
slyubomirsky Jun 27, 2023
119ced2
Handle global symbol in testing utils
slyubomirsky Jun 27, 2023
caa19a8
Transfer global symbol in SplitHostDevice
slyubomirsky Jun 27, 2023
b19cd9d
Preserve global symbol in USMP convert pool allocations
slyubomirsky Jun 27, 2023
182de83
Use global symbol in numerous tests
slyubomirsky Jun 27, 2023
80b4418
Remove some code duplication
slyubomirsky Jun 27, 2023
a7e86f7
formatting
slyubomirsky Jun 27, 2023
1543dce
pylint
slyubomirsky Jun 28, 2023
eedb556
Formatting
slyubomirsky Jun 28, 2023
be9ee73
Undo SplitHostDevice changes, use privacy annotation in tests instead
slyubomirsky Jun 29, 2023
98bd20f
Fix rebase mistakes in test_tir_schedule_pad_einsum
slyubomirsky Jun 30, 2023
9bf0c12
Factor out func_name in function.cc
slyubomirsky Jul 3, 2023
a7ca19a
Fix global symbol in more tests
slyubomirsky Jul 3, 2023
cc7f39c
Fix more global symbols in unit tests
slyubomirsky Jul 3, 2023
c07a149
Fix global symbols in test_pass_plan_devices.py::test_lowered
slyubomirsky Jul 6, 2023
57bbceb
Fix global symbol in test_tir_transform_helpers
slyubomirsky Jul 6, 2023
9777abf
Fix test_tir_schedule_cache_read_write tests
slyubomirsky Jul 14, 2023
3b9246b
Fix global symobls in test_tvmscript_parser_tir.py
slyubomirsky Jul 14, 2023
8b219e5
Transfer over without_attr from Unity branch
slyubomirsky Jul 14, 2023
4a2e3b2
ignore global symbols in meta schedule tests
slyubomirsky Jul 14, 2023
6ee9ce9
Do not use global symols in test_tir_transform_make_unpacked_pi.py
slyubomirsky Jul 17, 2023
f2f3976
Fix global symbol in test_evaluator_with_preproc
slyubomirsky Jul 17, 2023
3ab5307
Linting fix
slyubomirsky Jul 17, 2023
745421b
Fix unused import
slyubomirsky Jul 17, 2023
4baeacd
Do not use private with an explicit global symbol in test_tir_transfo…
slyubomirsky Jul 17, 2023
64c0cf0
Use an ordinary bool for the PrimFuncFrame's is_private attribute
slyubomirsky Jul 18, 2023
2f70c39
Do not use abbreviation in assert_structural_equal_gs (expand name)
slyubomirsky Jul 18, 2023
645d356
Fix test case in test_tir_transform_split_host_device
slyubomirsky Jul 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class PrimFuncFrameNode : public TIRFrameNode {
Optional<String> name;
/*! \brief Function parameters. */
Array<tvm::tir::Var> args;
/*! \brief Whether the PrimFunc is annotated as private. */
bool is_private;
/*! \brief The return type of the function. */
Optional<Type> ret_type;
/*! \brief Maps some parameters to specific Buffer data structures. */
Expand All @@ -86,6 +88,7 @@ class PrimFuncFrameNode : public TIRFrameNode {
TIRFrameNode::VisitAttrs(v);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("is_private", &is_private);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("attrs", &attrs);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Opt
* \brief The primitive function statement.
* \return The PrimFuncFrame.
*/
PrimFuncFrame PrimFunc();
PrimFuncFrame PrimFunc(bool is_private);

/*!
* \brief The PrimFunc variable arguments adding function.
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/ir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,19 @@ def with_attr(self, attr_key_or_dict, attr_value=None):
return _ffi_api.BaseFuncWithAttr(
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
)

def without_attr(self, attr_key: str) -> "BaseFunc":
"""Create a new copy of the function with an attribute without provided key.

Parameters
----------
attr_key : str
The attribute key to delete from the attrubte pairs.


Returns
-------
func : BaseFunc
A new copy of the function
"""
return _ffi_api.BaseFuncWithoutAttr(self, attr_key)
17 changes: 16 additions & 1 deletion python/tvm/meta_schedule/testing/space_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ def get_rules(
return [rule for rule in rules if isinstance(rule, types)]


def structural_equal_no_gs(mod1: IRModule, mod2: IRModule) -> bool:
"""
Checks structural equality but ignores global symbols
"""
# for every function in the modules, remove global symbols from the attrs and then compare
def remove_global_symbols(mod: IRModule) -> IRModule:
stripped_mod = IRModule()
for global_var in mod.get_global_vars():
func = mod[global_var]
stripped_mod[global_var] = func.without_attr("global_symbol")
return stripped_mod

return structural_equal(remove_global_symbols(mod1), remove_global_symbols(mod2))


def generate_design_space(
kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"],
mod: IRModule,
Expand Down Expand Up @@ -87,7 +102,7 @@ def _find_match_sketch_id(
insts=sketch.trace.insts,
decisions=new_decisions,
).apply_to_schedule(sch, remove_postproc=True)
if structural_equal(sch.mod, expected_mod):
if structural_equal_no_gs(sch.mod, expected_mod):
verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json")
return sketch_id
return None
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,22 @@ def buffer_decl(*args, **kwargs):
return buffer(*args, **kwargs)


def prim_func() -> frame.PrimFuncFrame:
def prim_func(is_private: bool = False) -> frame.PrimFuncFrame:
"""The primitive function statement.

Parameters
----------
is_private : bool
Whether the PrimFunc is annotated as private
(if yes, it does not have a global symbol assigned;
otherwise, the global symbol is the PrimFunc's name)

Returns
-------
res : frame.PrimFuncFrame
The PrimFuncFrame.
"""
return _ffi_api.PrimFunc() # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.PrimFunc(is_private) # type: ignore[attr-defined] # pylint: disable=no-member


def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var, Buffer]:
Expand Down
44 changes: 35 additions & 9 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""The entry point of TVM parser for tir."""
import inspect
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc
Expand All @@ -25,26 +25,52 @@
from .._core import doc, parse, parse_macro, utils


def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.

Parameters
----------
func : Callable
The function to be parsed as prim func.
(Listed as optional to allow the decorator to be used
without arguments, like `@prim_func`,
or with an argument, `@prim_func(private=True)`)

private : bool, optional
Whether the function should be treated as private.
A private function has no global symbol attribute;
if the function is not private, it will have a global symbol
matching the function name.

Returns
-------
res : Union[PrimFunc, Callable]
The parsed tir prim func.
"""
if not inspect.isfunction(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(inspect.stack(), func):
return func
f = parse(func, utils.inspect_function_capture(func))
setattr(f, "__name__", func.__name__)
return f
# pylint: disable=unused-argument
# (private will be used in the parser, but not immediately)

# need to capture this var outside the wrapper because the wrapper
# adds to the stack
outer_stack = inspect.stack()

def decorator_wrapper(func):
if not inspect.isfunction(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func))
setattr(f, "__name__", func.__name__)
return f

if func is not None:
# no optional args given => use wrapper directly
return decorator_wrapper(func)
else:
# if there is an optional arg given, return a new decorator
# that will then be invoked
setattr(decorator_wrapper, "dispatch_token", "tir")
return decorator_wrapper


setattr(prim_func, "dispatch_token", "tir")
Expand Down
18 changes: 17 additions & 1 deletion python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
return var


def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool:
"""
Check the value of given annotation (argument name) in the prim_func decorator.
Returns the value of the annotation if present, otherwise giving the default value.
"""
# look for the named argument in the prim_func decorator
for dec in node.decorator_list:
if not isinstance(dec, doc.Call) or dec.func.attr != "prim_func":
continue
for keyword in dec.keywords:
if keyword.arg == annotation:
return keyword.value.value
return default


@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
"""The for visiting method for tir.
Expand Down Expand Up @@ -365,10 +380,11 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
"""
supplied_annotation = self.function_annotations
func_annotation = supplied_annotation.get(node.name, {})
privacy = find_decorator_annotation(node, "private", default=False)
self.function_annotations = None
with self.var_table.with_frame():
self.var_table.add("range", T.serial)
with T.prim_func():
with T.prim_func(is_private=privacy):
T.func_name(node.name)
if node.returns is not None:
ret_type = self.eval_expr(node.returns)
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,11 +1983,11 @@ def inner(self):
if name.startswith("_"):
pass
elif isinstance(method, tvm.ir.function.BaseFunc):
func_dict[name] = method
func_dict[name] = method.with_attr("global_symbol", name)
else:
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
prim_func = tvm.script.from_source(source_code)
func_dict[name] = prim_func
func_dict[name] = prim_func.with_attr("global_symbol", name)
return tvm.IRModule(func_dict)

else:
Expand Down Expand Up @@ -2093,6 +2093,10 @@ def test_compare(self, before, expected, transform):
after = transform(before)

try:
# overwrite global symbol so it doesn't come up in the comparison
if isinstance(after, tvm.tir.PrimFunc):
after = after.with_attr("global_symbol", "main")
expected = expected.with_attr("global_symbol", "main")
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
before_str = before.script(name="before")
Expand Down
21 changes: 20 additions & 1 deletion python/tvm/tir/schedule/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,33 @@
# under the License.
# pylint: disable=dangerous-default-value
"""Testing utilities for the TensorIR schedule API"""
from typing import Sequence, Union
from typing import Any, Sequence, Union

import tvm
from tvm.ir import IRModule, assert_structural_equal
from tvm.tir import PrimFunc
from tvm.tir.schedule import Schedule, Trace


def assert_structural_equal_ignore_global_symbol(
func1: PrimFunc,
func2: PrimFunc,
*args: Any,
**kwargs: Any,
) -> None:
"""
Asserts that PrimFuncs func1 and func2 are structurally equal, setting both
their global symbol attributes to main so that the global symbol
will not be a point of comparison.
"""
assert_structural_equal(
func1.with_attr("global_symbol", "main"),
func2.with_attr("global_symbol", "main"),
*args,
**kwargs,
)


def verify_trace_roundtrip(
sch: Schedule,
mod: Union[PrimFunc, IRModule],
Expand Down
13 changes: 13 additions & 0 deletions src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief The function data structure.
*/
#include <tvm/ir/function.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>

Expand All @@ -44,4 +45,16 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
});

TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Looks like we always forgot to upstream this piece from the unity branch

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah lol

.set_body_typed([](BaseFunc func, String key) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithoutAttr(Downcast<tir::PrimFunc>(std::move(func)), key);
} else if (func->IsInstance<relay::FunctionNode>()) {
return WithoutAttr(Downcast<relay::Function>(std::move(func)), key);
} else {
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
return func;
}
});

} // namespace tvm
16 changes: 16 additions & 0 deletions src/script/ir_builder/tir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ namespace tir {

void PrimFuncFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
// if the prim func is not private and there isn't already a global symbol,
// add a global symbol
if (!is_private && name.defined()) {
if (!attrs.defined()) {
attrs = {{tvm::attr::kGlobalSymbol, name.value()}};
} else if (!attrs.value().count(tvm::attr::kGlobalSymbol)) {
// copy over attributes (can't mutate the dict inside the optional in-place)
Map<String, ObjectRef> new_attrs;
for (auto kv : attrs.value()) {
new_attrs.Set(kv.first, kv.second);
}
new_attrs.Set(tvm::attr::kGlobalSymbol, name.value());
attrs = std::move(new_attrs);
}
}

tvm::tir::PrimFunc func(
/*params=*/args,
/*body=*/AsStmt(stmts),
Expand Down
7 changes: 6 additions & 1 deletion src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Opt
axis_separators.value_or(Array<IntImm>()));
}

PrimFuncFrame PrimFunc() {
PrimFuncFrame PrimFunc(bool is_private) {
ObjectPtr<PrimFuncFrameNode> n = make_object<PrimFuncFrameNode>();
n->name = NullOpt;
n->is_private = is_private;
n->args.clear();
n->ret_type = NullOpt;
n->buffer_map.clear();
Expand Down Expand Up @@ -96,6 +97,10 @@ void FuncAttrs(Map<String, ObjectRef> attrs) {
if (frame->attrs.defined()) {
LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs;
}
if (attrs.count(tvm::attr::kGlobalSymbol) && frame->is_private) {
LOG(FATAL) << "ValueError: Specifying the global symbol even though the PrimFunc is annotated "
"as private";
}
frame->attrs = attrs;
}

Expand Down
36 changes: 31 additions & 5 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
With<TIRFrame> f(d, func);
(*f)->AddDispatchToken(d, "tir");
auto func_name = IdDoc(FindFunctionName(d, func).value_or("main"));
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<tir::VarNode>() || obj->IsInstance<tir::BufferNode>();
});
Expand Down Expand Up @@ -104,9 +105,25 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
// Step 2. Handle `func->attrs`
if (func->attrs.defined() && !func->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(TIR(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
// for global symbol, don't display it if it matches the func name
if (func->attrs->dict.count(tvm::attr::kGlobalSymbol) &&
Downcast<String>(func->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) {
Map<String, ObjectRef> new_attrs;
for (auto kv : func->attrs->dict) {
if (kv.first != tvm::attr::kGlobalSymbol) {
new_attrs.Set(kv.first, kv.second);
}
}
if (!new_attrs.empty()) {
(*f)->stmts.push_back(ExprStmtDoc(
TIR(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(DictAttrs(new_attrs), p->Attr("attrs"))})));
}
} else {
(*f)->stmts.push_back(
ExprStmtDoc(TIR(d, "func_attr") //
->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
}
}
// Step 3. Handle `func->buffer_map`
for (int i = 0; i < n_args; ++i) {
Expand Down Expand Up @@ -168,10 +185,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type"));
}
}
// Step 5. Determine if we need to display the private annotation in the decorator
ExprDoc decorator = TIR(d, "prim_func");
// mark private if there is no global symbol
if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) {
Array<ExprDoc> pos_args;
decorator = std::move(decorator->Call(pos_args, {"private"},
{LiteralDoc::Boolean(true, Optional<ObjectPath>())}));
}

return HeaderWrapper(d, FunctionDoc(
/*name=*/IdDoc(FindFunctionName(d, func).value_or("main")),
/*name=*/func_name,
/*args=*/args,
/*decorators=*/{TIR(d, "prim_func")},
/*decorators=*/{decorator},
/*return_type=*/ret_type,
/*body=*/(*f)->stmts));
});
Expand Down
Loading