Skip to content

Commit a8741e2

Browse files
authored
[microNPU] Fix layout assignment in layout optimizer pass (#10143)
Fixes the layout optimizer incorrectly assigning layouts for graphs with more complex topologies than previously considered. Specifically, this commit now ensures that intermediate layouts match (e.g. parent output = child input) and that all consumers are taken into account when altering the output layout - something not done previously due to an incorrect traversal order. Previously, the input layout was always altered if the producer was an NPU operation without regard to the output layout of that operation. Additionally, it was possible for the output layout to be incorrectly set due to a depth-first post-order traversal of the graph, meaning it was possible for not all consumers to be taken into account when altering the layout. Now the `AnalyzeConsumers` pass is run before `LayoutOptimization` which determines a mapping from NPU operation to list of boolean values that represent whether or not each consumer is an NPU operation. Since this is completed before `LayoutOptimization`, all consumers are guaranteed to be taken into account when altering the output layout. In turn, the input layouts can correctly be determined by checking whether the output of the producer will be altered. Change-Id: I04e9605da65fa9f12801109dd50c5e3f08cbc73c
1 parent 8736593 commit a8741e2

File tree

2 files changed

+169
-55
lines changed

2 files changed

+169
-55
lines changed

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 84 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Codegen for Arm(R) Ethos(TM)-U NPU"""
18+
from collections import defaultdict
1819

1920
import tvm
2021
from tvm import relay
@@ -24,7 +25,7 @@
2425
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
2526
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
2627
from tvm.relay.backend.contrib.ethosu import util
27-
from tvm.relay.expr_functor import ExprMutator
28+
from tvm.relay.expr_functor import ExprMutator, ExprVisitor
2829

2930
# pylint: disable=unused-import
3031
from tvm.relay.backend.contrib.ethosu.op import op_attrs
@@ -138,38 +139,76 @@ def __call__(self, *args, **kwargs):
138139
pass
139140

140141

141-
class LayoutOptimization(ExprMutator):
142-
"""A pass to optimize the layout of NPU operations. If both the
143-
producer and consumer of a tensor are NPU operators, then the
144-
layout is converted from NHWC to NHCWB16.
142+
class AnalyzeConsumers(ExprVisitor):
143+
"""Traverses the graph to determine consumers that are NPU operations. The
144+
result is maintained in `npu_consumers`.
145145
146146
Attributes
147147
----------
148-
children : Dict[tvm.relay.expr.Call, List[tvm.relay.expr.Call]]
149-
A map from current call to a list of calls that rely on the current
150-
call. This allows the graph to be traversed backwards, which is useful
151-
for checking whether the output layouts can be rewritten.
152-
optimize_op : Dict[str, Callable]
153-
A map from NPU op name to function that creates NPU op.
148+
npu_consumers : Dict[tvm.relay.expr.Call, List[bool]]
149+
Mapping from NPU operation to list of boolean values that represent
150+
whether or not each consumer is an NPU operation.
151+
optimize_ops : Dict[str, Callable]
152+
A map from NPU operation name to function that creates NPU operation.
154153
"""
155154

156-
def __init__(self):
157-
self.children = {}
158-
self.optimize_op = {
159-
"contrib.ethosu.conv2d": op.ethosu_conv2d,
160-
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
161-
"contrib.ethosu.pooling": op.ethosu_pooling,
162-
"contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
163-
"contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
164-
}
155+
def __init__(self, optimize_ops):
156+
self.npu_consumers = defaultdict(list)
157+
self.optimize_ops = optimize_ops
158+
super().__init__()
159+
160+
def visit_call(self, call: relay.Call):
161+
is_npu_consumer = call.op.name in self.optimize_ops
162+
args = []
165163

164+
# Expand tuples
165+
for arg in call.args:
166+
if isinstance(arg, relay.Tuple):
167+
args.extend(arg.fields)
168+
else:
169+
args.append(arg)
170+
171+
for arg in args:
172+
if isinstance(arg, relay.Call) and arg.op.name in self.optimize_ops:
173+
self.npu_consumers[arg].append(is_npu_consumer)
174+
175+
super().visit_call(call)
176+
177+
178+
class LayoutOptimization(ExprMutator):
179+
"""A pass to optimize the layout of NPU operations by converting to brick format (NHCWB16).
180+
This pass traverses the graph and attempts to alter the input/output layouts when an NPU
181+
operation is visited. Whether or not the input/output layout can be altered for a given NPU
182+
operation depends on the following:
183+
184+
Check alter input layout: For each argument, if the producer is also an NPU operation and
185+
its output is altered to brick format, then the input layout with respect to the current
186+
argument is altered to brick format.
187+
188+
Check alter output layout: If all consumers (child nodes) are an NPU operation, then the
189+
output layout is altered to brick format.
190+
191+
Note
192+
----
193+
In order for this pass to be run, the consumers of each NPU operation must first be analyzed
194+
by the `AnalyzeConsumers` pass, since Relay doesn't keep a reference to child nodes.
195+
196+
Attributes
197+
----------
198+
npu_consumers : Dict[tvm.relay.expr.Call, bool]
199+
A map from current call to a list of boolean values that state whether or not each consumer
200+
is an NPU operation.
201+
optimize_ops : Dict[str, Callable]
202+
A map from NPU operation name to function that creates NPU operation.
203+
"""
204+
205+
def __init__(self, npu_consumers, optimize_ops):
206+
self.npu_consumers = npu_consumers
207+
self.optimize_ops = optimize_ops
166208
super().__init__()
167209

168210
def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
169-
"""Alter the input and output layouts of an NPU operation if needed.
170-
Input layout is only altered if the producing operation is an NPU
171-
operation. Likewise, the output layout is only altered if the consuming
172-
operation is an NPU operation.
211+
"""Alter the layouts of given NPU operation to brick format if possible.
173212
174213
Parameters
175214
----------
@@ -189,46 +228,26 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Ca
189228
)
190229

191230
new_attrs = dict(call.attrs)
192-
parents = []
193231

194232
# Check if we can rewrite the input layouts
195233
input_count = 0
196234
for arg in call.args:
197235
input_count += 1
198-
if not isinstance(arg, tvm.relay.expr.Call):
236+
if arg not in self.npu_consumers:
199237
continue
200-
if isinstance(arg.op, tvm.ir.op.Op) and arg.op.name in self.optimize_op:
238+
consumers = self.npu_consumers[arg]
239+
parent_has_brick_output = consumers and all(consumers)
240+
if parent_has_brick_output:
201241
layout_string = "ifm_layout" if input_count <= 1 else f"ifm{input_count}_layout"
202242
new_attrs[layout_string] = "NHCWB16"
203-
parents.append(arg)
204243

205244
# Check if we can rewrite the output layouts
206-
if call in self.children:
207-
children = self.children[call]
208-
if all(
209-
isinstance(child, tvm.relay.expr.Call)
210-
and isinstance(child.op, tvm.ir.op.Op)
211-
and child.op.name in self.optimize_op
212-
and child.attrs["ifm_layout"] == "NHCWB16"
213-
for child in children
214-
):
215-
new_attrs["ofm_layout"] = "NHCWB16"
245+
consumers = self.npu_consumers[call]
246+
if consumers and all(consumers):
247+
new_attrs["ofm_layout"] = "NHCWB16"
216248

217249
name = call.op.name
218-
assert name in self.optimize_op, (
219-
f"Could not create operator '{name}' as the creation function "
220-
"is unknown. Please provide a mapping."
221-
)
222-
new_call = self.optimize_op[name](*call.args, **new_attrs)
223-
224-
# Update map of children
225-
for input_arg in parents:
226-
if input_arg in self.children:
227-
self.children[input_arg].append(new_call)
228-
else:
229-
self.children[input_arg] = [new_call]
230-
231-
return super().visit_call(new_call)
250+
return self.optimize_ops[name](*call.args, **new_attrs)
232251

233252
def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
234253
"""Recursively visit call nodes in the input graph and alter the
@@ -246,23 +265,33 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
246265
not refer to an Op. Else, a new call node with altered Op
247266
attributes.
248267
"""
249-
if isinstance(call.op, tvm.ir.op.Op) and call.op.name in self.optimize_op:
250-
return self.alter_ethosu_op_layout(call)
268+
if isinstance(call.op, tvm.ir.Op) and call.op.name in self.optimize_ops:
269+
call = self.alter_ethosu_op_layout(call)
251270
return super().visit_call(call)
252271

253272

254273
@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer")
255274
class LayoutOptimizer:
256275
"""Register LayoutOptimizer as a Relay pass."""
257276

277+
OPTIMIZE_OPS = {
278+
"contrib.ethosu.conv2d": op.ethosu_conv2d,
279+
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
280+
"contrib.ethosu.pooling": op.ethosu_pooling,
281+
"contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
282+
"contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
283+
}
284+
258285
def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
259286
"""A pass to optimize the layout of NPU operations. If both the
260287
producer and consumer of a tensor are NPU operators, then the
261288
layout is converted from NHWC to NHCWB16 as this is the layout NPU
262289
uses internally."""
263290
assert len(mod.functions.items()) == 1, "Module can only contain one function."
264291
global_var, func = mod.functions.items()[0]
265-
optimized_func = LayoutOptimization().visit(func)
292+
analyze = AnalyzeConsumers(self.OPTIMIZE_OPS)
293+
analyze.visit(func)
294+
optimized_func = LayoutOptimization(analyze.npu_consumers, self.OPTIMIZE_OPS).visit(func)
266295
mod.update_func(global_var, optimized_func)
267296
return mod
268297

tests/python/contrib/test_ethosu/test_layout_optimizer.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,91 @@ def get_graph(get_expected=False):
538538
_assert_structural_equal(a, b)
539539

540540

541+
def test_op_without_ethosu_consumer():
542+
"""Test the layout optimization pass works as expected when
543+
there is a case that the output layout should not be altered
544+
since not all consumers are NPU operations (in this case conv).
545+
546+
depthwise
547+
|
548+
conv
549+
/ \
550+
| pool
551+
\ /
552+
(concat)
553+
"""
554+
555+
def get_graph(get_expected=False):
556+
exp_layout = "NHCWB16" if get_expected else "NHWC"
557+
558+
x = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
559+
depthwise = infra.make_ethosu_depthwise_conv2d(
560+
x, 2, (1, 1), (0, 0), (1, 1), (0, 0), ofm_layout=exp_layout
561+
)
562+
conv = infra.make_ethosu_conv2d(
563+
depthwise,
564+
2,
565+
2,
566+
(1, 1),
567+
(0, 0),
568+
(1, 1),
569+
(0, 0),
570+
ifm_layout=exp_layout,
571+
)
572+
pool = infra.make_ethosu_pooling(conv, "MAX", (1, 1), 2, (1, 1), (0, 0))
573+
concat = relay.concatenate([conv, pool], axis=0)
574+
return relay.Function(relay.analysis.free_vars(concat), concat)
575+
576+
a = _optimize(get_graph())
577+
b = _optimize(get_graph(get_expected=True), optimize=False)
578+
_assert_structural_equal(a, b)
579+
580+
581+
def test_diamond_graph():
582+
"""
583+
Test the layout optimizer pass works as expected on a diamond graph
584+
with a case where the operation dominating the output operation
585+
cannot be altered, but operations within the diamond can.
586+
587+
pool_1
588+
|
589+
pool_2
590+
/ \
591+
| pool_3
592+
| |
593+
| pool_4
594+
| |
595+
| pool_5
596+
\ /
597+
(concat)
598+
"""
599+
600+
def get_graph(get_expected=False):
601+
exp_layout = "NHCWB16" if get_expected else "NHWC"
602+
x = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
603+
pool_1 = infra.make_ethosu_pooling(
604+
x, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=exp_layout
605+
)
606+
pool_2 = infra.make_ethosu_pooling(
607+
pool_1, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout
608+
)
609+
pool_3 = infra.make_ethosu_pooling(
610+
pool_2, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=exp_layout
611+
)
612+
pool_4 = infra.make_ethosu_pooling(
613+
pool_3, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout, ofm_layout=exp_layout
614+
)
615+
pool_5 = infra.make_ethosu_pooling(
616+
pool_4, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout
617+
)
618+
concat = relay.concatenate([pool_2, pool_5], axis=0)
619+
return relay.Function(relay.analysis.free_vars(concat), concat)
620+
621+
a = _optimize(get_graph())
622+
b = _optimize(get_graph(get_expected=True), optimize=False)
623+
_assert_structural_equal(a, b)
624+
625+
541626
def test_same_output_multiple_convolutions():
542627
"""Test running the layout optimization pass with multiple convolutions
543628
gives same output as TFLite."""

0 commit comments

Comments
 (0)