@@ -264,7 +264,17 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
264264 dy_np = np .random .uniform (- 1 , 1 , oshape ).astype (data_dtype )
265265 w_np = np .random .uniform (- 1 , 1 , wshape ).astype (data_dtype )
266266
267- dx_np = ref_func (dy_np , w_np , (stride_h , stride_w ), (pad_h , pad_w ), (0 , 0 ))
267+ if data_dtype == "float16" :
268+ dx_np = ref_func (
269+ dy_np .astype ("float32" ),
270+ w_np .astype ("float32" ),
271+ (stride_h , stride_w ),
272+ (pad_h , pad_w ),
273+ (0 , 0 ),
274+ )
275+ dx_np = dx_np .astype ("float16" )
276+ else :
277+ dx_np = ref_func (dy_np , w_np , (stride_h , stride_w ), (pad_h , pad_w ), (0 , 0 ))
268278
269279 dy = te .placeholder (oshape , name = "dy" , dtype = data_dtype )
270280 w = te .placeholder (wshape , name = "dw" , dtype = data_dtype )
@@ -292,6 +302,7 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
292302
293303 f (dy , w , dx )
294304 print (np .max (np .abs (dx .numpy () - dx_np )))
305+ print (np .mean (np .abs (dx .numpy () - dx_np )))
295306 tvm .testing .assert_allclose (dx .numpy (), dx_np , atol = tol , rtol = tol )
296307
297308
@@ -300,6 +311,9 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
def test_conv2d_backward_data():
    """Verify conv2d backward-data (dgrad) across dtype / tensor-format combos.

    Each case is (data_dtype, conv_dtype, tensor_format, tolerance).
    """
    # The scipy convolve function does not support fp16, so the fp16 reference
    # is computed in fp32. Use larger tolerance to be on the safe side
    # (1e-2 also seems mostly ok).
    cases = [
        ("float32", "float32", 0, 1e-5),
        ("float32", "float32", 1, 1e-2),
        ("float16", "float16", 1, 1e-1),
    ]
    for data_dtype, conv_dtype, fmt, tol in cases:
        verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=fmt, tol=tol)
303317
304318
305319test_kwargs_default_2d = {
0 commit comments