Skip to content

Commit 211a58b

Browse files
committed
fp16 also works
1 parent c2a34d4 commit 211a58b

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

tests/python/contrib/test_cudnn.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,17 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
264264
dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
265265
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
266266

267-
dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))
267+
if data_dtype == "float16":
268+
dx_np = ref_func(
269+
dy_np.astype("float32"),
270+
w_np.astype("float32"),
271+
(stride_h, stride_w),
272+
(pad_h, pad_w),
273+
(0, 0),
274+
)
275+
dx_np = dx_np.astype("float16")
276+
else:
277+
dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))
268278

269279
dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
270280
w = te.placeholder(wshape, name="dw", dtype=data_dtype)
@@ -292,6 +302,7 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
292302

293303
f(dy, w, dx)
294304
print(np.max(np.abs(dx.numpy() - dx_np)))
305+
print(np.mean(np.abs(dx.numpy() - dx_np)))
295306
tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol)
296307

297308

@@ -300,6 +311,9 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
300311
def test_conv2d_backward_data():
301312
verify_conv2d_backward_data("float32", "float32", tensor_format=0, tol=1e-5)
302313
verify_conv2d_backward_data("float32", "float32", tensor_format=1, tol=1e-2)
314+
# The scipy convolve function does not support fp16, so the reference will be computed with
315+
# fp32. Use larger tolerance to be on the safe side (1e-2 also seems mostly ok).
316+
verify_conv2d_backward_data("float16", "float16", tensor_format=1, tol=1e-1)
303317

304318

305319
test_kwargs_default_2d = {

0 commit comments

Comments
 (0)