5353 topi .arm_cpu .schedule_conv2d_nhwc_spatial_pack ,
5454 ),
5555 (
56- "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a" ,
56+ "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16" ,
5757 topi .arm_cpu .compute_conv2d_NHWC_hybrid ,
5858 topi .arm_cpu .schedule_conv2d_NHWC_hybrid ,
5959 ),
6464 ),
6565)
6666
67- dtype = tvm .testing .parameter ("float32" )
67+ dtype = tvm .testing .parameter ("float16" , "float32" )
6868
6969batch , in_channel , in_size , num_filter , kernel , stride , padding , dilation = tvm .testing .parameters (
7070 # Pad M, N, K
@@ -104,14 +104,36 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd
104104 a_shape = (batch , in_height , in_width , in_channel )
105105 w_shape = (kernel , kernel , in_channel , num_filter )
106106
107+ np .random .seed (0 )
107108 a_np = np .random .uniform (size = a_shape ).astype (dtype )
108109 w_np = np .random .uniform (size = w_shape ).astype (dtype )
109110 dw_np = tvm .topi .testing .dilate_python (w_np , (dilation , dilation , 1 , 1 ))
110- b_np = tvm .topi .testing .conv2d_nhwc_python (a_np , dw_np , stride , padding )
111+
112+ # scipy.signal.convolve2d does not support float16 data types,
113+ # and the python fallback would be too slow for general use.
114+ conv_dtype = "float32" if dtype == "float16" else dtype
115+ b_np = tvm .topi .testing .conv2d_nhwc_python (
116+ a_np .astype (conv_dtype ), dw_np .astype (conv_dtype ), stride , padding
117+ ).astype (dtype )
111118 return a_np , w_np , b_np
112119
113120
114- def test_conv2d_nhwc_gemm_fp32 (device , ref_data , dtype , stride , padding , dilation ):
def get_tolerance(dtype, w_np, b_np):
    """Return the tolerances used to compare a conv2d result with its reference.

    Parameters
    ----------
    dtype : str
        Data type under test. Only ``"float16"`` gets a widened absolute
        tolerance; every other value uses the fixed default.
    w_np : numpy.ndarray
        Weights. The product of its first three dimensions is used as the
        number of values accumulated per output element (callers in this
        file pass HWIO weights, i.e. kernel_h * kernel_w * in_channel).
    b_np : numpy.ndarray
        Reference output; its largest magnitude determines the float gap.

    Returns
    -------
    dict
        ``{"rtol": ..., "atol": ...}`` suitable for ``assert_allclose``.
    """
    if dtype != "float16":
        return {"rtol": 1e-5, "atol": 1e-7}

    # A summation in float16 with a single accumulator very quickly runs
    # into large rounding errors. This tolerance is necessary to avoid
    # spurious failures, but it may let genuine errors through, depending
    # on schedule behaviour.
    num_values_summed = w_np.shape[0] * w_np.shape[1] * w_np.shape[2]

    # Measure the gap at the largest *magnitude* so that a dominant negative
    # output does not understate the rounding step.
    peak = np.abs(b_np).max()
    next_float_gap_size = np.nextafter(peak, np.inf, dtype=b_np.dtype) - peak

    # Convert to a Python float before scaling so the tolerance itself is not
    # rounded (or overflowed) in the reference's low-precision dtype.
    atol = num_values_summed * float(next_float_gap_size) / 2
    return {"rtol": 1e-5, "atol": atol}
134+
135+
136+ def test_conv2d_nhwc_gemm (device , ref_data , dtype , stride , padding , dilation ):
115137 a_np , w_np , b_np = ref_data
116138
117139 A = te .placeholder (a_np .shape , name = "A" , dtype = dtype )
@@ -130,14 +152,21 @@ def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilatio
130152
131153 # Run only on AArch64 devices
132154 # Do not run SVE schedules on non-SVE devices
133- build_only = platform .machine () != "aarch64" or (
134- target .features .has_sve and not tvm .testing .requires_aarch64_sve .run_time_check ()
155+ build_only = (
156+ platform .machine () != "aarch64"
157+ or (target .features .has_sve and not tvm .testing .requires_aarch64_sve .run_time_check ())
158+ or (
159+ dtype == "float16"
160+ and target .features .has_fp16_simd
161+ and not tvm .testing .requires_arm_fp16 .run_time_check ()
162+ )
135163 )
136164 if build_only :
137165 return
138166
139167 func (a , w , b )
140- tvm .testing .assert_allclose (b .numpy (), b_np , rtol = 1e-5 )
168+ tol = get_tolerance (dtype , w_np , b_np )
169+ tvm .testing .assert_allclose (b .numpy (), b_np , rtol = tol ["rtol" ], atol = tol ["atol" ])
141170
142171
143172def test_conv2d_nhwc_hwio (target , dev , ref_data , dtype , stride , padding , dilation ):
@@ -155,7 +184,8 @@ def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilatio
155184 b = tvm .nd .array (np .zeros (get_const_tuple (B .shape ), dtype = B .dtype ), dev )
156185 func = tvm .build (s , [A , W , B ], target )
157186 func (a , w , b )
158- tvm .testing .assert_allclose (b .numpy (), b_np , rtol = 1e-5 )
187+ tol = get_tolerance (dtype , w_np , b_np )
188+ tvm .testing .assert_allclose (b .numpy (), b_np , rtol = tol ["rtol" ], atol = tol ["atol" ])
159189
160190
161191def test_conv2d_nhwc_ohwi (ref_data , dtype , stride , padding , dilation ):
@@ -184,7 +214,8 @@ def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
184214 b = tvm .nd .array (np .zeros (get_const_tuple (B .shape ), dtype = B .dtype ), dev )
185215 func = tvm .build (s , [A , W , B ], target )
186216 func (a , w , b )
187- tvm .testing .assert_allclose (b .numpy (), b_np , rtol = 1e-5 )
217+ tol = get_tolerance (dtype , w_np_hwio , b_np )
218+ tvm .testing .assert_allclose (b .numpy (), b_np , rtol = tol ["rtol" ], atol = tol ["atol" ])
188219
189220
190221if __name__ == "__main__" :
0 commit comments