-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Unity][IR][UX] Privacy annotation in Relax #15140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
3 similar comments
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
| return gv | ||
|
|
||
| @R.function | ||
| @R.function(private=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@masahi I requested your review because I am not sure why exactly I needed to insert the privacy annotations for these tests. Leaving it off prevents any of the functions from being merged, but I'm not sure why that would happen. Do you know where it might matter that the functions have a global_symbol attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yelite Masa said you might also be able to comment on this issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this diff, these tests will pass without making those function private
diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc
index fe36eb28ef..596c68ef3e 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -96,7 +96,7 @@ IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String> entry_funcs
for (auto f : existing_functions) {
// If a function has an external linkage type, we do not remove it.
// Otherwise, we check the function and remove it if it is not used anywhere.
- if (f.second->GetLinkageType() == LinkageType::kInternal && !tracer.check_if_called(f.first)) {
+ if (!tracer.check_if_called(f.first)) {
mod_->Remove(f.first);
}
}
diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc
index 0bc92ba923..bdae719e31 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -60,6 +60,8 @@
#include <tvm/tir/function.h>
#include "../../support/arena.h"
+#include "tvm/ir/attrs.h"
+#include "tvm/ir/function.h"
#include "utils.h"
namespace tvm {
@@ -311,7 +313,8 @@ class CompositeInliner : public ExprMutator {
if (func->GetAttr<String>(attr::kComposite)) {
if (!inlined_functions_.count(func)) {
- inlined_functions_.Set(func, CopyWithNewVars(func));
+ inlined_functions_.Set(func,
+ WithoutAttr(CopyWithNewVars(func), tvm::attr::kGlobalSymbol));
}
return Call(inlined_functions_[func], call->args);
}At the end of MergeCompositeFunction, a DCE pass will be called to removed the original copies of the functions that are merged. However, DCE will not remove public functions, i.e. those with a global symbol. Making those functions public prevents them from being cleaned up by DCE.
I think the logic here is correct. Those fused_xxx functions should be private as they are not meant to be called externally. So we probably don't want to change anything here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you very much for the explanation, @yelite. So, to confirm, you don't think there needs to be a change to the merge composite pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the merge composite pass shouldn't be changed. We might need to update the fusion pass if it produces fused functions with global symbol.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think by default fusion pass should not produce functions with global symbol. Ideally we should prevent fusing functions that with global symbol --- because user expects to access the original function
|
|
||
| // Step 2: Add the lifted function to the module | ||
| builder_->AddFunction(lift_plan_.f_transform_params, new_func_name); | ||
| // (The lifted function should be public so we add a global symbol to it) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am assuming we want them to be public.
|
cc @yongwww |
src/script/printer/relax/function.cc
Outdated
| ->Call({d->AsDoc<ExprDoc>(n->attrs, n_p->Attr("attrs"))}))); | ||
| // if the function is a global function and has a global symbol, | ||
| // then don't print the global symbol (it will be implicit from not being private) | ||
| if (d->frames.size() == 3 && n->attrs->dict.count("global_symbol")) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not fond of having this magic number, but it was the only way I could figure out to check that the function is at the top level of a module. Possibly a hack--is there a better way I could check for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yongwww might there be a more robust way to check if we're inside a module?
| bb = relax.BlockBuilder() | ||
| x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32")) | ||
| with bb.function("main", [x]): | ||
| with bb.function("main", [x], {"global_symbol": "main"}): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about we also have private=False flag here(default to False)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
private=False in the parser (which is the default) means that there's a global symbol in the attributes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean the block builder should also have a private flag when building a function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, this way it is closer to what function provide
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've implemented that, good idea 👍
|
Thanks @slyubomirsky on the default behavior, would be great if we default It will involve more handling in printing(like not printing global symbol, and when global symbol do not exist, print as private=True), and modifying the test cases that comes with global_symbol (by removing them). But it would be much cleaner from user pov |
|
Do you mean that private should be another attribute? Or that it should be a field in the function AST node? RIght now, privacy is handled entirely in the parser (if |
|
The tough part is dealing with nested functions. There is no way to have a public nested function, but the BlockBuilder/printer/etc would have to know whether we're inside another function or not. (Though it would not be hard to modify the BlockBuilder to track that if that's what we want.) |
|
sorry i am not talking about changing the AST. I only mean the interface And make it happen also for tir |
|
Then, yes, that is already how I had implemented it. False is the default value for private. I will investigate doing the same for TIR (I've never touched the TVMScript for TIR). |
|
We can get relax tests in first. Can we detect that printer prints out things correctly (aka skips the global symbol and instead use private annotation)? |
|
Yeah, there are printer tests included. I am making some further tweaks because the implementation right now is rather hacky. |
| ) | ||
|
|
||
|
|
||
| def test_nested_function(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This actually uncovered a seemingly unrelated bug in the printer, where the nested function would be printed with a suffix after its name (see the changes in src/script/printer/relax/function.cc) because the LHS variable in the VarBinding was treated as already having been encountered. Also fixed.
|
I will implement the same changes for TIR in a separate PR |
Implements proposal #14899 for TIR. Adapted from Unity PR #15171, by the same method as PR #15140. Namely, a TIR PrimFunc can be specified to be private (without a global symbol attribute) in TVMScript in the `prim_func` decorator. By default, `PrimFunc`s are not private, so they will have a `global_symbol` attribute that is mapped to their name. Example usage: ```python # not private: its global symbol will be "func" @T.prim_func def func(...): ... # no global symbol included @T.prim_func(private=True) def func(...): ... ``` This did require changing very, very many tests, unfortunately.
This PR implements the proposal in #14899. Namely, the
@R.functiondecorator now has an optionalprivateattribute. If a function is marked as private, then it will not have a global symbol attached to it and thus will not be externally accessible. By default, functions are not private, so the parser does insert a global symbol for them.Here is an example of how to use the privacy attribute:
Passes that introduce new functions should decide whether a global symbol should be included with those functions. Most of the time, this is not the case (e.g., the functions extracted by the lambda lifting pass probably should not be publicly accessible).
Note that there are alternative approaches we can consider. For example, we may have "private" as a field on a function node and fill in the global symbol attributes later. This may be important for some passes.