From 36f3a21675a53c7dbf525bfaa451faa74d507552 Mon Sep 17 00:00:00 2001 From: Ying Date: Thu, 22 Aug 2019 18:28:49 +0800 Subject: [PATCH] tvm numpy operator deg2rad && rad2deg * fix format error * change type * constangt must be tvm.const in tvm and add backward test, do not support float16 * add addto test * handle 0-dim and 0-size * add 0-dim test case * register to npi and add wrapper with doc * change function name, add infer type * fix format error * merge rad2deg to deg2rad * fix error according to review * change infer type * add TVM_OP in test --- contrib/tvmop/basic/ufunc.py | 137 ++++++++++++++++++++++++- python/mxnet/ndarray/numpy/_op.py | 78 +++++++++++++- python/mxnet/numpy/multiarray.py | 74 ++++++++++++- python/mxnet/symbol/numpy/_symbol.py | 64 +++++++++++- src/operator/contrib/tvmop/ufunc.cc | 113 ++++++++++++++++++++ tests/python/unittest/test_numpy_op.py | 9 +- 6 files changed, 467 insertions(+), 8 deletions(-) diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py index 6bb102ccf7e3..87356428428f 100644 --- a/contrib/tvmop/basic/ufunc.py +++ b/contrib/tvmop/basic/ufunc.py @@ -17,7 +17,7 @@ # coding: utf-8 import tvm -from .. import defop, AllTypes +from .. import defop, AllTypes, RealTypes from .. import assign_by_req, reduce_axes def compute_add(dtype, ndim): @@ -98,3 +98,138 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req): s[t].bind(bx, block_x) s[t].bind(tx, thread_x) return s, [X, in_grad_a, in_grad] + + +def compute_degandrad(dtype, ndim, n): + A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype) + import math + if n == 0: + B = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B') + else: + B = tvm.compute([tvm.var() for _ in range(ndim)], + lambda *index: A[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype), name='B') + s = tvm.create_schedule(B.op) + return s, A, B + + +@defop(name="deg2rad", target="cpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6))) +def deg2rad(dtype, ndim): + s, A, B = compute_degandrad(dtype, ndim, 0) + axes = [axis for axis in B.op.axis] + fused = s[B].fuse(*axes) + s[B].parallel(fused) + return s, [A, B] + + +@defop(name="rad2deg", target="cpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6))) +def rad2deg(dtype, ndim): + s, A, B = compute_degandrad(dtype, ndim, 1) + axes = [axis for axis in B.op.axis] + fused = s[B].fuse(*axes) + s[B].parallel(fused) + return s, [A, B] + + +@defop(name="cuda_deg2rad", target="cuda", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6))) +def deg2rad_gpu(dtype, ndim): + s, A, B = compute_degandrad(dtype, ndim, 0) + s = tvm.create_schedule(B.op) + axes = [axis for axis in B.op.axis] + fused = s[B].fuse(*axes) + bx, tx = s[B].split(fused, factor=64) + s[B].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B].bind(tx, tvm.thread_axis("threadIdx.x")) + return s, [A, B] + + +@defop(name="cuda_rad2deg", target="cuda", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6))) +def rad2deg_gpu(dtype, ndim): + s, A, B = compute_degandrad(dtype, ndim, 1) + s = tvm.create_schedule(B.op) + axes = [axis for axis in B.op.axis] + fused = s[B].fuse(*axes) + bx, tx = s[B].split(fused, factor=64) + s[B].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B].bind(tx, tvm.thread_axis("threadIdx.x")) + return s, [A, B] + + +def compute_backward_degandrad(dtype, ndim, req, n): + ishape = [tvm.var() for _ in range(ndim)] + in_grad_tmp = tvm.placeholder(ishape, name='in_grad_tmp', dtype=dtype) + in_grad = tvm.placeholder(ishape, name='in_grad', dtype=dtype) + out_grad = tvm.placeholder(ishape, name='out_grad', dtype=dtype) + import math + if n == 0: + ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype)) + else: + ret = tvm.compute(ishape, lambda *index: out_grad[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype)) + if (req == "kAddTo"): + in_grad = tvm.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index]) + else: + in_grad = tvm.compute(ishape, lambda *index: ret[index]) + s = tvm.create_schedule(in_grad.op) + return s, out_grad, in_grad_tmp, in_grad, [ret, in_grad] + + +@defop(name="backward_deg2rad", target="cpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"], + attrs=["req"]) +def backward_deg2rad(dtype, ndim, req): + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0) + for t in c_list: + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + s[t].parallel(fused) + return s, [out_grad, in_grad, in_grad_tmp] + + +@defop(name="backward_rad2deg", target="cpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"], + attrs=["req"]) +def backward_rad2deg(dtype, ndim, req): + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1) + for t in c_list: + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + s[t].parallel(fused) + return s, [out_grad, in_grad, in_grad_tmp] + + +@defop(name="cuda_backward_deg2rad", target="gpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"], + attrs=["req"]) +def cuda_backward_deg2rad(dtype, ndim, req): + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0) + num_thread = 64 + for t in c_list: + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + bx, tx = s[t].split(fused, factor=num_thread) + s[t].bind(bx, block_x) + s[t].bind(tx, thread_x) + return s, [out_grad, in_grad, in_grad_tmp] + + +@defop(name="cuda_backward_rad2deg", target="gpu", auto_broadcast=False, + dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"], + attrs=["req"]) +def cuda_backward_rad2deg(dtype, ndim, req): + s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1) + num_thread = 64 + for t in c_list: + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + axes = [axis for axis in t.op.axis] + fused = s[t].fuse(*axes) + bx, tx = s[t].split(fused, factor=num_thread) + s[t].bind(bx, block_x) + s[t].bind(tx, thread_x) + return s, [out_grad, in_grad, in_grad_tmp] diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 163d90835e2a..6433167d94cc 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -29,9 +29,9 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', - 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', - 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', + 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'rad2deg', 'log2', + 'log1p', 'rint', 'radians', 'deg2rad', 'reciprocal', 'square', 'negative', 'fix', 'ceil', + 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel'] @@ -1245,6 +1245,42 @@ def degrees(x, out=None, **kwargs): return _unary_func_helper(x, _npi.degrees, _np.degrees, out=out, **kwargs) +@set_module('mxnet.ndarray.numpy') +def rad2deg(x, out=None): + r""" + rad2deg(x, out=None) + + Convert angles from radians to degrees. + Parameters + ---------- + x : ndarray or scalar + Angles in degrees. + out : ndarray or None, optional + A location into which the result is stored. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + The corresponding angle in radians. + This is a scalar if `x` is a scalar. + + Notes + ----- + "rad2deg(x)" is "x *180 / pi". + + This function differs from the original numpy.arange in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + + Examples + -------- + >>> np.rad2deg(np.pi/2) + 90.0 + """ + return _unary_func_helper(x, _npi.rad2deg, _np.rad2deg, out=out) + + @set_module('mxnet.ndarray.numpy') def rint(x, out=None, **kwargs): """ @@ -1388,6 +1424,42 @@ def radians(x, out=None, **kwargs): return _unary_func_helper(x, _npi.radians, _np.radians, out=out, **kwargs) +@set_module('mxnet.ndarray.numpy') +def deg2rad(x, out=None): + r""" + deg2rad(x, out=None) + + Convert angles from degrees to radians. + Parameters + ---------- + x : ndarray or scalar + Angles in degrees. + out : ndarray or None, optional + A location into which the result is stored. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + The corresponding angle in radians. + This is a scalar if `x` is a scalar. + + Notes + ----- + "deg2rad(x)" is "x * pi / 180". + + This function differs from the original numpy.arange in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + + Examples + -------- + >>> np.deg2rad(180) + 3.1415927 + """ + return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out) + + @set_module('mxnet.ndarray.numpy') def reciprocal(x, out=None, **kwargs): r""" diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 52dc9fb7a685..a7908b395d41 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -49,7 +49,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', - 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', + 'degrees', 'rad2deg', 'log2', 'log1p', 'rint', 'radians', 'deg2rad', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', @@ -2725,6 +2725,42 @@ def degrees(x, out=None, **kwargs): return _mx_nd_np.degrees(x, out=out, **kwargs) +@set_module('mxnet.numpy') +def rad2deg(x, out=None): + r""" + rad2deg(x, out=None) + + Convert angles from radians to degrees. + Parameters + ---------- + x : ndarray or scalar + Angles in degrees. + out : ndarray or None, optional + A location into which the result is stored. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + The corresponding angle in radians. + This is a scalar if `x` is a scalar. + + Notes + ----- + "rad2deg(x)" is "x * 180 / pi". + + This function differs from the original numpy.arange in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + + Examples + -------- + >>> np.rad2deg(np.pi/2) + 90.0 + """ + return _mx_nd_np.rad2deg(x, out=out) + + @set_module('mxnet.numpy') def radians(x, out=None, **kwargs): """ @@ -2760,6 +2796,42 @@ def radians(x, out=None, **kwargs): return _mx_nd_np.radians(x, out=out, **kwargs) +@set_module('mxnet.numpy') +def deg2rad(x, out=None): + r""" + deg2rad(x, out=None) + + Convert angles from degrees to radians. + Parameters + ---------- + x : ndarray or scalar + Angles in degrees. + out : ndarray or None, optional + A location into which the result is stored. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + The corresponding angle in radians. + This is a scalar if `x` is a scalar. + + Notes + ----- + "deg2rad(x)" is "x * pi / 180". + + This function differs from the original numpy.arange in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + + Examples + -------- + >>> np.deg2rad(180) + 3.1415927 + """ + return _mx_nd_np.deg2rad(x, out=out) + + @set_module('mxnet.numpy') def reciprocal(x, out=None, **kwargs): r""" diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 962fee2d9375..6f0a5c199b6c 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -31,8 +31,8 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', - 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', - 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', + 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'rad2deg', 'log2', 'log1p', + 'rint', 'radians', 'deg2rad', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', @@ -1715,6 +1715,36 @@ def degrees(x, out=None, **kwargs): return _unary_func_helper(x, _npi.degrees, _np.degrees, out=out, **kwargs) +@set_module('mxnet.symbol.numpy') +def rad2deg(x, out=None): + r""" + rad2deg(x, out=None) + + Convert angles from radians to degrees. + Parameters + ---------- + x : _Symbol or scalar + Angles in degrees. + out : _Symbol or None, optional + A location into which the result is stored. + + Returns + ------- + y : _Symbol or scalar + The corresponding angle in radians. + This is a scalar if `x` is a scalar. + + Notes + ----- + "rad2deg(x)" is "x * 180 / pi". + + This function differs from the original numpy.arange in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + """ + return _unary_func_helper(x, _npi.rad2deg, _np.rad2deg, out=out) + + def rint(x, out=None, **kwargs): """ Round elements of the array to the nearest integer. @@ -1840,6 +1870,36 @@ def radians(x, out=None, **kwargs): return _unary_func_helper(x, _npi.radians, _np.radians, out=out, **kwargs) +@set_module('mxnet.symbol.numpy') +def deg2rad(x, out=None): + r""" + deg2rad(x, out=None) + + Convert angles from degrees to radians. + Parameters + ---------- + x : _Symbol or scalar + Angles in degrees. + out : _Symbol or None, optional + A location into which the result is stored. + + Returns + ------- + y : _Symbol or scalar + The corresponding angle in radians. + This is a scalar if `x` is a scalar. + + Notes + ----- + "deg2rad(x)" is "x * pi / 180". + + This function differs from the original numpy.arange in the following aspects: + - Only support float32 and float64. + - `out` must be in the same size of input. + """ + return _unary_func_helper(x, _npi.deg2rad, _np.deg2rad, out=out) + + @set_module('mxnet.symbol.numpy') def reciprocal(x, out=None, **kwargs): r""" diff --git a/src/operator/contrib/tvmop/ufunc.cc b/src/operator/contrib/tvmop/ufunc.cc index 89a90c022845..e6f5653022e0 100644 --- a/src/operator/contrib/tvmop/ufunc.cc +++ b/src/operator/contrib/tvmop/ufunc.cc @@ -144,6 +144,119 @@ NNVM_REGISTER_OP(_backward_contrib_tvm_vadd) .set_attr("FCompute", mxnet::op::TVMBinaryBackwardComputeUseNone); +inline bool DegandradOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + CHECK(in_attrs->at(0) == mshadow::kFloat64 || + in_attrs->at(0) == mshadow::kFloat32) + << "Only support float32 and float64."; + return out_attrs->at(0) != -1; +} + +template +void TVMUnaryCompute(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (0 == inputs[0].shape_.Size()) { + // 0-size + return; + } + tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], outputs[0]}); +} + +template +void TVMUnaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (0 == inputs[0].shape_.Size()) { + // 0-size + return; + } + std::string funcname = func; + funcname += "req_"; + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + if (req_type == kWriteTo) { + funcname += "kWriteTo"; + } else { + funcname += "kAddTo"; + } + tvm::runtime::TVMOpModule::Get()->Call(funcname, ctx, + {inputs[0], outputs[0], outputs[0]}); + }) +} + +static constexpr char func_deg2rad_cpu[] = "deg2rad"; +static constexpr char func_deg2rad_gpu[] = "cuda_deg2rad"; +static constexpr char func_backward_deg2rad_cpu[] = "backward_deg2rad"; +static constexpr char func_backward_deg2rad_gpu[] = "cuda_backward_deg2rad"; + +NNVM_REGISTER_OP(_npi_deg2rad) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"data"}; +}) +.set_attr("FInferShape", mxnet::op::ElemwiseShape<1, 1>) +.set_attr("FInferType", mxnet::op::DegandradOpType) +#if MXNET_USE_CUDA +.set_attr("FCompute", mxnet::op::TVMUnaryCompute) +#endif // MXNET_USE_CUDA +.set_attr("FCompute", mxnet::op::TVMUnaryCompute) +.add_argument("data", "NDArray-or-Symbol", "the input") +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_deg2rad"}); + +NNVM_REGISTER_OP(_backward_npi_deg2rad) +.set_attr("TIsBackward", true) +#if MXNET_USE_CUDA +.set_attr("FCompute", + mxnet::op::TVMUnaryBackwardComputeUseNone) +#endif // MXNET_USE_CUDA +.set_attr("FCompute", + mxnet::op::TVMUnaryBackwardComputeUseNone); + +static constexpr char func_rad2deg_cpu[] = "rad2deg"; +static constexpr char func_rad2deg_gpu[] = "cuda_rad2deg"; +static constexpr char func_backward_rad2deg_cpu[] = "backward_rad2deg"; +static constexpr char func_backward_rad2deg_gpu[] = "cuda_backward_rad2deg"; + +NNVM_REGISTER_OP(_npi_rad2deg) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"data"}; +}) +.set_attr("FInferShape", mxnet::op::ElemwiseShape<1, 1>) +.set_attr("FInferType", mxnet::op::DegandradOpType) +#if MXNET_USE_CUDA +.set_attr("FCompute", mxnet::op::TVMUnaryCompute) +#endif // MXNET_USE_CUDA +.set_attr("FCompute", mxnet::op::TVMUnaryCompute) +.add_argument("data", "NDArray-or-Symbol", "the input") +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_rad2deg"}); + +NNVM_REGISTER_OP(_backward_npi_rad2deg) +.set_attr("TIsBackward", true) +#if MXNET_USE_CUDA +.set_attr("FCompute", + mxnet::op::TVMUnaryBackwardComputeUseNone) +#endif // MXNET_USE_CUDA +.set_attr("FCompute", + mxnet::op::TVMUnaryBackwardComputeUseNone); } // namespace op } // namespace mxnet #endif // MXNET_USE_TVM_OP diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 07bd2864cfb8..ccef78a23ee2 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -27,10 +27,14 @@ from common import assertRaises, with_seed import random import scipy.stats as ss -from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry +from mxnet.runtime import Features import platform +_features = Features() + + @with_seed() @use_np def test_np_tensordot(): @@ -1015,6 +1019,9 @@ def hybrid_forward(self, F, a, *args, **kwargs): 'arccosh' : (lambda x: 1./(x**2 - 1.)**(1./2.), 2.0, 5.0), 'arctanh' : (lambda x: -1./(x**2 - 1.), -0.99, 0.99) } + if _features.is_enabled("TVM_OP"): + funcs['rad2deg'] = (lambda x: 180. / _np.pi * _np.ones(x.shape), -1.0, 1.0) + funcs['deg2rad'] = (lambda x: _np.pi / 180. * _np.ones(x.shape), -1.0, 1.0) ndim = random.choice([2, 3, 4]) shape = random.choice([rand_shape_nd(ndim, dim=3), (1, 0, 2)]) for shape in [rand_shape_nd(ndim, dim=3), (1, 0, 2)]: