Skip to content

Commit 663b1e4

Browse files
committed
test multiple requantize offload
Change-Id: I60a3283461a7a7083c05289e84f570698388077b
1 parent 677112f commit 663b1e4

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,5 +1559,49 @@ def verify(ext_func):
15591559
verify(mod["tvmgen_default_ethos_u_main_0"])
15601560

15611561

1562+
def test_multiple_requantize_offload():
1563+
"""
1564+
Testing requantize offload in the case one requauntize operation is part of
1565+
an existing pattern (in this case Mean: cast->mean->requantize) and the
1566+
other is a stand-alone requantize.
1567+
"""
1568+
1569+
def create_model():
1570+
ifm = relay.var("input", shape=(1, 3, 3, 4), dtype="int8")
1571+
cast = relay.cast(ifm, dtype="int32")
1572+
mean = relay.mean(cast, axis=1, keepdims=True)
1573+
requantize = relay.qnn.op.requantize(
1574+
mean,
1575+
input_scale=relay.const(1.0, dtype="float32"),
1576+
input_zero_point=relay.const(0, dtype="int32"),
1577+
output_scale=relay.const(1.0, dtype="float32"),
1578+
output_zero_point=relay.const(0, dtype="int32"),
1579+
)
1580+
requantize = relay.qnn.op.requantize(
1581+
requantize,
1582+
input_scale=relay.const(1.0, dtype="float32"),
1583+
input_zero_point=relay.const(0, dtype="int32"),
1584+
output_scale=relay.const(1.0, dtype="float32"),
1585+
output_zero_point=relay.const(0, dtype="int32"),
1586+
)
1587+
return tvm.IRModule.from_expr(relay.Function([ifm], requantize))
1588+
1589+
def verify(ext_func):
1590+
# If mean operation and separate requantize were offloaded correctly,
1591+
# there should only be a pooling operation followed by an identity
1592+
# operation leagalized.
1593+
op = ext_func.body
1594+
assert op.op.name == "contrib.ethosu.identity"
1595+
op = op.args[0]
1596+
assert ext_func.body.args[0].op.name == "contrib.ethosu.pooling"
1597+
op = op.args[0]
1598+
assert isinstance(op, relay.Var)
1599+
1600+
mod = create_model()
1601+
mod = ethosu.partition_for_ethosu(mod)
1602+
mod = legalize.LegalizeEthosU()(mod)
1603+
verify(mod["tvmgen_default_ethos_u_main_0"])
1604+
1605+
15621606
if __name__ == "__main__":
15631607
pytest.main([__file__])

0 commit comments

Comments
 (0)