From 5aa74e066277c2b55a8307f32973425bf85de4c0 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 29 Oct 2019 13:30:57 -0700 Subject: [PATCH 1/6] Move ops which don't support FP16 dtype to FP32 list (#16668) --- python/mxnet/contrib/amp/lists/symbol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py index 397f4775f8cd..2146853a6866 100644 --- a/python/mxnet/contrib/amp/lists/symbol.py +++ b/python/mxnet/contrib/amp/lists/symbol.py @@ -368,6 +368,9 @@ 'arctanh', '_sparse_arcsin', '_sparse_arctanh', + '_contrib_MultiBoxDetection', + '_contrib_MultiBoxPrior', + '_contrib_MultiBoxTarget', # Exponents 'exp', @@ -575,9 +578,6 @@ 'stack', '_Maximum', '_Minimum', - '_contrib_MultiBoxDetection', - '_contrib_MultiBoxPrior', - '_contrib_MultiBoxTarget', '_contrib_MultiProposal', '_contrib_PSROIPooling', '_contrib_Proposal', From 8e50fd9d862a2917e0727b8cd17f1e5cf0eba081 Mon Sep 17 00:00:00 2001 From: phinzphinz Date: Wed, 30 Oct 2019 01:55:02 +0100 Subject: [PATCH 2/6] no such method => modified function args (#16610) * no such method => modified function args ERROR: MethodError: no method matching mapreduce(::getfield(MXNet.mx, Symbol("##8072#8073")), ::typeof(+), ::Float64, ::Array{NDArray{Float32,1},1}) * julia could not build package before * Update julia/src/metric.jl Co-Authored-By: Iblis Lin --- julia/deps/build.jl | 2 +- julia/src/metric.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/julia/deps/build.jl b/julia/deps/build.jl index a87343d9dab5..a79d2a062c18 100644 --- a/julia/deps/build.jl +++ b/julia/deps/build.jl @@ -54,7 +54,7 @@ if Sys.isunix() nvcc_path = Sys.which("nvcc") if nvcc_path ≢ nothing @info "Found nvcc: $nvcc_path" - push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc", "lib64")) + push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc" => "lib64")) end end diff --git a/julia/src/metric.jl b/julia/src/metric.jl index f1cdc68d947f..2ae7fc85144b 100644 --- a/julia/src/metric.jl +++ b/julia/src/metric.jl @@ -260,7 +260,7 @@ end function get(metric::MSE) # Delay copy until last possible moment - mse_sum = mapreduce(nda->copy(nda)[1], +, 0.0, metric.mse_sum) + mse_sum = mapreduce(nda->copy(nda)[1], +, metric.mse_sum; init = zero(MX_float)) [(:MSE, mse_sum / metric.n_sample)] end From 27bddf8f1667cd7161f0c41ecabe6123736e403d Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 30 Oct 2019 17:34:50 +0800 Subject: [PATCH 3/6] [Numpy] Numpy operator diff (#15906) * numpy diff operator implemented append and prepend not supported yet remove the prepend and append checking interface from the backend refine the code, enrich the test set and all tests passed registered the diff operator into npi scope all tests passed comments and minor modification format codes and fix warning for sanity check minor modification for sanity check fix sanity fix the tolerance bound of testing np.diff resolve minor coding style issue replace the given tests by random picking minor fix * interoperability test added --- python/mxnet/ndarray/numpy/_op.py | 50 +++- python/mxnet/numpy/multiarray.py | 51 +++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 51 +++- src/operator/numpy/np_diff-inl.h | 220 ++++++++++++++++++ src/operator/numpy/np_diff.cc | 109 +++++++++ src/operator/numpy/np_diff.cu | 37 +++ .../unittest/test_numpy_interoperability.py | 24 ++ tests/python/unittest/test_numpy_op.py | 52 +++++ 9 files changed, 592 insertions(+), 3 deletions(-) create mode 100644 src/operator/numpy/np_diff-inl.h create mode 100644 src/operator/numpy/np_diff.cc create mode 100644 src/operator/numpy/np_diff.cu diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 84aa4a1572d9..297c40b1431e 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -38,7 +38,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', - 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory'] + 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff'] @set_module('mxnet.ndarray.numpy') @@ -4983,3 +4983,51 @@ def may_share_memory(a, b, max_work=None): - Actually it is same as `shares_memory` in MXNet DeepNumPy """ return _npi.share_memory(a, b).item() + + +def diff(a, n=1, axis=-1, prepend=None, append=None): + r""" + numpy.diff(a, n=1, axis=-1, prepend=, append=) + + Calculate the n-th discrete difference along the given axis. + + Parameters + ---------- + a : ndarray + Input array + n : int, optional + The number of times values are differenced. If zero, the input is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the last axis. + prepend, append : ndarray, optional + Not supported yet + + Returns + ------- + diff : ndarray + The n-th differences. + The shape of the output is the same as a except along axis where the dimension is smaller by n. + The type of the output is the same as the type of the difference between any two elements of a. + + Examples + -------- + >>> x = np.array([1, 2, 4, 7, 0]) + >>> np.diff(x) + array([ 1, 2, 3, -7]) + >>> np.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> np.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> np.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + + Notes + ----- + Optional inputs `prepend` and `append` are not supported yet + """ + if (prepend or append): + raise NotImplementedError('prepend and append options are not supported yet') + return _npi.diff(a, n=n, axis=axis) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index ef88638c857e..a6d90881da4f 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -56,7 +56,7 @@ 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory'] + 'may_share_memory', 'diff'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -6975,3 +6975,52 @@ def may_share_memory(a, b, max_work=None): - Actually it is same as `shares_memory` in MXNet DeepNumPy """ return _mx_nd_np.may_share_memory(a, b, max_work) + + +def diff(a, n=1, axis=-1, prepend=None, append=None): + r""" + numpy.diff(a, n=1, axis=-1, prepend=, append=) + + Calculate the n-th discrete difference along the given axis. + + Parameters + ---------- + a : ndarray + Input array + n : int, optional + The number of times values are differenced. If zero, the input is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the last axis. + prepend, append : ndarray, optional + Not supported yet + + Returns + ------- + diff : ndarray + The n-th differences. + The shape of the output is the same as a except along axis where the dimension is smaller by n. + The type of the output is the same as the type of the difference between any two elements of a. + This is the same as the type of a in most cases. + + Examples + -------- + >>> x = np.array([1, 2, 4, 7, 0]) + >>> np.diff(x) + array([ 1, 2, 3, -7]) + >>> np.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> np.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> np.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + + Notes + ----- + Optional inputs `prepend` and `append` are not supported yet + """ + if (prepend or append): + raise NotImplementedError('prepend and append options are not supported yet') + return _mx_nd_np.diff(a, n=n, axis=axis) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 6a5f166a70eb..2411f51b7aa6 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -130,6 +130,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'einsum', 'shares_memory', 'may_share_memory', + 'diff', ] diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 2e6d41446930..6f7f912d6e36 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -40,7 +40,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', - 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory'] + 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff'] def _num_outputs(sym): @@ -4629,4 +4629,53 @@ def may_share_memory(a, b, max_work=None): return _npi.share_memory(a, b) +def diff(a, n=1, axis=-1, prepend=None, append=None): + r""" + numpy.diff(a, n=1, axis=-1, prepend=, append=) + + Calculate the n-th discrete difference along the given axis. + + Parameters + ---------- + a : ndarray + Input array + n : int, optional + The number of times values are differenced. If zero, the input is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the last axis. + prepend, append : ndarray, optional + Not supported yet + + Returns + ------- + diff : ndarray + The n-th differences. + The shape of the output is the same as a except along axis where the dimension is smaller by n. + The type of the output is the same as the type of the difference between any two elements of a. + This is the same as the type of a in most cases. + + Examples + -------- + >>> x = np.array([1, 2, 4, 7, 0]) + >>> np.diff(x) + array([ 1, 2, 3, -7]) + >>> np.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> np.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> np.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + + Notes + ----- + Optional inputs `prepend` and `append` are not supported yet + """ + if (prepend or append): + raise NotImplementedError('prepend and append options are not supported yet') + return _npi.diff(a, n=n, axis=axis) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_diff-inl.h b/src/operator/numpy/np_diff-inl.h new file mode 100644 index 000000000000..69f175e802dd --- /dev/null +++ b/src/operator/numpy/np_diff-inl.h @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff-inl.h + * \brief Function definition of numpy-compatible diff operator + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_ + +#include +#include +#include +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct DiffParam : public dmlc::Parameter { + int n, axis; + dmlc::optional prepend; + dmlc::optional append; + DMLC_DECLARE_PARAMETER(DiffParam) { + DMLC_DECLARE_FIELD(n).set_default(1).describe( + "The number of times values are differenced." + " If zero, the input is returned as-is."); + DMLC_DECLARE_FIELD(axis).set_default(-1).describe( + "Axis along which the cumulative sum is computed." + " The default (None) is to compute the diff over the flattened array."); + } +}; + +inline void YanghuiTri(std::vector* buffer, int n) { + // apply basic yanghui's triangular to calculate the factors + (*buffer)[0] = 1; + for (int i = 1; i <= n; ++i) { + (*buffer)[i] = 1; + for (int j = i - 1; j > 0; --j) { + (*buffer)[j] += (*buffer)[j - 1]; + } + } +} + +struct diff_forward { + template + MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* out, + const IType* in, const int n, + const int stride, + const mshadow::Shape oshape, + const mshadow::Shape ishape) { + using namespace broadcast; + + // j represent the memory index of the corresponding input entry + int j = ravel(unravel(i, oshape), ishape); + int indicator = 1; + out[i] = 0; + for (int k = n; k >= 0; --k) { + out[i] += in[j + stride * k] * indicator * diffFactor[k]; + indicator *= -1; + } + } +}; + +template +void DiffForwardImpl(const OpContext& ctx, const TBlob& in, const TBlob& out, + const int n, const int axis) { + using namespace mshadow; + using namespace mxnet_op; + + // undefined behavior for n < 0 + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, in.ndim()); + // nothing in the output + if (n >= in.shape_[axis_checked]) return; + // stride for elements on the given axis, same in input and output + int stride = 1; + for (int i = in.ndim() - 1; i > axis_checked; --i) { + stride *= in.shape_[i]; + } + + Stream* s = ctx.get_stream(); + std::vector buffer(n+1, 0); + YanghuiTri(&buffer, n); + Tensor diffFactor = + ctx.requested[0].get_space_typed(Shape1(n + 1), s); + Copy(diffFactor, Tensor(&buffer[0], Shape1(n + 1), 0), s); + + MSHADOW_TYPE_SWITCH(in.type_flag_, IType, { + MSHADOW_TYPE_SWITCH(out.type_flag_, OType, { + MXNET_NDIM_SWITCH(in.ndim(), ndim, { + Kernel::Launch( + s, out.Size(), diffFactor.dptr_, + out.dptr(), in.dptr(), + n, stride, out.shape_.get(), + in.shape_.get()); + }); + }); + }); +} + +template +void DiffForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const DiffParam& param = nnvm::get(attrs.parsed); + + DiffForwardImpl(ctx, inputs[0], outputs[0], param.n, param.axis); +} + +struct diff_backward { + template + MSHADOW_XINLINE static void Map(int i, int* diffFactor, OType* igrad, + const IType* ograd, const int n, + const int stride, const int axis, + const mshadow::Shape oshape, + const mshadow::Shape ishape) { + using namespace broadcast; + if (n == 0) { + igrad[i] = ograd[i]; + return; + } + + Shape coor = unravel(i, oshape); + // one head thread for a whole sequence along the axis + if (coor[axis] != 0) return; + int j = ravel(coor, ishape); + // initialize the elements of output array + for (int k = 0; k < oshape[axis]; ++k) igrad[i + k * stride] = 0; + for (int k = 0; k < ishape[axis]; ++k) { + int indicator = 1; + for (int m = n; m >= 0; --m) { + igrad[i + (m + k) * stride] += + ograd[j + k * stride] * indicator * diffFactor[m]; + indicator *= -1; + } + } + } +}; + +template +void DiffBackwardImpl(const OpContext& ctx, const TBlob& ograd, + const TBlob& igrad, const int n, const int axis) { + using namespace mshadow; + using namespace mxnet_op; + + // undefined behavior for n < 0 + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, igrad.ndim()); + // nothing in the ograd and igrad + if (n >= igrad.shape_[axis_checked]) return; + // stride for elements on the given axis, same in input and output + int stride = 1; + for (int i = igrad.ndim() - 1; i > axis_checked; --i) { + stride *= igrad.shape_[i]; + } + + Stream* s = ctx.get_stream(); + std::vector buffer(n+1, 0); + YanghuiTri(&buffer, n); + Tensor diffFactor = + ctx.requested[0].get_space_typed(Shape1(n + 1), s); + Copy(diffFactor, Tensor(&buffer[0], Shape1(n + 1), 0), s); + + MSHADOW_TYPE_SWITCH(ograd.type_flag_, IType, { + MSHADOW_TYPE_SWITCH(igrad.type_flag_, OType, { + MXNET_NDIM_SWITCH(igrad.ndim(), ndim, { + Kernel::Launch( + s, igrad.Size(), diffFactor.dptr_, + igrad.dptr(), ograd.dptr(), + n, stride, axis_checked, + igrad.shape_.get(), ograd.shape_.get()); + }); + }); + }); +} + +template +void DiffBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const DiffParam& param = nnvm::get(attrs.parsed); + + DiffBackwardImpl(ctx, inputs[0], outputs[0], param.n, param.axis); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_DIFF_INL_H_ diff --git a/src/operator/numpy/np_diff.cc b/src/operator/numpy/np_diff.cc new file mode 100644 index 000000000000..a3dae332d842 --- /dev/null +++ b/src/operator/numpy/np_diff.cc @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff.cc + * \brief CPU implementation of numpy-compatible diff operator + */ + +#include "./np_diff-inl.h" + +namespace mxnet { +namespace op { + +inline TShape NumpyDiffShapeImpl(const TShape& ishape, + const int n, + const int axis) { + CHECK_GE(n, 0); + int axis_checked = CheckAxis(axis, ishape.ndim()); + + TShape oshape = ishape; + if (n >= ishape[axis_checked]) { + oshape[axis_checked] = 0; + } else { + oshape[axis_checked] -= n; + } + return oshape; +} + +inline bool DiffShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known(in_attrs->at(0))) { + return false; + } + const DiffParam& param = nnvm::get(attrs.parsed); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + NumpyDiffShapeImpl((*in_attrs)[0], param.n, param.axis)); + return shape_is_known(out_attrs->at(0)); +} + +inline bool DiffType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(DiffParam); + +NNVM_REGISTER_OP(_npi_diff) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.set_attr("FInferShape", DiffShape) +.set_attr("FInferType", DiffType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DiffForward) +.set_attr("FGradient", + ElemwiseGradUseNone{"_backward_npi_diff"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("a", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(DiffParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_diff) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DiffBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_diff.cu b/src/operator/numpy/np_diff.cu new file mode 100644 index 000000000000..daea6e368e05 --- /dev/null +++ b/src/operator/numpy/np_diff.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_diff.cu + * \brief GPU implementation of numpy-compatible diff operator + */ + +#include "./np_diff-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_diff) +.set_attr("FCompute", DiffForward); + +NNVM_REGISTER_OP(_backward_npi_diff) +.set_attr("FCompute", DiffBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 624fc0a107b0..103f2c117ea6 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1085,6 +1085,29 @@ def _add_workload_nonzero(): OpArgMngr.add_workload('nonzero', np.array([True, False, False], dtype=np.bool_)) +def _add_workload_diff(): + x = np.array([1, 4, 6, 7, 12]) + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, 2) + OpArgMngr.add_workload('diff', x, 3) + OpArgMngr.add_workload('diff', np.array([1.1, 2.2, 3.0, -0.2, -0.1])) + x = np.zeros((10, 20, 30)) + x[:, 1::2, :] = 1 + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, axis=-1) + OpArgMngr.add_workload('diff', x, axis=0) + OpArgMngr.add_workload('diff', x, axis=1) + OpArgMngr.add_workload('diff', x, axis=-2) + x = 20 * np.random.uniform(size=(10,20,30)) + OpArgMngr.add_workload('diff', x) + OpArgMngr.add_workload('diff', x, n=2) + OpArgMngr.add_workload('diff', x, axis=0) + OpArgMngr.add_workload('diff', x, n=2, axis=0) + x = np.array([list(range(3))]) + for n in range(1, 5): + OpArgMngr.add_workload('diff', x, n=n) + + @use_np def _prepare_workloads(): array_pool = { @@ -1190,6 +1213,7 @@ def _prepare_workloads(): _add_workload_greater_equal(array_pool) _add_workload_less(array_pool) _add_workload_less_equal(array_pool) + _add_workload_diff() _prepare_workloads() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index bfe6c3d43b50..67c1ede6cc1a 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3764,6 +3764,58 @@ def test_np_share_memory(): assert not op(np.ones((5, 0), dtype=dt), np.ones((0, 3, 0), dtype=adt)) +@with_seed() +@use_np +def test_np_diff(): + def np_diff_backward(ograd, n, axis): + res = ograd + for i in range(n): + res = _np.negative(_np.diff(res, n=1, axis=axis, prepend=0, append=0)) + return res + + class TestDiff(HybridBlock): + def __init__(self, n=1, axis=-1): + super(TestDiff, self).__init__() + self._n = n + self._axis = axis + + def hybrid_forward(self, F, a): + return F.np.diff(a, n=self._n, axis=self._axis) + + shapes = [tuple(random.randrange(10) for i in range(random.randrange(6))) for j in range(5)] + for hybridize in [True, False]: + for shape in shapes: + for axis in [i for i in range(-len(shape), len(shape))]: + for n in [i for i in range(0, shape[axis]+1)]: + test_np_diff = TestDiff(n=n, axis=axis) + if hybridize: + test_np_diff.hybridize() + for itype in [_np.float16, _np.float32, _np.float64]: + # note the tolerance shall be scaled by the input n + if itype == _np.float16: + rtol = atol = 1e-2*len(shape)*n + else: + rtol = atol = 1e-5*len(shape)*n + x = rand_ndarray(shape).astype(itype).as_np_ndarray() + x.attach_grad() + np_out = _np.diff(x.asnumpy(), n=n, axis=axis) + with mx.autograd.record(): + mx_out = test_np_diff(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + mx_out.backward() + if (np_out.size == 0): + np_backward = _np.zeros(shape) + else: + np_backward = np_diff_backward(_np.ones(np_out.shape, dtype=itype), n=n, axis=axis) + assert x.grad.shape == np_backward.shape + assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) + + mx_out = np.diff(x, n=n, axis=axis) + np_out = _np.diff(x.asnumpy(), n=n, axis=axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule() From 77e8f516e7f3d7d2a6c40387028a82ffd761909c Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Wed, 30 Oct 2019 16:33:15 -0700 Subject: [PATCH 4/6] fix cuDNN RNN dtype_with_fallback_ bug (#16671) --- src/operator/rnn-inl.h | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index b448261f215d..d5fd351986e3 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -1376,21 +1376,12 @@ class RNNOp { seed_)); // RNN descriptors - cudnnDataType_t dtype_with_fallback_; + // adopt pseudo-fp16 for all architectures + cudnnDataType_t dtype_with_fallback_ = + (cudnnGetVersion() >= 7500 && dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT + : dtype_; cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD; dgrad_sync_needed_ = (rnn_algo == CUDNN_RNN_ALGO_STANDARD) && param_.bidirectional; - // On arch's 50 and 52(Maxwell), the gpu doesn't support native fp16 compute. - // Before cuDNN 7.5.0, when running fp16, cuDNN fallback to fp32 under the hood on Maxwell. - // That's not the case begining from 7.5.0. Thereby adding fallback explicitly here. -#if __CUDA_ARCH__ < 530 && CUDNN_VERSION >= 7500 - if (dtype_ == CUDNN_DATA_HALF) { - dtype_with_fallback_ = CUDNN_DATA_FLOAT; - } else { - dtype_with_fallback_ = dtype_; - } -#else - dtype_with_fallback_ = dtype_; -#endif CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_, rnn_desc_, param_.state_size, From a6a9706bf962c756e2c934a2e86377c009935e9c Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 30 Oct 2019 21:02:06 -0700 Subject: [PATCH 5/6] Miscellaneous fix for several numpy issues (#16664) * fix behavior of np.array when given official numpy ndarray * bool for expand_dims and cast * recover original Makefile * address comments * add boolean support for cumsum * add gpu cast boolean support * add error message --- python/mxnet/numpy/multiarray.py | 21 ++++++++------ src/ndarray/ndarray_function.cu | 2 +- src/operator/mxnet_op.h | 2 +- src/operator/numpy/np_cumsum-inl.h | 4 +-- src/operator/numpy/np_cumsum.cc | 3 ++ src/operator/numpy/np_init_op.h | 2 +- src/operator/tensor/elemwise_unary_op.h | 2 +- .../unittest/test_numpy_interoperability.py | 5 ++-- tests/python/unittest/test_numpy_ndarray.py | 28 +++++++++++-------- tests/python/unittest/test_numpy_op.py | 12 +++++++- 10 files changed, 53 insertions(+), 28 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a6d90881da4f..bc4b409d5be7 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -577,7 +577,7 @@ def __setitem__(self, key, value): if not isinstance(key, tuple) or len(key) != 0: raise IndexError('scalar tensor can only accept `()` as index') if isinstance(value, numeric_types): - self.full(value) + self._full(value) elif isinstance(value, ndarray) and value.size == 1: if value.shape != self.shape: value = value.reshape(self.shape) @@ -1993,15 +1993,20 @@ def array(object, dtype=None, ctx=None): """ if ctx is None: ctx = current_context() - if isinstance(object, ndarray): + if isinstance(object, (ndarray, _np.ndarray)): dtype = object.dtype if dtype is None else dtype + elif isinstance(object, NDArray): + raise ValueError("If you're trying to create a mxnet.numpy.ndarray " + "from mx.nd.NDArray, please use the zero-copy as_np_ndarray function.") else: - dtype = _np.float32 if dtype is None else dtype - if not isinstance(object, (ndarray, _np.ndarray)): - try: - object = _np.array(object, dtype=dtype) - except Exception as e: - raise TypeError('{}'.format(str(e))) + if dtype is None: + dtype = object.dtype if hasattr(object, "dtype") else _np.float32 + try: + object = _np.array(object, dtype=dtype) + except Exception as e: + # printing out the error raised by official NumPy's array function + # for transparency on users' side + raise TypeError('{}'.format(str(e))) ret = empty(object.shape, dtype=dtype, ctx=ctx) if len(object.shape) == 0: ret[()] = object diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 2a1461cc8c48..da7b60db7f13 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -76,7 +76,7 @@ void Copy(const TBlob &from, TBlob *to, from.FlatTo1D(s), s); } else { - MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, { to->FlatTo1D(s) = mshadow::expr::tcast(from.FlatTo1D(s)); }) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 463c71b5b0eb..91478660a123 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -671,7 +671,7 @@ template MSHADOW_CINLINE void copy(mshadow::Stream *s, const TBlob& to, const TBlob& from) { CHECK_EQ(from.Size(), to.Size()); CHECK_EQ(from.dev_mask(), to.dev_mask()); - MSHADOW_TYPE_SWITCH(to.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(to.type_flag_, DType, { if (to.type_flag_ == from.type_flag_) { mshadow::Copy(to.FlatTo1D(s), from.FlatTo1D(s), s); } else { diff --git a/src/operator/numpy/np_cumsum-inl.h b/src/operator/numpy/np_cumsum-inl.h index 6c6b56d46e76..375d83b2240f 100644 --- a/src/operator/numpy/np_cumsum-inl.h +++ b/src/operator/numpy/np_cumsum-inl.h @@ -98,7 +98,7 @@ void CumsumForwardImpl(const OpContext& ctx, } Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(in.type_flag_, IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, IType, { MSHADOW_TYPE_SWITCH(out.type_flag_, OType, { Kernel::Launch( s, out.Size() / middle, out.dptr(), @@ -157,7 +157,7 @@ void CumsumBackwardImpl(const OpContext& ctx, } } Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(igrad.type_flag_, IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(igrad.type_flag_, IType, { MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, { Kernel::Launch( s, igrad.Size() / middle, igrad.dptr(), diff --git a/src/operator/numpy/np_cumsum.cc b/src/operator/numpy/np_cumsum.cc index 0ddbf521186c..2d5dbb99f90a 100644 --- a/src/operator/numpy/np_cumsum.cc +++ b/src/operator/numpy/np_cumsum.cc @@ -55,6 +55,9 @@ inline bool CumsumType(const nnvm::NodeAttrs& attrs, } else { TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + if (out_attrs->at(0) == mshadow::kBool) { + (*out_attrs)[0] = mshadow::kInt64; + } } return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h index 69999ae8710e..df30d611aa02 100644 --- a/src/operator/numpy/np_init_op.h +++ b/src/operator/numpy/np_init_op.h @@ -205,7 +205,7 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); const TBlob& out_data = outputs[0]; int n = out_data.shape_[0]; - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(out_data.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { Kernel, xpu>::Launch( s, out_data.Size(), out_data.dptr(), n); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index b7625fccf258..188ccd68a340 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -451,7 +451,7 @@ void CastCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, { Tensor out = outputs[0].FlatTo1D(s); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, { Tensor data = inputs[0].FlatTo1D(s); if (outputs[0].type_flag_ != inputs[0].type_flag_ || req[0] != kWriteInplace) { diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 103f2c117ea6..5d6e8af7fa47 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -93,6 +93,7 @@ def _add_workload_copy(): def _add_workload_expand_dims(): OpArgMngr.add_workload('expand_dims', np.random.uniform(size=(4, 1)), -1) + OpArgMngr.add_workload('expand_dims', np.random.uniform(size=(4, 1)) > 0.5, -1) for axis in range(-5, 4): OpArgMngr.add_workload('expand_dims', np.empty((2, 3, 4, 5)), axis) @@ -852,8 +853,8 @@ def _signs(dt): # test_float_remainder_corner_cases # Check remainder magnitude. for ct in _FLOAT_DTYPES: - b = _np.array(1.0) - a = np.array(_np.nextafter(_np.array(0.0), -b), dtype=ct) + b = _np.array(1.0, dtype=ct) + a = np.array(_np.nextafter(_np.array(0.0, dtype=ct), -b), dtype=ct) b = np.array(b, dtype=ct) OpArgMngr.add_workload('remainder', a, b) OpArgMngr.add_workload('remainder', -a, -b) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 6077f4df13ae..239f300e028e 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -18,6 +18,7 @@ # pylint: skip-file from __future__ import absolute_import from __future__ import division +import itertools import os import unittest import numpy as _np @@ -87,6 +88,7 @@ def test_np_array_creation(): [], (), [[1, 2], [3, 4]], + _np.random.randint(-10, 10, size=rand_shape_nd(3)), _np.random.uniform(size=rand_shape_nd(3)), _np.random.uniform(size=(3, 0, 4)) ] @@ -94,10 +96,12 @@ def test_np_array_creation(): for src in objects: mx_arr = np.array(src, dtype=dtype) assert mx_arr.ctx == mx.current_context() + if dtype is None: + dtype = src.dtype if isinstance(src, _np.ndarray) else _np.float32 if isinstance(src, mx.nd.NDArray): - np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) + np_arr = _np.array(src.asnumpy(), dtype=dtype) else: - np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32) + np_arr = _np.array(src, dtype=dtype) assert mx_arr.dtype == np_arr.dtype assert same(mx_arr.asnumpy(), np_arr) @@ -471,9 +475,6 @@ def test_np_grad_ndarray_type(): @with_seed() @use_np def test_np_ndarray_astype(): - mx_data = np.array([2, 3, 4, 5], dtype=_np.int32) - np_data = mx_data.asnumpy() - class TestAstype(HybridBlock): def __init__(self, dtype, copy): super(TestAstype, self).__init__() @@ -483,24 +484,29 @@ def __init__(self, dtype, copy): def hybrid_forward(self, F, x): return x.astype(dtype=self._dtype, copy=self._copy) - def check_astype_equal(dtype, copy, expect_zero_copy=False, hybridize=False): - test_astype = TestAstype(dtype, copy) + def check_astype_equal(itype, otype, copy, expect_zero_copy=False, hybridize=False): + expect_zero_copy = copy is False and itype == otype + mx_data = np.array([2, 3, 4, 5], dtype=itype) + np_data = mx_data.asnumpy() + test_astype = TestAstype(otype, copy) if hybridize: test_astype.hybridize() mx_ret = test_astype(mx_data) assert type(mx_ret) is np.ndarray - np_ret = np_data.astype(dtype=dtype, copy=copy) + np_ret = np_data.astype(dtype=otype, copy=copy) assert mx_ret.dtype == np_ret.dtype assert same(mx_ret.asnumpy(), np_ret) if expect_zero_copy and not hybridize: assert id(mx_ret) == id(mx_data) assert id(np_ret) == id(np_data) - for dtype in [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_, - 'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool']: + dtypes = [np.int8, np.uint8, np.int32, np.float16, np.float32, np.float64, np.bool, np.bool_, + 'int8', 'uint8', 'int32', 'float16', 'float32', 'float64', 'bool'] + + for itype, otype in itertools.product(dtypes, dtypes): for copy in [True, False]: for hybridize in [True, False]: - check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype, hybridize) + check_astype_equal(itype, otype, copy, hybridize) @with_seed() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 67c1ede6cc1a..0b15c7ea0d2d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -442,7 +442,7 @@ def is_int(dtype): for axis in ([i for i in range(in_data_dim)] + [(), None]): for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']: for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']: - if (is_int(dtype) and not is_int(itype))\ + if (is_int(dtype) and not is_int(itype)) or (is_windows and is_int(itype))\ or (itype == 'bool' and\ (dtype not in ('float32', 'float64', 'int32', 'int64') or is_windows)): continue @@ -2390,6 +2390,16 @@ def hybrid_forward(self, F, a): np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + for shape in shapes: + for axis in [None] + [i for i in range(0, len(shape))]: + for otype in [None, _np.int32, _np.int64]: + for itype in [_np.bool, _np.int8, _np.int32, _np.int64]: + x = rand_ndarray(shape).astype(itype).as_np_ndarray() + np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype) + mx_out = np.cumsum(x, axis=axis, dtype=otype) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + @with_seed() @use_np From f9baec9a020d125d601790659e15fd7f9c2978c1 Mon Sep 17 00:00:00 2001 From: vexilligera Date: Thu, 31 Oct 2019 04:59:27 +0000 Subject: [PATCH 6/6] [Numpy] implement np.column_stack (#16594) * implement np.column_stack * cpplint * remove column_stack from numpy interoperability test temporarily * style and test fix * fix pylint and add interoperability test * fix doc string, add comment, remove dead code * pylint * ci * ci * ci * [Numpy] Numpy operator diff (#15906) * numpy diff operator implemented append and prepend not supported yet remove the prepend and append checking interface from the backend refine the code, enrich the test set and all tests passed registered the diff operator into npi scope all tests passed comments and minor modification format codes and fix warning for sanity check minor modification for sanity check fix sanity fix the tolerance bound of testing np.diff resolve minor coding style issue replace the given tests by random picking minor fix * interoperability test added * implement np.column_stack * cpplint * remove column_stack from numpy interoperability test temporarily * style and test fix * fix pylint and add interoperability test * fix doc string, add comment, remove dead code * pylint * ci * ci * ci * rebase resolve conflicts * pylint --- python/mxnet/ndarray/numpy/_op.py | 35 ++++- python/mxnet/numpy/multiarray.py | 38 ++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 40 ++++- src/operator/numpy/np_matrix_op-inl.h | 80 ++++++++++ src/operator/numpy/np_matrix_op.cc | 148 +++++++++++++++++- src/operator/numpy/np_matrix_op.cu | 6 + .../unittest/test_numpy_interoperability.py | 6 + tests/python/unittest/test_numpy_op.py | 106 +++++++++++++ 9 files changed, 452 insertions(+), 8 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 297c40b1431e..256cfb7d5708 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -34,13 +34,12 @@ 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', - 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', - 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', + 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', + 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff'] - @set_module('mxnet.ndarray.numpy') def zeros(shape, dtype=_np.float32, order='C', ctx=None): """Return a new array of given shape and type, filled with zeros. @@ -3004,6 +3003,36 @@ def get_list(arrays): return _npi.vstack(*arrays) +@set_module('mxnet.ndarray.numpy') +def column_stack(tup): + """ + Stack 1-D arrays as columns into a 2-D array. + Take a sequence of 1-D arrays and stack them as columns + to make a single 2-D array. 2-D arrays are stacked as-is, + just like with `hstack`. 1-D arrays are turned into 2-D columns + first. + + Returns + -------- + stacked : 2-D array + The array formed by stacking the given arrays. + + See Also + -------- + stack, hstack, vstack, concatenate + + Examples + -------- + >>> a = np.array((1,2,3)) + >>> b = np.array((2,3,4)) + >>> np.column_stack((a,b)) + array([[1., 2.], + [2., 3.], + [3., 4.]]) + """ + return _npi.column_stack(*tup) + + @set_module('mxnet.ndarray.numpy') def dstack(arrays): """ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index bc4b409d5be7..8e0d5b209a8d 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -51,7 +51,7 @@ 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', - 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', + 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', @@ -4904,6 +4904,42 @@ def vstack(arrays, out=None): return _mx_nd_np.vstack(arrays) +@set_module('mxnet.numpy') +def column_stack(tup): + """ + Stack 1-D arrays as columns into a 2-D array. + + Take a sequence of 1-D arrays and stack them as columns + to make a single 2-D array. 2-D arrays are stacked as-is, + just like with `hstack`. 1-D arrays are turned into 2-D columns + first. + + Parameters + ---------- + tup : sequence of 1-D or 2-D arrays. + Arrays to stack. All of them must have the same first dimension. + + Returns + -------- + stacked : 2-D array + The array formed by stacking the given arrays. + + See Also + -------- + stack, hstack, vstack, concatenate + + Examples + -------- + >>> a = np.array((1,2,3)) + >>> b = np.array((2,3,4)) + >>> np.column_stack((a,b)) + array([[1., 2.], + [2., 3.], + [3., 4.]]) + """ + return _mx_nd_np.column_stack(tup) + + @set_module('mxnet.numpy') def dstack(arrays): """ diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 2411f51b7aa6..cfab2a49699d 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -121,6 +121,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'var', 'vdot', 'vstack', + 'column_stack', 'zeros_like', 'linalg.norm', 'trace', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 6f7f912d6e36..7469875f267a 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -36,8 +36,8 @@ 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', - 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', - 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', + 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', + 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff'] @@ -3072,6 +3072,42 @@ def get_list(arrays): return _npi.vstack(*arrays) +@set_module('mxnet.symbol.numpy') +def column_stack(tup): + """ + Stack 1-D arrays as columns into a 2-D array. + + Take a sequence of 1-D arrays and stack them as columns + to make a single 2-D array. 2-D arrays are stacked as-is, + just like with `hstack`. 1-D arrays are turned into 2-D columns + first. + + Parameters + ---------- + tup : sequence of 1-D or 2-D arrays. + Arrays to stack. All of them must have the same first dimension. + + Returns + ------- + stacked : 2-D array + The array formed by stacking the given arrays. + + See Also + -------- + stack, hstack, vstack, concatenate + + Examples + -------- + >>> a = np.array((1,2,3)) + >>> b = np.array((2,3,4)) + >>> np.column_stack((a,b)) + array([[1., 2.], + [2., 3.], + [3., 4.]]) + """ + return _npi.column_stack(*tup) + + @set_module('mxnet.symbol.numpy') def dstack(arrays): """ diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 9ce84835f1a8..2545adcb3555 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -52,6 +52,14 @@ struct NumpyVstackParam : public dmlc::Parameter { } }; +struct NumpyColumnStackParam : public dmlc::Parameter { + int num_args; + DMLC_DECLARE_PARAMETER(NumpyColumnStackParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) + .describe("Number of inputs to be column stacked"); + } +}; + struct NumpyReshapeParam : public dmlc::Parameter { mxnet::TShape newshape; std::string order; @@ -124,6 +132,78 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, } } +template +void NumpyColumnStackForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), param.num_args); + CHECK_EQ(outputs.size(), 1); + CHECK_EQ(req.size(), 1); + + // reshape if necessary + std::vector data(param.num_args); + for (int i = 0; i < param.num_args; i++) { + if (inputs[i].shape_.ndim() == 0 || inputs[i].shape_.ndim() == 1) { + TShape shape = Shape2(inputs[i].shape_.Size(), 1); + data[i] = inputs[i].reshape(shape); + } else { + data[i] = inputs[i]; + } + } + + // initialize ConcatOp + ConcatParam cparam; + cparam.num_args = param.num_args; + cparam.dim = 1; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + ConcatOp op; + op.Init(cparam); + op.Forward(ctx, data, req, outputs); + }); +} + +template +void NumpyColumnStackBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), param.num_args); + CHECK_EQ(req.size(), param.num_args); + + // reshape if necessary + std::vector data(param.num_args); + for (int i = 0; i < param.num_args; i++) { + if (outputs[i].shape_.ndim() == 0 || outputs[i].shape_.ndim() == 1) { + TShape shape = Shape2(outputs[i].shape_.Size(), 1); + data[i] = outputs[i].reshape(shape); + } else { + data[i] = outputs[i]; + } + } + + // initialize ConcatOp + ConcatParam cparam; + cparam.num_args = param.num_args; + cparam.dim = 1; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + ConcatOp op; + op.Init(cparam); + op.Backward(ctx, inputs[0], req, data); + }); +} + template void NumpyVstackForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 0a6f9a150d8b..18594cd9cff1 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -613,6 +613,152 @@ Examples:: .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack") .add_arguments(StackParam::__FIELDS__()); +bool NumpyColumnStackType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_type->size(), param.num_args); + CHECK_EQ(out_type->size(), 1); + int dtype = -1; + for (int i = 0; i < param.num_args; i++) { + if (dtype == -1) { + dtype = in_type->at(i); + } + } + if (dtype == -1) { + dtype = out_type->at(0); + } + for (int i = 0; i < param.num_args; i++) { + TYPE_ASSIGN_CHECK(*in_type, i, dtype); + } + TYPE_ASSIGN_CHECK(*out_type, 0, dtype); + return dtype != -1; +} + +bool NumpyColumnStackShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(out_attrs->size(), 1U); + const NumpyColumnStackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), param.num_args); + std::vector in_attrs_tmp(param.num_args); + TShape dshape; + // For each array in the input, reshape to 2D if ndim < 2. + for (int i = 0; i < param.num_args; i++) { + if ((*in_attrs)[i].ndim() == 0) { + in_attrs_tmp[i] = TShape(2, 1); + } else if ((*in_attrs)[i].ndim() == 1) { + // Transpose 1D row into a column. + in_attrs_tmp[i] = TShape(2, 1); + in_attrs_tmp[i][0] = (*in_attrs)[i][0]; + } else { + in_attrs_tmp[i] = (*in_attrs)[i]; + } + TShape tmp(in_attrs_tmp[i].ndim(), -1); + shape_assign(&dshape, tmp); + } + TShape tmp((*out_attrs)[0].ndim(), -1); + shape_assign(&dshape, tmp); + for (int i = 0; i < param.num_args; i++) { + SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape) + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape) + if (dshape.ndim() == -1) { + return false; + } + // Accumulate along column axis. + int cnt = 0, sum = 0, pos = -1; + for (int i = 0; i < param.num_args; i++) { + TShape tmp = in_attrs_tmp[i]; + if (!dim_size_is_known(tmp, 1)) { + cnt++; + pos = i; + } else { + sum += tmp[1]; + } + tmp[1] = -1; + shape_assign(&dshape, tmp); + } + tmp = out_attrs->at(0); + if (!dim_size_is_known(tmp, 1)) { + cnt++; + pos = -1; + } else { + sum += tmp[1]; + } + tmp[1] = -1; + shape_assign(&dshape, tmp); + for (int i = 0; i < param.num_args; i++) { + SHAPE_ASSIGN_CHECK(in_attrs_tmp, i, dshape) + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape)\ + dshape[1] = 0; + if (!shape_is_known(dshape)) { + return false; + } + dshape[1] = sum; + if (cnt == 0) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); + } else if (cnt == 1) { + // Infer missing dimension if only one column dimension of the input is missing + if (pos >= 0) { + in_attrs_tmp[pos][1] = out_attrs->at(0)[1] - sum; + } else { + out_attrs->at(0)[1] = sum; + } + } else { + return false; + } + for (int i = 0; i < param.num_args; i++) { + if (in_attrs->at(i).ndim() == 1) { + in_attrs->at(i)[0] = in_attrs_tmp[i][1]; + } else if (in_attrs->at(i).ndim() >= 2) { + in_attrs->at(i) = in_attrs_tmp[i]; + } + } + + return true; +} + +DMLC_REGISTER_PARAMETER(NumpyColumnStackParam); + +NNVM_REGISTER_OP(_npi_column_stack) +.describe(R"code()code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const NumpyColumnStackParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args); +}) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const nnvm::NodeAttrs& attrs) { + int num_args = dmlc::get(attrs.parsed).num_args; + std::vector ret; + for (int i = 0; i < num_args; ++i) { + ret.push_back(std::string("arg") + std::to_string(i)); + } + return ret; + }) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInferShape", NumpyColumnStackShape) +.set_attr("FInferType", NumpyColumnStackType) +.set_attr("FCompute", NumpyColumnStackForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_column_stack"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to column_stack") +.add_arguments(NumpyColumnStackParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_np_column_stack) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const NumpyColumnStackParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args); +}) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyColumnStackBackward); + +DMLC_REGISTER_PARAMETER(NumpyVstackParam); + bool NumpyVstackType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { @@ -718,8 +864,6 @@ bool NumpyVstackShape(const nnvm::NodeAttrs& attrs, return true; } -DMLC_REGISTER_PARAMETER(NumpyVstackParam); - NNVM_REGISTER_OP(_npi_vstack) .describe(R"code()code" ADD_FILELINE) .set_attr_parser(ParamParser) diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 6b4f7a11a9a2..fccc8f257e64 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -59,6 +59,12 @@ NNVM_REGISTER_OP(_npi_dstack) NNVM_REGISTER_OP(_backward_np_dstack) .set_attr("FCompute", DStackGradCompute); +NNVM_REGISTER_OP(_npi_column_stack) +.set_attr("FCompute", NumpyColumnStackForward); + +NNVM_REGISTER_OP(_backward_np_column_stack) +.set_attr("FCompute", NumpyColumnStackBackward); + NNVM_REGISTER_OP(_np_roll) .set_attr("FCompute", NumpyRollCompute); diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 5d6e8af7fa47..15912dc47ad3 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1021,6 +1021,11 @@ def _add_workload_vstack(array_pool): OpArgMngr.add_workload('vstack', array_pool['4x1']) OpArgMngr.add_workload('vstack', array_pool['1x1x0']) +def _add_workload_column_stack(): + OpArgMngr.add_workload('column_stack', (np.array([1, 2, 3]), np.array([2, 3, 4]))) + OpArgMngr.add_workload('column_stack', (np.array([[1], [2], [3]]), np.array([[2], [3], [4]]))) + OpArgMngr.add_workload('column_stack', [np.array(_np.arange(3)) for _ in range(2)]) + def _add_workload_equal(array_pool): # TODO(junwu): fp16 does not work yet with TVM generated ops @@ -1208,6 +1213,7 @@ def _prepare_workloads(): _add_workload_logical_not(array_pool) _add_workload_vdot() _add_workload_vstack(array_pool) + _add_workload_column_stack() _add_workload_equal(array_pool) _add_workload_not_equal(array_pool) _add_workload_greater(array_pool) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 0b15c7ea0d2d..605fa85e1f77 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3690,6 +3690,58 @@ def test_np_true_divide(): @with_seed() @use_np +def test_np_column_stack(): + class TestColumnStack(HybridBlock): + def __init__(self): + super(TestColumnStack, self).__init__() + + def hybrid_forward(self, F, a, *args): + return F.np.column_stack([a] + list(args)) + + def g(data): + return _np.ones_like(data) + + configs = [ + ((), (), ()), + ((2), (2), (2)), + ((0), (0), (0)), + ((0, 3, 0), (0, 0, 0), (0, 1, 0)), + ((2, 2), (2, 1), (2, 3)), + ((4, 3), (4, 0), (4, 1)), + ((2, 2, 2), (2, 4, 2), (2, 2, 2)), + ((0, 1, 1), (0, 1, 1), (0, 1, 1)) + ] + types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64'] + for config, hybridize, dtype in itertools.product(configs, [True, False], types): + test_column_stack = TestColumnStack() + if hybridize: + test_column_stack.hybridize() + rtol = 1e-3 + atol = 1e-5 + v = [] + v_np = [] + for i in range(3): + v_np.append(_np.array(_np.random.uniform(-10.0, 10.0, config[i]), dtype=dtype)) + v.append(mx.nd.array(v_np[i]).as_np_ndarray()) + v[i].attach_grad() + expected_np = _np.column_stack(v_np) + with mx.autograd.record(): + mx_out = test_column_stack(*v) + assert mx_out.shape == expected_np.shape + assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol) + + # Test gradient + mx_out.backward() + for i in range(3): + expected_grad = g(v_np[i]) + assert_almost_equal(v[i].grad.asnumpy(), expected_grad, rtol=rtol, atol=atol) + + # Test imperative once again + mx_out = np.column_stack(v) + expected_np = _np.column_stack(v_np) + assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol) + + def test_npx_reshape(): class TestNumpyXReshape(HybridBlock): def __init__(self, newshape, reverse): @@ -3826,6 +3878,60 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) +@with_seed() +@use_np +def test_np_column_stack(): + class TestColumnStack(HybridBlock): + def __init__(self): + super(TestColumnStack, self).__init__() + + def hybrid_forward(self, F, a, *args): + return F.np.column_stack([a] + list(args)) + + def g(data): + return _np.ones_like(data) + + configs = [ + ((), (), ()), + ((2), (2), (2)), + ((0), (0), (0)), + ((0, 3, 0), (0, 0, 0), (0, 1, 0)), + ((2, 2), (2, 1), (2, 3)), + ((4, 3), (4, 0), (4, 1)), + ((2, 2, 2), (2, 4, 2), (2, 2, 2)), + ((0, 1, 1), (0, 1, 1), (0, 1, 1)) + ] + types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64'] + for config, hybridize, dtype in itertools.product(configs, [True, False], types): + test_column_stack = TestColumnStack() + if hybridize: + test_column_stack.hybridize() + rtol = 1e-3 + atol = 1e-5 + v = [] + v_np = [] + for i in range(3): + v_np.append(_np.array(_np.random.uniform(-10.0, 10.0, config[i]), dtype=dtype)) + v.append(mx.nd.array(v_np[i]).as_np_ndarray()) + v[i].attach_grad() + expected_np = _np.column_stack(v_np) + with mx.autograd.record(): + mx_out = test_column_stack(*v) + assert mx_out.shape == expected_np.shape + assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol) + + # Test gradient + mx_out.backward() + for i in range(3): + expected_grad = g(v_np[i]) + assert_almost_equal(v[i].grad.asnumpy(), expected_grad, rtol=rtol, atol=atol) + + # Test imperative once again + mx_out = np.column_stack(v) + expected_np = _np.column_stack(v_np) + assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule()