Skip to content

Commit 9265b08

Browse files
lhutton1 authored and yangulei committed
[microNPU] Add NHWC -> NHCWB16 layout transformation pass (apache#9561)
Adds a layout optimization pass that modifies the ifm/ofm layout of an operation to NHCWB16 where possible. This can occur when the producer or consumer of a tensor is also an NPU operator.
1 parent ec502f5 commit 9265b08

File tree

3 files changed

+761
-0
lines changed

3 files changed

+761
-0
lines changed

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,134 @@ def transform_function(
134134
return OptimizeLUTs().visit(func)
135135

136136

137+
class LayoutOptimization(ExprMutator):
    """A pass to optimize the layout of NPU operations. If both the
    producer and consumer of a tensor are NPU operators, then the
    layout is converted from NHWC to NHCWB16.

    Attributes
    ----------
    children : Dict[tvm.relay.expr.Call, List[tvm.relay.expr.Call]]
        A map from current call to a list of calls that rely on the current
        call. This allows the graph to be traversed backwards, which is useful
        for checking whether the output layouts can be rewritten.
    optimize_op : Dict[str, Callable]
        A map from NPU op name to function that creates NPU op.
    """

    def __init__(self):
        self.children = {}
        self.optimize_op = {
            "contrib.ethosu.conv2d": op.ethosu_conv2d,
            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
            "contrib.ethosu.pooling": op.ethosu_pooling,
            "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
            "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
        }

        super().__init__()

    @staticmethod
    def _input_layout_attr(position: int) -> str:
        """Return the layout attribute name for the 1-based input ``position``.

        The first input uses ``ifm_layout``; later inputs use
        ``ifm<position>_layout`` (e.g. ``ifm2_layout``).
        """
        return "ifm_layout" if position <= 1 else f"ifm{position}_layout"

    def _is_npu_call(self, expr) -> bool:
        """Return True if ``expr`` is a call to an op listed in ``optimize_op``."""
        return (
            isinstance(expr, tvm.relay.expr.Call)
            and isinstance(expr.op, tvm.ir.op.Op)
            and expr.op.name in self.optimize_op
        )

    def _consumes_in_brick_layout(self, consumer, producer) -> bool:
        """Return True if ``consumer`` reads ``producer``'s output as NHCWB16.

        The layout attribute that is checked is the one belonging to the
        input slot(s) of ``consumer`` that ``producer`` occupies, so a
        producer wired into the second input is checked against
        ``ifm2_layout`` rather than ``ifm_layout``.
        """
        if not self._is_npu_call(consumer):
            return False

        layout_attrs = [
            self._input_layout_attr(position)
            for position, arg in enumerate(consumer.args, start=1)
            if arg.same_as(producer)
        ]
        # If the producer cannot be located among the consumer's arguments,
        # fall back to inspecting the primary input layout only.
        if not layout_attrs:
            layout_attrs = ["ifm_layout"]

        return all(consumer.attrs[attr_name] == "NHCWB16" for attr_name in layout_attrs)

    def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
        """Alter the input and output layouts of an NPU operation if needed.
        Input layout is only altered if the producing operation is an NPU
        operation. Likewise, the output layout is only altered if the consuming
        operation is an NPU operation.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The call pointing to an NPU operation that will be checked if
            the layout needs altering.

        Returns
        -------
        new_call : tvm.relay.expr.Call
            New call with altered layouts.
        """
        assert isinstance(call.attrs, tvm.ir.Attrs), (
            f"The attributes for operator '{call.op.name}' could not be "
            "found. Did you register the relay.attrs.Ethosu<opname>Attrs "
            "object in python api?"
        )

        new_attrs = dict(call.attrs)
        parents = []

        # Check if we can rewrite the input layouts: every argument produced
        # by an NPU operator is read in NHCWB16.
        for position, arg in enumerate(call.args, start=1):
            if self._is_npu_call(arg):
                new_attrs[self._input_layout_attr(position)] = "NHCWB16"
                parents.append(arg)

        # Check if we can rewrite the output layouts: every recorded consumer
        # must read this call's output in NHCWB16.
        # NOTE(review): only NPU consumers are ever recorded in self.children
        # (see the update below), so a non-NPU consumer of the same tensor
        # cannot veto the rewrite here - confirm this case cannot occur.
        if call in self.children and all(
            self._consumes_in_brick_layout(child, call) for child in self.children[call]
        ):
            new_attrs["ofm_layout"] = "NHCWB16"

        name = call.op.name
        assert name in self.optimize_op, (
            f"Could not create operator '{name}' as the creation function "
            "is unknown. Please provide a mapping."
        )
        new_call = self.optimize_op[name](*call.args, **new_attrs)

        # Update map of children so that each NPU producer can later inspect
        # the (already rewritten) call that consumes it.
        for input_arg in parents:
            self.children.setdefault(input_arg, []).append(new_call)

        return super().visit_call(new_call)

    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
        """Recursively visit call nodes in the input graph and alter the
        layout of an op if needed.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The current call node being visited.

        Returns
        -------
        tvm.relay.expr.Call
            The input call node in the case the current call node does
            not refer to an Op. Else, a new call node with altered Op
            attributes.
        """
        if self._is_npu_call(call):
            return self.alter_ethosu_op_layout(call)
        return super().visit_call(call)
248+
249+
250+
@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer")
class LayoutOptimizer(Pass):
    """Relay function pass that applies LayoutOptimization to a function."""

    def transform_function(
        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
    ) -> tvm.IRModule:
        """Rewrite NPU operator layouts in ``func`` from NHWC to NHCWB16.

        The rewrite is delegated to a fresh LayoutOptimization mutator, which
        converts layouts where both the producer and consumer of a tensor are
        NPU operators.

        Parameters
        ----------
        func : tvm.relay.function.Function
            The function to run the layout optimization on.
        mod : tvm.IRModule
            The module containing ``func``; it must hold exactly one function.
        _ :
            Unused third argument supplied by the pass infrastructure.

        Returns
        -------
        The result of visiting ``func`` with LayoutOptimization.
        """
        function_count = len(mod.functions.items())
        assert function_count == 1, "Module can only contain one function."
        mutator = LayoutOptimization()
        return mutator.visit(func)
263+
264+
137265
@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
138266
def constant_updater(expr, symbol): # pylint: disable=unused-argument
139267
"""

python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,13 @@ class EthosuDepthwiseConv2DAttrs(Attrs):
3737
@tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs")
class EthosuPooling2DAttrs(Attrs):
    """Attributes node for the contrib.ethosu.pooling operator.

    Registered under the object key "relay.attrs.EthosuPoolingAttrs"; the
    class itself adds no fields or methods on top of Attrs.
    """
40+
41+
42+
@tvm._ffi.register_object("relay.attrs.EthosuBinaryElementwiseAttrs")
class EthosuBinaryElementwiseAttrs(Attrs):
    """Attributes node for the contrib.ethosu.binary_elementwise operator.

    Registered under the object key
    "relay.attrs.EthosuBinaryElementwiseAttrs"; the class itself adds no
    fields or methods on top of Attrs.
    """
45+
46+
47+
@tvm._ffi.register_object("relay.attrs.EthosuUnaryElementwiseAttrs")
class EthosuUnaryElementwiseAttrs(Attrs):
    """Attributes node for the contrib.ethosu.unary_elementwise operator.

    Registered under the object key
    "relay.attrs.EthosuUnaryElementwiseAttrs"; the class itself adds no
    fields or methods on top of Attrs.
    """

0 commit comments

Comments
 (0)