Skip to content

Commit 5d4bc85

Browse files
committed
Updated tests in test_scheduler.py
1 parent f41e5a6 commit 5d4bc85

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/python/contrib/test_ethosu/test_scheduler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
make_ethosu_conv2d,
4242
make_ethosu_identity,
4343
make_ethosu_binary_elementwise,
44+
copy_allocate_const_data,
4445
)
4546

4647

@@ -198,10 +199,15 @@ class DiamondGraphTir:
198199
@T.prim_func
199200
def main(input_placeholder: T.Buffer((1, 56, 56, 96), "int8"), input_ethosu_write: T.Buffer((1, 56, 56, 24), "int8")) -> None:
200201
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
202+
203+
data3 = T.allocate_const([0]*976, 'uint8', [976])
204+
buffer3 = T.Buffer([976], "uint8", data=data3)
205+
data1 = T.allocate_const([0]*2848, 'uint8', [2848])
206+
buffer1 = T.Buffer([2848], "uint8", data=data1)
207+
208+
201209
placeholder = T.Buffer([301056], dtype='int8', data=input_placeholder.data)
202210
ethosu_write = T.Buffer([75264], dtype='int8', data=input_ethosu_write.data)
203-
buffer1 = T.Buffer([2848], "uint8")
204-
buffer3 = T.Buffer([976], "uint8")
205211
p1_data = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True})
206212
p1 = T.Buffer([2848], "uint8", data=p1_data)
207213
p2_data = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True})
@@ -230,6 +236,7 @@ def test_schedule_diamond_graph():
230236

231237
test_mod = _lower_to_tir(func, copy_constants())
232238
reference_mod = DiamondGraphTir
239+
reference_mod = copy_allocate_const_data(test_mod, reference_mod)
233240
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
234241

235242

0 commit comments

Comments
 (0)