|
33 | 33 | from tvm import relay |
34 | 34 | from tvm.relay.op.contrib.ethosu import partition_for_ethosu |
35 | 35 | from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer |
| 36 | +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func |
36 | 37 |
|
37 | 38 | from . import infra |
38 | 39 |
|
39 | 40 |
|
40 | | -def _run_pass(expr, relay_pass): |
41 | | - """Create IRModule and run Relay pass.""" |
| 41 | +def _optimize(expr, optimize=True): |
| 42 | + """Create IRModule and run layout optimizer pass.""" |
42 | 43 | mod = tvm.IRModule.from_expr(expr) |
43 | | - mod = relay_pass(mod) |
| 44 | + mod = relay.transform.InferType()(mod) |
| 45 | + if optimize: |
| 46 | + mod = LayoutOptimizer()(mod) |
44 | 47 | entry = mod["main"] |
45 | 48 | return entry if isinstance(expr, relay.Function) else entry.body |
46 | 49 |
|
@@ -111,8 +114,8 @@ def get_graph(): |
111 | 114 | ) |
112 | 115 | return relay.Function(relay.analysis.free_vars(x), x) |
113 | 116 |
|
114 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
115 | | - b = _run_pass(get_graph(), relay.transform.InferType()) |
| 117 | + a = _optimize(get_graph()) |
| 118 | + b = _optimize(get_graph(), optimize=False) |
116 | 119 | _assert_structural_equal(a, b) |
117 | 120 |
|
118 | 121 |
|
@@ -144,8 +147,8 @@ def get_graph(get_expected=False): |
144 | 147 | ) |
145 | 148 | return relay.Function(relay.analysis.free_vars(x), x) |
146 | 149 |
|
147 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
148 | | - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) |
| 150 | + a = _optimize(get_graph()) |
| 151 | + b = _optimize(get_graph(get_expected=True), optimize=False) |
149 | 152 | _assert_structural_equal(a, b) |
150 | 153 |
|
151 | 154 |
|
@@ -176,8 +179,8 @@ def get_graph(get_expected=False): |
176 | 179 | ) |
177 | 180 | return relay.Function(relay.analysis.free_vars(x), x) |
178 | 181 |
|
179 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
180 | | - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) |
| 182 | + a = _optimize(get_graph()) |
| 183 | + b = _optimize(get_graph(get_expected=True), optimize=False) |
181 | 184 | _assert_structural_equal(a, b) |
182 | 185 |
|
183 | 186 |
|
@@ -222,8 +225,8 @@ def get_graph(): |
222 | 225 | ) |
223 | 226 | return relay.Function(relay.analysis.free_vars(conv_2), conv_2) |
224 | 227 |
|
225 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
226 | | - b = _run_pass(get_graph(), relay.transform.InferType()) |
| 228 | + a = _optimize(get_graph()) |
| 229 | + b = _optimize(get_graph(), optimize=False) |
227 | 230 | _assert_structural_equal(a, b) |
228 | 231 |
|
229 | 232 |
|
@@ -268,8 +271,8 @@ def get_graph(): |
268 | 271 | ) |
269 | 272 | return relay.Function(relay.analysis.free_vars(conv_2), conv_2) |
270 | 273 |
|
271 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
272 | | - b = _run_pass(get_graph(), relay.transform.InferType()) |
| 274 | + a = _optimize(get_graph()) |
| 275 | + b = _optimize(get_graph(), optimize=False) |
273 | 276 | _assert_structural_equal(a, b) |
274 | 277 |
|
275 | 278 |
|
@@ -322,8 +325,8 @@ def get_graph(): |
322 | 325 | ) |
323 | 326 | return relay.Function(relay.analysis.free_vars(pool_3), pool_3) |
324 | 327 |
|
325 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
326 | | - b = _run_pass(get_graph(), relay.transform.InferType()) |
| 328 | + a = _optimize(get_graph()) |
| 329 | + b = _optimize(get_graph(), optimize=False) |
327 | 330 | _assert_structural_equal(a, b) |
328 | 331 |
|
329 | 332 |
|
@@ -368,8 +371,8 @@ def get_graph(): |
368 | 371 | ) |
369 | 372 | return relay.Function(relay.analysis.free_vars(conv), conv) |
370 | 373 |
|
371 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
372 | | - b = _run_pass(get_graph(), relay.transform.InferType()) |
| 374 | + a = _optimize(get_graph()) |
| 375 | + b = _optimize(get_graph(), optimize=False) |
373 | 376 | _assert_structural_equal(a, b) |
374 | 377 |
|
375 | 378 |
|
@@ -413,8 +416,8 @@ def get_graph(get_expected=False): |
413 | 416 | concat = relay.concatenate(poolings, axis=0) |
414 | 417 | return relay.Function(relay.analysis.free_vars(concat), concat) |
415 | 418 |
|
416 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
417 | | - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) |
| 419 | + a = _optimize(get_graph()) |
| 420 | + b = _optimize(get_graph(get_expected=True), optimize=False) |
418 | 421 | _assert_structural_equal(a, b) |
419 | 422 |
|
420 | 423 |
|
@@ -467,8 +470,8 @@ def get_graph(get_expected=False): |
467 | 470 | ) |
468 | 471 | return relay.Function(relay.analysis.free_vars(add_3), add_3) |
469 | 472 |
|
470 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
471 | | - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) |
| 473 | + a = _optimize(get_graph()) |
| 474 | + b = _optimize(get_graph(get_expected=True), optimize=False) |
472 | 475 | _assert_structural_equal(a, b) |
473 | 476 |
|
474 | 477 |
|
@@ -500,8 +503,8 @@ def get_graph(get_expected=False): |
500 | 503 | ) |
501 | 504 | return relay.Function(relay.analysis.free_vars(x), x) |
502 | 505 |
|
503 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
504 | | - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) |
| 506 | + a = _optimize(get_graph()) |
| 507 | + b = _optimize(get_graph(get_expected=True), optimize=False) |
505 | 508 | _assert_structural_equal(a, b) |
506 | 509 |
|
507 | 510 |
|
@@ -530,8 +533,8 @@ def get_graph(get_expected=False): |
530 | 533 | ) |
531 | 534 | return relay.Function(relay.analysis.free_vars(x), x) |
532 | 535 |
|
533 | | - a = _run_pass(get_graph(), LayoutOptimizer()) |
534 | | - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) |
| 536 | + a = _optimize(get_graph()) |
| 537 | + b = _optimize(get_graph(get_expected=True), optimize=False) |
535 | 538 | _assert_structural_equal(a, b) |
536 | 539 |
|
537 | 540 |
|
@@ -619,5 +622,32 @@ def representative_dataset(): |
619 | 622 | _compile_and_compare_model(create_model(), ifm_shape, dtype) |
620 | 623 |
|
621 | 624 |
|
| 625 | +def test_layout_optimizer_runs_in_compilation_pipeline(): |
| 626 | + """Checks that the layout optimization pass runs as part of the NPU compilation |
| 627 | + pipeline.""" |
| 628 | + |
| 629 | + def get_graph(): |
| 630 | + x = relay.var("x", shape=(1, 4, 4, 4), dtype="int8") |
| 631 | + for _ in range(2): |
| 632 | + x = relay.nn.max_pool2d(x, layout="NHWC") |
| 633 | + |
| 634 | + func = relay.Function(relay.analysis.free_vars(x), x) |
| 635 | + return tvm.IRModule.from_expr(func) |
| 636 | + |
| 637 | + mod = get_graph() |
| 638 | + mod = partition_for_ethosu(mod) |
| 639 | + |
| 640 | + external_gv_name = mod["main"].body.op.name_hint |
| 641 | + external_func = mod[external_gv_name] |
| 642 | + prim_func = relay_to_tir_func(external_func) |
| 643 | + |
| 644 | + # Check for hints in the TIR prim func that the layout optimization pass has ran |
| 645 | + ops = prim_func.body.body.seq |
| 646 | + max_pool1, max_pool2 = ops |
| 647 | + |
| 648 | + assert str(max_pool1.value.args[31]) == '"NHCWB16"' |
| 649 | + assert str(max_pool2.value.args[14]) == '"NHCWB16"' |
| 650 | + |
| 651 | + |
622 | 652 | if __name__ == "__main__": |
623 | 653 | pytest.main([__file__] + sys.argv[1:]) |
0 commit comments