
Commit ac9a943

[TOPI][Testing] Enable conv2d NHWC fp16 topi testing for arm_cpu (#17007)
This commit adds fp16 test cases to the conv2d NHWC TOPI schedules for `arm_cpu`. Following the example of #8529, the numpy reference conv2d output is computed in fp32 instead of fp16, while the absolute tolerance varies for each test case according to the size of the summed axis and the output's largest element.
1 parent a5862a5 · commit ac9a943
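
For illustration, here is a minimal standalone numpy sketch (not part of the patch; the array length and seed are arbitrary) of why the reference accumulates in fp32: a naive single-accumulator float16 reduction drifts as terms are added, while summing in float32 and casting back stays close to the true value.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(size=4096).astype("float16")  # stand-in for one output's reduction axis

# Naive fp16 reduction: every add rounds to the nearest float16,
# so the accumulated error grows with the number of summed values.
acc = np.float16(0.0)
for v in x:
    acc = np.float16(acc + v)

# Accumulate in fp32 and cast back, as the reference data in this patch does.
ref = x.astype("float32").sum().astype("float16")

print(acc, ref)  # the fp16 accumulator visibly drifts from the fp32-based result
```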

File tree

2 files changed: +47 -9 lines changed


python/tvm/testing/utils.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -1057,6 +1057,13 @@ def _has_cpu_feat(features):
 )
 
 
+requires_arm_fp16 = Feature(
+    "arm_fp16",
+    "Arm(R) Neon(TM) instructions for FP16",
+    run_time_check=lambda: _has_cpu_feat("fullfp16"),
+)
+
+
 requires_aarch64_sve = Feature(
     "arm_sve",
     "AArch64 SVE",
```

tests/python/topi/test_topi_conv2d_nhwc.py

Lines changed: 40 additions & 9 deletions

```diff
@@ -53,7 +53,7 @@
         topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
     ),
     (
-        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a",
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16",
         topi.arm_cpu.compute_conv2d_NHWC_hybrid,
         topi.arm_cpu.schedule_conv2d_NHWC_hybrid,
     ),
@@ -64,7 +64,7 @@
     ),
 )
 
-dtype = tvm.testing.parameter("float32")
+dtype = tvm.testing.parameter("float16", "float32")
 
 batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters(
     # Pad M, N, K
@@ -104,14 +104,36 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd
     a_shape = (batch, in_height, in_width, in_channel)
     w_shape = (kernel, kernel, in_channel, num_filter)
 
+    np.random.seed(0)
     a_np = np.random.uniform(size=a_shape).astype(dtype)
     w_np = np.random.uniform(size=w_shape).astype(dtype)
     dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
-    b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
+
+    # scipy.signal.convolve2d does not support float16 data types,
+    # and the python fallback would be too slow for general use.
+    conv_dtype = "float32" if dtype == "float16" else dtype
+    b_np = tvm.topi.testing.conv2d_nhwc_python(
+        a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding
+    ).astype(dtype)
     return a_np, w_np, b_np
 
 
-def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilation):
+def get_tolerance(dtype, w_np, b_np):
+    if dtype == "float16":
+        # A summation in float16 with a single accumulator very
+        # quickly runs into large rounding errors.
+        # This tolerance is necessary to ensure no false negatives,
+        # but it may introduce false positives, depending on schedule behaviour.
+        num_values_summed = w_np.shape[0] * w_np.shape[1] * w_np.shape[2]
+        next_float_gap_size = np.nextafter(b_np.max(), np.inf, dtype=b_np.dtype) - b_np.max()
+        tol = {"rtol": 1e-5, "atol": num_values_summed * next_float_gap_size / 2}
+    else:
+        tol = {"rtol": 1e-5, "atol": 1e-7}
+
+    return tol
+
+
+def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
     a_np, w_np, b_np = ref_data
 
     A = te.placeholder(a_np.shape, name="A", dtype=dtype)
@@ -130,14 +152,21 @@ def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilatio
 
     # Run only on AArch64 devices
     # Do not run SVE schedules on non-SVE devices
-    build_only = platform.machine() != "aarch64" or (
-        target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check()
+    build_only = (
+        platform.machine() != "aarch64"
+        or (target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check())
+        or (
+            dtype == "float16"
+            and target.features.has_fp16_simd
+            and not tvm.testing.requires_arm_fp16.run_time_check()
+        )
     )
     if build_only:
         return
 
     func(a, w, b)
-    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+    tol = get_tolerance(dtype, w_np, b_np)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
 
 
 def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation):
@@ -155,7 +184,8 @@ def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilatio
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
     func = tvm.build(s, [A, W, B], target)
     func(a, w, b)
-    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+    tol = get_tolerance(dtype, w_np, b_np)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
 
 
 def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
@@ -184,7 +214,8 @@ def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
     func = tvm.build(s, [A, W, B], target)
     func(a, w, b)
-    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+    tol = get_tolerance(dtype, w_np_hwio, b_np)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
 
 
 if __name__ == "__main__":
```
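
As a concrete check of the `get_tolerance` formula, assume (hypothetically) a 3x3 kernel with 16 input channels and an fp16 output whose largest element lands in [32, 64), where consecutive float16 values are 2^-5 apart:

```python
import numpy as np

num_values_summed = 3 * 3 * 16  # kernel_h * kernel_w * in_channel = 144 terms per output

b_max = np.float16(40.0)  # assumed largest output element, in [32, 64)
gap = np.nextafter(b_max, np.inf, dtype=np.float16) - b_max  # 2**-5 == 0.03125

atol = num_values_summed * float(gap) / 2
print(atol)  # 2.25: half a float16 gap of rounding error budgeted per summed value
```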
