2727from tvm .relay .testing import run_opt_pass
2828from tvm .script import tir as T
2929
30- from .infra import make_ethosu_conv2d
30+ from .infra import make_ethosu_conv2d , copy_allocate_const_data
3131
3232
3333# fmt: off
@@ -37,7 +37,10 @@ class ReferenceModule:
3737 def main (input_placeholder_3 : T .Buffer ((1 , 16 , 16 , 32 ), "int8" ), input_ethosu_write_1 : T .Buffer ((1 , 16 , 16 , 8 ), "int8" )) -> None :
3838 # function attr dict
3939 T .func_attr ({"from_legacy_te_schedule" : True , "global_symbol" : "main" , "tir.noalias" : True })
40- buffer_1 = T .Buffer ([384 ], "uint8" )
40+
41+ data_1 = T .allocate_const ([0 ]* 384 ,'uint8' ,[384 ])
42+ buffer_1 = T .Buffer ([384 ], "uint8" ,data = data_1 )
43+
4144 placeholder_3 = T .Buffer ([8192 ], dtype = "int8" , data = input_placeholder_3 .data )
4245 ethosu_write_1 = T .Buffer ([2048 ], dtype = "int8" , data = input_ethosu_write_1 .data )
4346 # body
@@ -71,6 +74,7 @@ def _get_func():
7174 script = mod .script ()
7275 test_mod = tvm .script .from_source (script )
7376 reference_mod = ReferenceModule
77+ reference_mod = copy_allocate_const_data (test_mod , reference_mod )
7478 tvm .ir .assert_structural_equal (test_mod ["main" ], reference_mod ["main" ], True )
7579
7680
@@ -81,8 +85,12 @@ class WeightStream:
8185 def main (input_placeholder_5 : T .Buffer ((1 , 16 , 16 , 32 ), "int8" ), input_ethosu_write_1 : T .Buffer ((1 , 16 , 16 , 16 ), "int8" )) -> None :
8286 # function attr dict
8387 T .func_attr ({"from_legacy_te_schedule" : True , "global_symbol" : "main" , "tir.noalias" : True })
84- buffer = T .Buffer ([528 ], "uint8" )
85- buffer_2 = T .Buffer ([336 ], "uint8" )
88+
89+ data_2 = T .allocate_const ([0 ]* 336 , 'uint8' ,[336 ])
90+ buffer_2 = T .Buffer ([336 ], "uint8" ,data = data_2 )
91+ data = T .allocate_const ([0 ]* 528 , 'uint8' ,[528 ])
92+ buffer = T .Buffer ([528 ], "uint8" ,data = data )
93+
8694 placeholder_5 = T .Buffer ([8192 ], dtype = "int8" , data = input_placeholder_5 .data )
8795 ethosu_write_1 = T .Buffer ([4096 ], dtype = "int8" , data = input_ethosu_write_1 .data )
8896 # body
@@ -131,6 +139,7 @@ def _get_func():
131139 script = mod .script ()
132140 test_mod = tvm .script .from_source (script )
133141 reference_mod = WeightStream
142+ reference_mod = copy_allocate_const_data (test_mod , reference_mod )
134143 tvm .ir .assert_structural_equal (test_mod ["main" ], reference_mod ["main" ], True )
135144
136145
0 commit comments