From cd19367dcbd7fb022c0ffc13face6c45c7e2aacb Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Fri, 28 Jun 2019 11:23:46 +0800 Subject: [PATCH] add 'asnumpy' dtype option to check_symbolic_backward (#15186) --- python/mxnet/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index bd102412c6e2..d247c0fcde95 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1133,7 +1133,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= >>> grad_expected = ograd.copy().asnumpy() >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected]) """ - assert dtype in (np.float16, np.float32, np.float64) + assert dtype == 'asnumpy' or dtype in (np.float16, np.float32, np.float64) if ctx is None: ctx = default_context() @@ -1146,7 +1146,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()} args_grad_data = {} for k, v in args_grad_npy.items(): - nd = mx.nd.array(v, ctx=ctx, dtype=dtype) + nd = mx.nd.array(v, ctx=ctx, dtype=expected[k].dtype if dtype == "asnumpy" else dtype) if grad_stypes is not None and k in grad_stypes: stype = grad_stypes[k] if stype is not None and stype != 'default': @@ -1170,7 +1170,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= outg = list() for arr in out_grads: if isinstance(arr, np.ndarray): - outg.append(mx.nd.array(arr, ctx=ctx, dtype=dtype)) + outg.append(mx.nd.array(arr, ctx=ctx, dtype=arr.dtype if dtype == "asnumpy" else dtype)) else: outg.append(arr) out_grads = outg @@ -1178,7 +1178,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= outg = dict() for k, v in out_grads.items(): if isinstance(v, np.ndarray): - outg[k] = mx.nd.array(v, ctx=ctx, dtype=dtype) + outg[k] = mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) else: outg[k] = v out_grads = outg