Skip to content

Commit 079eb4e

Browse files
authored
[microNPU] Add a pass to move allocate nodes to the outer scope (#10725)
* [microNPU] Add a pass to move allocate nodes to the outer scope Adds a pass called `HoistAllocates` to move allocate nodes to the top of the body of the main function. In doing so, it opens the door to other optimizations that need to swap the ordering of external calls. Pass illustration: (before) ``` allocate { extern_call { allocate { extern_call { } } } } ``` (after) ``` allocate { allocate { extern_call extern_call } } ``` Change-Id: Ibcfc3c75b15deebb5c6645a4923a6ddf683b37c4 * address comments * uses prim func pass, rather than module pass. * adds error message informing user to run this pass with LowerToTIR() pass for now. Change-Id: I57757b9dc5bff0208034a974a341c09cce0294bc * Support allocates when not followed by a sequence statement With a test to back this case up. Change-Id: I670809f5ee53b583a15d9b783852dda3089756e9 * Add new directory tir/contrib/ethosu to cmake build Change-Id: I3e9f24adfe992ace4e03238a18a8378b03257e1a
1 parent 937a14f commit 079eb4e

File tree

6 files changed

+437
-8
lines changed

6 files changed

+437
-8
lines changed

cmake/modules/contrib/EthosU.cmake

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ if(USE_ETHOSU)
1919
tvm_file_glob(GLOB COMPILER_ETHOSU_SRCS
2020
src/relay/backend/contrib/ethosu/*
2121
src/contrib/ethosu/cascader/*
22-
src/contrib/ethosu/cascader/parts/*)
22+
src/contrib/ethosu/cascader/parts/*
23+
src/tir/contrib/ethosu/*)
2324
list(APPEND COMPILER_SRCS ${COMPILER_ETHOSU_SRCS})
2425
else()
2526
# Keeping just utils.cc because it has Object definitions

python/tvm/relay/backend/contrib/ethosu/_ffi_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
import tvm._ffi # type: ignore
1919

2020
tvm._ffi._init_api("relay.ext.ethos-u", __name__)
21+
tvm._ffi._init_api("tir.contrib.ethos-u", __name__)

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
8888
mod = ethosu_passes.ReplaceOperators()(mod)
8989
mod = tvm.tir.transform.RemoveNoOp()(mod)
9090
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
91+
mod = ethosu_passes.HoistAllocates()(mod)
9192
disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite", False)
9293
if not disable_storage_rewrite:
9394
mod = tvm.tir.transform.StorageRewrite()(mod)

python/tvm/relay/backend/contrib/ethosu/tir/passes.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from .transform import get_copy_params
3131
from .utils import get_weights_buffer, get_scale_bias_buffer
3232

33+
from .. import _ffi_api
34+
3335

3436
def RemoveZeroStores():
3537
"""This pass removes stores which just store zero to initialise buffers.
@@ -48,7 +50,7 @@ def _ftransform(f, mod, ctx):
4850
)
4951

5052
return tvm.tir.transform.prim_func_pass(
51-
_ftransform, opt_level=0, name="tir.ethosu.remove_zero_stores"
53+
_ftransform, opt_level=0, name="tir.contrib.ethos-u.remove_zero_stores"
5254
)
5355

5456

@@ -207,7 +209,7 @@ def _ftransform(f, mod, ctx):
207209
)
208210

209211
return tvm.tir.transform.prim_func_pass(
210-
_ftransform, opt_level=0, name="tir.ethosu.replace_operators"
212+
_ftransform, opt_level=0, name="tir.contrib.ethos-u.replace_operators"
211213
)
212214

213215

@@ -296,7 +298,7 @@ def _ftransform(f, mod, ctx):
296298

297299
def _divide_constants(mod):
298300
transform_func = tvm.tir.transform.prim_func_pass(
299-
_ftransform, opt_level=0, name="tir.ethosu.divide_constants"
301+
_ftransform, opt_level=0, name="tir.contrib.ethos-u.divide_constants"
300302
)
301303
new_func = transform_func(mod)
302304
return new_func, new_const_dict
@@ -549,7 +551,7 @@ def _encode_constants(mod):
549551
for key, value in divided_const_dict.items():
550552
const_dict[key] = value
551553
transform_func = tvm.tir.transform.prim_func_pass(
552-
_ftransform, opt_level=0, name="tir.ethosu.encode_constants"
554+
_ftransform, opt_level=0, name="tir.contrib.ethos-u.encode_constants"
553555
)
554556
new_func = transform_func(mod)
555557
return new_func, new_const_dict
@@ -584,7 +586,7 @@ def _ftransform(f, mod, ctx):
584586
)
585587

586588
return tvm.tir.transform.prim_func_pass(
587-
_ftransform, opt_level=0, name="tir.ethosu.annotate_allocates"
589+
_ftransform, opt_level=0, name="tir.contrib.ethos-u.annotate_allocates"
588590
)
589591

590592

@@ -751,7 +753,7 @@ def _ftransform(f, mod, ctx):
751753
)
752754

753755
return tvm.tir.transform.prim_func_pass(
754-
_ftransform, opt_level=0, name="tir.ethosu.remove_concatenates"
756+
_ftransform, opt_level=0, name="tir.contrib.ethos-u.remove_concatenates"
755757
)
756758

757759

@@ -795,9 +797,21 @@ def _ftransform(f, mod, ctx):
795797

796798
def _create_primfunc_without_constants(mod):
797799
transform_func = tvm.tir.transform.prim_func_pass(
798-
_ftransform, opt_level=0, name="tir.ethosu.CreatePrimFuncWithoutConstants"
800+
_ftransform, opt_level=0, name="tir.contrib.ethos-u.CreatePrimFuncWithoutConstants"
799801
)
800802
mod = transform_func(mod)
801803
return mod, new_const_dict
802804

803805
return _create_primfunc_without_constants
806+
807+
808+
def HoistAllocates() -> tvm.IRModule:
809+
"""
810+
Hoist allocate nodes up to the top of the body of the main function.
811+
812+
Returns
813+
-------
814+
tvm.IRModule
815+
The new module with hoisted allocate nodes.
816+
"""
817+
return _ffi_api.HoistAllocates()

src/tir/contrib/ethosu/passes.cc

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tir/contrib/ethosu/passes.cc
22+
*
23+
* \brief Passes used in TIR lowering for the microNPU compiler.
24+
*/
25+
#include <tvm/tir/builtin.h>
26+
#include <tvm/tir/function.h>
27+
#include <tvm/tir/stmt_functor.h>
28+
#include <tvm/tir/transform.h>
29+
30+
namespace tvm {
31+
namespace tir {
32+
namespace contrib {
33+
namespace ethosu {
34+
35+
/*!
36+
* \brief This mutator moves allocates to the top of the body of the main
37+
* function.
38+
*
39+
* Note: This pass can currently only be run in conjunction with the
40+
* LowerToTIR() pass as it expects a single primitive function called
41+
* "main" that is being offloaded to the NPU.
42+
*
43+
* For example,
44+
* Before:
45+
* allocate {
46+
* extern_call(...)
47+
* allocate {
48+
* extern_call(...)
49+
* }
50+
* }
51+
*
52+
* After:
53+
* allocate {
54+
* allocate {
55+
* extern_call(...)
56+
* extern_call(...)
57+
* }
58+
* }
59+
*/
60+
class HoistAllocatesMutator : public StmtExprMutator {
61+
public:
62+
HoistAllocatesMutator() {}
63+
64+
PrimFunc operator()(PrimFunc main_func) {
65+
Stmt new_main_func_body = this->VisitStmt(main_func->body);
66+
67+
// Write all allocates that were removed in reverse order
68+
for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) {
69+
Allocate current_alloc = *it;
70+
if (it != allocates_.rbegin()) {
71+
new_main_func_body = SeqStmt({new_main_func_body});
72+
}
73+
new_main_func_body =
74+
Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents,
75+
current_alloc->condition, new_main_func_body, current_alloc->annotations,
76+
current_alloc->span);
77+
}
78+
79+
PrimFunc new_main_func =
80+
PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map,
81+
main_func->preflattened_buffer_map, main_func->attrs);
82+
return new_main_func;
83+
}
84+
85+
private:
86+
Stmt VisitStmt_(const AllocateNode* op) override {
87+
allocates_.push_back(GetRef<Allocate>(op));
88+
89+
// Skip the allocate node itself
90+
if (const auto* seq = op->body.as<SeqStmtNode>()) {
91+
// Traverse the allocate body recursively and flatten
92+
Array<Stmt> new_stmts;
93+
new_stmts.reserve(seq->seq.size());
94+
for (const Stmt& old_stmt : seq->seq) {
95+
new_stmts.push_back(VisitStmt(old_stmt));
96+
}
97+
return SeqStmt::Flatten(new_stmts);
98+
} else {
99+
return VisitStmt(op->body);
100+
}
101+
}
102+
103+
/*! A stack to store allocates as they are visited. */
104+
std::vector<Allocate> allocates_;
105+
};
106+
107+
/*!
108+
* \brief A pass to hoist allocate nodes to the top of the body of the main function.
109+
*
110+
* \return tvm::transform::Pass
111+
*/
112+
tvm::transform::Pass HoistAllocates() {
113+
auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
114+
ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
115+
<< "Expected a single primitive function called 'main'. Please run the HoistAllocates pass "
116+
"in conjunction with the LowerToTIR() pass.";
117+
return HoistAllocatesMutator()(f);
118+
};
119+
return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.HoistAllocates",
120+
{});
121+
}
122+
123+
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);
124+
125+
} // namespace ethosu
126+
} // namespace contrib
127+
} // namespace tir
128+
} // namespace tvm

0 commit comments

Comments
 (0)