Skip to content

Commit 60358a1

Browse files
[microNPU] Add hardware constraints for binary elementwise (#13772)
Does not fuse min and max operations with requantize if the scales differ, as this is not supported on the NPU. Because of this hardware constraint, a min or max operation cannot be fused with requantize when the scales differ (see the NPU_SET_OFM_SCALE register description: https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-). min/max operations with matching scales are offloaded to the NPU as a single ethosu_binary_elementwise operation; min/max operations with different scales are offloaded to the NPU as ethosu_binary_elementwise + ethosu_identity.
1 parent 0730422 commit 60358a1

File tree

3 files changed

+150
-34
lines changed

3 files changed

+150
-34
lines changed

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -700,15 +700,13 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation:
700700
clip = None
701701
requantize = None
702702

703-
if is_quantized_operation:
704-
if str(current_call.op.name) == "clip":
705-
clip = current_call
706-
current_call = clip.args[0]
707-
else:
708-
if str(current_call.op.name) == "qnn.requantize":
709-
requantize = current_call
710-
clip = current_call.args[0]
711-
current_call = clip.args[0]
703+
if str(current_call.op.name) == "clip":
704+
clip = current_call
705+
current_call = clip.args[0]
706+
elif str(current_call.op.name) == "qnn.requantize":
707+
requantize = current_call
708+
clip = current_call.args[0]
709+
current_call = clip.args[0]
712710
binary_op = current_call
713711

714712
layout = "NHWC"
@@ -941,21 +939,40 @@ def is_valid(self):
941939
[self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
942940
):
943941
return False
942+
# MIN with different scales is not supported on NPU
943+
# (please look at NPU_SET_OFM_SCALE register description
944+
# https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
945+
if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
946+
return False
944947
return True
945948

946949

950+
# This pattern is for the case when there are different scales for requantize and
951+
# minimum + clip + qnn.requantize can't be offloaded to the NPU as one operation
952+
# due to hardware constraints.
953+
# It's offloaded by two operations ethosu_binary_elementwise + ethosu_identity.
947954
def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
948955
"""
949-
This function creates the pattern for minimum with optional fused RELU activation.
956+
This function creates the pattern for minimum with optional fused RELU activation without
957+
requantize.
950958
"""
951959
minimum = is_op("minimum")(wildcard(), wildcard())
952960
optional_min_clip = is_op("clip")(minimum)
953-
optional_min_clip = is_op("qnn.requantize")(
954-
optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
955-
)
956961
return minimum | optional_min_clip
957962

958963

964+
def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
965+
"""
966+
This function creates the pattern for minimum with fused RELU activation with requantize.
967+
"""
968+
pattern = is_op("minimum")(wildcard(), wildcard())
969+
pattern = is_op("clip")(pattern)
970+
pattern = is_op("qnn.requantize")(
971+
pattern, is_constant(), is_constant(), is_constant(), is_constant()
972+
)
973+
return pattern
974+
975+
959976
class MaxParams(BinaryElementwiseParams):
960977
"""
961978
This class will parse a call to a ethosu.binary_elementwise Max composite function
@@ -979,21 +996,40 @@ def is_valid(self):
979996
[self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
980997
):
981998
return False
999+
# MAX with different scales is not supported on NPU
1000+
# (please look at NPU_SET_OFM_SCALE register description
1001+
# https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
1002+
if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
1003+
return False
9821004
return True
9831005

9841006

1007+
# This pattern is for the case when there are different scales for requantize and
1008+
# maximum + clip + qnn.requantize can't be offloaded to the NPU as one operation due to
1009+
# hardware constraints.
1010+
# It's offloaded by two operations ethosu_binary_elementwise + ethosu_identity.
9851011
def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
9861012
"""
987-
This function creates the pattern for maximum with optional fused RELU activation.
1013+
This function creates the pattern for maximum with optional fused RELU activation without
1014+
requantize.
9881015
"""
9891016
maximum = is_op("maximum")(wildcard(), wildcard())
9901017
optional_max_clip = is_op("clip")(maximum)
991-
optional_max_clip = is_op("qnn.requantize")(
992-
optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
993-
)
9941018
return maximum | optional_max_clip
9951019

9961020

1021+
def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
1022+
"""
1023+
This function creates the pattern for maximum with fused RELU activation with requantize.
1024+
"""
1025+
pattern = is_op("maximum")(wildcard(), wildcard())
1026+
pattern = is_op("clip")(pattern)
1027+
pattern = is_op("qnn.requantize")(
1028+
pattern, is_constant(), is_constant(), is_constant(), is_constant()
1029+
)
1030+
return pattern
1031+
1032+
9971033
class ShlParams(BinaryElementwiseParams):
9981034
"""
9991035
This class will parse a call to a ethosu.binary_elementwise Shl composite function
@@ -1913,11 +1949,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
19131949
qnn_mul_pattern(),
19141950
lambda pat: MulParams(pat).is_valid(),
19151951
),
1952+
(
1953+
MinParams.composite_name,
1954+
minimum_clip_requantize_pattern(),
1955+
lambda pat: MinParams(pat).is_valid(),
1956+
),
19161957
(
19171958
MinParams.composite_name,
19181959
minimum_pattern(),
19191960
lambda pat: MinParams(pat).is_valid(),
19201961
),
1962+
(
1963+
MaxParams.composite_name,
1964+
maximum_clip_requantize_pattern(),
1965+
lambda pat: MaxParams(pat).is_valid(),
1966+
),
19211967
(
19221968
MaxParams.composite_name,
19231969
maximum_pattern(),

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,29 @@ def conv2d_relu6(x):
11911191
)
11921192

11931193

1194+
# Specific case when the operation cannot be offloaded to the NPU by a single binary elementwise operation because
1195+
# min and max operations cannot be fused with requantize if there are different scales as it's not supported on NPU.
1196+
@pytest.mark.parametrize("operation", [tf.math.minimum, tf.math.maximum])
1197+
def test_tflite_min_max_relu_n1_to_1(operation):
1198+
np.random.seed(0)
1199+
accel_type = "ethos-u55-128"
1200+
ifm_shape = (1, 12, 16, 8)
1201+
1202+
@tf.function
1203+
def min_max_relu_n1_to_1(lhs, rhs):
1204+
op = operation(lhs, rhs)
1205+
# The specific pattern will be replaced with RELU_N1_TO_1 by tflite.
1206+
return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))
1207+
1208+
infra.compare_tvm_with_tflite(
1209+
min_max_relu_n1_to_1,
1210+
[ifm_shape, ifm_shape],
1211+
accel_type,
1212+
enable_cascader=True,
1213+
ranges=[(-1, 1), (0, 2)],
1214+
)
1215+
1216+
11941217
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
11951218
@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
11961219
@pytest.mark.parametrize("ofm_channels", [32, 64])

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def partition_ethosu_by_table(mod, pattern_table):
5353
return mod
5454

5555

56+
def relu_n1_to_1(x):
57+
"""
58+
The specific pattern will be replaced with RELU_N1_TO_1 by tflite.
59+
"""
60+
return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))
61+
62+
5663
def test_split_indices_legalize():
5764
def create_graph(axis):
5865
x = relay.var("x", shape=(1, 50, 50, 3))
@@ -881,7 +888,7 @@ def verify(ext_func):
881888
([1, 4, 4], [4, 1], False),
882889
],
883890
)
884-
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
891+
@pytest.mark.parametrize("activation_function", [None, tf.nn.relu])
885892
def test_tflite_binary_elemwise_legalize(
886893
operator_type,
887894
ifm_shape,
@@ -906,8 +913,8 @@ def tf_function(self, x, y):
906913
op = tf.math.minimum(x, y)
907914
elif operator_type == "MAX":
908915
op = tf.math.maximum(x, y)
909-
if activation_function == "RELU":
910-
op = tf.nn.relu(op)
916+
if activation_function:
917+
op = activation_function(op)
911918
return op
912919

913920
model = Model()
@@ -938,9 +945,13 @@ def verify(ext_func):
938945
op = ext_func.body
939946

940947
has_reshaped_output = False
948+
has_separate_requantize = False
941949
shapes_padded = [[1] * (4 - len(s)) + s for s in shapes]
942950
out_padded = [1] * (4 - len(out_shape)) + out_shape
943-
if op.op.name != "contrib.ethosu.binary_elementwise":
951+
if op.op.name == "contrib.ethosu.identity":
952+
op = op.args[0]
953+
has_separate_requantize = True
954+
if op.op.name == "reshape":
944955
has_reshaped_output = True
945956
op = op.args[0]
946957

@@ -951,20 +962,30 @@ def verify(ext_func):
951962
assert op.checked_type.dtype == dtype
952963
assert op.attrs.operator_type == operator_type
953964
assert op.attrs.reversed_operands == reversed_operands
954-
if activation_function == "RELU":
965+
if activation_function != None:
955966
assert str(op.attrs.activation) == "CLIP"
956967

957968
if operator_type in ["MIN", "MAX"]:
958-
# MIN and MAX with an activation must have a requantize operation
959-
# baked into the output. To check the extra requantize node was
960-
# picked up by the pattern, we can make sure the quantization
961-
# information is not default.
962-
assert float(op.attrs.ifm_scale) != 1.0
963-
assert int(op.attrs.ifm_zero_point) != 0
964-
assert float(op.attrs.ifm2_scale) != 1.0
965-
assert int(op.attrs.ifm2_zero_point) != 0
966-
assert float(op.attrs.ofm_scale) != 1.0
967-
assert int(op.attrs.ofm_zero_point) != 0
969+
if has_separate_requantize:
970+
# In the case when requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints,
971+
# the quantization values should be the defaults, since requantize is a separate operation.
972+
assert float(op.attrs.ifm_scale) == 1.0
973+
assert int(op.attrs.ifm_zero_point) == 0
974+
assert float(op.attrs.ifm2_scale) == 1.0
975+
assert int(op.attrs.ifm2_zero_point) == 0
976+
assert float(op.attrs.ofm_scale) == 1.0
977+
assert int(op.attrs.ofm_zero_point) == 0
978+
else:
979+
# MIN and MAX with an activation must have a requantize operation
980+
# baked into the output. To check the extra requantize node was
981+
# picked up by the pattern, we can make sure the quantization
982+
# information is not default.
983+
assert float(op.attrs.ifm_scale) != 1.0
984+
assert int(op.attrs.ifm_zero_point) != 0
985+
assert float(op.attrs.ifm2_scale) != 1.0
986+
assert int(op.attrs.ifm2_zero_point) != 0
987+
assert float(op.attrs.ofm_scale) != 1.0
988+
assert int(op.attrs.ofm_zero_point) != 0
968989

969990
if has_reshaped_output:
970991
assert list(ext_func.body.checked_type.shape) == out_shape
@@ -997,22 +1018,42 @@ def verify(ext_func):
9971018
),
9981019
]
9991020
elif operator_type == "MIN":
1000-
rewriter = legalize.MinRewriter()
1021+
rewriter = [legalize.MinRewriter(), legalize.RequantizeRewriter()]
10011022
pattern_table = [
1023+
(
1024+
ethosu.MinParams.composite_name,
1025+
ethosu.minimum_clip_requantize_pattern(),
1026+
lambda pat: ethosu.MinParams(pat).is_valid(),
1027+
),
10021028
(
10031029
ethosu.MinParams.composite_name,
10041030
ethosu.minimum_pattern(),
10051031
lambda pat: ethosu.MinParams(pat).is_valid(),
10061032
),
1033+
(
1034+
ethosu.RequantizeParams.composite_name,
1035+
ethosu.requantize_pattern(),
1036+
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
1037+
),
10071038
]
10081039
elif operator_type == "MAX":
1009-
rewriter = legalize.MaxRewriter()
1040+
rewriter = [legalize.MaxRewriter(), legalize.RequantizeRewriter()]
10101041
pattern_table = [
1042+
(
1043+
ethosu.MaxParams.composite_name,
1044+
ethosu.maximum_clip_requantize_pattern(),
1045+
lambda pat: ethosu.MaxParams(pat).is_valid(),
1046+
),
10111047
(
10121048
ethosu.MaxParams.composite_name,
10131049
ethosu.maximum_pattern(),
10141050
lambda pat: ethosu.MaxParams(pat).is_valid(),
10151051
),
1052+
(
1053+
ethosu.RequantizeParams.composite_name,
1054+
ethosu.requantize_pattern(),
1055+
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
1056+
),
10161057
]
10171058

10181059
tflite_graph = create_tflite_graph()
@@ -1031,6 +1072,12 @@ def verify(ext_func):
10311072
verify(mod["tvmgen_default_ethos_u_main_0"])
10321073

10331074

1075+
# This test is for checking the case when requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints.
1076+
def test_tflite_max_relu_n1_to_1_legalize():
1077+
ifm_shape = [1, 4, 8, 16]
1078+
test_tflite_binary_elemwise_legalize("MAX", ifm_shape, ifm_shape, False, relu_n1_to_1)
1079+
1080+
10341081
def test_binary_add_from_constant_scalar():
10351082
dtype = "uint8"
10361083
ifm_shape = (1, 4, 4, 8)

0 commit comments

Comments
 (0)