1515# specific language governing permissions and limitations
1616# under the License.
1717
18+ # pylint: disable=redefined-outer-name
19+
1820""" Tests for avg_pool2d fake quantization to integer """
1921
2022import numpy as np
23+ import pytest
24+
2125import tvm
2226import tvm .testing
2327import tvm .topi .testing
2428from tvm import relay
2529from tvm .contrib .hexagon .session import Session
2630from tvm .contrib .hexagon .pytest_plugin import HEXAGON_AOT_LLVM_TARGET
27- from .infrastructure import quantize_np , build_module , run_module
28-
29-
30- def compare_graphs (expr , ref_expr ):
31- """Compares the given graph with the expected graph"""
32- mod = tvm .IRModule .from_expr (expr )
33- mod = tvm .relay .transform .InferType ()(mod )
34- mod_int = tvm .relay .transform .FakeQuantizationToInteger ()(mod )
35- ref_mod = tvm .IRModule .from_expr (ref_expr )
36- ref_mod = tvm .relay .transform .InferType ()(ref_mod )
37- assert tvm .ir .structural_equal (mod_int ["main" ], ref_mod ["main" ], map_free_vars = True )
38-
39-
40- def compare_fq_to_int (hexagon_session , expr , inputs ):
41- """Compares the float module output with the integer module output"""
42- mod = tvm .IRModule .from_expr (expr )
43- mod = tvm .relay .transform .InferType ()(mod )
44- mod_int = tvm .relay .transform .FakeQuantizationToInteger ()(mod )
45- assert not tvm .ir .structural_equal (mod , mod_int )
46-
47- mod = build_module (
48- mod , tvm .target .Target (HEXAGON_AOT_LLVM_TARGET , host = HEXAGON_AOT_LLVM_TARGET )
49- )
50- mod_int = build_module (
51- mod_int , tvm .target .Target (HEXAGON_AOT_LLVM_TARGET , host = HEXAGON_AOT_LLVM_TARGET )
52- )
53-
54- hexagon_mod = hexagon_session .get_executor_from_factory (mod )
55- result = run_module (hexagon_mod , inputs )
56-
57- hexagon_mod = hexagon_session .get_executor_from_factory (mod_int )
58- result_int = run_module (hexagon_mod , inputs )
5931
60- tvm . testing . assert_allclose ( result , result_int , rtol = 1e-02 , atol = 1e-02 )
32+ from . infrastructure import quantize_np , build_module , run_module
6133
6234
63- @tvm .testing .requires_hexagon
64- def test_avgpool_conv2d (hexagon_session : Session ):
35+ def _make_avgpool_conv2d ():
6536 """Test case with avg_pool2d followed by a conv2d"""
6637 dtype = "int8"
6738 shape_x = [1 , 2 , 9 , 9 ]
@@ -112,8 +83,6 @@ def test_avgpool_conv2d(hexagon_session: Session):
11283 expr = relay .qnn .op .dequantize (expr , out_sc , out_zp )
11384 args = {"input" : input_quant , "weight" : weight_quant }
11485
115- compare_fq_to_int (hexagon_session , expr , args )
116-
11786 # Expected graph
11887 op0 = relay .qnn .op .avg_pool2d (
11988 inp ,
@@ -148,11 +117,11 @@ def test_avgpool_conv2d(hexagon_session: Session):
148117 out_dtype = "int8" ,
149118 )
150119 ref_expr = relay .qnn .op .dequantize (op2 , out_sc , out_zp )
151- compare_graphs (expr , ref_expr )
152120
121+ return expr , args , ref_expr
153122
154- @ tvm . testing . requires_hexagon
155- def test_avgpool_avgpool ( hexagon_session : Session ):
123+
124+ def _make_avgpool_avgpool ( ):
156125 """Test case with avg_pool2d followed by an avg_pool2d"""
157126 dtype = "uint8"
158127 shape_x = [1 , 2 , 9 , 9 ]
@@ -197,7 +166,6 @@ def test_avgpool_avgpool(hexagon_session: Session):
197166 expr = relay .qnn .op .quantize (op2 , out_sc , out_zp , out_dtype = dtype )
198167 expr = relay .qnn .op .dequantize (expr , out_sc , out_zp )
199168 args = {"input" : input_quant }
200- compare_fq_to_int (hexagon_session , expr , args )
201169
202170 # Expected graph
203171 op0 = relay .qnn .op .avg_pool2d (
@@ -227,12 +195,11 @@ def test_avgpool_avgpool(hexagon_session: Session):
227195 count_include_pad = False ,
228196 )
229197 ref_expr = relay .qnn .op .dequantize (op1 , out_sc , out_zp )
230- compare_graphs (expr , ref_expr )
231198
199+ return expr , args , ref_expr
232200
233- @tvm .testing .requires_hexagon
234- def test_avgpool (hexagon_session : Session ):
235- """Test case of a regular avg_pool2d"""
201+
202+ def _make_avgpool ():
236203 dtype = "int8"
237204 shape_x = [1 , 2 , 9 , 9 ]
238205 kernel = [3 , 3 ]
@@ -266,7 +233,6 @@ def test_avgpool(hexagon_session: Session):
266233 expr = relay .qnn .op .quantize (op1 , out_sc , out_zp , out_dtype = dtype )
267234 expr = relay .qnn .op .dequantize (expr , out_sc , out_zp )
268235 args = {"input" : input_quant }
269- compare_fq_to_int (hexagon_session , expr , args )
270236
271237 # Expected graph
272238 op = relay .qnn .op .avg_pool2d (
@@ -283,6 +249,63 @@ def test_avgpool(hexagon_session: Session):
283249 count_include_pad = False ,
284250 )
285251 ref_expr = relay .qnn .op .dequantize (op , out_sc , out_zp )
252+
253+ return expr , args , ref_expr
254+
255+
256+ def compare_graphs (expr , ref_expr ):
257+ """Compares the given graph with the expected graph"""
258+ mod = tvm .IRModule .from_expr (expr )
259+ mod = tvm .relay .transform .InferType ()(mod )
260+ mod_int = tvm .relay .transform .FakeQuantizationToInteger ()(mod )
261+ ref_mod = tvm .IRModule .from_expr (ref_expr )
262+ ref_mod = tvm .relay .transform .InferType ()(ref_mod )
263+ tvm .ir .assert_structural_equal (mod_int ["main" ], ref_mod ["main" ], map_free_vars = True )
264+
265+
266+ def compare_fq_to_int (hexagon_session , expr , inputs ):
267+ """Compares the float module output with the integer module output"""
268+ mod = tvm .IRModule .from_expr (expr )
269+ mod = tvm .relay .transform .InferType ()(mod )
270+ mod_int = tvm .relay .transform .FakeQuantizationToInteger ()(mod )
271+ assert not tvm .ir .structural_equal (mod , mod_int )
272+
273+ mod = build_module (
274+ mod , tvm .target .Target (HEXAGON_AOT_LLVM_TARGET , host = HEXAGON_AOT_LLVM_TARGET )
275+ )
276+ mod_int = build_module (
277+ mod_int , tvm .target .Target (HEXAGON_AOT_LLVM_TARGET , host = HEXAGON_AOT_LLVM_TARGET )
278+ )
279+
280+ hexagon_mod = hexagon_session .get_executor_from_factory (mod )
281+ result = run_module (hexagon_mod , inputs )
282+
283+ hexagon_mod = hexagon_session .get_executor_from_factory (mod_int )
284+ result_int = run_module (hexagon_mod , inputs )
285+
286+ tvm .testing .assert_allclose (result , result_int , rtol = 1e-02 , atol = 1e-02 )
287+
288+
289+ avgpool_test_case = tvm .testing .parameter (
290+ _make_avgpool ,
291+ _make_avgpool_avgpool ,
292+ pytest .param (
293+ _make_avgpool_conv2d ,
294+ marks = pytest .mark .xfail (
295+ reason = "Rounding differences causing mismatch of Constant, difference around 10^-7"
296+ ),
297+ ),
298+ )
299+
300+
301+ @tvm .testing .requires_hexagon
302+ def test_execution (hexagon_session : Session , avgpool_test_case ):
303+ expr , args , _ = avgpool_test_case ()
304+ compare_fq_to_int (hexagon_session , expr , args )
305+
306+
307+ def test_quantization (avgpool_test_case ):
308+ expr , _ , ref_expr = avgpool_test_case ()
286309 compare_graphs (expr , ref_expr )
287310
288311
0 commit comments