|
41 | 41 | make_ethosu_conv2d, |
42 | 42 | make_ethosu_identity, |
43 | 43 | make_ethosu_binary_elementwise, |
| 44 | + copy_allocate_const_data, |
44 | 45 | ) |
45 | 46 |
|
46 | 47 |
|
@@ -198,10 +199,15 @@ class DiamondGraphTir: |
198 | 199 | @T.prim_func |
199 | 200 | def main(input_placeholder: T.Buffer((1, 56, 56, 96), "int8"), input_ethosu_write: T.Buffer((1, 56, 56, 24), "int8")) -> None: |
200 | 201 | T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) |
| 202 | + |
| 203 | + data3 = T.allocate_const([0]*976, 'uint8', [976]) |
| 204 | + buffer3 = T.Buffer([976], "uint8", data=data3) |
| 205 | + data1 = T.allocate_const([0]*2848, 'uint8', [2848]) |
| 206 | + buffer1 = T.Buffer([2848], "uint8", data=data1) |
| 207 | + |
| 208 | + |
201 | 209 | placeholder = T.Buffer([301056], dtype='int8', data=input_placeholder.data) |
202 | 210 | ethosu_write = T.Buffer([75264], dtype='int8', data=input_ethosu_write.data) |
203 | | - buffer1 = T.Buffer([2848], "uint8") |
204 | | - buffer3 = T.Buffer([976], "uint8") |
205 | 211 | p1_data = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) |
206 | 212 | p1 = T.Buffer([2848], "uint8", data=p1_data) |
207 | 213 | p2_data = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True}) |
@@ -230,6 +236,7 @@ def test_schedule_diamond_graph(): |
230 | 236 |
|
231 | 237 | test_mod = _lower_to_tir(func, copy_constants()) |
232 | 238 | reference_mod = DiamondGraphTir |
| 239 | + reference_mod = copy_allocate_const_data(test_mod, reference_mod) |
233 | 240 | tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) |
234 | 241 |
|
235 | 242 |
|
|
0 commit comments