Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
tvm numpy operator deg2rad
Browse files Browse the repository at this point in the history
* fix format error

* change type

* constangt must be tvm.const in tvm and add backward test, do not support float16

* add addto test

* handle 0-dim and 0-size

* add 0-dim test case

* register to npi and add wrapper with doc

* change function name, add infer type

* fix format error
  • Loading branch information
Ying committed Aug 26, 2019
1 parent d8b6e47 commit 93cbb6d
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 4 deletions.
78 changes: 77 additions & 1 deletion contrib/tvmop/basic/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# coding: utf-8
import tvm
from .. import defop, AllTypes
from .. import defop, AllTypes, RealTypes
from .. import assign_by_req, reduce_axes

def compute_add(dtype, ndim):
Expand Down Expand Up @@ -98,3 +98,79 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req):
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [X, in_grad_a, in_grad]


def compute_deg2rad(dtype, ndim):
A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
import math
B = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B')
s = tvm.create_schedule(B.op)
return s, A, B


@defop(name="deg2rad", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad(dtype, ndim):
s, A, B = compute_deg2rad(dtype, ndim)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
s[B].parallel(fused)
return s, [A, B]


@defop(name="cuda_deg2rad", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad_gpu(dtype, ndim):
s, A, B = compute_deg2rad(dtype, ndim)
s = tvm.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B]


def compute_backward_deg2rad(dtype, ndim, req):
ishape = [tvm.var() for _ in range(ndim)]
in_grad_tmp = tvm.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
in_grad = tvm.placeholder(ishape, name='in_grad', dtype=dtype)
out_grad = tvm.placeholder(ishape, name='out_grad', dtype=dtype)
import math
ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype))
if (req == "kAddTo"):
in_grad = tvm.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
else:
in_grad = tvm.compute(ishape, lambda *index: ret[index])
s = tvm.create_schedule(in_grad.op)
return s, out_grad, in_grad_tmp, in_grad, [ret, in_grad]


@defop(name="backward_deg2rad", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_deg2rad(dtype, ndim, req)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [out_grad, in_grad, in_grad_tmp]


@defop(name="cuda_backward_deg2rad", target="gpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def cuda_backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_deg2rad(dtype, ndim, req)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [out_grad, in_grad, in_grad_tmp]
39 changes: 38 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1905,3 +1905,40 @@ def get_list(arrays):

arrays = get_list(arrays)
return _npi.stack(*arrays, axis=axis, out=out)



@set_module('mxnet.ndarray.numpy')
def deg2rad(x, out=None):
r"""
deg2rad(x, out=None)
Convert angles from degrees to radians.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"deg2rad(x)" is "x * pi / 180".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.deg2rad(180)
3.1415927
"""
return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out)
38 changes: 37 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack']
'stack', 'deg2rad']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -3086,3 +3086,39 @@ def stack(arrays, axis=0, out=None):
stacked : ndarray
The stacked array has one more dimension than the input arrays."""
return _mx_nd_np.stack(arrays, axis=axis, out=out)


@set_module('mxnet.numpy')
def deg2rad(x, out=None):
r"""
deg2rad(x, out=None)
Convert angles from degrees to radians.
Parameters
----------
x : ndarray or scalar
Angles in degrees.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"deg2rad(x)" is "x * pi / 180".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
Examples
--------
>>> np.deg2rad(180)
3.1415927
"""
return _mx_nd_np.deg2rad(x, out=out)
32 changes: 31 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'deg2rad']


def _num_outputs(sym):
Expand Down Expand Up @@ -2328,4 +2328,34 @@ def get_list(arrays):
return _npi.stack(*arrays, axis=axis, out=out)


@set_module('mxnet.symbol.numpy')
def deg2rad(x, out=None):
r"""
deg2rad(x, out=None)
Convert angles from degrees to radians.
Parameters
----------
x : _Symbol or scalar
Angles in degrees.
out : _Symbol or None, optional
A location into which the result is stored.
Returns
-------
y : _Symbol or scalar
The corresponding angle in radians.
This is a scalar if `x` is a scalar.
Notes
-----
"deg2rad(x)" is "x * pi / 180".
This function differs from the original numpy.arange in the following aspects:
- Only support float32 and float64.
- `out` must be in the same size of input.
"""
return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out)


_set_np_symbol_class(_Symbol)
83 changes: 83 additions & 0 deletions src/operator/contrib/tvmop/ufunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,89 @@ NNVM_REGISTER_OP(_backward_contrib_tvm_vadd)
.set_attr<mxnet::FCompute>("FCompute<cpu>",
mxnet::op::TVMBinaryBackwardComputeUseNone<func_bakcward_vadd_cpu>);

inline bool Deg2radOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
CHECK_LE(in_attrs->at(0), mshadow::kFloat64)
<< "Only support float32 and float64.";
return out_attrs->at(0) != -1;
}

template<const char* func>
void TVMUnaryCompute(const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (0 == inputs[0].shape_.Size()) {
// 0-size
return;
}
tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], outputs[0]});
}

template<const char* func>
void TVMUnaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (0 == inputs[0].shape_.Size()) {
// 0-size
return;
}
std::string funcname = func;
funcname += "req_";
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (req_type == kWriteTo) {
funcname += "kWriteTo";
} else {
funcname += "kAddTo";
}
tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx,
{inputs[0], outputs[0], outputs[0]});
})
}

static constexpr char func_deg2rad_cpu[] = "deg2rad";
static constexpr char func_deg2rad_gpu[] = "cuda_deg2rad";
static constexpr char func_backward_deg2rad_cpu[] = "backward_deg2rad";
static constexpr char func_backward_deg2rad_gpu[] = "cuda_backward_deg2rad";

NNVM_REGISTER_OP(_npi_deg2rad)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", mxnet::op::ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", mxnet::op::Deg2radOpType)
#if MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<gpu>", mxnet::op::TVMUnaryCompute<func_deg2rad_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<mxnet::FCompute>("FCompute<cpu>", mxnet::op::TVMUnaryCompute<func_deg2rad_cpu>)
.add_argument("data", "NDArray-or-Symbol", "the input")
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_npi_deg2rad"});

NNVM_REGISTER_OP(_backward_npi_deg2rad)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_CUDA
.set_attr<FCompute>("FCompute<gpu>",
mxnet::op::TVMUnaryBackwardComputeUseNone<func_backward_deg2rad_gpu>)
#endif // MXNET_USE_CUDA
.set_attr<FCompute>("FCompute<cpu>",
mxnet::op::TVMUnaryBackwardComputeUseNone<func_backward_deg2rad_cpu>);

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_TVM_OP
49 changes: 49 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,55 @@ def hybrid_forward(self, F, a, *args):
assert same(mx_out.asnumpy(), np_out)


@with_seed()
@use_np
def test_np_deg2rad():
class TestDeg2rad(HybridBlock):
def __init__(self):
super(TestDeg2rad, self).__init__()

def hybrid_forward(self, F, x):
return F.np.deg2rad(x)

types = ['float64', 'float32']
for hybridize in [True, False]:
for shape in [(),
(1,),
(1, 1),
(1, 2, 3),
(1, 0),
(2, 0, 3)
]:
for oneType in types:
rtol=1e-3
atol=1e-5
test_deg2rad = TestDeg2rad()
if hybridize:
test_deg2rad.hybridize()
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
x.attach_grad()
np_out = _np.deg2rad(x.asnumpy())
with mx.autograd.record():
mx_out = test_deg2rad(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol)
mx_out.backward()
import math
np_backward = math.pi / 180
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol)

mx_out = np.deg2rad(x)
np_out = _np.deg2rad(x.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol)

# Test AddTo Request
with mx.autograd.record():
a = test_deg2rad(x)
b = test_deg2rad(x)
mx.autograd.backward([a, b])
assert_almost_equal(x.grad.asnumpy(), 2 * np_backward, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 93cbb6d

Please sign in to comment.