Skip to content

Conversation

@slyubomirsky
Copy link
Contributor

@slyubomirsky slyubomirsky commented Jun 21, 2023

This PR implements the proposal in #14899. Namely, the @R.function decorator now has an optional private attribute. If a function is marked as private, then it will not have a global symbol attached to it and thus will not be externally accessible. By default, functions are not private, so the parser does insert a global symbol for them.

Here is an example of how to use the privacy attribute:

@I.ir_module
class Module:
    # Will not have a global symbol! Often not very useful to make main private...
    @R.function(private=True)
    def main(...): ...

    # private is false by default. This function will have a global symbol ("other")
    @R.function
    def other(...): ...

Passes that introduce new functions should decide whether a global symbol should be included with those functions. Most of the time, this is not the case (e.g., the functions extracted by the lambda lifting pass probably should not be publicly accessible).

Note that there are alternative approaches we can consider. For example, we may have "private" as a field on a function node and fill in the global symbol attributes later. This may be important for some passes.

@tvm-bot
Copy link
Collaborator

tvm-bot commented Jun 21, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

3 similar comments
@tvm-bot
Copy link
Collaborator

tvm-bot commented Jun 21, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@tvm-bot
Copy link
Collaborator

tvm-bot commented Jun 21, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@tvm-bot
Copy link
Collaborator

tvm-bot commented Jun 21, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

return gv

@R.function
@R.function(private=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@masahi I requested your review because I am not sure why exactly I needed to insert the privacy annotations for these tests. Leaving it off prevents any of the functions from being merged, but I'm not sure why that would happen. Do you know where it might matter that the functions have a global_symbol attribute?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yelite Masa said you might also be able to comment on this issue

Copy link
Contributor

@yelite yelite Jun 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this diff, these tests will pass without making those function private

diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc
index fe36eb28ef..596c68ef3e 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -96,7 +96,7 @@ IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String> entry_funcs
   for (auto f : existing_functions) {
     // If a function has an external linkage type, we do not remove it.
     // Otherwise, we check the function and remove it if it is not used anywhere.
-    if (f.second->GetLinkageType() == LinkageType::kInternal && !tracer.check_if_called(f.first)) {
+    if (!tracer.check_if_called(f.first)) {
       mod_->Remove(f.first);
     }
   }
diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc
index 0bc92ba923..bdae719e31 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -60,6 +60,8 @@
 #include <tvm/tir/function.h>
 
 #include "../../support/arena.h"
+#include "tvm/ir/attrs.h"
+#include "tvm/ir/function.h"
 #include "utils.h"
 
 namespace tvm {
@@ -311,7 +313,8 @@ class CompositeInliner : public ExprMutator {
 
       if (func->GetAttr<String>(attr::kComposite)) {
         if (!inlined_functions_.count(func)) {
-          inlined_functions_.Set(func, CopyWithNewVars(func));
+          inlined_functions_.Set(func,
+                                 WithoutAttr(CopyWithNewVars(func), tvm::attr::kGlobalSymbol));
         }
         return Call(inlined_functions_[func], call->args);
       }

At the end of MergeCompositeFunction, a DCE pass will be called to removed the original copies of the functions that are merged. However, DCE will not remove public functions, i.e. those with a global symbol. Making those functions public prevents them from being cleaned up by DCE.

I think the logic here is correct. Those fused_xxx functions should be private as they are not meant to be called externally. So we probably don't want to change anything here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much for the explanation, @yelite. So, to confirm, you don't think there needs to be a change to the merge composite pass?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the merge composite pass shouldn't be changed. We might need to update the fusion pass if it produces fused functions with global symbol.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think by default fusion pass should not produce functions with global symbol. Ideally we should prevent fusing functions that with global symbol --- because user expects to access the original function


// Step 2: Add the lifted function to the module
builder_->AddFunction(lift_plan_.f_transform_params, new_func_name);
// (The lifted function should be public so we add a global symbol to it)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am assuming we want them to be public.

@tqchen
Copy link
Member

tqchen commented Jun 21, 2023

cc @yongwww

->Call({d->AsDoc<ExprDoc>(n->attrs, n_p->Attr("attrs"))})));
// if the function is a global function and has a global symbol,
// then don't print the global symbol (it will be implicit from not being private)
if (d->frames.size() == 3 && n->attrs->dict.count("global_symbol")) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not fond of having this magic number, but it was the only way I could figure out to check that the function is at the top level of a module. Possibly a hack--is there a better way I could check for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yongwww might there be a more robust way to check if we're inside a module?

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32"))
with bb.function("main", [x]):
with bb.function("main", [x], {"global_symbol": "main"}):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about we also have private=False flag here(default to False)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private=False in the parser (which is the default) means that there's a global symbol in the attributes.

Copy link
Contributor Author

@slyubomirsky slyubomirsky Jun 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the block builder should also have a private flag when building a function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, this way it is closer to what function provide

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented that, good idea 👍

@tqchen
Copy link
Member

tqchen commented Jun 22, 2023

Thanks @slyubomirsky on the default behavior, would be great if we default private=False if not specified, and ask for explicit private=True.

It will involve more handling in printing(like not printing global symbol, and when global symbol do not exist, print as private=True), and modifying the test cases that comes with global_symbol (by removing them). But it would be much cleaner from user pov

@slyubomirsky
Copy link
Contributor Author

Do you mean that private should be another attribute? Or that it should be a field in the function AST node? RIght now, privacy is handled entirely in the parser (if private is False, which is the default, the parser adds a global symbol).

@slyubomirsky
Copy link
Contributor Author

slyubomirsky commented Jun 22, 2023

The tough part is dealing with nested functions. There is no way to have a public nested function, but the BlockBuilder/printer/etc would have to know whether we're inside another function or not. (Though it would not be hard to modify the BlockBuilder to track that if that's what we want.)

@tqchen
Copy link
Member

tqchen commented Jun 22, 2023

sorry i am not talking about changing the AST. I only mean the interface

@R.function
def public_func():
    pass
    
    
@R.function(private=True):
def private_func():
    pass

And make it happen also for tir

@slyubomirsky
Copy link
Contributor Author

slyubomirsky commented Jun 22, 2023

Then, yes, that is already how I had implemented it. False is the default value for private.

I will investigate doing the same for TIR (I've never touched the TVMScript for TIR).

@tqchen
Copy link
Member

tqchen commented Jun 23, 2023

We can get relax tests in first. Can we detect that printer prints out things correctly (aka skips the global symbol and instead use private annotation)?

@slyubomirsky
Copy link
Contributor Author

Yeah, there are printer tests included. I am making some further tweaks because the implementation right now is rather hacky.

)


def test_nested_function():
Copy link
Contributor Author

@slyubomirsky slyubomirsky Jun 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually uncovered a seemingly unrelated bug in the printer, where the nested function would be printed with a suffix after its name (see the changes in src/script/printer/relax/function.cc) because the LHS variable in the VarBinding was treated as already having been encountered. Also fixed.

@slyubomirsky slyubomirsky changed the title [Unity][IR][UX] Privacy annotation [Unity][IR][UX] Privacy annotation in Relax Jun 23, 2023
@slyubomirsky
Copy link
Contributor Author

I will implement the same changes for TIR in a separate PR

@tqchen tqchen merged commit daf9c20 into apache:unity Jun 27, 2023
junrushao pushed a commit that referenced this pull request Jul 20, 2023
Implements proposal #14899 for TIR. Adapted from Unity PR #15171, by the same method as PR #15140. Namely, a TIR PrimFunc can be specified to be private (without a global symbol attribute) in TVMScript in the `prim_func` decorator. By default, `PrimFunc`s are not private, so they will have a `global_symbol` attribute that is mapped to their name.

Example usage:
```python
# not private: its global symbol will be "func"
@T.prim_func
def func(...): ...

# no global symbol included
@T.prim_func(private=True)
def func(...): ...
```

This did require changing very, very many tests, unfortunately.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants