1717 * specific language governing permissions and limitations
1818 * under the License.
1919 */
20+ /* !
21+ * \file extract_constant.cc
22+ * \brief Pushes out constants within partitioned functions all the way upto main()
23+ */
24+
2025#include < tvm/relay/attrs/nn.h>
2126#include < tvm/relay/expr_functor.h>
2227#include < tvm/relay/transform.h>
@@ -30,44 +35,47 @@ namespace relay {
3035namespace contrib {
3136namespace cmsisnn {
3237
38+ /* !
39+ * \brief This Mutator finds all functions with constants. Constants are replaced with function
40+ * parameter variables. Constants are pushed all the way upto main().
41+ */
3342class ExtractConstantsMutator : public MixedModeMutator {
3443 public:
35- explicit ExtractConstantsMutator (IRModule& mod) : mod_(mod) {}
44+ explicit ExtractConstantsMutator (const IRModule& mod) : mod_(mod) {}
3645
3746 private:
3847 String gen_var_name () { return " tvm_var_extract_const_" + std::to_string (var_count_++); }
3948
40- Expr VisitExpr_ (const FunctionNode* func) final {
41- Function final_func = GetRef<Function>(func);
42- ++func_nesting_level_;
49+ Expr VisitExpr_ (const FunctionNode* function) final {
50+ Function func = GetRef<Function>(function);
51+ function_to_constants_.Set (func, Array<Constant>{});
52+ functions_.push_back (func);
4353 auto new_body = VisitExpr (func->body );
44- --func_nesting_level_;
45- if (!new_body.same_as (func->body )) {
46- final_func = Function (FreeVars (new_body), new_body, func->ret_type ,
47- FreeTypeVars (new_body, mod_), func->attrs );
48- function_to_constants_.Set (GetRef<Function>(func), constants_within_function_);
49- constants_within_function_.clear ();
54+ functions_.pop_back ();
55+ if (function_to_constants_[func].size ()) {
56+ func = Function (FreeVars (new_body), new_body, func->ret_type , FreeTypeVars (new_body, mod_),
57+ func->attrs );
5058 }
51- return final_func ;
59+ return func ;
5260 }
5361
5462 Expr Rewrite_ (const CallNode* call, const Expr& post ) final {
5563 Expr final_call = post ;
5664 auto * post_call = post .as <CallNode>();
57- if (post_call == nullptr ) {
58- return final_call;
59- }
6065
6166 // Replace Constant arguments with Vars for ML Operators
6267 // Perform this for non-main Call Nodes only
63- if (func_nesting_level_ && call->op .as <OpNode>()) {
68+ if (!functions_. empty () && call->op .as <OpNode>()) {
6469 Array<Expr> new_args;
6570 for (auto & arg : post_call->args ) {
6671 auto * const_arg = arg.as <ConstantNode>();
6772 if (const_arg && !const_arg->is_scalar ()) {
6873 Var var_arg = Var (gen_var_name (), const_arg->tensor_type ());
6974 new_args.push_back (var_arg);
70- constants_within_function_.push_back (GetRef<Constant>(const_arg));
75+ const Function& last_func = functions_.back ();
76+ Array<Constant> fconstants (function_to_constants_[last_func]);
77+ fconstants.push_back (GetRef<Constant>(const_arg));
78+ function_to_constants_.Set (last_func, fconstants);
7179 } else {
7280 new_args.push_back (arg);
7381 }
@@ -94,17 +102,21 @@ class ExtractConstantsMutator : public MixedModeMutator {
94102
95103 // Since the constants are kicked out of the local partitioned functions
96104 // a new call to local function is needed
105+ // Also, pass on the constants to the callee of this function to support nested functions
97106 if (auto * func_node = call->op .as <FunctionNode>()) {
98107 Function func = GetRef<Function>(func_node);
99108 auto new_func = VisitExpr (func);
100109 if (!new_func.same_as (func)) {
101110 Array<Expr> new_args = post_call->args ;
102111 ICHECK (function_to_constants_.find (func) != function_to_constants_.end ());
112+ const Function& last_func = functions_.back ();
113+ Array<Constant> fconstants (function_to_constants_[last_func]);
103114 for (auto constant : function_to_constants_.at (func)) {
104- constants_within_function_ .push_back (constant);
115+ fconstants .push_back (constant);
105116 Var var_arg = Var (gen_var_name (), constant->tensor_type ());
106117 new_args.push_back (var_arg);
107118 }
119+ function_to_constants_.Set (last_func, fconstants);
108120 final_call = Call (new_func, new_args);
109121 }
110122 }
@@ -117,16 +129,14 @@ class ExtractConstantsMutator : public MixedModeMutator {
117129 IRModule mod_;
118130 /* \brief Maintains mapping of original function to the replaced constants */
119131 Map<Function, Array<Constant>> function_to_constants_;
120- /* \brief Constants being kicked out of a function during the function visit */
121- Array<Constant> constants_within_function_ ;
132+ /* \brief Stack of functions to determine scope while filling up function_to_constants_ */
133+ Array<Function> functions_ ;
122134 /* \brief Keeps track of variables being created */
123135 int var_count_ = 0 ;
124- /* \brief Keeps track of function scope */
125- int func_nesting_level_ = 0 ;
126136};
127137
128138/* ! * \brief Kicks out all constants out of the partitioned function into main() */
129- IRModule ExtractConstants (IRModule mod) {
139+ IRModule ExtractConstants (const IRModule& mod) {
130140 String func_name;
131141 Function func;
132142
@@ -150,7 +160,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
150160}
151161
152162TVM_REGISTER_GLOBAL (" relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction" )
153- .set_body_typed([]() { return ExtractConstantsFromPartitionedFunction (); } );
163+ .set_body_typed(ExtractConstantsFromPartitionedFunction);
154164
155165} // namespace cmsisnn
156166} // namespace contrib
0 commit comments