@@ -36,12 +36,15 @@ namespace ethosu {
3636 * \brief This mutator moves allocates to the top of the body of the main
3737 * function.
3838 *
39+ * Note: This pass can currently only be run in conjunction with the
40+ * LowerToTIR() pass as it expects a single primitive function called
41+ * "main" that is being offloaded to the NPU.
42+ *
3943 * For example,
4044 * allocate {
41- * extern_call(...) {
42- * allocate {
43- * Before: extern_call(...)
44- * }
45+ * extern_call(...)
46+ * allocate {
47+ * Before: extern_call(...)
4548 * }
4649 * }
4750 *
@@ -56,9 +59,7 @@ class HoistAllocatesMutator : public StmtExprMutator {
5659 public:
5760 HoistAllocatesMutator () {}
5861
59- IRModule operator ()(IRModule mod) {
60- GlobalVar gv = mod->GetGlobalVar (" main" );
61- PrimFunc main_func = Downcast<PrimFunc>(mod->Lookup (gv));
62+ PrimFunc operator ()(PrimFunc main_func) {
6263 Stmt new_main_func_body = this ->VisitStmt (main_func->body );
6364
6465 // Write all allocates that were removed in reverse order
@@ -76,8 +77,7 @@ class HoistAllocatesMutator : public StmtExprMutator {
7677 PrimFunc new_main_func =
7778 PrimFunc (main_func->params , new_main_func_body, main_func->ret_type , main_func->buffer_map ,
7879 main_func->preflattened_buffer_map , main_func->attrs );
79- mod->Update (gv, new_main_func);
80- return mod;
80+ return new_main_func;
8181 }
8282
8383 private:
@@ -107,10 +107,14 @@ class HoistAllocatesMutator : public StmtExprMutator {
107107 * \return tvm::transform::Pass
108108 */
109109tvm::transform::Pass HoistAllocates () {
110- auto pass_func = [=](IRModule mod, tvm::transform::PassContext ctx) {
111- return HoistAllocatesMutator ()(mod);
110+ auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
111+ ICHECK (mod->GetGlobalVars ().size () == 1 && mod->ContainGlobalVar (" main" ))
112+ << " Expected a single primitive function called 'main'. Please run the HoistAllocates pass "
113+ " in conjunction with the LowerToTIR() pass." ;
114+ return HoistAllocatesMutator ()(f);
112115 };
113- return tvm::transform::CreateModulePass (pass_func, 0 , " tir.contrib.ethos-u.HoistAllocates" , {});
116+ return tvm::tir::transform::CreatePrimFuncPass (pass_func, 0 , " tir.contrib.ethos-u.HoistAllocates" ,
117+ {});
114118}
115119
116120TVM_REGISTER_GLOBAL (" tir.contrib.ethos-u.HoistAllocates" ).set_body_typed(HoistAllocates);
0 commit comments