From 275d961b8b54098aefc40660e5d56ed4a554eade Mon Sep 17 00:00:00 2001 From: Alicia Date: Wed, 22 Jan 2020 20:25:14 +0800 Subject: [PATCH] add polynomial polyval --- python/mxnet/ndarray/numpy/__init__.py | 1 + python/mxnet/ndarray/numpy/polynomial.py | 77 ++++++++ python/mxnet/numpy/__init__.py | 1 + python/mxnet/numpy/polynomial.py | 71 +++++++ python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/__init__.py | 1 + python/mxnet/symbol/numpy/polynomial.py | 64 +++++++ src/operator/numpy/np_polynomial_op-inl.h | 174 ++++++++++++++++++ src/operator/numpy/np_polynomial_op.cc | 49 +++++ src/operator/numpy/np_polynomial_op.cu | 37 ++++ .../unittest/test_numpy_interoperability.py | 15 ++ tests/python/unittest/test_numpy_op.py | 67 +++++++ 12 files changed, 558 insertions(+) create mode 100644 python/mxnet/ndarray/numpy/polynomial.py create mode 100644 python/mxnet/numpy/polynomial.py create mode 100644 python/mxnet/symbol/numpy/polynomial.py create mode 100644 src/operator/numpy/np_polynomial_op-inl.h create mode 100644 src/operator/numpy/np_polynomial_op.cc create mode 100644 src/operator/numpy/np_polynomial_op.cu diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy/__init__.py index 7eb478f792f5..da45a6fa31b6 100644 --- a/python/mxnet/ndarray/numpy/__init__.py +++ b/python/mxnet/ndarray/numpy/__init__.py @@ -21,6 +21,7 @@ from . import linalg from . import _op, _internal from . import _register +from .polynomial import * # pylint: disable=wildcard-import from ._op import * # pylint: disable=wildcard-import __all__ = _op.__all__ diff --git a/python/mxnet/ndarray/numpy/polynomial.py b/python/mxnet/ndarray/numpy/polynomial.py new file mode 100644 index 000000000000..577c24eeb288 --- /dev/null +++ b/python/mxnet/ndarray/numpy/polynomial.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for operators used in Gluon dispatched by F=ndarray.""" +from __future__ import absolute_import +import numpy as _np +from ...util import set_module +from . import _internal as _npi + + +__all__ = ['polyval'] + + +@set_module('mxnet.ndarray.numpy') +def polyval(p, x): + """ + Evaluate a polynomial at specific values. + If p is of length N, this function returns the value: + p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1] + If x is a sequence, then p(x) is returned for each element of x. + If x is another polynomial then the composite polynomial p(x(t)) is returned. + + Parameters + ---------- + p : ndarray + 1D array of polynomial coefficients (including coefficients equal to zero) + from highest degree to the constant term. + x : ndarray + An array of numbers, at which to evaluate p. + + Returns + ------- + values : ndarray + Result array of polynomials + + Notes + ----- + This function differs from the original `numpy.polyval + `_ in + the following way(s): + - Does not support poly1d. + - X should be ndarray type even if it contains only one element. + + Examples + -------- + >>> p = np.array([3, 0, 1]) + array([3., 0., 1.]) + >>> x = np.array([5]) + array([5.]) + >>> np.polyval(p, x) # 3 * 5**2 + 0 * 5**1 + 1 + array([76.]) + >>> x = np.array([5, 4]) + array([5., 4.]) + >>> np.polyval(p, x) + array([76., 49.]) + """ + from ...numpy import ndarray + if isinstance(p, ndarray) and isinstance(x, ndarray): + return _npi.polyval(p, x) + elif not isinstance(p, ndarray) and not isinstance(x, ndarray): + return _np.polyval(p, x) + else: + raise TypeError('type not supported') diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 1994148d14d1..3a4623504324 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -24,6 +24,7 @@ from .multiarray import * # pylint: disable=wildcard-import from . import _op from . import _register +from .polynomial import * # pylint: disable=wildcard-import from ._op import * # pylint: disable=wildcard-import from .utils import * # pylint: disable=wildcard-import from .function_base import * # pylint: disable=wildcard-import diff --git a/python/mxnet/numpy/polynomial.py b/python/mxnet/numpy/polynomial.py new file mode 100644 index 000000000000..aac327bb8c78 --- /dev/null +++ b/python/mxnet/numpy/polynomial.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for ops used in imperative programming.""" + +from __future__ import absolute_import +from ..ndarray import numpy as _mx_nd_np +from ..util import set_module + + +__all__ = ['polyval'] + + +@set_module('mxnet.numpy') +def polyval(p, x): + """ + Evaluate a polynomial at specific values. + If p is of length N, this function returns the value: + p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1] + If x is a sequence, then p(x) is returned for each element of x. + If x is another polynomial then the composite polynomial p(x(t)) is returned. + + Parameters + ---------- + p : ndarray + 1D array of polynomial coefficients (including coefficients equal to zero) + from highest degree to the constant term. + x : ndarray + An array of numbers, at which to evaluate p. + + Returns + ------- + values : ndarray + Result array of polynomials + + Notes + ----- + This function differs from the original `numpy.polyval + `_ in + the following way(s): + - Does not support poly1d. + - X should be ndarray type even if it contains only one element. + + Examples + -------- + >>> p = np.array([3, 0, 1]) + array([3., 0., 1.]) + >>> x = np.array([5]) + array([5.]) + >>> np.polyval(p, x) # 3 * 5**2 + 0 * 5**1 + 1 + array([76.]) + >>> x = np.array([5, 4]) + array([5., 4.]) + >>> np.polyval(p, x) + array([76., 49.]) + """ + return _mx_nd_np.polyval(p, x) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 56944facac81..86f5fec519d4 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -164,6 +164,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'bincount', 'empty_like', 'nan_to_num', + 'polyval', ] diff --git a/python/mxnet/symbol/numpy/__init__.py b/python/mxnet/symbol/numpy/__init__.py index 857849c4ae62..040666b2d7b3 100644 --- a/python/mxnet/symbol/numpy/__init__.py +++ b/python/mxnet/symbol/numpy/__init__.py @@ -22,6 +22,7 @@ from . import _op, _symbol, _internal from ._symbol import _Symbol from . import _register +from .polynomial import * # pylint: disable=wildcard-import from ._op import * # pylint: disable=wildcard-import from ._symbol import * # pylint: disable=wildcard-import diff --git a/python/mxnet/symbol/numpy/polynomial.py b/python/mxnet/symbol/numpy/polynomial.py new file mode 100644 index 000000000000..818c38bb5142 --- /dev/null +++ b/python/mxnet/symbol/numpy/polynomial.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for operators used in Gluon dispatched by F=symbol.""" + +from __future__ import absolute_import +import numpy as _np +from ...util import set_module +from . import _internal as _npi +from ..symbol import Symbol + + +__all__ = ['polyval'] + +@set_module('mxnet.symbol.numpy') +def polyval(p, x): + """ + Evaluate a polynomial at specific values. + If p is of length N, this function returns the value: + p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1] + If x is a sequence, then p(x) is returned for each element of x. + If x is another polynomial then the composite polynomial p(x(t)) is returned. + + Parameters + ---------- + p : _Symbol + 1D array of polynomial coefficients (including coefficients equal to zero) + from highest degree to the constant term. + x : _Symbol + An array of numbers, at which to evaluate p. + + Returns + ------- + values : _Symbol + Result array of polynomials + + Notes + ----- + This function differs from the original `numpy.polyval + `_ in + the following way(s): + - Does not support poly1d. + - X should be ndarray type even if it contains only one element. + """ + if isinstance(p, Symbol) and isinstance(x, Symbol): + return _npi.polyval(p, x) + elif not isinstance(p, Symbol) and not isinstance(x, Symbol): + return _np.polyval(p, x) + else: + raise TypeError('type not supported') diff --git a/src/operator/numpy/np_polynomial_op-inl.h b/src/operator/numpy/np_polynomial_op-inl.h new file mode 100644 index 000000000000..bb6324027fed --- /dev/null +++ b/src/operator/numpy/np_polynomial_op-inl.h @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_polynomial_op.h + * \brief Functions for dealing with polynomials. + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_INL_H_ + +#include +#include +#include +#include +#include +#include "../mxnet_op.h" +#include "../../common/utils.h" +#include "../tensor/elemwise_binary_broadcast_op.h" + + +namespace mxnet { +namespace op { + +inline bool NumpyPolyvalShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& p_shape = in_attrs->at(0); + const mxnet::TShape& x_shape = in_attrs->at(1); + const mxnet::TShape& v_shape = out_attrs->at(0); + CHECK_EQ(p_shape.ndim(), 1U) << "ValueError: p has to be an 1-D array."; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, x_shape); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, v_shape); + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +template +struct polyval_forward { + template + MSHADOW_XINLINE static void Map(int i, + DType* out_data, + const DType* p_data, + const DType* x_data, + const index_t p_size) { + DType val = 0; + for (index_t j = 0; j < p_size; j++) { + val = val * x_data[i] + p_data[j]; + } + KERNEL_ASSIGN(out_data[i], req, val); + } +}; + +template +void NumpyPolyvalForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + mshadow::Stream *s = ctx.get_stream(); + const TBlob& p_data = inputs[0]; + const TBlob& x_data = inputs[1]; + const TBlob& out_data = outputs[0]; + const size_t p_size = p_data.Size(); + using namespace mxnet_op; + + MSHADOW_TYPE_SWITCH(x_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, out_data.Size(), out_data.dptr(), + p_data.dptr(), x_data.dptr(), p_size); + }); + }); +} + +template +struct polyval_backward_x { + template + MSHADOW_XINLINE static void Map(int i, const DType* p_dptr, const DType* x_dptr, + DType* igrad_x_dptr, const DType* ograd_dptr, + const index_t p_size) { + DType igrad_x = 0; + index_t j = p_size - 1; + while (j > 0) { + igrad_x = igrad_x * x_dptr[i] + p_dptr[p_size - j - 1] * j; + j--; + } + KERNEL_ASSIGN(igrad_x_dptr[i], req, igrad_x * ograd_dptr[i]); + } +}; + +template +struct polyval_backward_p { + template + MSHADOW_XINLINE static void Map(int i, const DType* p_dptr, const DType* x_dptr, + DType* igrad_p_dptr, const DType* ograd_dptr, + const index_t p_size, const index_t x_size) { + DType igrad_p = 0; + index_t j = x_size - 1; + while (j >= 0) { + igrad_p += pow(x_dptr[j], p_size - i - 1) * ograd_dptr[j]; + j--; + } + KERNEL_ASSIGN(igrad_p_dptr[i], req, igrad_p); + } +}; + +template +void NumpyPolyvalBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_NE(req[0], kWriteInplace); + + MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, PType, { + MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, XType, { + if (!std::is_same::value || + std::is_integral::value || + std::is_integral::value) { + return; + } + }) + }) + + mshadow::Stream *s = ctx.get_stream(); + const TBlob& ograd = inputs[0]; + const TBlob& p = inputs[1]; + const TBlob& x = inputs[2]; + const TBlob& igrad_p = outputs[0]; + const TBlob& igrad_x = outputs[1]; + const size_t p_size = p.Size(); + + using namespace mxnet_op; + MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, ograd.Size(), p.dptr(), x.dptr(), + igrad_x.dptr(), ograd.dptr(), p_size); + Kernel, xpu>::Launch( + s, p_size, p.dptr(), x.dptr(), + igrad_p.dptr(), ograd.dptr(), p_size, x.Size()); + }); + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_POLYNOMIAL_INL_H_ diff --git a/src/operator/numpy/np_polynomial_op.cc b/src/operator/numpy/np_polynomial_op.cc new file mode 100644 index 000000000000..22b0683f417a --- /dev/null +++ b/src/operator/numpy/np_polynomial_op.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 202o by Contributors + * \file np_polynomial_op.cc +*/ +#include "np_polynomial_op-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_polyval) + .set_num_inputs(2) + .set_num_outputs(1) + .add_argument("p", "NDArray-or-Symbol", "polynomial coefficients") + .add_argument("x", "NDArray-or-Symbol", "variables") + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"p", "x"}; + }) + .set_attr("FInferShape", NumpyPolyvalShape) + .set_attr("FInferType", mxnet::op::ElemwiseType<2, 1>) + .set_attr("FCompute", NumpyPolyvalForward) + .set_attr("FGradient", ElemwiseGradUseIn{"_npi_backward_polyval"}); + +NNVM_REGISTER_OP(_npi_backward_polyval) + .set_num_inputs(3) + .set_num_outputs(2) + .set_attr("TIsBackward", true) + .set_attr("FCompute", NumpyPolyvalBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_polynomial_op.cu b/src/operator/numpy/np_polynomial_op.cu new file mode 100644 index 000000000000..a23e41e7a417 --- /dev/null +++ b/src/operator/numpy/np_polynomial_op.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file np_polynomial_op.cu + */ + +#include "np_polynomial_op-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_polyval) +.set_attr("FCompute", NumpyPolyvalForward); + +NNVM_REGISTER_OP(_npi_backward_polyval) +.set_attr("FCompute", NumpyPolyvalBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 1ee116c20959..da7a24423078 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1648,6 +1648,20 @@ def _add_workload_nan_to_num(): OpArgMngr.add_workload('nan_to_num', array3, True) +def _add_workload_polyval(): + p1 = np.arange(20) + p2 = np.arange(1) + x1 = np.arange(20) + x2 = np.ones((3,3)) + x3 = np.array(2) + OpArgMngr.add_workload('polyval', p1, x1) + OpArgMngr.add_workload('polyval', p1, x2) + OpArgMngr.add_workload('polyval', p1, x3) + OpArgMngr.add_workload('polyval', p2, x1) + OpArgMngr.add_workload('polyval', p2, x2) + OpArgMngr.add_workload('polyval', p2, x3) + + def _add_workload_linalg_cond(): A = np.array([[1., 0, 1], [0, -2., 0], [0, 0, 3.]]) OpArgMngr.add_workload('linalg.cond', A, np.inf) @@ -1820,6 +1834,7 @@ def _prepare_workloads(): _add_workload_full_like(array_pool) _add_workload_empty_like() _add_workload_nan_to_num() + _add_workload_polyval() _add_workload_heaviside() _add_workload_spacing() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 773476e91903..559c66e0a5aa 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -6391,6 +6391,73 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) +@with_seed() +@use_np +def test_np_polyval(): + class TestPolyval(HybridBlock): + def __init__(self): + super(TestPolyval, self).__init__() + + def hybrid_forward(self, F, p, x, *args, **kwargs): + return F.np.polyval(p, x) + + def polyval_grad(p, x): + x_shape = x.shape + x = x.reshape((x.size, 1)) + x = _np.broadcast_to(x, (x.size, p.size)) + exp = _np.arange(p.size-1, -1, -1) + p_grad = _np.power(x, exp) + coeff = exp-1 + coeff[-1] = 0 + x_grad = _np.power(x, coeff) * p * exp + p_grad = _np.sum(p_grad, axis=0) + x_grad = _np.sum(x_grad, axis=-1).reshape(x_shape) + return (p_grad, x_grad) + + dtypes = ['float32', 'float64', 'int32', 'int64'] + x_shapes = [ + (5,), + (10), + (3, 3), + (3, 4), + (3, 3, 3), + (2, 2, 4, 3), + (2, 0, 2, 3) + ] + flags = [True, False] + for dtype, x_shape, hybridize in itertools.product(dtypes, x_shapes, flags): + p_shape = (random.randint(1, 8),) + test_polyval = TestPolyval() + if hybridize: + test_polyval.hybridize() + rtol = 1e-2 + atol = 1e-4 + if dtype in ['int32', 'int64']: + p = np.random.randint(-16, 16, p_shape, dtype=dtype) + x = np.random.randint(-5, 5, x_shape, dtype=dtype) + else: + p = np.random.uniform(-1.0, 1.0, size=p_shape, dtype=dtype) + x = np.random.uniform(-1.0, 1.0, size=x_shape, dtype=dtype) + + p.attach_grad() + x.attach_grad() + np_out = _np.polyval(p.asnumpy(), x.asnumpy()) + with mx.autograd.record(): + mx_out = test_polyval(p, x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol) + + mx_out.backward() + if dtype in ['float16', 'float32', 'float64']: + p_grad, x_grad = polyval_grad(p.asnumpy(), x.asnumpy()) + assert_almost_equal(p.grad.asnumpy(), p_grad, atol=atol, rtol=rtol) + assert_almost_equal(x.grad.asnumpy(), x_grad, atol=atol, rtol=rtol) + + mx_out = np.polyval(p, x) + np_out = _np.polyval(p.asnumpy(), x.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol) + + @with_seed() @use_np def test_np_where():