Skip to content

Commit 7836a42

Browse files
[ETHOSU][MicroNPU][Pass] Add a pass to replicate pads
1 parent f172f6c commit 7836a42

File tree

4 files changed

+255
-0
lines changed

4 files changed

+255
-0
lines changed

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,6 +2348,8 @@ def partition_for_ethosu(
23482348

23492349
pattern = relay.op.contrib.get_pattern_table("ethos-u")
23502350
mod = relay.transform.InferType()(mod)
2351+
mod = relay.transform.replicate_pads(mod)
2352+
mod = relay.transform.InferType()(mod)
23512353
mod = relay.transform.MergeComposite(pattern)(mod)
23522354
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
23532355
mod = relay.transform.MergeCompilerRegions()(mod)

python/tvm/relay/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""The Relay IR namespace containing transformations."""
1919
# transformation passes
2020
from .transform import *
21+
from .replicate_pads_with_multiple_consumers import *
2122
from .recast import recast
2223
from . import fake_quantization_to_integer, mixed_precision
2324
from .flexible_shape import FlexibleShapeDispatch
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"Adds pads so that each conv2d operator has only one consumer"
18+
19+
import tvm
20+
from tvm import relay
21+
22+
from ..expr_functor import ExprMutator, Call
23+
from .. import expr as _expr
24+
25+
26+
class PadsWithMultipleConsumersReplicator(ExprMutator):
    r"""A pass to handle the situation when an nn.pad operator has
    more than one qnn.conv2d consumer.

             pad
           /     \
       Conv2D   Conv2D

    In this case, because of the peculiarities of pattern parsing,
    conv2d does not get into the composite for the NPU.
    Therefore, pads are added so that each has only one consumer.
    """

    def __init__(self):
        ExprMutator.__init__(self)
        # Structural hashes of the nn.pad calls that have already been seen
        # as the first argument of a qnn.conv2d call.
        self.hashes = set()

    @staticmethod
    def _is_conv2d_fed_by_pad(call):
        """Return True when `call` is a qnn.conv2d whose first argument is an nn.pad call."""
        # Guard against calls with an empty argument list: indexing
        # call.args[0] on such a call would raise instead of simply not matching.
        if not isinstance(call.op, tvm.ir.Op) or len(call.args) == 0:
            return False
        first_arg = call.args[0]
        return (
            isinstance(first_arg, Call)
            and isinstance(first_arg.op, tvm.ir.Op)
            and call.op == relay.op.get("qnn.conv2d")
            and first_arg.op == relay.op.get("nn.pad")
        )

    def visit_call(self, call):
        """Rewrite `call`, giving a qnn.conv2d its own copy of an already-used nn.pad.

        Parameters
        ----------
        call : relay.Call
            The call expression being visited.

        Returns
        -------
        relay.Expr
            The rewritten call. When `call` is a qnn.conv2d whose nn.pad input
            has a structural hash that was already recorded, the returned conv2d
            consumes a newly constructed nn.pad call instead of the shared one.
        """
        if self._is_conv2d_fed_by_pad(call):
            pad_hash = tvm.ir.structural_hash(call.args[0])
            if pad_hash not in self.hashes:
                # First conv2d seen for this pad: remember the pad and fall
                # through to the default rewrite, which keeps the pad as is.
                self.hashes.add(pad_hash)
            else:
                used_pad = self.visit(call.args[0])
                used_pad_args = [self.visit(arg) for arg in used_pad.args]
                # Build a fresh Call object so this conv2d gets its own pad node
                # rather than the one already consumed by another conv2d.
                new_pad = Call(
                    used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span
                )
                # NOTE(review): the repeated visits of new_pad look redundant;
                # they are kept to preserve the original traversal exactly.
                new_pad = self.visit(new_pad)
                new_conv2d_args = []
                for i, arg in enumerate(call.args):
                    if i == 0:
                        new_conv2d_args.append(self.visit(new_pad))
                    else:
                        new_conv2d_args.append(self.visit(arg))
                new_conv2d_op = self.visit(call.op)
                return _expr.CallWithFields(
                    call,
                    new_conv2d_op,
                    new_conv2d_args,
                    call.attrs,
                    call.type_args,
                    None,
                    call.span,
                )

        # Default rewrite: visit the arguments and the operator, keeping
        # everything else from the original call.
        new_args = [self.visit(arg) for arg in call.args]
        new_op = self.visit(call.op)
        return _expr.CallWithFields(
            call, new_op, new_args, call.attrs, call.type_args, None, call.span
        )
84+
85+
86+
def replicate_pads(mod):
    """Traverses the Relay graph to replicate nn.pad operators if they have
    multiple qnn.conv2d consumers. This removes the situation where
    e.g. pad+conv2d corresponds to qnn_conv2d_pattern, but cannot be grouped
    because several conv2d operators use the same pad operation.

    Parameters
    ----------
    mod : tvm.ir.IRModule
        The IRModule that gets generated from a relay frontend.

    Returns
    -------
    tvm.ir.IRModule
        The IRModule without nn.pad operators with multiple consumers.
    """
    pad_replicator = PadsWithMultipleConsumersReplicator()
    # A single mutator instance is reused for every function in the module.
    for gvar, relay_func in mod.functions.items():
        mod.update_func(gvar, pad_replicator.visit(relay_func))
    return mod

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table):
4444
want to add the operator's pattern to the pattern table so that the compiler
4545
wouldn't attempt to offload an operator without full stack support."""
4646
mod = relay.transform.InferType()(mod)
47+
mod = relay.transform.replicate_pads(mod)
48+
mod = relay.transform.InferType()(mod)
4749
mod = relay.transform.MergeComposite(pattern_table)(mod)
4850
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
4951
mod = relay.transform.MergeCompilerRegions()(mod)
@@ -3646,5 +3648,149 @@ def _visit(stmt):
36463648
verify(mod["tvmgen_default_ethos_u_main_0"])
36473649

36483650

3651+
@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
@pytest.mark.parametrize("op_padding", ["SAME", "VALID"])
@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)])
@pytest.mark.parametrize(
    "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
)
def test_tflite_shared_pad_legalize(
    ifm_shape,
    kernel_shape,
    strides,
    dilation,
    op_padding,
    sep_padding,
    op_pairs,
):
    """Check legalization when one tf.pad feeds two (depthwise) convolutions.

    The TFLite graph pads the input once, applies the two convolution kinds
    named in `op_pairs` to the padded tensor and adds the results. After
    partitioning and legalization, both partitioned functions are expected to
    have the matching contrib.ethosu convolution operator as their body.
    """
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x):
                paddings = [
                    [0, 0],
                    [sep_padding[0], sep_padding[2]],
                    [sep_padding[1], sep_padding[3]],
                    [0, 0],
                ]
                padded = tf.pad(x, paddings, "CONSTANT")

                # The input strides to the TensorFlow API needs to be of shape 1x4
                tf_strides = [1, strides[0], strides[1], 1]

                def convolve(op_kind):
                    # Depthwise kernels use a channel multiplier of 1,
                    # regular conv2d kernels produce 3 output channels.
                    is_depthwise = op_kind == "depthwise"
                    last_dim = 1 if is_depthwise else 3
                    weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], last_dim]
                    weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
                    conv_fn = tf.nn.depthwise_conv2d if is_depthwise else tf.nn.conv2d
                    return conv_fn(
                        padded,
                        weight,
                        strides=tf_strides,
                        padding=op_padding,
                        dilations=dilation,
                    )

                # Both branches consume the same padded tensor.
                first_branch = convolve(op_pairs[0])
                second_branch = convolve(op_pairs[1])
                return tf.math.add(first_branch, second_branch)

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                sample = np.random.rand(*tuple(ifm_shape))
                yield [sample.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    conv2d_pattern_table = [
        (
            ethosu.QnnConv2DParams.composite_name,
            ethosu.qnn_conv2d_pattern(),
            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
        ),
        (
            ethosu.QnnDepthwiseConv2DParams.composite_name,
            ethosu.qnn_depthwise_conv2d_pattern(),
            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
        ),
    ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod["main"] = bind_params_by_name(mod["main"], params)
    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)

    # One partitioned function is expected per convolution, in op_pairs order.
    partition_names = ("tvmgen_default_ethos_u_main_0", "tvmgen_default_ethos_u_main_1")

    # Legalize every partition first, then check the results.
    for partition_name in partition_names:
        mod[partition_name] = dataflow_pattern.rewrite(
            [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
            mod[partition_name],
        )

    for partition_name, op_kind in zip(partition_names, op_pairs):
        if op_kind == "depthwise":
            assert mod[partition_name].body.op.name == "contrib.ethosu.depthwise_conv2d"
        else:
            assert mod[partition_name].body.op.name == "contrib.ethosu.conv2d"
3793+
3794+
36493795
if __name__ == "__main__":
36503796
tvm.testing.main()

0 commit comments

Comments
 (0)