Skip to content

Commit 2c8b757

Browse files
committed
Updated tests in test_remove_concatenates.py
1 parent 5d4bc85 commit 2c8b757

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

tests/python/contrib/test_ethosu/test_remove_concatenates.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tvm.relay.testing import run_opt_pass
2525
from tvm.script import tir as T
2626

27-
from .infra import make_ethosu_conv2d
27+
from .infra import make_ethosu_conv2d, copy_allocate_const_data
2828

2929

3030
# fmt: off
@@ -35,18 +35,29 @@ def main(input_placeholder: T.Buffer((1,8,12,16), "int8"), input_placeholder_1:
3535
# function attr dict
3636
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
3737

38+
data_3 = T.allocate_const([0]*160, 'uint8', [160])
39+
data_2 = T.allocate_const([0]*2992, 'uint8', [2992])
40+
data_7 = T.allocate_const([0]*160, 'uint8', [160])
41+
data_6 = T.allocate_const([0]*2992, 'uint8', [2992])
42+
data_1 = T.allocate_const([0]*160, 'uint8', [160])
43+
data = T.allocate_const([0]*2992, 'uint8', [2992])
44+
data_5 = T.allocate_const([0]*160, 'uint8', [160])
45+
data_4 = T.allocate_const([0]*2992, 'uint8', [2992])
46+
47+
buffer = T.Buffer([2992], "uint8", data=data)
48+
buffer_1 = T.Buffer([160], "uint8", data=data_1)
49+
buffer_2 = T.Buffer([2992], "uint8", data=data_2)
50+
buffer_3 = T.Buffer([160], "uint8", data=data_3)
51+
buffer_4 = T.Buffer([2992], "uint8", data=data_4)
52+
buffer_5 = T.Buffer([160], "uint8", data=data_5)
53+
buffer_6 = T.Buffer([2992], "uint8", data=data_6)
54+
buffer_7 = T.Buffer([160], "uint8", data=data_7)
55+
3856
placeholder = T.Buffer(1536, dtype="int8", data=input_placeholder.data)
3957
placeholder_1 = T.Buffer(1280, dtype="int8", data=input_placeholder_1.data)
4058
T_concat = T.Buffer(4096, dtype="int8", data=input_T_concat.data)
4159

42-
buffer = T.Buffer([2992], "uint8")
43-
buffer_1 = T.Buffer([160], "uint8")
44-
buffer_2 = T.Buffer([2992], "uint8")
45-
buffer_3 = T.Buffer([160], "uint8")
46-
buffer_4 = T.Buffer([2992], "uint8")
47-
buffer_5 = T.Buffer([160], "uint8")
48-
buffer_6 = T.Buffer([2992], "uint8")
49-
buffer_7 = T.Buffer([160], "uint8")
60+
5061
# body
5162
T_concat_1_data = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True})
5263
T_concat_1 = T.Buffer([2816], "int8", data=T_concat_1_data)
@@ -78,6 +89,7 @@ def _get_func():
7889
test_mod = tvm.script.from_source(script)
7990

8091
reference_mod = ReferenceModule
92+
reference_mod = copy_allocate_const_data(test_mod, reference_mod)
8193
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
8294

8395

0 commit comments

Comments
 (0)