Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
merge rad2deg to deg2rad
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying committed Aug 27, 2019
1 parent 93cbb6d commit 39e0f1e
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 35 deletions.
77 changes: 68 additions & 9 deletions contrib/tvmop/basic/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,33 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req):
return s, [X, in_grad_a, in_grad]


def compute_deg2rad(dtype, ndim):
def compute_degandrad(dtype, ndim, n):
A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
import math
B = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B')
if n == 0:
B = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B')
else:
B = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype), name='B')
s = tvm.create_schedule(B.op)
return s, A, B


@defop(name="deg2rad", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad(dtype, ndim):
s, A, B = compute_deg2rad(dtype, ndim)
s, A, B = compute_degandrad(dtype, ndim, 0)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
s[B].parallel(fused)
return s, [A, B]


@defop(name="rad2deg", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def rad2deg(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 1)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
s[B].parallel(fused)
Expand All @@ -122,7 +136,7 @@ def deg2rad(dtype, ndim):
@defop(name="cuda_deg2rad", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad_gpu(dtype, ndim):
s, A, B = compute_deg2rad(dtype, ndim)
s, A, B = compute_degandrad(dtype, ndim, 0)
s = tvm.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
Expand All @@ -132,13 +146,29 @@ def deg2rad_gpu(dtype, ndim):
return s, [A, B]


def compute_backward_deg2rad(dtype, ndim, req):
@defop(name="cuda_rad2deg", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def rad2deg_gpu(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 1)
s = tvm.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B]


def compute_backward_degandrad(dtype, ndim, req, n):
ishape = [tvm.var() for _ in range(ndim)]
in_grad_tmp = tvm.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
in_grad = tvm.placeholder(ishape, name='in_grad', dtype=dtype)
out_grad = tvm.placeholder(ishape, name='out_grad', dtype=dtype)
import math
ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype))
if n == 0:
ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype))
else:
ret = tvm.compute(ishape, lambda *index: out_grad[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype))
if (req == "kAddTo"):
in_grad = tvm.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
else:
Expand All @@ -151,7 +181,19 @@ def compute_backward_deg2rad(dtype, ndim, req):
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_deg2rad(dtype, ndim, req)
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [out_grad, in_grad, in_grad_tmp]


@defop(name="backward_rad2deg", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def backward_rad2deg(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
Expand All @@ -163,7 +205,24 @@ def backward_deg2rad(dtype, ndim, req):
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def cuda_backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_deg2rad(dtype, ndim, req)
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [out_grad, in_grad, in_grad_tmp]


@defop(name="cuda_backward_rad2deg", target="gpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def cuda_backward_rad2deg(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
Expand Down
40 changes: 38 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad',
'rad2deg']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1907,7 +1908,6 @@ def get_list(arrays):
return _npi.stack(*arrays, axis=axis, out=out)



@set_module('mxnet.ndarray.numpy')
def deg2rad(x, out=None):
r"""
Expand Down Expand Up @@ -1942,3 +1942,39 @@ def deg2rad(x, out=None):
3.1415927
"""
return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out)


@set_module('mxnet.ndarray.numpy')
def rad2deg(x, out=None):
r"""
rad2deg(x, out=None)
Convert angles from radians to degrees.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"rad2deg(x)" is "x *180 / pi".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.rad2deg(np.pi/2)
90.0
"""
return _unary_func_helper(x, _npi.rad2deg, _np.rad2deg, out=out)
38 changes: 37 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'deg2rad']
'stack', 'deg2rad', 'rad2deg']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -3122,3 +3122,39 @@ def deg2rad(x, out=None):
3.1415927
"""
return _mx_nd_np.deg2rad(x, out=out)


@set_module('mxnet.numpy')
def rad2deg(x, out=None):
r"""
rad2deg(x, out=None)
Convert angles from radians to degrees.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"rad2deg(x)" is "x * 180 / pi".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.rad2deg(np.pi/2)
90.0
"""
return _mx_nd_np.rad2deg(x, out=out)
33 changes: 32 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad',
'rad2deg']


def _num_outputs(sym):
Expand Down Expand Up @@ -2358,4 +2359,34 @@ def deg2rad(x, out=None):
return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out)


@set_module('mxnet.symbol.numpy')
def rad2deg(x, out=None):
r"""
rad2deg(x, out=None)
Convert angles from radians to degrees.
Parameters
----------
x : _Symbol or scalar
Angles in degrees.
out : _Symbol or None, optional
A location into which the result is stored.
Returns
-------
y : _Symbol or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"rad2deg(x)" is "x * 180 / pi".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
"""
return _unary_func_helper(x, _npi.rad2deg, _np.rad2deg, out=out)


_set_np_symbol_class(_Symbol)
37 changes: 33 additions & 4 deletions src/operator/contrib/tvmop/ufunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ NNVM_REGISTER_OP(_backward_contrib_tvm_vadd)
.set_attr<mxnet::FCompute>("FCompute<cpu>",
mxnet::op::TVMBinaryBackwardComputeUseNone<func_bakcward_vadd_cpu>);

inline bool Deg2radOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
inline bool DegandradOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

Expand Down Expand Up @@ -210,7 +210,7 @@ NNVM_REGISTER_OP(_npi_deg2rad)
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", mxnet::op::ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", mxnet::op::Deg2radOpType)
.set_attr<nnvm::FInferType>("FInferType", mxnet::op::DegandradOpType)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMUnaryCompute<func_deg2rad_gpu>)
#endif // MXNET_USE_CUDA
Expand All @@ -227,6 +227,35 @@ NNVM_REGISTER_OP(_backward_npi_deg2rad)
.set_attr<FCompute>("FCompute<cpu>",
mxnet::op::TVMUnaryBackwardComputeUseNone<func_backward_deg2rad_cpu>);

static constexpr char func_rad2deg_cpu[] = "rad2deg";
static constexpr char func_rad2deg_gpu[] = "cuda_rad2deg";
static constexpr char func_backward_rad2deg_cpu[] = "backward_rad2deg";
static constexpr char func_backward_rad2deg_gpu[] = "cuda_backward_rad2deg";

NNVM_REGISTER_OP(_npi_rad2deg)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", mxnet::op::ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", mxnet::op::DegandradOpType)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMUnaryCompute<func_rad2deg_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMUnaryCompute<func_rad2deg_cpu>)
.add_argument("data", "NDArray-or-Symbol", "the input")
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_npi_rad2deg"});

NNVM_REGISTER_OP(_backward_npi_rad2deg)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_CUDA
.set_attr<FCompute>("FCompute<gpu>",
mxnet::op::TVMUnaryBackwardComputeUseNone<func_backward_rad2deg_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<FCompute>("FCompute<cpu>",
mxnet::op::TVMUnaryBackwardComputeUseNone<func_backward_rad2deg_cpu>);
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_TVM_OP
Loading

0 comments on commit 39e0f1e

Please sign in to comment.