Skip to content

Commit eeb9d55

Browse files
committed
[microNPU] Add support for requantize
Adds support for stand-alone requantize operation which is legalized to an identity operation on the NPU. Change-Id: Ie2450c5fc72f405eddf517593236074aa4716c3b
1 parent 0d2340c commit eeb9d55

File tree

4 files changed

+191
-0
lines changed

4 files changed

+191
-0
lines changed

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,49 @@ def __call__(self, *args, **kwargs):
12421242
pass
12431243

12441244

1245+
class RequantizeRewriter(DFPatternCallback):
1246+
"""Convert ethos-u.requantize composite function to an identity operation."""
1247+
1248+
def __init__(self):
1249+
super().__init__(require_type=True)
1250+
self.pattern = (
1251+
wildcard().has_attr({"Composite": ethosu_patterns.RequantizeParams.composite_name})
1252+
)(wildcard())
1253+
1254+
def callback(
1255+
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
1256+
) -> tvm.relay.Expr:
1257+
params = ethosu_patterns.RequantizeParams(post.op.body)
1258+
params.ifm.tensor = post.args[0]
1259+
1260+
lut = relay.const([], "int8")
1261+
1262+
return ethosu_ops.ethosu_identity(
1263+
ifm=params.ifm.tensor,
1264+
lut=lut,
1265+
ifm_scale=float(params.ifm.q_params.scale_f32),
1266+
ifm_zero_point=int(params.ifm.q_params.zero_point),
1267+
ofm_scale=float(params.ofm.q_params.scale_f32),
1268+
ofm_zero_point=int(params.ofm.q_params.zero_point),
1269+
)
1270+
1271+
1272+
@ir.transform.module_pass(opt_level=1)
1273+
class LegalizeRequantize:
1274+
"""This is the pass that wraps RequantizeRewriter."""
1275+
1276+
def transform_module(
1277+
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
1278+
) -> tvm.ir.IRModule:
1279+
for global_var, func in mod.functions.items():
1280+
func = rewrite(RequantizeRewriter(), func)
1281+
mod.update_func(global_var, func)
1282+
return mod
1283+
1284+
def __call__(self, *args, **kwargs):
1285+
pass
1286+
1287+
12451288
@ir.transform.module_pass(opt_level=1)
12461289
class LegalizeEthosU:
12471290
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -1271,6 +1314,7 @@ def transform_module(
12711314
mod = LegalizeMean()(mod)
12721315
mod = LegalizeConcat()(mod)
12731316
mod = LegalizeSigmoid()(mod)
1317+
mod = LegalizeRequantize()(mod)
12741318
mod = LegalizeReshape()(mod)
12751319
mod = LegalizeStridedSlice()(mod)
12761320
mod = LegalizeNoOps()(mod)

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,60 @@ def split_pattern():
11451145
return split
11461146

11471147

1148+
class RequantizeParams:
1149+
"""
1150+
This class will parse a call to ethos-u.requantize composite function
1151+
and extract the parameter information.
1152+
"""
1153+
1154+
composite_name = "ethos-u.requantize"
1155+
1156+
def __init__(self, func_body: Call):
1157+
from tvm.relay.backend.contrib.ethosu.util import RequantArgs
1158+
1159+
layout = "NHWC"
1160+
in_var = func_body.args[0]
1161+
requantize = func_body
1162+
1163+
self.ifm = TensorParams(
1164+
in_var,
1165+
layout=layout,
1166+
scale=requantize.args[RequantArgs.IFM_SCALE.value],
1167+
zero_point=requantize.args[RequantArgs.IFM_ZERO_POINT.value],
1168+
)
1169+
self.ofm = TensorParams(
1170+
requantize,
1171+
layout=layout,
1172+
scale=requantize.args[RequantArgs.OFM_SCALE.value],
1173+
zero_point=requantize.args[RequantArgs.OFM_ZERO_POINT.value],
1174+
)
1175+
1176+
attrs = requantize.attrs
1177+
self.out_dtype = attrs.out_dtype
1178+
1179+
def is_valid(self) -> bool:
1180+
"""
1181+
Checks whether qnn.requantize has compatible attributes with HW.
1182+
"""
1183+
tensor_params = [self.ifm, self.ofm]
1184+
if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
1185+
return False
1186+
if not check_dimensions(self.ifm) or not check_dimensions(self.ofm):
1187+
return False
1188+
if self.out_dtype and self.out_dtype != "int8":
1189+
return False
1190+
return True
1191+
1192+
1193+
def requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
1194+
"""
1195+
This function creates the pattern for qnn.requantize.
1196+
"""
1197+
return is_op("qnn.requantize")(
1198+
wildcard(), is_constant(), is_constant(), is_constant(), is_constant()
1199+
)
1200+
1201+
11481202
@register_pattern_table("ethos-u")
11491203
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
11501204
return [
@@ -1230,6 +1284,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
12301284
split_pattern(),
12311285
lambda pat: SplitParams(pat).is_valid(),
12321286
),
1287+
(
1288+
RequantizeParams.composite_name,
1289+
requantize_pattern(),
1290+
lambda pat: RequantizeParams(pat).is_valid(),
1291+
),
12331292
]
12341293

12351294

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,5 +986,36 @@ def split_func(x):
986986
_compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)
987987

988988

989+
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
990+
@pytest.mark.parametrize(
991+
"ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
992+
[
993+
[(1, 8, 8, 3), 1.0, 0, 1.0, 0],
994+
[(1, 20, 30, 3), 1.345, 34, 0.32, -23],
995+
],
996+
)
997+
def test_ethosu_requantize(accel_type, ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
998+
dtype = "int8"
999+
ifm_shape = [1, 8, 8, 3]
1000+
1001+
def create_model():
1002+
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
1003+
requantize = relay.qnn.op.requantize(
1004+
ifm,
1005+
relay.const(ifm_scale, dtype="float32"),
1006+
relay.const(ifm_zp, dtype="int32"),
1007+
relay.const(ofm_scale, dtype="float32"),
1008+
relay.const(ofm_zp, dtype="int32"),
1009+
)
1010+
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))
1011+
1012+
cpu_mod = create_model()
1013+
input_data = {"ifm": np.random.randint(-128, high=127, size=ifm_shape, dtype=dtype)}
1014+
output_data = generate_ref_data(cpu_mod, input_data)
1015+
ethosu_mod = partition_for_ethosu(cpu_mod)
1016+
1017+
_compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)
1018+
1019+
9891020
if __name__ == "__main__":
9901021
pytest.main([__file__])

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
pytest.importorskip("ethosu.vela")
2222

2323
import math
24+
2425
import numpy as np
2526
import tensorflow as tf
2627
import tflite.Model
@@ -1502,5 +1503,61 @@ def verify(ext_func):
15021503
verify(mod["tvmgen_default_ethos_u_main_0"])
15031504

15041505

1506+
@pytest.mark.parametrize(
1507+
"ifm_shape,ifm_scale,ifm_zp,ofm_scale,ofm_zp",
1508+
[[(1, 8, 8, 3), 1.0, 0, 1.0, 0], [(1, 20, 30, 3), 1.345, 34, 0.32, -23]],
1509+
)
1510+
def test_ethosu_requantize(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
1511+
dtype = "int8"
1512+
ifm_shape = [1, 8, 8, 3]
1513+
1514+
def create_model():
1515+
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
1516+
requantize = relay.qnn.op.requantize(
1517+
ifm,
1518+
relay.const(ifm_scale, dtype="float32"),
1519+
relay.const(ifm_zp, dtype="int32"),
1520+
relay.const(ofm_scale, dtype="float32"),
1521+
relay.const(ofm_zp, dtype="int32"),
1522+
)
1523+
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))
1524+
1525+
def verify(ext_func):
1526+
op = ext_func.body
1527+
1528+
# Check IFM
1529+
ifm = op.args[0].checked_type
1530+
assert list(ifm.shape) == list(ifm_shape)
1531+
assert str(ifm.dtype) == dtype
1532+
1533+
# Check OFM
1534+
ofm = op.checked_type
1535+
assert list(ofm.shape) == list(ifm_shape)
1536+
assert str(ofm.dtype) == dtype
1537+
1538+
# Check quantization params
1539+
assert math.isclose(op.attrs.ifm_scale, ifm_scale, abs_tol=1e-7)
1540+
assert op.attrs.ifm_zero_point == ifm_zp
1541+
assert math.isclose(op.attrs.ofm_scale, ofm_scale, abs_tol=1e-7)
1542+
assert op.attrs.ofm_zero_point == ofm_zp
1543+
1544+
rewriter = legalize.RequantizeRewriter()
1545+
pattern_table = [
1546+
(
1547+
ethosu.RequantizeParams.composite_name,
1548+
ethosu.requantize_pattern(),
1549+
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
1550+
),
1551+
]
1552+
1553+
mod = create_model()
1554+
mod = partition_ethosu_by_table(mod, pattern_table)
1555+
1556+
mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
1557+
rewriter, mod["tvmgen_default_ethos_u_main_0"]
1558+
)
1559+
verify(mod["tvmgen_default_ethos_u_main_0"])
1560+
1561+
15051562
if __name__ == "__main__":
15061563
pytest.main([__file__])

0 commit comments

Comments
 (0)