This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix rtrue_divide grad #16769

Merged (2 commits) on Nov 10, 2019.
src/operator/numpy/np_true_divide.cc (2 changes: 1 addition & 1 deletion)
@@ -126,7 +126,7 @@ NNVM_REGISTER_OP(_npi_rtrue_divide_scalar)
 })
 #endif
 .set_attr<FCompute>("FCompute<cpu>", TrueDivideScalarCompute<cpu, mshadow_op::rtrue_divide>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_rdiv_scalar"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rdiv_scalar"})
 .add_argument("data", "NDArray-or-Symbol", "source input")
 .add_argument("scalar", "float", "scalar input");

tests/python/unittest/test_numpy_ndarray.py (107 changes: 95 additions & 12 deletions)
@@ -27,7 +27,7 @@
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np
 from common import with_seed, TemporaryDirectory
-from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, assert_exception, is_op_runnable
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, assert_exception, is_op_runnable, collapse_sum_like
 from mxnet.ndarray.ndarray import py_slice
 from mxnet.base import integer_types
 import scipy.stats as ss
@@ -281,6 +281,62 @@ def test_np_ndarray_binary_element_wise_ops():
         '<=': _np.less_equal
     })
 
+    def _get_grad_func(op, scalar=None, reverse=False):
+        if op == '+':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, x1.shape),
+                                                   collapse_sum_like(ograd, x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd
+            else:
+                return lambda ograd, x1, x2, out: ograd
+        elif op == '-':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, x1.shape),
+                                                   -collapse_sum_like(ograd, x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd
+            else:
+                return lambda ograd, x1, x2, out: -ograd
+        elif op == '*':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd * x2, x1.shape),
+                                                   collapse_sum_like(ograd * x1, x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd * x2
+            else:
+                return lambda ograd, x1, x2, out: ograd * x1
+        elif op == '/':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd / x2, x1.shape),
+                                                   collapse_sum_like(-x1 * ograd / (x2 * x2), x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd / x2
+            else:
+                return lambda ograd, x1, x2, out: -x1 * ograd / (x2 * x2)
+        elif op == 'mod':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, x1.shape),
+                                                   collapse_sum_like(-ograd * _np.floor(x1 / x2), x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd
+            else:
+                return lambda ograd, x1, x2, out: -ograd * _np.floor(x1 / x2)
+        elif op == 'pow':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd * x2 * _np.power(x1, x2 - 1), x1.shape),
+                                                   collapse_sum_like(ograd * out * _np.log(x1), x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd * x2 * _np.power(x1, x2 - 1)
+            else:
+                return lambda ograd, x1, x2, out: ograd * out * _np.log(x1)
+        elif op in ('==', '!=', '<', '<=', '>', '>='):
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (_np.zeros_like(x1), _np.zeros_like(x2))
+            else:
+                return lambda ograd, x1, x2, out: _np.zeros_like(ograd)
+        return None
+
     def get_np_ret(x1, x2, op):
         return np_op_map[op](x1, x2)
 
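Most entries in _get_grad_func above are textbook; the mod gradients deserve a note. Writing x1 % x2 as x1 - floor(x1 / x2) * x2 and treating floor as locally constant gives d/dx1 = 1 and d/dx2 = -floor(x1 / x2), which is exactly what the helper returns. A standalone NumPy check (a sketch, valid only away from the points where x1 / x2 is an integer and floor jumps):

import numpy as np

x1 = np.random.uniform(1.0, 10.0, size=(3, 4))
x2 = np.random.uniform(1.0, 10.0, size=(3, 4))

eps = 1e-6
# Central-difference gradient of x1 % x2 with respect to x2.
numeric = ((x1 % (x2 + eps)) - (x1 % (x2 - eps))) / (2 * eps)
# Analytic gradient used by _get_grad_func: -floor(x1 / x2).
analytic = -np.floor(x1 / x2)

# Compare only where x1 / x2 is safely away from an integer,
# since the derivative does not exist at floor's jump points.
mask = np.abs(x1 / x2 - np.round(x1 / x2)) > 1e-3
assert np.allclose(numeric[mask], analytic[mask], rtol=1e-3, atol=1e-4)
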
@@ -364,13 +420,15 @@ def check_binary_op_result(shape1, shape2, op, dtype=None):
             mx_input1 = abs(_np.random.uniform()) + 1
             np_input1 = mx_input1
         else:
-            mx_input1 = rand_ndarray(shape1, dtype=dtype).abs() + 1
+            mx_input1 = (rand_ndarray(shape1, dtype=dtype).abs() + 1).as_np_ndarray()
+            mx_input1.attach_grad()
             np_input1 = mx_input1.asnumpy()
         if shape2 is None:
             mx_input2 = abs(_np.random.uniform()) + 1
             np_input2 = mx_input2
         else:
-            mx_input2 = rand_ndarray(shape2, dtype=dtype).abs() + 1
+            mx_input2 = (rand_ndarray(shape2, dtype=dtype).abs() + 1).as_np_ndarray()
+            mx_input2.attach_grad()
             np_input2 = mx_input2.asnumpy()
 
         scalar = None
@@ -382,34 +440,59 @@ def check_binary_op_result(shape1, shape2, op, dtype=None):
             scalar = mx_input1
             reverse = True
 
+        grad_func = _get_grad_func(op, scalar, reverse)
         np_out = get_np_ret(np_input1, np_input2, op)
+        ograd = _np.ones_like(np_out)
         for hybridize in [True, False]:
             if scalar is None:
                 get_mx_ret_np = TestBinaryElementWiseOp(op)
                 get_mx_ret_classic = TestBinaryElementWiseOp(op)
                 if hybridize:
                     get_mx_ret_np.hybridize()
                     get_mx_ret_classic.hybridize()
-                mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray())
+                if grad_func is None:
+                    mx_out = get_mx_ret_np(mx_input1, mx_input2)
+                else:
+                    with mx.autograd.record():
+                        mx_out = get_mx_ret_np(mx_input1, mx_input2)
+                    mx_out.backward()
                 assert type(mx_out) == np.ndarray
                 assert np_out.shape == mx_out.shape
                 if op in logic_ops:
                     assert np_out.dtype == mx_out.dtype
-                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5, use_broadcast=False)
+
+                if grad_func is not None:
+                    x1_grad_expected, x2_grad_expected = grad_func(ograd, np_input1, np_input2, np_out)
+                    assert_almost_equal(mx_input1.grad.asnumpy(), x1_grad_expected, atol=1e-5, rtol=1e-3,
+                                        use_broadcast=False)
+                    assert_almost_equal(mx_input2.grad.asnumpy(), x2_grad_expected, atol=1e-5, rtol=1e-3,
+                                        use_broadcast=False)
             else:
                 get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse)
                 if hybridize:
                     get_mx_ret.hybridize()
                 if reverse:
-                    mx_out = get_mx_ret(mx_input2.as_np_ndarray())
-                    assert type(mx_out) == np.ndarray
+                    mx_input = mx_input2
                 else:
-                    mx_out = get_mx_ret(mx_input1.as_np_ndarray())
-                    assert type(mx_out) == np.ndarray
-                    assert np_out.shape == mx_out.shape
+                    mx_input = mx_input1
 
+                if grad_func is None:
+                    mx_out = get_mx_ret(mx_input)
+                else:
+                    with mx.autograd.record():
+                        mx_out = get_mx_ret(mx_input)
+                    mx_out.backward()
+                assert type(mx_out) == np.ndarray
+
                 if op in logic_ops:
                     assert np_out.dtype == mx_out.dtype
-                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5, use_broadcast=False)
+
+                # check grad
+                if grad_func is not None:
+                    x_grad_expected = grad_func(ograd, np_input1, np_input2, np_out)
+                    assert_almost_equal(mx_input.grad.asnumpy(), x_grad_expected, atol=1e-5, rtol=1e-3,
+                                        use_broadcast=False)
 
     dtypes = [_np.float32, _np.float64, None]
     ops = np_op_map.keys()
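
The updated test drives gradients through the standard MXNet autograd pattern: attach_grad() allocates gradient storage, autograd.record() captures the forward graph, and backward() fills in .grad. An end-to-end sketch of the exact case this PR fixes (assuming MXNet with numpy semantics enabled and this fix applied):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

x = np.array([1.0, 2.0, 4.0])
x.attach_grad()                # allocate storage for x.grad
with mx.autograd.record():     # record the forward pass
    out = 3.0 / x              # dispatches to _npi_rtrue_divide_scalar
out.backward()                 # runs _backward_rdiv_scalar
print(x.grad)                  # expect -3 / x**2 = [-3., -0.75, -0.1875]

Before the fix, the backward node received no input data and produced incorrect gradients; with ElemwiseGradUseIn, x.grad matches the analytic value.
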
tests/python/unittest/test_numpy_op.py (7 changes: 3 additions & 4 deletions)
@@ -1572,8 +1572,8 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                                 rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
             if rgrads is None:
                 assert_almost_equal(mx_test_x2.grad.asnumpy(),
-                    collapse_sum_like(rgrad(y.asnumpy(), np_test_x2, np_test_x1), mx_test_x2.shape),
-                    rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
+                                    collapse_sum_like(rgrad(y.asnumpy(), np_test_x2, np_test_x1), mx_test_x2.shape),
+                                    rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
             else:
                 assert_almost_equal(mx_test_x2.grad.asnumpy(),
                                     collapse_sum_like(rgrad(y.asnumpy(), np_test_x1, np_test_x2), mx_test_x2.shape),
@@ -1594,7 +1594,6 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
             assertRaises(NotImplementedError, getattr(np, func), mx_test_x1, mx_test_x2, order='C')
             assertRaises(NotImplementedError, getattr(np, func), mx_test_x1, mx_test_x2, order='mxnet')
 
-
     funcs = {
         'add': (-1.0, 1.0, [lambda y, x1, x2: _np.ones(y.shape)], None),
         'subtract':
@@ -1603,7 +1602,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
         'multiply': (-1.0, 1.0, [lambda y, x1, x2: _np.broadcast_to(x2, y.shape)],
                      [lambda y, x1, x2: _np.broadcast_to(x1, y.shape)]),
         'divide': (0.1, 1.0, [lambda y, x1, x2: _np.ones(y.shape) / x2],
-                       [lambda y, x1, x2: -x1 / (x2 * x2)]),
+                   [lambda y, x1, x2: -x1 / (x2 * x2)]),
         'mod': (1.0, 10.0,
                 [lambda y, x1, x2: _np.ones(y.shape),
                  lambda y, x1, x2: _np.zeros(y.shape)],
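
In these funcs entries, the gradient lambdas return arrays with the broadcast output shape; the test then reduces them back to each input's shape with collapse_sum_like before comparing against the recorded .grad. A rough NumPy equivalent of that reduction (a sketch of the semantics, not the mxnet.test_utils implementation):

import numpy as _np

def collapse_sum_like(a, shape):
    # Sum a gradient of the broadcast output shape back down to `shape`,
    # the shape of the input it belongs to.
    # 1) Sum away the leading axes that broadcasting prepended.
    extra = a.ndim - len(shape)
    if extra > 0:
        a = a.sum(axis=tuple(range(extra)))
    # 2) Sum (keeping dims) over axes where the input had size 1.
    axes = tuple(i for i, s in enumerate(shape) if s == 1 and a.shape[i] != 1)
    if axes:
        a = a.sum(axis=axes, keepdims=True)
    return a.reshape(shape)

g = _np.ones((2, 3))                 # gradient in the broadcast shape
print(collapse_sum_like(g, (1, 3)))  # [[2., 2., 2.]]

For example, if x2 has shape (1, 3) and the output has shape (2, 3), the x2 gradient is summed over axis 0 before the comparison.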