@@ -1559,5 +1559,49 @@ def verify(ext_func):
15591559 verify (mod ["tvmgen_default_ethos_u_main_0" ])
15601560
15611561
1562+ def test_multiple_requantize_offload ():
1563+ """
1564+ Testing requantize offload in the case one requauntize operation is part of
1565+ an existing pattern (in this case Mean: cast->mean->requantize) and the
1566+ other is a stand-alone requantize.
1567+ """
1568+
1569+ def create_model ():
1570+ ifm = relay .var ("input" , shape = (1 , 3 , 3 , 4 ), dtype = "int8" )
1571+ cast = relay .cast (ifm , dtype = "int32" )
1572+ mean = relay .mean (cast , axis = 1 , keepdims = True )
1573+ requantize = relay .qnn .op .requantize (
1574+ mean ,
1575+ input_scale = relay .const (1.0 , dtype = "float32" ),
1576+ input_zero_point = relay .const (0 , dtype = "int32" ),
1577+ output_scale = relay .const (1.0 , dtype = "float32" ),
1578+ output_zero_point = relay .const (0 , dtype = "int32" ),
1579+ )
1580+ requantize = relay .qnn .op .requantize (
1581+ requantize ,
1582+ input_scale = relay .const (1.0 , dtype = "float32" ),
1583+ input_zero_point = relay .const (0 , dtype = "int32" ),
1584+ output_scale = relay .const (1.0 , dtype = "float32" ),
1585+ output_zero_point = relay .const (0 , dtype = "int32" ),
1586+ )
1587+ return tvm .IRModule .from_expr (relay .Function ([ifm ], requantize ))
1588+
1589+ def verify (ext_func ):
1590+ # If mean operation and separate requantize were offloaded correctly,
1591+ # there should only be a pooling operation followed by an identity
1592+ # operation leagalized.
1593+ op = ext_func .body
1594+ assert op .op .name == "contrib.ethosu.identity"
1595+ op = op .args [0 ]
1596+ assert ext_func .body .args [0 ].op .name == "contrib.ethosu.pooling"
1597+ op = op .args [0 ]
1598+ assert isinstance (op , relay .Var )
1599+
1600+ mod = create_model ()
1601+ mod = ethosu .partition_for_ethosu (mod )
1602+ mod = legalize .LegalizeEthosU ()(mod )
1603+ verify (mod ["tvmgen_default_ethos_u_main_0" ])
1604+
1605+
15621606if __name__ == "__main__" :
15631607 pytest .main ([__file__ ])
0 commit comments