Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
tvm numpy operator deg2rad && rad2deg
Browse files Browse the repository at this point in the history
* fix format error

* change type

* constangt must be tvm.const in tvm and add backward test, do not support float16

* add addto test

* handle 0-dim and 0-size

* add 0-dim test case

* register to npi and add wrapper with doc

* change function name, add infer type

* fix format error

* merge rad2deg to deg2rad

* fix error according to review

* change infer type

* add TVM_OP in test
  • Loading branch information
Ying authored and tingying2020 committed Sep 19, 2019
1 parent a37a76c commit 36f3a21
Show file tree
Hide file tree
Showing 6 changed files with 467 additions and 8 deletions.
137 changes: 136 additions & 1 deletion contrib/tvmop/basic/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# coding: utf-8
import tvm
from .. import defop, AllTypes
from .. import defop, AllTypes, RealTypes
from .. import assign_by_req, reduce_axes

def compute_add(dtype, ndim):
Expand Down Expand Up @@ -98,3 +98,138 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req):
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [X, in_grad_a, in_grad]


def compute_degandrad(dtype, ndim, n):
A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
import math
if n == 0:
B = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B')
else:
B = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype), name='B')
s = tvm.create_schedule(B.op)
return s, A, B


@defop(name="deg2rad", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 0)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
s[B].parallel(fused)
return s, [A, B]


@defop(name="rad2deg", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def rad2deg(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 1)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
s[B].parallel(fused)
return s, [A, B]


@defop(name="cuda_deg2rad", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad_gpu(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 0)
s = tvm.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B]


@defop(name="cuda_rad2deg", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def rad2deg_gpu(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 1)
s = tvm.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B]


def compute_backward_degandrad(dtype, ndim, req, n):
ishape = [tvm.var() for _ in range(ndim)]
in_grad_tmp = tvm.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
in_grad = tvm.placeholder(ishape, name='in_grad', dtype=dtype)
out_grad = tvm.placeholder(ishape, name='out_grad', dtype=dtype)
import math
if n == 0:
ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype))
else:
ret = tvm.compute(ishape, lambda *index: out_grad[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype))
if (req == "kAddTo"):
in_grad = tvm.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
else:
in_grad = tvm.compute(ishape, lambda *index: ret[index])
s = tvm.create_schedule(in_grad.op)
return s, out_grad, in_grad_tmp, in_grad, [ret, in_grad]


@defop(name="backward_deg2rad", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [out_grad, in_grad, in_grad_tmp]


@defop(name="backward_rad2deg", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def backward_rad2deg(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [out_grad, in_grad, in_grad_tmp]


@defop(name="cuda_backward_deg2rad", target="gpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def cuda_backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [out_grad, in_grad, in_grad_tmp]


@defop(name="cuda_backward_rad2deg", target="gpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def cuda_backward_rad2deg(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [out_grad, in_grad, in_grad_tmp]
78 changes: 75 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@

__all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute',
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'rad2deg', 'log2',
'log1p', 'rint', 'radians', 'deg2rad', 'reciprocal', 'square', 'negative', 'fix', 'ceil',
'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel']
Expand Down Expand Up @@ -1245,6 +1245,42 @@ def degrees(x, out=None, **kwargs):
return _unary_func_helper(x, _npi.degrees, _np.degrees, out=out, **kwargs)


@set_module('mxnet.ndarray.numpy')
def rad2deg(x, out=None):
r"""
rad2deg(x, out=None)
Convert angles from radians to degrees.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"rad2deg(x)" is "x *180 / pi".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.rad2deg(np.pi/2)
90.0
"""
return _unary_func_helper(x, _npi.rad2deg, _np.rad2deg, out=out)


@set_module('mxnet.ndarray.numpy')
def rint(x, out=None, **kwargs):
"""
Expand Down Expand Up @@ -1388,6 +1424,42 @@ def radians(x, out=None, **kwargs):
return _unary_func_helper(x, _npi.radians, _np.radians, out=out, **kwargs)


@set_module('mxnet.ndarray.numpy')
def deg2rad(x, out=None):
r"""
deg2rad(x, out=None)
Convert angles from degrees to radians.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"deg2rad(x)" is "x * pi / 180".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.deg2rad(180)
3.1415927
"""
return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out)


@set_module('mxnet.ndarray.numpy')
def reciprocal(x, out=None, **kwargs):
r"""
Expand Down
74 changes: 73 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide',
'mod', 'remainder', 'power', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt',
'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'degrees', 'rad2deg', 'log2', 'log1p', 'rint', 'radians', 'deg2rad', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
Expand Down Expand Up @@ -2725,6 +2725,42 @@ def degrees(x, out=None, **kwargs):
return _mx_nd_np.degrees(x, out=out, **kwargs)


@set_module('mxnet.numpy')
def rad2deg(x, out=None):
r"""
rad2deg(x, out=None)
Convert angles from radians to degrees.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"rad2deg(x)" is "x * 180 / pi".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.rad2deg(np.pi/2)
90.0
"""
return _mx_nd_np.rad2deg(x, out=out)


@set_module('mxnet.numpy')
def radians(x, out=None, **kwargs):
"""
Expand Down Expand Up @@ -2760,6 +2796,42 @@ def radians(x, out=None, **kwargs):
return _mx_nd_np.radians(x, out=out, **kwargs)


@set_module('mxnet.numpy')
def deg2rad(x, out=None):
r"""
deg2rad(x, out=None)
Convert angles from degrees to radians.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"deg2rad(x)" is "x * pi / 180".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.deg2rad(180)
3.1415927
"""
return _mx_nd_np.deg2rad(x, out=out)


@set_module('mxnet.numpy')
def reciprocal(x, out=None, **kwargs):
r"""
Expand Down
Loading

0 comments on commit 36f3a21

Please sign in to comment.