|
21 | 21 | import tvm |
22 | 22 | from tvm import relay |
23 | 23 | from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator |
| 24 | +from tvm.relay.backend.contrib.ethosu.codegen import collect_consts |
24 | 25 | from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir |
25 | 26 | from tvm.relay.backend.contrib.ethosu.tir.scheduler import ( |
26 | 27 | OperatorCompute, |
@@ -140,12 +141,12 @@ def _get_func(): |
140 | 141 | } |
141 | 142 | with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): |
142 | 143 | func = _get_func() |
143 | | - mod, consts = _lower_to_tir(func, cascader=_planner) |
| 144 | + mod = _lower_to_tir(func, cascader=_planner) |
144 | 145 | script = mod.script() |
145 | 146 | test_mod = tvm.script.from_source(script) |
146 | 147 | tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) |
147 | 148 |
|
148 | | - test_const_size = [value.size for value in list(consts.values())] |
| 149 | + test_const_size = [value.size for value in collect_consts(test_mod).values()] |
149 | 150 | assert reference_const_sizes.sort() == test_const_size.sort() |
150 | 151 |
|
151 | 152 |
|
@@ -242,12 +243,12 @@ def _get_func(): |
242 | 243 | } |
243 | 244 | with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): |
244 | 245 | func = _get_func() |
245 | | - mod, consts = _lower_to_tir(func, cascader=_cascader) |
| 246 | + mod = _lower_to_tir(func, cascader=_cascader) |
246 | 247 | script = mod.script() |
247 | 248 | test_mod = tvm.script.from_source(script) |
248 | 249 | tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) |
249 | 250 |
|
250 | | - test_const_size = [value.size for value in list(consts.values())] |
| 251 | + test_const_size = [value.size for value in collect_consts(test_mod).values()] |
251 | 252 | assert reference_const_sizes.sort() == test_const_size.sort() |
252 | 253 |
|
253 | 254 |
|
@@ -339,13 +340,13 @@ def _get_func(): |
339 | 340 | } |
340 | 341 | with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): |
341 | 342 | func = _get_func() |
342 | | - mod, consts = _lower_to_tir(func) |
| 343 | + mod = _lower_to_tir(func) |
343 | 344 |
|
344 | 345 | script = mod.script() |
345 | 346 | test_mod = tvm.script.from_source(script) |
346 | 347 | tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) |
347 | 348 |
|
348 | | - test_const_size = [value.size for value in list(consts.values())] |
| 349 | + test_const_size = [value.size for value in collect_consts(test_mod).values()] |
349 | 350 | assert reference_const_sizes.sort() == test_const_size.sort() |
350 | 351 |
|
351 | 352 |
|
@@ -474,13 +475,13 @@ def _get_func(): |
474 | 475 | } |
475 | 476 | with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): |
476 | 477 | func = _get_func() |
477 | | - mod, consts = _lower_to_tir(func, cascader=_planner) |
| 478 | + mod = _lower_to_tir(func, cascader=_planner) |
478 | 479 |
|
479 | 480 | script = mod.script() |
480 | 481 | test_mod = tvm.script.from_source(script) |
481 | 482 | tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) |
482 | 483 |
|
483 | | - test_const_size = [value.size for value in list(consts.values())] |
| 484 | + test_const_size = [value.size for value in collect_consts(test_mod).values()] |
484 | 485 | assert reference_const_sizes.sort() == test_const_size.sort() |
485 | 486 |
|
486 | 487 |
|
|
0 commit comments