Commit 9301f72
[microNPU] Move optimization passes to be module passes and ensure they are running

Moves the LayoutOptimizer and LUTsOptimizer passes to be module passes rather than function passes, since it was found that these passes were not running in the NPU compilation flow. In addition, a test for each of LayoutOptimizer and LUTsOptimizer has been added to check that the passes run in the NPU compilation pipeline.

Change-Id: I5145c6f02eeb0daea3cdba56198e0804ec32f351
1 parent 364e2db commit 9301f72

File tree: 3 files changed, +114 -38 lines
  python/tvm/relay/backend/contrib/ethosu/codegen.py
  tests/python/contrib/test_ethosu/test_layout_optimizer.py
  tests/python/contrib/test_ethosu/test_lut_optimizer.py
python/tvm/relay/backend/contrib/ethosu/codegen.py
Lines changed: 21 additions & 13 deletions

@@ -18,13 +18,13 @@

 import tvm
 from tvm import relay
+from tvm import ir
 from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
 from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
 from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.expr_functor import ExprMutator
-from tvm.ir.transform import Pass

 # pylint: disable=unused-import
 from tvm.relay.backend.contrib.ethosu.op import op_attrs

@@ -109,13 +109,11 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
         return new_call


-@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
-class LUTsOptimizer(Pass):
+@ir.transform.module_pass(opt_level=1, name="LUTsOptimizer")
+class LUTsOptimizer:
     """Register LUTsOptimizer as a relay pass."""

-    def transform_function(
-        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
-    ) -> tvm.IRModule:
+    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
         """Visit relay nodes in the given module.

         Parameters

@@ -131,7 +129,13 @@ def transform_function(
         New module with optimized LUTs.
         """
         assert len(mod.functions.items()) == 1, "Module can only contain one function."
-        return OptimizeLUTs().visit(func)
+        global_var, func = mod.functions.items()[0]
+        optimized_func = OptimizeLUTs().visit(func)
+        mod.update_func(global_var, optimized_func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass


 class LayoutOptimization(ExprMutator):

@@ -247,19 +251,23 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
         return super().visit_call(call)


-@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer")
-class LayoutOptimizer(Pass):
+@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer")
+class LayoutOptimizer:
     """Register LayoutOptimizer as a Relay pass."""

-    def transform_function(
-        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
-    ) -> tvm.IRModule:
+    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
         """A pass to optimize the layout of NPU operations. If both the
         producer and consumer of a tensor are NPU operators, then the
         layout is converted from NHWC to NHCWB16 as this is the layout NPU
         uses internally."""
         assert len(mod.functions.items()) == 1, "Module can only contain one function."
-        return LayoutOptimization().visit(func)
+        global_var, func = mod.functions.items()[0]
+        optimized_func = LayoutOptimization().visit(func)
+        mod.update_func(global_var, optimized_func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass


 @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
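With both optimizations now registered as module passes, applying them is a plain call on an IRModule. A minimal usage sketch (assuming `mod` is an IRModule that already holds exactly one partitioned NPU function, as the asserts in both passes require):

from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer, LayoutOptimizer

mod = LUTsOptimizer()(mod)    # optimize LUT usage within the single NPU function
mod = LayoutOptimizer()(mod)  # convert internal tensors from NHWC to NHCWB16 where possible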

tests/python/contrib/test_ethosu/test_layout_optimizer.py
Lines changed: 55 additions & 25 deletions

@@ -33,14 +33,17 @@
 from tvm import relay
 from tvm.relay.op.contrib.ethosu import partition_for_ethosu
 from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer
+from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func

 from . import infra


-def _run_pass(expr, relay_pass):
-    """Create IRModule and run Relay pass."""
+def _optimize(expr, optimize=True):
+    """Create IRModule and run layout optimizer pass."""
     mod = tvm.IRModule.from_expr(expr)
-    mod = relay_pass(mod)
+    mod = relay.transform.InferType()(mod)
+    if optimize:
+        mod = LayoutOptimizer()(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body

@@ -111,8 +114,8 @@ def get_graph():
         )
         return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
     _assert_structural_equal(a, b)

@@ -144,8 +147,8 @@ def get_graph(get_expected=False):
         )
         return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
     _assert_structural_equal(a, b)

@@ -176,8 +179,8 @@ def get_graph(get_expected=False):
         )
         return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
     _assert_structural_equal(a, b)

@@ -222,8 +225,8 @@ def get_graph():
         )
         return relay.Function(relay.analysis.free_vars(conv_2), conv_2)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
     _assert_structural_equal(a, b)

@@ -268,8 +271,8 @@ def get_graph():
         )
         return relay.Function(relay.analysis.free_vars(conv_2), conv_2)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
     _assert_structural_equal(a, b)

@@ -322,8 +325,8 @@ def get_graph():
         )
         return relay.Function(relay.analysis.free_vars(pool_3), pool_3)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
     _assert_structural_equal(a, b)

@@ -368,8 +371,8 @@ def get_graph():
         )
         return relay.Function(relay.analysis.free_vars(conv), conv)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
     _assert_structural_equal(a, b)

@@ -413,8 +416,8 @@ def get_graph(get_expected=False):
         concat = relay.concatenate(poolings, axis=0)
         return relay.Function(relay.analysis.free_vars(concat), concat)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
     _assert_structural_equal(a, b)

@@ -467,8 +470,8 @@ def get_graph(get_expected=False):
         )
         return relay.Function(relay.analysis.free_vars(add_3), add_3)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
     _assert_structural_equal(a, b)

@@ -500,8 +503,8 @@ def get_graph(get_expected=False):
         )
         return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
     _assert_structural_equal(a, b)

@@ -530,8 +533,8 @@ def get_graph(get_expected=False):
         )
         return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
     _assert_structural_equal(a, b)

@@ -619,5 +622,32 @@ def representative_dataset():
     _compile_and_compare_model(create_model(), ifm_shape, dtype)


+def test_layout_optimizer_runs_in_compilation_pipeline():
+    """Checks that the layout optimization pass runs as part of the NPU compilation
+    pipeline."""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 4, 4, 4), dtype="int8")
+        for _ in range(2):
+            x = relay.nn.max_pool2d(x, layout="NHWC")
+
+        func = relay.Function(relay.analysis.free_vars(x), x)
+        return tvm.IRModule.from_expr(func)
+
+    mod = get_graph()
+    mod = partition_for_ethosu(mod)
+
+    external_gv_name = mod["main"].body.op.name_hint
+    external_func = mod[external_gv_name]
+    prim_func = relay_to_tir_func(external_func)
+
+    # Check for hints in the TIR prim func that the layout optimization pass has run
+    ops = prim_func.body.body.seq
+    max_pool1, max_pool2 = ops
+
+    assert str(max_pool1.value.args[31]) == '"NHCWB16"'
+    assert str(max_pool2.value.args[14]) == '"NHCWB16"'
+
+
 if __name__ == "__main__":
     pytest.main([__file__] + sys.argv[1:])
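The argument indices used in the asserts above (31 and 14) depend on the positional signature of the lowered pooling call, so they are brittle if that signature ever changes. A looser, index-independent cross-check is to search the printed TIR for the layout string; a sketch, reusing `prim_func` from the test above:

# The NHCWB16 layout string should only appear in the lowered TIR
# if the LayoutOptimizer pass actually ran.
assert "NHCWB16" in str(prim_func)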

tests/python/contrib/test_ethosu/test_lut_optimizer.py
Lines changed: 38 additions & 0 deletions

@@ -21,9 +21,16 @@

 pytest.importorskip("ethosu.vela")

+import tensorflow as tf
+import numpy as np
+
 import tvm
 from tvm import relay
 from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer
+from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func
+from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+
+from .test_codegen import _get_tflite_graph
 from . import infra

@@ -59,6 +66,7 @@ def after():
         return mod

     mod = LUTsOptimizer()(before())
+    mod = relay.transform.InferType()(mod)

     assert tvm.ir.structural_equal(mod, after())

@@ -91,5 +99,35 @@ def after():
         return mod

     mod = LUTsOptimizer()(before())
+    mod = relay.transform.InferType()(mod)

     assert tvm.ir.structural_equal(mod, after())
+
+
+def test_lut_optimizer_runs_in_compilation_pipeline():
+    """Test that the LUT optimization pass runs as part of the NPU compilation pipeline."""
+    ifm_shape = (1, 4, 4, 4)
+
+    @tf.function
+    def get_graph(x):
+        weight1 = tf.constant(np.random.uniform(size=(1, 1, 4, 4)), dtype=tf.float32)
+        op = tf.nn.conv2d(x, weight1, (1, 1), "VALID")
+        op = tf.nn.tanh(op)
+        weight2 = tf.constant(np.random.uniform(size=(1, 1, 4, 1)), dtype=tf.float32)
+        op = tf.nn.depthwise_conv2d(op, weight2, (1, 1, 1, 1), "VALID")
+        return tf.nn.tanh(op)
+
+    mod, _ = _get_tflite_graph(get_graph, [ifm_shape])
+    mod = partition_for_ethosu(mod)
+
+    external_gv_name = mod["main"].body.op.name_hint
+    external_func = mod[external_gv_name]
+    prim_func = relay_to_tir_func(external_func)
+
+    # Check for hints in the TIR prim func that the LUT optimization pass has run.
+    # If the module was optimized, there should be no identity operations.
+    def check_identity(stmt):
+        if isinstance(stmt, tvm.tir.expr.Call):
+            assert stmt.args[0] != "ethosu_identity"
+
+    tvm.tir.stmt_functor.post_order_visit(prim_func.body, check_identity)
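The same post-order visitor can double as a small debugging aid that lists every call name in the lowered TIR, which makes a failure easier to diagnose than a bare assert. A sketch reusing `prim_func` from the test above (`op_names` and `_collect_call_names` are illustrative names, not part of this commit):

op_names = []


def _collect_call_names(stmt):
    # The first argument of each call in the lowered NPU TIR carries the op
    # name; str() of a StringImm includes surrounding quotes, so strip them.
    if isinstance(stmt, tvm.tir.expr.Call) and stmt.args:
        op_names.append(str(stmt.args[0]).strip('"'))


tvm.tir.stmt_functor.post_order_visit(prim_func.body, _collect_call_names)
assert "ethosu_identity" not in op_names  # same invariant as the test above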
