Commit 94d01d3

[microNPU] Add support for hard swish (#12120)
Adds support for hard swish by populating a LUT, similar to Vela's implementation.

Change-Id: I7ca15a3e21bc91c1b41cdd4547fabaa00de96e90
1 parent 6bad21e commit 94d01d3
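For orientation: hard swish is hardswish(x) = x * relu6(x + 3) / 6, and this commit lowers it to a 256-entry lookup table so that the whole activation, including requantization, becomes a single table lookup on the NPU. Below is a minimal numpy sketch of that idea, using plain float math and made-up quantization parameters rather than the Vela-style fixed-point routine in the diff that follows:

import numpy as np

def hard_swish(x):
    # hardswish(x) = x * relu6(x + 3) / 6
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

# Made-up int8 quantization parameters, for illustration only.
scale, zero_point = 0.1, 0

# One LUT entry per possible int8 input value: dequantize the value,
# apply the float activation, then requantize back to int8.
lut = np.array(
    [
        np.clip(np.round(hard_swish((i - zero_point) * scale) / scale) + zero_point, -128, 127)
        for i in range(-128, 128)
    ],
    dtype=np.int8,
)

# At runtime the activation reduces to an indexed load.
x = np.array([-100, -3, 0, 42, 127], dtype=np.int8)
y = lut[x.astype(np.int32) + 128]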

File tree: 4 files changed, +201 -0 lines changed
python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 78 additions & 0 deletions

@@ -298,6 +298,83 @@ def calculate_lut_value(i):
         return identity
 
 
+class HardSwishRewriter(DFPatternCallback):
+    """Convert ethosu.hard_swish composite functions to an identity operation with LUT."""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.params_class = ethosu_patterns.HardSwishParams
+        self.pattern = wildcard().has_attr({"Composite": self.params_class.composite_name})(
+            wildcard()
+        )
+
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
+        params = self.params_class(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        # The calculation of the LUT values is similar to that in Vela's
+        # convert_hardswish_to_lut(op, arch, nng)
+        # (https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.2.0/ethosu/vela/tflite_graph_optimiser.py#719)  # pylint: disable=line-too-long
+        input_scale = np.double(params.ifm.q_params.scale_f32)
+        input_zp = int(params.ifm.q_params.zero_point)
+        hires_input_scale = (1 / 128) * input_scale
+
+        output_scale = np.double(params.ofm.q_params.scale_f32)
+        output_zp = int(params.ofm.q_params.zero_point)
+        output_scale, output_shift = scaling.quantise_scale(hires_input_scale / output_scale)
+        output_scale_16 = fp_math.downscale_multiplier_int32_to_int16(output_scale)
+        output_shift = 31 - output_shift
+        output_shift = -output_shift if output_shift < 0 else 0
+
+        dtype = params.ifm.dtype
+        qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
+
+        def calculate_relu_multiplier(inp, input_scale):
+            rmultiplier = np.double(3 / 32768)
+            rscale, rshift = scaling.quantise_scale(input_scale / rmultiplier)
+            rscale_16 = fp_math.downscale_multiplier_int32_to_int16(rscale)
+
+            rvalue = np.int16(inp)
+            if rshift < 31:
+                rvalue = fp_math.shift_left16(rvalue, 30 - rshift)
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+                rvalue = fp_math.shift_left16(rvalue, 1)
+            elif rshift > 31:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+                rvalue = fp_math.rounding_divide_by_pot(rvalue, rshift - 31)
+            else:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+
+            rvalue = (rvalue + (1 << 15)) >> 1
+            return rvalue
+
+        def calculate_lut_values(i):
+            hires_input_value = (i - input_zp) * 128
+            preshift_input_value = fp_math.saturating_rounding_mul16(
+                hires_input_value, output_scale_16
+            )
+            relu_value = calculate_relu_multiplier(hires_input_value, hires_input_scale)
+            lut_result = fp_math.saturating_mul16(relu_value, preshift_input_value)
+            lut_result = fp_math.rounding_divide_by_pot(lut_result, output_shift) + output_zp
+            return min(qmax, max(qmin, lut_result))
+
+        values = list(map(calculate_lut_values, range(-128, 128)))
+        lut = relay.const(values, dtype=dtype)
+
+        # We baked the requantization into the LUT, so we don't requantize the
+        # identity operator.
+        identity = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=input_scale,
+            ifm_zero_point=input_zp,
+            ofm_scale=input_scale,
+            ofm_zero_point=input_zp,
+            activation="LUT",
+        )
+
+        return identity
+
+
 class Conv2DRewriter(DFPatternCallback):
     """Convert conv2d related composite functions into ethosu_conv2d operators"""

@@ -1306,6 +1383,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
             ShlRewriter(),
             AbsRewriter(),
             TanhRewriter(),
+            HardSwishRewriter(),
             LeakyReLURewriter(),
             MeanRewriter(),
             ConcatRewriter(),
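The scaling and fp_math helpers used by the rewriter come from the ethos-u-vela package. As a rough guide to the Q15 arithmetic involved, the sketch below approximates the semantics of the two central helpers in plain numpy (gemmlowp-style fixed point; a simplification for illustration, not Vela's exact implementation, whose rounding of negative products differs slightly):

import numpy as np

def saturating_rounding_mul16_sketch(a, b):
    # Rounding doubling high multiply on int16 operands:
    # result ~= round(a * b / 2**15), saturated to the int16 range.
    product = np.int64(a) * np.int64(b)
    rounded = (product + (1 << 14)) >> 15
    return np.int16(np.clip(rounded, -32768, 32767))

def rounding_divide_by_pot_sketch(x, exponent):
    # Divide by 2**exponent with round-to-nearest.
    if exponent <= 0:
        return x
    return (x + (1 << (exponent - 1))) >> exponent

Under this reading, the LUT computation stays in fixed point throughout: the input is first moved to a high-resolution representation (the factor of 128), multiplied by the quantised output rescale and the ReLU6 factor, and only shifted back down at the very end.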

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 53 additions & 0 deletions

@@ -1724,6 +1724,54 @@ def qnn_fc_pattern():
     return optional_clip
 
 
+class HardSwishParams:
+    """
+    This class will parse a call to an ethos-u.hard_swish composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.hard_swish"
+
+    def __init__(self, func_body):
+        from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
+        from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
+
+        quantize = func_body
+        divide = quantize.args[0]
+        multiply = divide.args[0]
+        clip = multiply.args[1]
+        add = clip.args[0]
+        dequantize = add.args[0]
+
+        self.ifm = TensorParams(
+            dequantize.args[0],
+            scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
+            zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            quantize,
+            scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
+            zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
+        )
+
+    def is_valid(self):
+        tensor_params = [self.ifm, self.ofm]
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+            return False
+        return True
+
+
+def hard_swish_pattern():
+    """Create the pattern for hard swish."""
+    dequantize = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+    add = is_op("add")(dequantize, is_constant())
+    clip = is_op("clip")(add)
+    multiply = is_op("multiply")(dequantize, clip)
+    divide = is_op("divide")(multiply, is_constant())
+    quantize = is_op("qnn.quantize")(divide, is_constant(), is_constant())
+    return quantize
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -1844,6 +1892,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
             squeeze_pattern(),
             lambda pat: SqueezeParams(pat).is_valid(),
         ),
+        (
+            HardSwishParams.composite_name,
+            hard_swish_pattern(),
+            lambda pat: HardSwishParams(pat).is_valid(),
+        ),
     ]
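To see concretely what hard_swish_pattern matches, one can rebuild the sequence the TFLite frontend produces for quantized hard swish and run the matcher over it. A sketch, assuming the ethos-u contrib module and its dependencies are importable, and using made-up quantization constants:

from tvm import relay
from tvm.relay.op.contrib.ethosu import hard_swish_pattern

# dequantize -> add(3) -> clip(0, 6) -> multiply -> divide(6) -> quantize
x = relay.var("x", shape=(1, 5, 5, 3), dtype="int8")
scale = relay.const(0.024, "float32")  # made-up quantization params
zero_point = relay.const(-128, "int32")

dequantize = relay.qnn.op.dequantize(x, scale, zero_point)
add = relay.add(dequantize, relay.const(3.0))
clip = relay.clip(add, 0.0, 6.0)
multiply = relay.multiply(dequantize, clip)
divide = relay.divide(multiply, relay.const(6.0))
quantize = relay.qnn.op.quantize(divide, scale, zero_point, out_dtype="int8")

assert hard_swish_pattern().match(quantize)

Note how dequantize feeds both the add and the multiply, mirroring the two uses of x in x * relu6(x + 3) / 6.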
tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 15 additions & 0 deletions

@@ -819,6 +819,21 @@ def tanh_func(x):
     )
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(accel_type, ifm_shape):
+    np.random.seed(0)
+
+    @tf.function
+    def hard_swish_func(x):
+        op = tf.keras.layers.Lambda(
+            lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+        )(x)
+        return op
+
+    infra.compare_tvm_with_tflite(hard_swish_func, [ifm_shape], accel_type, ranges=[(-1, 1)])
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize(
     "shapes, axis",

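Assuming a working Ethos-U test environment (TensorFlow, ethos-u-vela, and the Corstone reference FVP that infra.compare_tvm_with_tflite presumably drives), the new codegen test can be selected by keyword:

import pytest

# Run only the new hard swish codegen test.
pytest.main(["tests/python/contrib/test_ethosu/test_codegen.py", "-k", "hard_swish"])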
tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 55 additions & 0 deletions

@@ -2751,5 +2751,60 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(ifm_shape):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                op = tf.keras.layers.Lambda(
+                    lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+                )(x)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32)
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod = ethosu.partition_for_ethosu(mod, params)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.HardSwishRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod = relay.transform.InferType()(mod)
+
+    func_body = mod["tvmgen_default_ethos_u_main_0"].body
+    assert func_body.op.name == "contrib.ethosu.identity"
+    assert func_body.attrs.activation == "LUT"
+    assert tuple(func_body.args[0].checked_type.shape) == ifm_shape
+    assert tuple(func_body.args[1].checked_type.shape) == (256,)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
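A note on the final assertions: after legalization the function body is a single contrib.ethosu.identity call whose first argument keeps the input shape and whose second argument is the LUT of shape (256,), one entry per possible int8 input value, conceptually indexed as lut[value + 128] as in the sketch near the top of this page.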
