Skip to content

Commit 9a592bc

Browse files
committed
Support allocates when not followed by a sequence statement
With a test to back this case up. Change-Id: I670809f5ee53b583a15d9b783852dda3089756e9
1 parent c83e144 commit 9a592bc

File tree

2 files changed

+66
-21
lines changed

2 files changed

+66
-21
lines changed

src/tir/contrib/ethosu/passes.cc

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,21 @@ namespace ethosu {
4141
* "main" that is being offloaded to the NPU.
4242
*
4343
* For example,
44-
* allocate {
45-
* extern_call(...)
46-
* allocate {
47-
* Before: extern_call(...)
48-
* }
49-
* }
44+
* Before:
45+
* allocate {
46+
* extern_call(...)
47+
* allocate {
48+
* extern_call(...)
49+
* }
50+
* }
5051
*
51-
* allocate {
52-
* allocate {
53-
* extern_call(...)
54-
* After: extern_call(...)
55-
* }
56-
* }
52+
* After:
53+
* allocate {
54+
* allocate {
55+
* extern_call(...)
56+
* extern_call(...)
57+
* }
58+
* }
5759
*/
5860
class HoistAllocatesMutator : public StmtExprMutator {
5961
public:
@@ -85,16 +87,17 @@ class HoistAllocatesMutator : public StmtExprMutator {
8587
allocates_.push_back(GetRef<Allocate>(op));
8688

8789
// Skip the allocate node itself
88-
const auto* seq = op->body.as<SeqStmtNode>();
89-
ICHECK(seq) << "Expected a sequence statement but got " << op->body->GetTypeKey() << ".";
90-
91-
// Traverse the allocate body recursively and flatten
92-
Array<Stmt> new_stmts;
93-
new_stmts.reserve(seq->seq.size());
94-
for (const Stmt& old_stmt : seq->seq) {
95-
new_stmts.push_back(VisitStmt(old_stmt));
90+
if (const auto* seq = op->body.as<SeqStmtNode>()) {
91+
// Traverse the allocate body recursively and flatten
92+
Array<Stmt> new_stmts;
93+
new_stmts.reserve(seq->seq.size());
94+
for (const Stmt& old_stmt : seq->seq) {
95+
new_stmts.push_back(VisitStmt(old_stmt));
96+
}
97+
return SeqStmt::Flatten(new_stmts);
98+
} else {
99+
return VisitStmt(op->body);
96100
}
97-
return SeqStmt::Flatten(new_stmts);
98101
}
99102

100103
/*! A stack to store allocates as they are visited. */

tests/python/contrib/test_ethosu/test_hoist_allocates.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,48 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
204204
CheckAllocates(allocate_info)(mod)
205205

206206

207+
def test_allocate_without_seq_stmt():
208+
"""
209+
Tests the case when an allocate statement does not have a sequence statement as its body.
210+
"""
211+
# fmt: off
212+
@tvm.script.ir_module
213+
class Module:
214+
@T.prim_func
215+
def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None:
216+
# function attr dict
217+
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
218+
T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
219+
T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
220+
# body
221+
placeholder_global = T.allocate([128], "uint8", "global")
222+
placeholder_global_1 = T.allocate([112], "uint8", "global")
223+
placeholder_global_2 = T.allocate([112], "uint8", "global")
224+
placeholder_d_global = T.allocate([32], "uint8", "global")
225+
placeholder_d_global_1 = T.allocate([32], "uint8", "global")
226+
placeholder_d_global_2 = T.allocate([32], "uint8", "global")
227+
placeholder_global_3 = T.allocate([112], "uint8", "global")
228+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, placeholder_global[0], dtype="handle"))
229+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 32, placeholder_d_global[0], dtype="handle"))
230+
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
231+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2[0], 112, placeholder_global_1[0], dtype="handle"))
232+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3[0], 32, placeholder_d_global_1[0], dtype="handle"))
233+
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 112, 12, placeholder_d_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
234+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 112, placeholder_global_2[0], dtype="handle"))
235+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle"))
236+
placeholder_d_global_3 = T.allocate([32], "uint8", "global")
237+
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 112, 12, placeholder_d_global_2[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
238+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6[0], 112, placeholder_global_3[0], dtype="handle"))
239+
T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7[0], 32, placeholder_d_global_3[0], dtype="handle"))
240+
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_3[0], 112, 12, placeholder_d_global_3[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
241+
# fmt: on
242+
243+
mod = Module
244+
allocate_info = ExtractAllocateInfo()(mod)
245+
mod = HoistAllocates()(mod)
246+
CheckAllocates(allocate_info)(mod)
247+
248+
207249
def test_multiple_prim_funcs():
208250
@tvm.script.ir_module
209251
class Module:

0 commit comments

Comments
 (0)