Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 84 additions & 55 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Codegen for Arm(R) Ethos(TM)-U NPU"""
from collections import defaultdict

import tvm
from tvm import relay
Expand All @@ -24,7 +25,7 @@
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

# pylint: disable=unused-import
from tvm.relay.backend.contrib.ethosu.op import op_attrs
Expand Down Expand Up @@ -138,38 +139,76 @@ def __call__(self, *args, **kwargs):
pass


class AnalyzeConsumers(ExprVisitor):
    """Traverses the graph to determine consumers that are NPU operations. The
    result is maintained in `npu_consumers`.

    Attributes
    ----------
    npu_consumers : Dict[tvm.relay.expr.Call, List[bool]]
        Mapping from NPU operation to list of boolean values that represent
        whether or not each consumer is an NPU operation.
    optimize_ops : Dict[str, Callable]
        A map from NPU operation name to function that creates NPU operation.
    """

    def __init__(self, optimize_ops):
        # defaultdict(list) so an NPU op with no recorded consumers yields an
        # empty list rather than raising KeyError on lookup.
        self.npu_consumers = defaultdict(list)
        self.optimize_ops = optimize_ops
        super().__init__()

    def visit_call(self, call: relay.Call):
        """Record, for every NPU-op argument of `call`, whether `call`
        itself is an NPU operation."""
        is_npu_consumer = call.op.name in self.optimize_ops
        args = []

        # Expand tuples so arguments wrapped in a tuple (e.g. the input of a
        # concatenate) are considered individually.
        for arg in call.args:
            if isinstance(arg, relay.Tuple):
                args.extend(arg.fields)
            else:
                args.append(arg)

        for arg in args:
            if isinstance(arg, relay.Call) and arg.op.name in self.optimize_ops:
                self.npu_consumers[arg].append(is_npu_consumer)

        super().visit_call(call)


class LayoutOptimization(ExprMutator):
"""A pass to optimize the layout of NPU operations by converting to brick format (NHCWB16).
This pass traverses the graph and attempts to alter the input/output layouts when an NPU
operation is visited. Whether or not the input/output layout can be altered for a given NPU
operation depends on the following:

Check alter input layout: For each argument, if the producer is also an NPU operation and
its output is altered to brick format, then the input layout with respect to the current
argument is altered to brick format.

Check alter output layout: If all consumers (child nodes) are an NPU operation, then the
output layout is altered to brick format.

Note
----
In order for this pass to be run, the consumers of each NPU operation must first be analyzed
by the `AnalyzeConsumers` pass, since Relay doesn't keep a reference to child nodes.

Attributes
----------
npu_consumers : Dict[tvm.relay.expr.Call, bool]
A map from current call to a list boolean values that state whether or not each consumer
is an NPU operation.
optimize_ops : Dict[str, Callable]
A map from NPU operation name to function that creates NPU operation.
"""

def __init__(self, npu_consumers, optimize_ops):
self.npu_consumers = npu_consumers
self.optimize_ops = optimize_ops
super().__init__()

def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
"""Alter the input and output layouts of an NPU operation if needed.
Input layout is only altered if the producing operation is an NPU
operation. Likewise, the output layout is only altered if the consuming
operation is an NPU operation.
"""Alter the layouts of given NPU operation to brick format if possible.

Parameters
----------
Expand All @@ -189,46 +228,26 @@ def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Ca
)

new_attrs = dict(call.attrs)
parents = []

# Check if we can rewrite the input layouts
input_count = 0
for arg in call.args:
input_count += 1
if not isinstance(arg, tvm.relay.expr.Call):
if arg not in self.npu_consumers:
continue
if isinstance(arg.op, tvm.ir.op.Op) and arg.op.name in self.optimize_op:
consumers = self.npu_consumers[arg]
parent_has_brick_output = consumers and all(consumers)
if parent_has_brick_output:
layout_string = "ifm_layout" if input_count <= 1 else f"ifm{input_count}_layout"
new_attrs[layout_string] = "NHCWB16"
parents.append(arg)

# Check if we can rewrite the output layouts
if call in self.children:
children = self.children[call]
if all(
isinstance(child, tvm.relay.expr.Call)
and isinstance(child.op, tvm.ir.op.Op)
and child.op.name in self.optimize_op
and child.attrs["ifm_layout"] == "NHCWB16"
for child in children
):
new_attrs["ofm_layout"] = "NHCWB16"
consumers = self.npu_consumers[call]
if consumers and all(consumers):
new_attrs["ofm_layout"] = "NHCWB16"

name = call.op.name
assert name in self.optimize_op, (
f"Could not create operator '{name}' as the creation function "
"is unknown. Please provide a mapping."
)
new_call = self.optimize_op[name](*call.args, **new_attrs)

# Update map of children
for input_arg in parents:
if input_arg in self.children:
self.children[input_arg].append(new_call)
else:
self.children[input_arg] = [new_call]

return super().visit_call(new_call)
return self.optimize_ops[name](*call.args, **new_attrs)

def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
"""Recursively visit call nodes in the input graph and alter the
Expand All @@ -246,23 +265,33 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
not refer to an Op. Else, a new call node with altered Op
attributes.
"""
if isinstance(call.op, tvm.ir.op.Op) and call.op.name in self.optimize_op:
return self.alter_ethosu_op_layout(call)
if isinstance(call.op, tvm.ir.Op) and call.op.name in self.optimize_ops:
call = self.alter_ethosu_op_layout(call)
return super().visit_call(call)


@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer")
class LayoutOptimizer:
    """Register LayoutOptimizer as a Relay pass."""

    # Map from NPU operation name to the function that constructs that
    # operation; shared by both the analysis and the mutation passes.
    OPTIMIZE_OPS = {
        "contrib.ethosu.conv2d": op.ethosu_conv2d,
        "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
        "contrib.ethosu.pooling": op.ethosu_pooling,
        "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
        "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
    }

    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
        """A pass to optimize the layout of NPU operations. If both the
        producer and consumer of a tensor are NPU operators, then the
        layout is converted from NHWC to NHCWB16 as this is the layout NPU
        uses internally."""
        assert len(mod.functions.items()) == 1, "Module can only contain one function."
        global_var, func = mod.functions.items()[0]
        # Consumers must be analyzed first, since Relay expressions do not
        # keep references to their child (consumer) nodes.
        analyze = AnalyzeConsumers(self.OPTIMIZE_OPS)
        analyze.visit(func)
        optimized_func = LayoutOptimization(analyze.npu_consumers, self.OPTIMIZE_OPS).visit(func)
        mod.update_func(global_var, optimized_func)
        return mod

Expand Down
85 changes: 85 additions & 0 deletions tests/python/contrib/test_ethosu/test_layout_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,91 @@ def get_graph(get_expected=False):
_assert_structural_equal(a, b)


def test_op_without_ethosu_consumer():
    """Check the layout optimization pass leaves an output layout
    unaltered when not every consumer is an NPU operation (here the
    conv output also feeds a concatenate).

    depthwise
        |
       conv
      /    \
     |    pool
      \    /
     (concat)
    """

    def build_graph(expected=False):
        if expected:
            layout = "NHCWB16"
        else:
            layout = "NHWC"

        inp = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
        depthwise = infra.make_ethosu_depthwise_conv2d(
            inp, 2, (1, 1), (0, 0), (1, 1), (0, 0), ofm_layout=layout
        )
        conv = infra.make_ethosu_conv2d(
            depthwise,
            2,
            2,
            (1, 1),
            (0, 0),
            (1, 1),
            (0, 0),
            ifm_layout=layout,
        )
        pool = infra.make_ethosu_pooling(conv, "MAX", (1, 1), 2, (1, 1), (0, 0))
        concat = relay.concatenate([conv, pool], axis=0)
        return relay.Function(relay.analysis.free_vars(concat), concat)

    actual = _optimize(build_graph())
    expected = _optimize(build_graph(expected=True), optimize=False)
    _assert_structural_equal(actual, expected)


def test_diamond_graph():
    """Check the layout optimizer pass on a diamond graph where the
    operation dominating the output operation cannot be altered, but
    the operations within the diamond can.

     pool_1
       |
     pool_2
     /    \
    |   pool_3
    |      |
    |   pool_4
    |      |
    |   pool_5
     \    /
    (concat)
    """

    def build_func(expected=False):
        layout = "NHCWB16" if expected else "NHWC"
        inp = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
        p1 = infra.make_ethosu_pooling(
            inp, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=layout
        )
        p2 = infra.make_ethosu_pooling(
            p1, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=layout
        )
        p3 = infra.make_ethosu_pooling(
            p2, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=layout
        )
        p4 = infra.make_ethosu_pooling(
            p3, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=layout, ofm_layout=layout
        )
        p5 = infra.make_ethosu_pooling(
            p4, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=layout
        )
        concat = relay.concatenate([p2, p5], axis=0)
        return relay.Function(relay.analysis.free_vars(concat), concat)

    actual = _optimize(build_func())
    expected = _optimize(build_func(expected=True), optimize=False)
    _assert_structural_equal(actual, expected)


def test_same_output_multiple_convolutions():
"""Test running the layout optimization pass with multiple convolutions
gives same output as TFLite."""
Expand Down