Skip to content

Commit 42b4f21

Browse files
authored
[Hexagon][UnitTest] Disable flaky quantization test (#16337)
* [Hexagon][UnitTest] Disable flaky quantization test The `test_pass_fq2i_avg_pool2d.py::test_avgpool_conv2d` test is sensitive to rounding errors, and failed about a third of the time (42 / 100 tests). This was first noticed as CI failures in unrelated PRs (e.g. https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-hexagon/detail/PR-16184/6/tests). This commit marks the flaky portions of the test with `pytest.mark.xfail`, to avoid causing breaking CI for other PRs. To minimize the extent of the disabled test cases, this commit breaks up each of the unit tests. Where previously a single test performed both hardware/simulation tests and relay graph comparisons, these are now done in separate test functions. The hardware/simulation tests use `tvm.testing.assert_allclose` and have a tolerance of `1e-02`, while the graph-comparison tests use `tvm.ir.structural_equal`, and require identical floating-point values. Only the graph-comparison test is disabled here. The other two test cases in `test_pass_fq2i_avg_pool2d.py` do not show this same sensitivity, with no failures seen in 100 executions. * Disable pylint for pytest fixture names
1 parent eb15d04 commit 42b4f21

File tree

1 file changed

+69
-46
lines changed

1 file changed

+69
-46
lines changed

tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py

Lines changed: 69 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,53 +15,24 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
# pylint: disable=redefined-outer-name
19+
1820
""" Tests for avg_pool2d fake quantization to integer """
1921

2022
import numpy as np
23+
import pytest
24+
2125
import tvm
2226
import tvm.testing
2327
import tvm.topi.testing
2428
from tvm import relay
2529
from tvm.contrib.hexagon.session import Session
2630
from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET
27-
from .infrastructure import quantize_np, build_module, run_module
28-
29-
30-
def compare_graphs(expr, ref_expr):
31-
"""Compares the given graph with the expected graph"""
32-
mod = tvm.IRModule.from_expr(expr)
33-
mod = tvm.relay.transform.InferType()(mod)
34-
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
35-
ref_mod = tvm.IRModule.from_expr(ref_expr)
36-
ref_mod = tvm.relay.transform.InferType()(ref_mod)
37-
assert tvm.ir.structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True)
38-
39-
40-
def compare_fq_to_int(hexagon_session, expr, inputs):
41-
"""Compares the float module output with the integer module output"""
42-
mod = tvm.IRModule.from_expr(expr)
43-
mod = tvm.relay.transform.InferType()(mod)
44-
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
45-
assert not tvm.ir.structural_equal(mod, mod_int)
46-
47-
mod = build_module(
48-
mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
49-
)
50-
mod_int = build_module(
51-
mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
52-
)
53-
54-
hexagon_mod = hexagon_session.get_executor_from_factory(mod)
55-
result = run_module(hexagon_mod, inputs)
56-
57-
hexagon_mod = hexagon_session.get_executor_from_factory(mod_int)
58-
result_int = run_module(hexagon_mod, inputs)
5931

60-
tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02)
32+
from .infrastructure import quantize_np, build_module, run_module
6133

6234

63-
@tvm.testing.requires_hexagon
64-
def test_avgpool_conv2d(hexagon_session: Session):
35+
def _make_avgpool_conv2d():
6536
"""Test case with avg_pool2d followed by a conv2d"""
6637
dtype = "int8"
6738
shape_x = [1, 2, 9, 9]
@@ -112,8 +83,6 @@ def test_avgpool_conv2d(hexagon_session: Session):
11283
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
11384
args = {"input": input_quant, "weight": weight_quant}
11485

115-
compare_fq_to_int(hexagon_session, expr, args)
116-
11786
# Expected graph
11887
op0 = relay.qnn.op.avg_pool2d(
11988
inp,
@@ -148,11 +117,11 @@ def test_avgpool_conv2d(hexagon_session: Session):
148117
out_dtype="int8",
149118
)
150119
ref_expr = relay.qnn.op.dequantize(op2, out_sc, out_zp)
151-
compare_graphs(expr, ref_expr)
152120

121+
return expr, args, ref_expr
153122

154-
@tvm.testing.requires_hexagon
155-
def test_avgpool_avgpool(hexagon_session: Session):
123+
124+
def _make_avgpool_avgpool():
156125
"""Test case with avg_pool2d followed by an avg_pool2d"""
157126
dtype = "uint8"
158127
shape_x = [1, 2, 9, 9]
@@ -197,7 +166,6 @@ def test_avgpool_avgpool(hexagon_session: Session):
197166
expr = relay.qnn.op.quantize(op2, out_sc, out_zp, out_dtype=dtype)
198167
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
199168
args = {"input": input_quant}
200-
compare_fq_to_int(hexagon_session, expr, args)
201169

202170
# Expected graph
203171
op0 = relay.qnn.op.avg_pool2d(
@@ -227,12 +195,11 @@ def test_avgpool_avgpool(hexagon_session: Session):
227195
count_include_pad=False,
228196
)
229197
ref_expr = relay.qnn.op.dequantize(op1, out_sc, out_zp)
230-
compare_graphs(expr, ref_expr)
231198

199+
return expr, args, ref_expr
232200

233-
@tvm.testing.requires_hexagon
234-
def test_avgpool(hexagon_session: Session):
235-
"""Test case of a regular avg_pool2d"""
201+
202+
def _make_avgpool():
236203
dtype = "int8"
237204
shape_x = [1, 2, 9, 9]
238205
kernel = [3, 3]
@@ -266,7 +233,6 @@ def test_avgpool(hexagon_session: Session):
266233
expr = relay.qnn.op.quantize(op1, out_sc, out_zp, out_dtype=dtype)
267234
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
268235
args = {"input": input_quant}
269-
compare_fq_to_int(hexagon_session, expr, args)
270236

271237
# Expected graph
272238
op = relay.qnn.op.avg_pool2d(
@@ -283,6 +249,63 @@ def test_avgpool(hexagon_session: Session):
283249
count_include_pad=False,
284250
)
285251
ref_expr = relay.qnn.op.dequantize(op, out_sc, out_zp)
252+
253+
return expr, args, ref_expr
254+
255+
256+
def compare_graphs(expr, ref_expr):
257+
"""Compares the given graph with the expected graph"""
258+
mod = tvm.IRModule.from_expr(expr)
259+
mod = tvm.relay.transform.InferType()(mod)
260+
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
261+
ref_mod = tvm.IRModule.from_expr(ref_expr)
262+
ref_mod = tvm.relay.transform.InferType()(ref_mod)
263+
tvm.ir.assert_structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True)
264+
265+
266+
def compare_fq_to_int(hexagon_session, expr, inputs):
267+
"""Compares the float module output with the integer module output"""
268+
mod = tvm.IRModule.from_expr(expr)
269+
mod = tvm.relay.transform.InferType()(mod)
270+
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
271+
assert not tvm.ir.structural_equal(mod, mod_int)
272+
273+
mod = build_module(
274+
mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
275+
)
276+
mod_int = build_module(
277+
mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
278+
)
279+
280+
hexagon_mod = hexagon_session.get_executor_from_factory(mod)
281+
result = run_module(hexagon_mod, inputs)
282+
283+
hexagon_mod = hexagon_session.get_executor_from_factory(mod_int)
284+
result_int = run_module(hexagon_mod, inputs)
285+
286+
tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02)
287+
288+
289+
avgpool_test_case = tvm.testing.parameter(
290+
_make_avgpool,
291+
_make_avgpool_avgpool,
292+
pytest.param(
293+
_make_avgpool_conv2d,
294+
marks=pytest.mark.xfail(
295+
reason="Rounding differences causing mismatch of Constant, difference around 10^-7"
296+
),
297+
),
298+
)
299+
300+
301+
@tvm.testing.requires_hexagon
302+
def test_execution(hexagon_session: Session, avgpool_test_case):
303+
expr, args, _ = avgpool_test_case()
304+
compare_fq_to_int(hexagon_session, expr, args)
305+
306+
307+
def test_quantization(avgpool_test_case):
308+
expr, _, ref_expr = avgpool_test_case()
286309
compare_graphs(expr, ref_expr)
287310

288311

0 commit comments

Comments
 (0)