From 27bddf8f1667cd7161f0c41ecabe6123736e403d Mon Sep 17 00:00:00 2001
From: Zhiqiang Xie
Date: Wed, 30 Oct 2019 17:34:50 +0800
Subject: [PATCH] [Numpy] Numpy operator diff (#15906)

* numpy diff operator implemented

append and prepend not supported yet

remove the prepend and append checking interface from the backend

refine the code, enrich the test set and all tests passed

registered the diff operator into npi scope

all tests passed

comments and minor modification

format codes and fix warning for sanity check

minor modification for sanity check

fix sanity

fix the tolerance bound of testing np.diff

resolve minor coding style issue

replace the given tests by random picking

minor fix

* interoperability test added
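
Example usage, as a quick sketch (assuming a build that includes this patch,
with numpy semantics enabled via `npx.set_np()`; the printed values follow the
operator's definition rather than a captured session):

    from mxnet import np, npx
    npx.set_np()

    x = np.array([1., 2., 4., 7., 0.])
    print(np.diff(x))         # first difference:  [ 1.  2.  3. -7.]
    print(np.diff(x, n=2))    # second difference: [  1.   1. -10.], equal to np.diff(np.diff(x))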
---
 python/mxnet/ndarray/numpy/_op.py             |  50 +++-
 python/mxnet/numpy/multiarray.py              |  51 +++-
 python/mxnet/numpy_dispatch_protocol.py       |   1 +
 python/mxnet/symbol/numpy/_symbol.py          |  51 +++-
 src/operator/numpy/np_diff-inl.h              | 220 ++++++++++++++++++
 src/operator/numpy/np_diff.cc                 | 109 +++++++++
 src/operator/numpy/np_diff.cu                 |  37 +++
 .../unittest/test_numpy_interoperability.py   |  24 ++
 tests/python/unittest/test_numpy_op.py        |  52 +++++
 9 files changed, 592 insertions(+), 3 deletions(-)
 create mode 100644 src/operator/numpy/np_diff-inl.h
 create mode 100644 src/operator/numpy/np_diff.cc
 create mode 100644 src/operator/numpy/np_diff.cu

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 84aa4a1572d9..297c40b1431e 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -38,7 +38,7 @@
            'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
            'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
            'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
-           'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory']
+           'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -4983,3 +4983,51 @@ def may_share_memory(a, b, max_work=None):
     - Actually it is same as `shares_memory` in MXNet DeepNumPy
     """
     return _npi.share_memory(a, b).item()
+
+
+def diff(a, n=1, axis=-1, prepend=None, append=None):
+    r"""
+    numpy.diff(a, n=1, axis=-1, prepend=None, append=None)
+
+    Calculate the n-th discrete difference along the given axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array
+    n : int, optional
+        The number of times values are differenced. If zero, the input is returned as-is.
+    axis : int, optional
+        The axis along which the difference is taken, default is the last axis.
+    prepend, append : ndarray, optional
+        Not supported yet
+
+    Returns
+    -------
+    diff : ndarray
+        The n-th differences.
+        The shape of the output is the same as `a` except along `axis` where the dimension is smaller by `n`.
+        The type of the output is the same as the type of the difference between any two elements of `a`.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.diff(x)
+    array([ 1,  2,  3, -7])
+    >>> np.diff(x, n=2)
+    array([  1,   1, -10])
+
+    >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
+    >>> np.diff(x)
+    array([[2, 3, 4],
+           [5, 1, 2]])
+    >>> np.diff(x, axis=0)
+    array([[-1,  2,  0, -2]])
+
+    Notes
+    -----
+    Optional inputs `prepend` and `append` are not supported yet.
+    """
+    if prepend is not None or append is not None:
+        raise NotImplementedError('prepend and append options are not supported yet')
+    return _npi.diff(a, n=n, axis=axis)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index ef88638c857e..a6d90881da4f 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -56,7 +56,7 @@
            'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
            'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
            'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
-           'may_share_memory']
+           'may_share_memory', 'diff']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -6975,3 +6975,52 @@ def may_share_memory(a, b, max_work=None):
     - Actually it is same as `shares_memory` in MXNet DeepNumPy
     """
     return _mx_nd_np.may_share_memory(a, b, max_work)
+
+
+def diff(a, n=1, axis=-1, prepend=None, append=None):
+    r"""
+    numpy.diff(a, n=1, axis=-1, prepend=None, append=None)
+
+    Calculate the n-th discrete difference along the given axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array
+    n : int, optional
+        The number of times values are differenced. If zero, the input is returned as-is.
+    axis : int, optional
+        The axis along which the difference is taken, default is the last axis.
+    prepend, append : ndarray, optional
+        Not supported yet
+
+    Returns
+    -------
+    diff : ndarray
+        The n-th differences.
+        The shape of the output is the same as `a` except along `axis` where the dimension is smaller by `n`.
+        The type of the output is the same as the type of the difference between any two elements of `a`.
+        This is the same as the type of `a` in most cases.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.diff(x)
+    array([ 1,  2,  3, -7])
+    >>> np.diff(x, n=2)
+    array([  1,   1, -10])
+
+    >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
+    >>> np.diff(x)
+    array([[2, 3, 4],
+           [5, 1, 2]])
+    >>> np.diff(x, axis=0)
+    array([[-1,  2,  0, -2]])
+
+    Notes
+    -----
+    Optional inputs `prepend` and `append` are not supported yet.
+    """
+    if prepend is not None or append is not None:
+        raise NotImplementedError('prepend and append options are not supported yet')
+    return _mx_nd_np.diff(a, n=n, axis=axis)
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index 6a5f166a70eb..2411f51b7aa6 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -130,6 +130,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
     'einsum',
     'shares_memory',
     'may_share_memory',
+    'diff',
 ]
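
Registering 'diff' in the list above is what routes the official NumPy entry
point to MXNet. A minimal sketch of what this enables (assuming NumPy >= 1.17,
where `__array_function__` dispatch is on by default):

    import numpy as onp           # official NumPy
    from mxnet import np, npx
    npx.set_np()

    x = np.array([1., 2., 4., 7., 0.])
    out = onp.diff(x)             # dispatches to mxnet.numpy.diff
    assert isinstance(out, np.ndarray)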
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 2e6d41446930..6f7f912d6e36 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -40,7 +40,7 @@
            'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
            'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
            'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
-           'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory']
+           'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff']
 
 
 def _num_outputs(sym):
@@ -4629,4 +4629,53 @@ def may_share_memory(a, b, max_work=None):
     return _npi.share_memory(a, b)
 
 
+def diff(a, n=1, axis=-1, prepend=None, append=None):
+    r"""
+    numpy.diff(a, n=1, axis=-1, prepend=None, append=None)
+
+    Calculate the n-th discrete difference along the given axis.
+
+    Parameters
+    ----------
+    a : _Symbol
+        Input array
+    n : int, optional
+        The number of times values are differenced. If zero, the input is returned as-is.
+    axis : int, optional
+        The axis along which the difference is taken, default is the last axis.
+    prepend, append : _Symbol, optional
+        Not supported yet
+
+    Returns
+    -------
+    diff : _Symbol
+        The n-th differences.
+        The shape of the output is the same as `a` except along `axis` where the dimension is smaller by `n`.
+        The type of the output is the same as the type of the difference between any two elements of `a`.
+        This is the same as the type of `a` in most cases.
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 4, 7, 0])
+    >>> np.diff(x)
+    array([ 1,  2,  3, -7])
+    >>> np.diff(x, n=2)
+    array([  1,   1, -10])
+
+    >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
+    >>> np.diff(x)
+    array([[2, 3, 4],
+           [5, 1, 2]])
+    >>> np.diff(x, axis=0)
+    array([[-1,  2,  0, -2]])
+
+    Notes
+    -----
+    Optional inputs `prepend` and `append` are not supported yet.
+    """
+    if prepend is not None or append is not None:
+        raise NotImplementedError('prepend and append options are not supported yet')
+    return _npi.diff(a, n=n, axis=axis)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/numpy/np_diff-inl.h b/src/operator/numpy/np_diff-inl.h
new file mode 100644
index 000000000000..69f175e802dd
--- /dev/null
+++ b/src/operator/numpy/np_diff-inl.h
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_diff-inl.h
+ * \brief Function definition of numpy-compatible diff operator
+ */
+
+#ifndef MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_
+
+#include <dmlc/parameter.h>
+#include <mxnet/base.h>
+#include <vector>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct DiffParam : public dmlc::Parameter<DiffParam> {
+  int n, axis;
+  dmlc::optional<mxnet::TShape> prepend;
+  dmlc::optional<mxnet::TShape> append;
+  DMLC_DECLARE_PARAMETER(DiffParam) {
+    DMLC_DECLARE_FIELD(n).set_default(1).describe(
+        "The number of times values are differenced."
+        " If zero, the input is returned as-is.");
+    DMLC_DECLARE_FIELD(axis).set_default(-1).describe(
+        "Axis along which the difference is taken,"
+        " default is the last axis.");
+  }
+};
+
+inline void YanghuiTri(std::vector<int>* buffer, int n) {
+  // use Yanghui's (Pascal's) triangle to compute the binomial coefficient factors
+  (*buffer)[0] = 1;
+  for (int i = 1; i <= n; ++i) {
+    (*buffer)[i] = 1;
+    for (int j = i - 1; j > 0; --j) {
+      (*buffer)[j] += (*buffer)[j - 1];
+    }
+  }
+}
+
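+// The n-th discrete difference has the closed form
+//   out[i] = sum_{k=0}^{n} (-1)^(n-k) * C(n, k) * in[i + k]   (along `axis`),
+// where C(n, k) are the binomial coefficients produced by YanghuiTri above.
+// Each thread below evaluates this stencil for a single output element.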
+struct diff_forward {
+  template <int ndim, typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* out,
+                                  const IType* in, const int n,
+                                  const int stride,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const mshadow::Shape<ndim> ishape) {
+    using namespace broadcast;
+
+    // j is the memory index of the corresponding input entry
+    int j = ravel(unravel(i, oshape), ishape);
+    int indicator = 1;
+    out[i] = 0;
+    for (int k = n; k >= 0; --k) {
+      out[i] += in[j + stride * k] * indicator * diffFactor[k];
+      indicator *= -1;
+    }
+  }
+};
+
+template <typename xpu>
+void DiffForwardImpl(const OpContext& ctx, const TBlob& in, const TBlob& out,
+                     const int n, const int axis) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  // undefined behavior for n < 0
+  CHECK_GE(n, 0);
+  int axis_checked = CheckAxis(axis, in.ndim());
+  // nothing in the output
+  if (n >= in.shape_[axis_checked]) return;
+  // stride for elements on the given axis, same in input and output
+  int stride = 1;
+  for (int i = in.ndim() - 1; i > axis_checked; --i) {
+    stride *= in.shape_[i];
+  }
+
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  std::vector<int> buffer(n + 1, 0);
+  YanghuiTri(&buffer, n);
+  Tensor<xpu, 1, int> diffFactor =
+      ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(n + 1), s);
+  Copy(diffFactor, Tensor<cpu, 1, int>(&buffer[0], Shape1(n + 1), 0), s);
+
+  MSHADOW_TYPE_SWITCH(in.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(out.type_flag_, OType, {
+      MXNET_NDIM_SWITCH(in.ndim(), ndim, {
+        Kernel<diff_forward, xpu>::Launch(
+            s, out.Size(), diffFactor.dptr_,
+            out.dptr<OType>(), in.dptr<IType>(),
+            n, stride, out.shape_.get<ndim>(),
+            in.shape_.get<ndim>());
+      });
+    });
+  });
+}
+
+template <typename xpu>
+void DiffForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                 const std::vector<TBlob>& inputs,
+                 const std::vector<OpReqType>& req,
+                 const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const DiffParam& param = nnvm::get<DiffParam>(attrs.parsed);
+
+  DiffForwardImpl<xpu>(ctx, inputs[0], outputs[0], param.n, param.axis);
+}
+
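+// diff is linear, so its backward pass is the transpose of the forward
+// stencil: each output gradient ograd[k] is scattered as
+//   igrad[k + m] += (-1)^(n-m) * C(n, m) * ograd[k]   for m = 0..n,
+// accumulated below by one "head" thread per sequence along the axis.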
+struct diff_backward {
+  template <int ndim, typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* igrad,
+                                  const IType* ograd, const int n,
+                                  const int stride, const int axis,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const mshadow::Shape<ndim> ishape) {
+    using namespace broadcast;
+    if (n == 0) {
+      igrad[i] = ograd[i];
+      return;
+    }
+
+    Shape<ndim> coor = unravel(i, oshape);
+    // one head thread handles the whole sequence along the axis
+    if (coor[axis] != 0) return;
+    int j = ravel(coor, ishape);
+    // initialize the elements of the output array
+    for (int k = 0; k < oshape[axis]; ++k) igrad[i + k * stride] = 0;
+    for (int k = 0; k < ishape[axis]; ++k) {
+      int indicator = 1;
+      for (int m = n; m >= 0; --m) {
+        igrad[i + (m + k) * stride] +=
+            ograd[j + k * stride] * indicator * diffFactor[m];
+        indicator *= -1;
+      }
+    }
+  }
+};
+
+template <typename xpu>
+void DiffBackwardImpl(const OpContext& ctx, const TBlob& ograd,
+                      const TBlob& igrad, const int n, const int axis) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  // undefined behavior for n < 0
+  CHECK_GE(n, 0);
+  int axis_checked = CheckAxis(axis, igrad.ndim());
+  // nothing in the ograd and igrad
+  if (n >= igrad.shape_[axis_checked]) return;
+  // stride for elements on the given axis, same in input and output
+  int stride = 1;
+  for (int i = igrad.ndim() - 1; i > axis_checked; --i) {
+    stride *= igrad.shape_[i];
+  }
+
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  std::vector<int> buffer(n + 1, 0);
+  YanghuiTri(&buffer, n);
+  Tensor<xpu, 1, int> diffFactor =
+      ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(n + 1), s);
+  Copy(diffFactor, Tensor<cpu, 1, int>(&buffer[0], Shape1(n + 1), 0), s);
+
+  MSHADOW_TYPE_SWITCH(ograd.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(igrad.type_flag_, OType, {
+      MXNET_NDIM_SWITCH(igrad.ndim(), ndim, {
+        Kernel<diff_backward, xpu>::Launch(
+            s, igrad.Size(), diffFactor.dptr_,
+            igrad.dptr<OType>(), ograd.dptr<IType>(),
+            n, stride, axis_checked,
+            igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>());
+      });
+    });
+  });
+}
+
+template <typename xpu>
+void DiffBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const DiffParam& param = nnvm::get<DiffParam>(attrs.parsed);
+
+  DiffBackwardImpl<xpu>(ctx, inputs[0], outputs[0], param.n, param.axis);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_
diff --git a/src/operator/numpy/np_diff.cc b/src/operator/numpy/np_diff.cc
new file mode 100644
index 000000000000..a3dae332d842
--- /dev/null
+++ b/src/operator/numpy/np_diff.cc
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_diff.cc
+ * \brief CPU implementation of numpy-compatible diff operator
+ */
+
+#include "./np_diff-inl.h"
+
+namespace mxnet {
+namespace op {
+
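+// The output shape equals the input shape except along `axis`, whose
+// dimension shrinks by n (clamped at 0): e.g. (10, 20, 30) with n=2 and
+// axis=1 gives (10, 18, 30).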
+inline TShape NumpyDiffShapeImpl(const TShape& ishape,
+                                 const int n,
+                                 const int axis) {
+  CHECK_GE(n, 0);
+  int axis_checked = CheckAxis(axis, ishape.ndim());
+
+  TShape oshape = ishape;
+  if (n >= ishape[axis_checked]) {
+    oshape[axis_checked] = 0;
+  } else {
+    oshape[axis_checked] -= n;
+  }
+  return oshape;
+}
+
+inline bool DiffShape(const nnvm::NodeAttrs& attrs,
+                      std::vector<TShape>* in_attrs,
+                      std::vector<TShape>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const DiffParam& param = nnvm::get<DiffParam>(attrs.parsed);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyDiffShapeImpl((*in_attrs)[0], param.n, param.axis));
+  return shape_is_known(out_attrs->at(0));
+}
+
+inline bool DiffType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int>* in_attrs,
+                     std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
+DMLC_REGISTER_PARAMETER(DiffParam);
+
+NNVM_REGISTER_OP(_npi_diff)
+.set_attr_parser(ParamParser<DiffParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", DiffShape)
+.set_attr<nnvm::FInferType>("FInferType", DiffType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", DiffForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+                           ElemwiseGradUseNone{"_backward_npi_diff"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("a", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(DiffParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_diff)
+.set_attr_parser(ParamParser<DiffParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", DiffBackward<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_diff.cu b/src/operator/numpy/np_diff.cu
new file mode 100644
index 000000000000..daea6e368e05
--- /dev/null
+++ b/src/operator/numpy/np_diff.cu
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_diff.cu
+ * \brief GPU implementation of numpy-compatible diff operator
+ */
+
+#include "./np_diff-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_diff)
+.set_attr<FCompute>("FCompute<gpu>", DiffForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_npi_diff)
+.set_attr<FCompute>("FCompute<gpu>", DiffBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 624fc0a107b0..103f2c117ea6 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1085,6 +1085,29 @@ def _add_workload_nonzero():
     OpArgMngr.add_workload('nonzero', np.array([True, False, False], dtype=np.bool_))
 
 
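+# Each workload below records one concrete call signature for np.diff; the
+# interoperability harness replays it through official NumPy's
+# __array_function__ protocol and checks that dispatching onto mxnet.numpy
+# reproduces NumPy's result.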
+def _add_workload_diff():
+    x = np.array([1, 4, 6, 7, 12])
+    OpArgMngr.add_workload('diff', x)
+    OpArgMngr.add_workload('diff', x, 2)
+    OpArgMngr.add_workload('diff', x, 3)
+    OpArgMngr.add_workload('diff', np.array([1.1, 2.2, 3.0, -0.2, -0.1]))
+    x = np.zeros((10, 20, 30))
+    x[:, 1::2, :] = 1
+    OpArgMngr.add_workload('diff', x)
+    OpArgMngr.add_workload('diff', x, axis=-1)
+    OpArgMngr.add_workload('diff', x, axis=0)
+    OpArgMngr.add_workload('diff', x, axis=1)
+    OpArgMngr.add_workload('diff', x, axis=-2)
+    x = 20 * np.random.uniform(size=(10, 20, 30))
+    OpArgMngr.add_workload('diff', x)
+    OpArgMngr.add_workload('diff', x, n=2)
+    OpArgMngr.add_workload('diff', x, axis=0)
+    OpArgMngr.add_workload('diff', x, n=2, axis=0)
+    x = np.array([list(range(3))])
+    for n in range(1, 5):
+        OpArgMngr.add_workload('diff', x, n=n)
+
+
 @use_np
 def _prepare_workloads():
     array_pool = {
@@ -1190,6 +1213,7 @@ def _prepare_workloads():
     _add_workload_greater_equal(array_pool)
     _add_workload_less(array_pool)
     _add_workload_less_equal(array_pool)
+    _add_workload_diff()
 
 
 _prepare_workloads()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index bfe6c3d43b50..67c1ede6cc1a 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3764,6 +3764,58 @@ def test_np_share_memory():
                 assert not op(np.ones((5, 0), dtype=dt), np.ones((0, 3, 0), dtype=adt))
 
 
+@with_seed()
+@use_np
+def test_np_diff():
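+    # The adjoint of a single forward difference is the negative difference of
+    # the zero-padded output gradient, so the n-th diff gradient is obtained by
+    # applying that map n times.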
+    def np_diff_backward(ograd, n, axis):
+        res = ograd
+        for i in range(n):
+            res = _np.negative(_np.diff(res, n=1, axis=axis, prepend=0, append=0))
+        return res
+
+    class TestDiff(HybridBlock):
+        def __init__(self, n=1, axis=-1):
+            super(TestDiff, self).__init__()
+            self._n = n
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.np.diff(a, n=self._n, axis=self._axis)
+
+    shapes = [tuple(random.randrange(10) for i in range(random.randrange(6))) for j in range(5)]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            for axis in range(-len(shape), len(shape)):
+                for n in range(0, shape[axis] + 1):
+                    test_np_diff = TestDiff(n=n, axis=axis)
+                    if hybridize:
+                        test_np_diff.hybridize()
+                    for itype in [_np.float16, _np.float32, _np.float64]:
+                        # note: the tolerance must be scaled with the rank of the input and with n
+                        if itype == _np.float16:
+                            rtol = atol = 1e-2 * len(shape) * n
+                        else:
+                            rtol = atol = 1e-5 * len(shape) * n
+                        x = rand_ndarray(shape).astype(itype).as_np_ndarray()
+                        x.attach_grad()
+                        np_out = _np.diff(x.asnumpy(), n=n, axis=axis)
+                        with mx.autograd.record():
+                            mx_out = test_np_diff(x)
+                        assert mx_out.shape == np_out.shape
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+                        mx_out.backward()
+                        if np_out.size == 0:
+                            np_backward = _np.zeros(shape)
+                        else:
+                            np_backward = np_diff_backward(_np.ones(np_out.shape, dtype=itype), n=n, axis=axis)
+                        assert x.grad.shape == np_backward.shape
+                        assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol)
+
+                        mx_out = np.diff(x, n=n, axis=axis)
+                        np_out = _np.diff(x.asnumpy(), n=n, axis=axis)
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
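
To exercise just the new operator locally, the test above can be run on its
own, e.g. `nosetests -v tests/python/unittest/test_numpy_op.py:test_np_diff`
(assuming the nose-based setup implied by the file's `__main__` block).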