Skip to content

Commit 6b7629b

Browse files
committed
More comprehensive testing of is_int8_hw_support
1 parent 0ffa19b commit 6b7629b

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

python/tvm/topi/arm_cpu/conv2d_int8.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
125125

126126
# 3) Check target
127127
current_target = target.Target.current(allow_none=False)
128-
is_target_support = bool(current_target.features.has_asimd or current_target.features.has_dotprod)
128+
is_target_support = bool(
129+
current_target.features.has_asimd or current_target.features.has_dotprod
130+
)
129131

130132
return is_dtype_support and is_llvm_support and is_target_support
131133

tests/python/target/test_arm_target.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,32 @@
2020
from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support
2121
from tvm.target import codegen
2222

23-
arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters(
23+
llvm_version, arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters(
2424
# Testing mcpu type
25-
("c -mcpu=cortex-m4 -keys=arm_cpu", "int8", "int8", False),
26-
("c -mcpu=cortex-m7 -keys=arm_cpu", "int8", "int8", False),
27-
("c -mcpu=cortex-m33 -keys=arm_cpu", "int8", "int8", False),
28-
("c -mcpu=cortex-m55 -keys=arm_cpu", "int8", "int8", False),
29-
("c -mcpu=cortex-m3 -keys=arm_cpu", "int8", "int8", False),
30-
("llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int8", "int8", True),
31-
("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.4a,+dotprod", "int8", "int8", True),
25+
(8, "c -mcpu=cortex-m4", "int8", "int8", False),
26+
(8, "c -mcpu=cortex-m7", "int8", "int8", False),
27+
(8, "c -mcpu=cortex-m33", "int8", "int8", False),
28+
(8, "c -mcpu=cortex-m55", "int8", "int8", False),
29+
(8, "c -mcpu=cortex-m3", "int8", "int8", False),
30+
31+
(7, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", False),
32+
(8, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", True),
33+
(9, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", True),
34+
(8, "llvm -mtriple=arm-linux-gnueabi", "int8", "int8", False),
35+
36+
(7, "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.4a,+dotprod", "int8", "int8", False),
37+
(8, "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.4a,+dotprod", "int8", "int8", True),
38+
(9, "llvm -mtriple=arm-linux-gnueabi -mattr=+neon", "int8", "int8", True),
39+
(8, "llvm -mtriple=aarch64-linux-gnu", "int8", "int8", True),
40+
3241
# Testing dtype
33-
("llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int16", "int8", False),
34-
("llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int8", "int16", False),
35-
("llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int16", "int16", False),
42+
(8, "llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int16", "int8", False),
43+
(8, "llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int8", "int16", False),
44+
(8, "llvm -mtriple=aarch64-linux-gnu -mattr=+neon", "int16", "int16", False),
3645
)
3746

3847

39-
def test_arm_conv2d_int8_support(arm_target, input_dtype, kernel_dtype, is_supported):
48+
def test_arm_conv2d_int8_support(monkeypatch, llvm_version, arm_target, input_dtype, kernel_dtype, is_supported):
4049
"""Test ARM conv2d int8 support for different targets.
4150
4251
Parameters
@@ -51,5 +60,5 @@ def test_arm_conv2d_int8_support(arm_target, input_dtype, kernel_dtype, is_suppo
5160
Expected result.
5261
"""
5362
with tvm.target.Target(arm_target):
54-
expected_result = is_supported and (codegen.llvm_version_major() >= 8)
55-
assert is_int8_hw_support(input_dtype, kernel_dtype) == expected_result
63+
monkeypatch.setattr(codegen, 'llvm_version_major', lambda: llvm_version)
64+
assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported

0 commit comments

Comments
 (0)