Skip to content

Commit c83e144

Browse files
committed
address comments
* uses prim func pass, rather than module pass. * adds error message informing user to run this pass with LowerToTIR() pass for now. Change-Id: I57757b9dc5bff0208034a974a341c09cce0294bc
1 parent 833812a commit c83e144

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

src/tir/contrib/ethosu/passes.cc

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ namespace ethosu {
3636
* \brief This mutator moves allocates to the top of the body of the main
3737
* function.
3838
*
39+
* Note: This pass can currently only be run in conjunction with the
40+
* LowerToTIR() pass as it expects a single primitive function called
41+
* "main" that is being offloaded to the NPU.
42+
*
3943
* For example,
4044
* allocate {
41-
* extern_call(...) {
42-
* allocate {
43-
* Before: extern_call(...)
44-
* }
45+
* extern_call(...)
46+
* allocate {
47+
* Before: extern_call(...)
4548
* }
4649
* }
4750
*
@@ -56,9 +59,7 @@ class HoistAllocatesMutator : public StmtExprMutator {
5659
public:
5760
HoistAllocatesMutator() {}
5861

59-
IRModule operator()(IRModule mod) {
60-
GlobalVar gv = mod->GetGlobalVar("main");
61-
PrimFunc main_func = Downcast<PrimFunc>(mod->Lookup(gv));
62+
PrimFunc operator()(PrimFunc main_func) {
6263
Stmt new_main_func_body = this->VisitStmt(main_func->body);
6364

6465
// Write all allocates that were removed in reverse order
@@ -76,8 +77,7 @@ class HoistAllocatesMutator : public StmtExprMutator {
7677
PrimFunc new_main_func =
7778
PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map,
7879
main_func->preflattened_buffer_map, main_func->attrs);
79-
mod->Update(gv, new_main_func);
80-
return mod;
80+
return new_main_func;
8181
}
8282

8383
private:
@@ -107,10 +107,14 @@ class HoistAllocatesMutator : public StmtExprMutator {
107107
* \return tvm::transform::Pass
108108
*/
109109
tvm::transform::Pass HoistAllocates() {
110-
auto pass_func = [=](IRModule mod, tvm::transform::PassContext ctx) {
111-
return HoistAllocatesMutator()(mod);
110+
auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
111+
ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
112+
<< "Expected a single primitive function called 'main'. Please run the HoistAllocates pass "
113+
"in conjunction with the LowerToTIR() pass.";
114+
return HoistAllocatesMutator()(f);
112115
};
113-
return tvm::transform::CreateModulePass(pass_func, 0, "tir.contrib.ethos-u.HoistAllocates", {});
116+
return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.HoistAllocates",
117+
{});
114118
}
115119

116120
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);

tests/python/contrib/test_ethosu/test_hoist_allocates.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,41 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
202202
allocate_info = ExtractAllocateInfo()(mod)
203203
mod = HoistAllocates()(mod)
204204
CheckAllocates(allocate_info)(mod)
205+
206+
207+
def test_multiple_prim_funcs():
208+
@tvm.script.ir_module
209+
class Module:
210+
@T.prim_func
211+
def main():
212+
T.evaluate(0)
213+
214+
@T.prim_func
215+
def abc():
216+
T.evaluate(0)
217+
218+
mod = Module
219+
220+
err_rgx = (
221+
r"Expected a single primitive function called 'main'. "
222+
r"Please run the HoistAllocates pass in conjunction with the LowerToTIR\(\) pass."
223+
)
224+
with pytest.raises(tvm.TVMError, match=err_rgx):
225+
mod = HoistAllocates()(mod)
226+
227+
228+
def test_no_main_prim_func():
229+
@tvm.script.ir_module
230+
class Module:
231+
@T.prim_func
232+
def abs():
233+
T.evaluate(0)
234+
235+
mod = Module
236+
237+
err_rgx = (
238+
r"Expected a single primitive function called 'main'. "
239+
r"Please run the HoistAllocates pass in conjunction with the LowerToTIR\(\) pass."
240+
)
241+
with pytest.raises(tvm.TVMError, match=err_rgx):
242+
mod = HoistAllocates()(mod)

0 commit comments

Comments
 (0)