Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add 'asnumpy' dtype option to check_symbolic_backward (#15186)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and wkcn committed Jun 28, 2019
1 parent 582489c commit cd19367
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
>>> grad_expected = ograd.copy().asnumpy()
>>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])
"""
assert dtype in (np.float16, np.float32, np.float64)
assert dtype == 'asnumpy' or dtype in (np.float16, np.float32, np.float64)
if ctx is None:
ctx = default_context()

Expand All @@ -1146,7 +1146,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()}
args_grad_data = {}
for k, v in args_grad_npy.items():
nd = mx.nd.array(v, ctx=ctx, dtype=dtype)
nd = mx.nd.array(v, ctx=ctx, dtype=expected[k].dtype if dtype == "asnumpy" else dtype)
if grad_stypes is not None and k in grad_stypes:
stype = grad_stypes[k]
if stype is not None and stype != 'default':
Expand All @@ -1170,15 +1170,15 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
outg = list()
for arr in out_grads:
if isinstance(arr, np.ndarray):
outg.append(mx.nd.array(arr, ctx=ctx, dtype=dtype))
outg.append(mx.nd.array(arr, ctx=ctx, dtype=arr.dtype if dtype == "asnumpy" else dtype))
else:
outg.append(arr)
out_grads = outg
elif isinstance(out_grads, dict):
outg = dict()
for k, v in out_grads.items():
if isinstance(v, np.ndarray):
outg[k] = mx.nd.array(v, ctx=ctx, dtype=dtype)
outg[k] = mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype)
else:
outg[k] = v
out_grads = outg
Expand Down

0 comments on commit cd19367

Please sign in to comment.