diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 84aa4a1572d9..297c40b1431e 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -38,7 +38,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', - 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory'] + 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff'] @set_module('mxnet.ndarray.numpy') @@ -4983,3 +4983,51 @@ def may_share_memory(a, b, max_work=None): - Actually it is same as `shares_memory` in MXNet DeepNumPy """ return _npi.share_memory(a, b).item() + + +def diff(a, n=1, axis=-1, prepend=None, append=None): + r""" + numpy.diff(a, n=1, axis=-1, prepend=, append=) + + Calculate the n-th discrete difference along the given axis. + + Parameters + ---------- + a : ndarray + Input array + n : int, optional + The number of times values are differenced. If zero, the input is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the last axis. + prepend, append : ndarray, optional + Not supported yet + + Returns + ------- + diff : ndarray + The n-th differences. + The shape of the output is the same as a except along axis where the dimension is smaller by n. + The type of the output is the same as the type of the difference between any two elements of a. + + Examples + -------- + >>> x = np.array([1, 2, 4, 7, 0]) + >>> np.diff(x) + array([ 1, 2, 3, -7]) + >>> np.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> np.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> np.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + + Notes + ----- + Optional inputs `prepend` and `append` are not supported yet + """ + if (prepend or append): + raise NotImplementedError('prepend and append options are not supported yet') + return _npi.diff(a, n=n, axis=axis) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index ef88638c857e..a6d90881da4f 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -56,7 +56,7 @@ 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory'] + 'may_share_memory', 'diff'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -6975,3 +6975,52 @@ def may_share_memory(a, b, max_work=None): - Actually it is same as `shares_memory` in MXNet DeepNumPy """ return _mx_nd_np.may_share_memory(a, b, max_work) + + +def diff(a, n=1, axis=-1, prepend=None, append=None): + r""" + numpy.diff(a, n=1, axis=-1, prepend=, append=) + + Calculate the n-th discrete difference along the given axis. + + Parameters + ---------- + a : ndarray + Input array + n : int, optional + The number of times values are differenced. If zero, the input is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the last axis. + prepend, append : ndarray, optional + Not supported yet + + Returns + ------- + diff : ndarray + The n-th differences. + The shape of the output is the same as a except along axis where the dimension is smaller by n. + The type of the output is the same as the type of the difference between any two elements of a. + This is the same as the type of a in most cases. + + Examples + -------- + >>> x = np.array([1, 2, 4, 7, 0]) + >>> np.diff(x) + array([ 1, 2, 3, -7]) + >>> np.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> np.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> np.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + + Notes + ----- + Optional inputs `prepend` and `append` are not supported yet + """ + if (prepend or append): + raise NotImplementedError('prepend and append options are not supported yet') + return _mx_nd_np.diff(a, n=n, axis=axis) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 6a5f166a70eb..2411f51b7aa6 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -130,6 +130,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'einsum', 'shares_memory', 'may_share_memory', + 'diff', ] diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 2e6d41446930..6f7f912d6e36 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -40,7 +40,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', - 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory'] + 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff'] def _num_outputs(sym): @@ -4629,4 +4629,53 @@ def may_share_memory(a, b, max_work=None): return _npi.share_memory(a, b) +def diff(a, n=1, axis=-1, prepend=None, append=None): + r""" + numpy.diff(a, n=1, axis=-1, prepend=, append=) + + Calculate the n-th discrete difference along the given axis. + + Parameters + ---------- + a : ndarray + Input array + n : int, optional + The number of times values are differenced. If zero, the input is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the last axis. + prepend, append : ndarray, optional + Not supported yet + + Returns + ------- + diff : ndarray + The n-th differences. + The shape of the output is the same as a except along axis where the dimension is smaller by n. + The type of the output is the same as the type of the difference between any two elements of a. + This is the same as the type of a in most cases. + + Examples + -------- + >>> x = np.array([1, 2, 4, 7, 0]) + >>> np.diff(x) + array([ 1, 2, 3, -7]) + >>> np.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> np.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> np.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + + Notes + ----- + Optional inputs `prepend` and `append` are not supported yet + """ + if (prepend or append): + raise NotImplementedError('prepend and append options are not supported yet') + return _npi.diff(a, n=n, axis=axis) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_diff-inl.h b/src/operator/numpy/np_diff-inl.h new file mode 100644 index 000000000000..69f175e802dd --- /dev/null +++ b/src/operator/numpy/np_diff-inl.h @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff-inl.h + * \brief Function definition of numpy-compatible diff operator + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_ + +#include +#include +#include +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct DiffParam : public dmlc::Parameter { + int n, axis; + dmlc::optional prepend; + dmlc::optional append; + DMLC_DECLARE_PARAMETER(DiffParam) { + DMLC_DECLARE_FIELD(n).set_default(1).describe( + "The number of times values are differenced." + " If zero, the input is returned as-is."); + DMLC_DECLARE_FIELD(axis).set_default(-1).describe( + "Axis along which the cumulative sum is computed." + " The default (None) is to compute the diff over the flattened array."); + } +}; + +inline void YanghuiTri(std::vector* buffer, int n) { + // apply basic yanghui's triangular to calculate the factors + (*buffer)[0] = 1; + for (int i = 1; i <= n; ++i) { + (*buffer)[i] = 1; + for (int j = i - 1; j > 0; --j) { + (*buffer)[j] += (*buffer)[j - 1]; + } + } +} + +struct diff_forward { + template + MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* out, + const IType* in, const int n, + const int stride, + const mshadow::Shape oshape, + const mshadow::Shape ishape) { + using namespace broadcast; + + // j represent the memory index of the corresponding input entry + int j = ravel(unravel(i, oshape), ishape); + int indicator = 1; + out[i] = 0; + for (int k = n; k >= 0; --k) { + out[i] += in[j + stride * k] * indicator * diffFactor[k]; + indicator *= -1; + } + } +}; + +template +void DiffForwardImpl(const OpContext& ctx, const TBlob& in, const TBlob& out, + const int n, const int axis) { + using namespace mshadow; + using namespace mxnet_op; + + // undefined behavior for n < 0 + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, in.ndim()); + // nothing in the output + if (n >= in.shape_[axis_checked]) return; + // stride for elements on the given axis, same in input and output + int stride = 1; + for (int i = in.ndim() - 1; i > axis_checked; --i) { + stride *= in.shape_[i]; + } + + Stream* s = ctx.get_stream(); + std::vector buffer(n+1, 0); + YanghuiTri(&buffer, n); + Tensor diffFactor = + ctx.requested[0].get_space_typed(Shape1(n + 1), s); + Copy(diffFactor, Tensor(&buffer[0], Shape1(n + 1), 0), s); + + MSHADOW_TYPE_SWITCH(in.type_flag_, IType, { + MSHADOW_TYPE_SWITCH(out.type_flag_, OType, { + MXNET_NDIM_SWITCH(in.ndim(), ndim, { + Kernel::Launch( + s, out.Size(), diffFactor.dptr_, + out.dptr(), in.dptr(), + n, stride, out.shape_.get(), + in.shape_.get()); + }); + }); + }); +} + +template +void DiffForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const DiffParam& param = nnvm::get(attrs.parsed); + + DiffForwardImpl(ctx, inputs[0], outputs[0], param.n, param.axis); +} + +struct diff_backward { + template + MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* igrad, + const IType* ograd, const int n, + const int stride, const int axis, + const mshadow::Shape oshape, + const mshadow::Shape ishape) { + using namespace broadcast; + if (n == 0) { + igrad[i] = ograd[i]; + return; + } + + Shape coor = unravel(i, oshape); + // one head thread for a whole sequence along the axis + if (coor[axis] != 0) return; + int j = ravel(coor, ishape); + // initialize the elements of output array + for (int k = 0; k < oshape[axis]; ++k) igrad[i + k * stride] = 0; + for (int k = 0; k < ishape[axis]; ++k) { + int indicator = 1; + for (int m = n; m >= 0; --m) { + igrad[i + (m + k) * stride] += + ograd[j + k * stride] * indicator * diffFactor[m]; + indicator *= -1; + } + } + } +}; + +template +void DiffBackwardImpl(const OpContext& ctx, const TBlob& ograd, + const TBlob& igrad, const int n, const int axis) { + using namespace mshadow; + using namespace mxnet_op; + + // undefined behavior for n < 0 + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, igrad.ndim()); + // nothing in the ograd and igrad + if (n >= igrad.shape_[axis_checked]) return; + // stride for elements on the given axis, same in input and output + int stride = 1; + for (int i = igrad.ndim() - 1; i > axis_checked; --i) { + stride *= igrad.shape_[i]; + } + + Stream* s = ctx.get_stream(); + std::vector buffer(n+1, 0); + YanghuiTri(&buffer, n); + Tensor diffFactor = + ctx.requested[0].get_space_typed(Shape1(n + 1), s); + Copy(diffFactor, Tensor(&buffer[0], Shape1(n + 1), 0), s); + + MSHADOW_TYPE_SWITCH(ograd.type_flag_, IType, { + MSHADOW_TYPE_SWITCH(igrad.type_flag_, OType, { + MXNET_NDIM_SWITCH(igrad.ndim(), ndim, { + Kernel::Launch( + s, igrad.Size(), diffFactor.dptr_, + igrad.dptr(), ograd.dptr(), + n, stride, axis_checked, + igrad.shape_.get(), ograd.shape_.get()); + }); + }); + }); +} + +template +void DiffBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const DiffParam& param = nnvm::get(attrs.parsed); + + DiffBackwardImpl(ctx, inputs[0], outputs[0], param.n, param.axis); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_ diff --git a/src/operator/numpy/np_diff.cc b/src/operator/numpy/np_diff.cc new file mode 100644 index 000000000000..a3dae332d842 --- /dev/null +++ b/src/operator/numpy/np_diff.cc @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff.cc + * \brief CPU implementation of numpy-compatible diff operator + */ + +#include "./np_diff-inl.h" + +namespace mxnet { +namespace op { + +inline TShape NumpyDiffShapeImpl(const TShape& ishape, + const int n, + const int axis) { + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, ishape.ndim()); + + TShape oshape = ishape; + if (n >= ishape[axis_checked]) { + oshape[axis_checked] = 0; + } else { + oshape[axis_checked] -= n; + } + return oshape; +} + +inline bool DiffShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + const DiffParam& param = nnvm::get(attrs.parsed); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + NumpyDiffShapeImpl((*in_attrs)[0], param.n, param.axis)); + return shape_is_known(out_attrs->at(0)); +} + +inline bool DiffType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(DiffParam); + +NNVM_REGISTER_OP(_npi_diff) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.set_attr("FInferShape", DiffShape) +.set_attr("FInferType", DiffType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DiffForward) +.set_attr("FGradient", + ElemwiseGradUseNone{"_backward_npi_diff"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("a", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(DiffParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_diff) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DiffBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_diff.cu b/src/operator/numpy/np_diff.cu new file mode 100644 index 000000000000..daea6e368e05 --- /dev/null +++ b/src/operator/numpy/np_diff.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff.cu + * \brief GPU implementation of numpy-compatible diff operator + */ + +#include "./np_diff-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_diff) +.set_attr("FCompute", DiffForward); + +NNVM_REGISTER_OP(_backward_npi_diff) +.set_attr("FCompute", DiffBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 624fc0a107b0..103f2c117ea6 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1085,6 +1085,29 @@ def _add_workload_nonzero(): OpArgMngr.add_workload('nonzero', np.array([True, False, False], dtype=np.bool_)) +def _add_workload_diff(): + x = np.array([1, 4, 6, 7, 12]) + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, 2) + OpArgMngr.add_workload('diff', x, 3) + OpArgMngr.add_workload('diff', np.array([1.1, 2.2, 3.0, -0.2, -0.1])) + x = np.zeros((10, 20, 30)) + x[:, 1::2, :] = 1 + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, axis=-1) + OpArgMngr.add_workload('diff', x, axis=0) + OpArgMngr.add_workload('diff', x, axis=1) + OpArgMngr.add_workload('diff', x, axis=-2) + x = 20 * np.random.uniform(size=(10,20,30)) + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, n=2) + OpArgMngr.add_workload('diff', x, axis=0) + OpArgMngr.add_workload('diff', x, n=2, axis=0) + x = np.array([list(range(3))]) + for n in range(1, 5): + OpArgMngr.add_workload('diff', x, n=n) + + @use_np def _prepare_workloads(): array_pool = { @@ -1190,6 +1213,7 @@ def _prepare_workloads(): _add_workload_greater_equal(array_pool) _add_workload_less(array_pool) _add_workload_less_equal(array_pool) + _add_workload_diff() _prepare_workloads() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index bfe6c3d43b50..67c1ede6cc1a 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3764,6 +3764,58 @@ def test_np_share_memory(): assert not op(np.ones((5, 0), dtype=dt), np.ones((0, 3, 0), dtype=adt)) +@with_seed() +@use_np +def test_np_diff(): + def np_diff_backward(ograd, n, axis): + res = ograd + for i in range(n): + res = _np.negative(_np.diff(res, n=1, axis=axis, prepend=0, append=0)) + return res + + class TestDiff(HybridBlock): + def __init__(self, n=1, axis=-1): + super(TestDiff, self).__init__() + self._n = n + self._axis = axis + + def hybrid_forward(self, F, a): + return F.np.diff(a, n=self._n, axis=self._axis) + + shapes = [tuple(random.randrange(10) for i in range(random.randrange(6))) for j in range(5)] + for hybridize in [True, False]: + for shape in shapes: + for axis in [i for i in range(-len(shape), len(shape))]: + for n in [i for i in range(0, shape[axis]+1)]: + test_np_diff = TestDiff(n=n, axis=axis) + if hybridize: + test_np_diff.hybridize() + for itype in [_np.float16, _np.float32, _np.float64]: + # note the tolerance shall be scaled by the input n + if itype == _np.float16: + rtol = atol = 1e-2*len(shape)*n + else: + rtol = atol = 1e-5*len(shape)*n + x = rand_ndarray(shape).astype(itype).as_np_ndarray() + x.attach_grad() + np_out = _np.diff(x.asnumpy(), n=n, axis=axis) + with mx.autograd.record(): + mx_out = test_np_diff(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + mx_out.backward() + if (np_out.size == 0): + np_backward = _np.zeros(shape) + else: + np_backward = np_diff_backward(_np.ones(np_out.shape, dtype=itype), n=n, axis=axis) + assert x.grad.shape == np_backward.shape + assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) + + mx_out = np.diff(x, n=n, axis=axis) + np_out = _np.diff(x.asnumpy(), n=n, axis=axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule()