diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py index 0dc7ea9022ee..87356428428f 100644 --- a/contrib/tvmop/basic/ufunc.py +++ b/contrib/tvmop/basic/ufunc.py @@ -100,11 +100,15 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req): return s, [X, in_grad_a, in_grad] -def compute_deg2rad(dtype, ndim): +def compute_degandrad(dtype, ndim, n): A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype) import math - B = tvm.compute([tvm.var() for _ in range(ndim)], - lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B') + if n == 0: + B = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B') + else: + B = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *index: A[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype), name='B') s = tvm.create_schedule(B.op) return s, A, B @@ -112,7 +116,17 @@ def compute_deg2rad(dtype, ndim): @defop(name="deg2rad", target="cpu", auto_broadcast=False, dtype=["float32", "float64"], ndim=list(range(0, 6))) def deg2rad(dtype, ndim): - s, A, B = compute_deg2rad(dtype, ndim) + s, A, B = compute_degandrad(dtype, ndim, 0) + axes = [axis for axis in B.op.axis] + fused = s[B].fuse(*axes) + s[B].parallel(fused) + return s, [A, B] + + +@defop(name="rad2deg", target="cpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6))) +def rad2deg(dtype, ndim): + s, A, B = compute_degandrad(dtype, ndim, 1) axes = [axis for axis in B.op.axis] fused = s[B].fuse(*axes) s[B].parallel(fused) @@ -122,7 +136,7 @@ def deg2rad(dtype, ndim): @defop(name="cuda_deg2rad", target="cuda", auto_broadcast=False, dtype=["float32", "float64"], ndim=list(range(0, 6))) def deg2rad_gpu(dtype, ndim): - s, A, B = compute_deg2rad(dtype, ndim) + s, A, B = compute_degandrad(dtype, ndim, 0) s = tvm.create_schedule(B.op) axes = [axis for axis in B.op.axis] fused = s[B].fuse(*axes) @@ -132,13 +146,29 
@@ def deg2rad_gpu(dtype, ndim): return s, [A, B] -def compute_backward_deg2rad(dtype, ndim, req): +@defop(name="cuda_rad2deg", target="cuda", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6))) +def rad2deg_gpu(dtype, ndim): + s, A, B = compute_degandrad(dtype, ndim, 1) + s = tvm.create_schedule(B.op) + axes = [axis for axis in B.op.axis] + fused = s[B].fuse(*axes) + bx, tx = s[B].split(fused, factor=64) + s[B].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B].bind(tx, tvm.thread_axis("threadIdx.x")) + return s, [A, B] + + +def compute_backward_degandrad(dtype, ndim, req, n): ishape = [tvm.var() for _ in range(ndim)] in_grad_tmp = tvm.placeholder(ishape, name='in_grad_tmp', dtype=dtype) in_grad = tvm.placeholder(ishape, name='in_grad', dtype=dtype) out_grad = tvm.placeholder(ishape, name='out_grad', dtype=dtype) import math - ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype)) + if n == 0: + ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype)) + else: + ret = tvm.compute(ishape, lambda *index: out_grad[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype)) if (req == "kAddTo"): in_grad = tvm.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index]) else: @@ -151,7 +181,19 @@ def compute_backward_deg2rad(dtype, ndim, req): dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"], attrs=["req"]) def backward_deg2rad(dtype, ndim, req): - s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_deg2rad(dtype, ndim, req) + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0) + for t in c_list: + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + s[t].parallel(fused) + return s, [out_grad, in_grad, in_grad_tmp] + + +@defop(name="backward_rad2deg", target="cpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6)), 
req=["kWriteTo", "kAddTo"], + attrs=["req"]) +def backward_rad2deg(dtype, ndim, req): + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1) for t in c_list: axes = [axis for axis in t.op.axis] fused = s[t].fuse(*axes) @@ -163,7 +205,24 @@ def backward_deg2rad(dtype, ndim, req): dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"], attrs=["req"]) def cuda_backward_deg2rad(dtype, ndim, req): - s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_deg2rad(dtype, ndim, req) + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0) + num_thread = 64 + for t in c_list: + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + bx, tx = s[t].split(fused, factor=num_thread) + s[t].bind(bx, block_x) + s[t].bind(tx, thread_x) + return s, [out_grad, in_grad, in_grad_tmp] + + +@defop(name="cuda_backward_rad2deg", target="gpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"], + attrs=["req"]) +def cuda_backward_rad2deg(dtype, ndim, req): + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1) num_thread = 64 for t in c_list: block_x = tvm.thread_axis("blockIdx.x") diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index a00e2621d43e..36acc97e8b9f 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -32,7 +32,8 @@ 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 
'deg2rad', + 'rad2deg'] @set_module('mxnet.ndarray.numpy') @@ -1907,7 +1908,6 @@ def get_list(arrays): return _npi.stack(*arrays, axis=axis, out=out) - @set_module('mxnet.ndarray.numpy') def deg2rad(x, out=None): r""" @@ -1942,3 +1942,39 @@ def deg2rad(x, out=None): 3.1415927 """ return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out) + + +@set_module('mxnet.ndarray.numpy') +def rad2deg(x, out=None): + r""" + rad2deg(x, out=None) + + Convert angles from radians to degrees. + Parameters + ---------- + x : ndarray or scalar + Angles in radians. + out : ndarray or None, optional + A location into which the result is stored. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + The corresponding angle in degrees. + This is a scalar if `x` is a scalar. + + Notes + ----- + "rad2deg(x)" is "x * 180 / pi". + + This function differs from the original numpy.rad2deg in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + + Examples + -------- + >>> np.rad2deg(np.pi/2) + 90.0 + """ + return _unary_func_helper(x, _npi.rad2deg, _np.rad2deg, out=out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 55692aea55bb..8bed50b823d0 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -49,7 +49,7 @@ 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', - 'stack', 'deg2rad'] + 'stack', 'deg2rad', 'rad2deg'] # This function is copied from ndarray.py since pylint @@ -3122,3 +3122,39 @@ def deg2rad(x, out=None): 3.1415927 """ return _mx_nd_np.deg2rad(x, out=out) + + +@set_module('mxnet.numpy') +def rad2deg(x, out=None): + r""" + rad2deg(x, out=None) + + Convert angles from radians to degrees. 
+ Parameters + ---------- + x : ndarray or scalar + Angles in radians. + out : ndarray or None, optional + A location into which the result is stored. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + The corresponding angle in degrees. + This is a scalar if `x` is a scalar. + + Notes + ----- + "rad2deg(x)" is "x * 180 / pi". + + This function differs from the original numpy.rad2deg in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + + Examples + -------- + >>> np.rad2deg(np.pi/2) + 90.0 + """ + return _mx_nd_np.rad2deg(x, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 2ac94dbb93d6..8901dc0b2537 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -34,7 +34,8 @@ 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad', + 'rad2deg'] def _num_outputs(sym): @@ -2358,4 +2359,34 @@ def deg2rad(x, out=None): return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out) +@set_module('mxnet.symbol.numpy') +def rad2deg(x, out=None): + r""" + rad2deg(x, out=None) + + Convert angles from radians to degrees. + Parameters + ---------- + x : _Symbol or scalar + Angles in radians. + out : _Symbol or None, optional + A location into which the result is stored. + + Returns + ------- + y : _Symbol or scalar + The corresponding angle in degrees. + This is a scalar if `x` is a scalar. + + Notes + ----- + "rad2deg(x)" is "x * 180 / pi". 
+ + This function differs from the original numpy.rad2deg in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + """ + return _unary_func_helper(x, _npi.rad2deg, _np.rad2deg, out=out) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/contrib/tvmop/ufunc.cc b/src/operator/contrib/tvmop/ufunc.cc index 52f4032c2dac..30db366c5f7c 100644 --- a/src/operator/contrib/tvmop/ufunc.cc +++ b/src/operator/contrib/tvmop/ufunc.cc @@ -144,9 +144,9 @@ NNVM_REGISTER_OP(_backward_contrib_tvm_vadd) .set_attr("FCompute", mxnet::op::TVMBinaryBackwardComputeUseNone); -inline bool Deg2radOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { +inline bool DegandradOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); @@ -210,7 +210,7 @@ NNVM_REGISTER_OP(_npi_deg2rad) return std::vector{"data"}; }) .set_attr("FInferShape", mxnet::op::ElemwiseShape<1, 1>) -.set_attr("FInferType", mxnet::op::Deg2radOpType) +.set_attr("FInferType", mxnet::op::DegandradOpType) #if MXNET_USE_CUDA .set_attr("FCompute", mxnet::op::TVMUnaryCompute) #endif // MXNET_USE_CUDA @@ -227,6 +227,35 @@ NNVM_REGISTER_OP(_backward_npi_deg2rad) .set_attr("FCompute", mxnet::op::TVMUnaryBackwardComputeUseNone); +static constexpr char func_rad2deg_cpu[] = "rad2deg"; +static constexpr char func_rad2deg_gpu[] = "cuda_rad2deg"; +static constexpr char func_backward_rad2deg_cpu[] = "backward_rad2deg"; +static constexpr char func_backward_rad2deg_gpu[] = "cuda_backward_rad2deg"; + +NNVM_REGISTER_OP(_npi_rad2deg) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"data"}; +}) +.set_attr("FInferShape", mxnet::op::ElemwiseShape<1, 1>) +.set_attr("FInferType", mxnet::op::DegandradOpType) +#if MXNET_USE_CUDA +.set_attr("FCompute", mxnet::op::TVMUnaryCompute) +#endif // 
MXNET_USE_CUDA +.set_attr("FCompute", mxnet::op::TVMUnaryCompute) +.add_argument("data", "NDArray-or-Symbol", "the input") +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_rad2deg"}); + +NNVM_REGISTER_OP(_backward_npi_rad2deg) +.set_attr("TIsBackward", true) +#if MXNET_USE_CUDA +.set_attr("FCompute", + mxnet::op::TVMUnaryBackwardComputeUseNone) +#endif // MXNET_USE_CUDA +.set_attr("FCompute", + mxnet::op::TVMUnaryBackwardComputeUseNone); } // namespace op } // namespace mxnet #endif // MXNET_USE_TVM_OP diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index cdc3774f5396..7907ffb0cbad 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1082,7 +1082,7 @@ def hybrid_forward(self, F, a, *args): @with_seed() @use_np -def test_np_deg2rad(): +def test_np_degandrad(): class TestDeg2rad(HybridBlock): def __init__(self): super(TestDeg2rad, self).__init__() @@ -1090,6 +1090,13 @@ def __init__(self): def hybrid_forward(self, F, x): return F.np.deg2rad(x) + class TestRad2deg(HybridBlock): + def __init__(self): + super(TestRad2deg, self).__init__() + + def hybrid_forward(self, F, x): + return F.np.rad2deg(x) + types = ['float64', 'float32'] for hybridize in [True, False]: for shape in [(), @@ -1103,30 +1110,48 @@ def hybrid_forward(self, F, x): rtol=1e-3 atol=1e-5 test_deg2rad = TestDeg2rad() + test_rad2deg = TestRad2deg() if hybridize: test_deg2rad.hybridize() - x = rand_ndarray(shape, dtype=oneType).as_np_ndarray() - x.attach_grad() - np_out = _np.deg2rad(x.asnumpy()) + test_rad2deg.hybridize() + x1 = rand_ndarray(shape, dtype=oneType).as_np_ndarray() + x2 = rand_ndarray(shape, dtype=oneType).as_np_ndarray() + x1.attach_grad() + x2.attach_grad() + np_out1 = _np.deg2rad(x1.asnumpy()) + np_out2 = _np.rad2deg(x2.asnumpy()) with mx.autograd.record(): - mx_out = test_deg2rad(x) - assert mx_out.shape == np_out.shape - assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol) - 
mx_out.backward() + mx_out1 = test_deg2rad(x1) + mx_out2 = test_rad2deg(x2) + assert mx_out1.shape == np_out1.shape + assert mx_out2.shape == np_out2.shape + assert_almost_equal(mx_out1.asnumpy(), np_out1, rtol, atol) + assert_almost_equal(mx_out2.asnumpy(), np_out2, rtol, atol) + mx_out1.backward() + mx_out2.backward() import math - np_backward = math.pi / 180 - assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) - - mx_out = np.deg2rad(x) - np_out = _np.deg2rad(x.asnumpy()) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol) + np_backward1 = math.pi / 180 + np_backward2 = 180 / math.pi + assert_almost_equal(x1.grad.asnumpy(), np_backward1, rtol=rtol, atol=atol) + assert_almost_equal(x2.grad.asnumpy(), np_backward2, rtol=rtol, atol=atol) + + mx_out1 = np.deg2rad(x1) + np_out1 = _np.deg2rad(x1.asnumpy()) + assert_almost_equal(mx_out1.asnumpy(), np_out1, rtol, atol) + mx_out2 = np.rad2deg(x2) + np_out2 = _np.rad2deg(x2.asnumpy()) + assert_almost_equal(mx_out2.asnumpy(), np_out2, rtol, atol) # Test AddTo Request with mx.autograd.record(): - a = test_deg2rad(x) - b = test_deg2rad(x) - mx.autograd.backward([a, b]) - assert_almost_equal(x.grad.asnumpy(), 2 * np_backward, rtol=rtol, atol=atol) + a1 = test_deg2rad(x1) + b1 = test_deg2rad(x1) + a2 = test_rad2deg(x2) + b2 = test_rad2deg(x2) + mx.autograd.backward([a1, b1]) + mx.autograd.backward([a2, b2]) + assert_almost_equal(x1.grad.asnumpy(), 2 * np_backward1, rtol=rtol, atol=atol) + assert_almost_equal(x2.grad.asnumpy(), 2 * np_backward2, rtol=rtol, atol=atol) if __name__ == '__main__':