Skip to content

Commit f41e5a6

Browse files
committed
Updated tests in test_replace_copy.py
1 parent b81361f commit f41e5a6

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

tests/python/contrib/test_ethosu/test_replace_copy.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from tvm.relay.testing import run_opt_pass
2828
from tvm.script import tir as T
2929

30-
from .infra import make_ethosu_conv2d
30+
from .infra import make_ethosu_conv2d, copy_allocate_const_data
3131

3232

3333
# fmt: off
@@ -37,7 +37,10 @@ class ReferenceModule:
3737
def main(input_placeholder_3: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 8), "int8")) -> None:
3838
# function attr dict
3939
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
40-
buffer_1 = T.Buffer([384], "uint8")
40+
41+
data_1 = T.allocate_const([0]*384,'uint8',[384])
42+
buffer_1 = T.Buffer([384], "uint8",data=data_1)
43+
4144
placeholder_3 = T.Buffer([8192], dtype="int8", data=input_placeholder_3.data)
4245
ethosu_write_1 = T.Buffer([2048], dtype="int8", data=input_ethosu_write_1.data)
4346
# body
@@ -71,6 +74,7 @@ def _get_func():
7174
script = mod.script()
7275
test_mod = tvm.script.from_source(script)
7376
reference_mod = ReferenceModule
77+
reference_mod = copy_allocate_const_data(test_mod, reference_mod)
7478
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
7579

7680

@@ -81,8 +85,12 @@ class WeightStream:
8185
def main(input_placeholder_5: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 16), "int8")) -> None:
8286
# function attr dict
8387
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
84-
buffer = T.Buffer([528], "uint8")
85-
buffer_2 = T.Buffer([336], "uint8")
88+
89+
data_2 = T.allocate_const([0]*336, 'uint8',[336])
90+
buffer_2 = T.Buffer([336], "uint8",data=data_2)
91+
data = T.allocate_const([0]*528, 'uint8',[528])
92+
buffer = T.Buffer([528], "uint8",data=data)
93+
8694
placeholder_5 = T.Buffer([8192], dtype="int8", data=input_placeholder_5.data)
8795
ethosu_write_1 = T.Buffer([4096], dtype="int8", data=input_ethosu_write_1.data)
8896
# body
@@ -131,6 +139,7 @@ def _get_func():
131139
script = mod.script()
132140
test_mod = tvm.script.from_source(script)
133141
reference_mod = WeightStream
142+
reference_mod = copy_allocate_const_data(test_mod, reference_mod)
134143
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
135144

136145

0 commit comments

Comments
 (0)