From a7d0c54ada2359964f0912aa79395a5444af165c Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Tue, 19 Nov 2019 21:41:15 +0800 Subject: [PATCH] Fix redefinition error --- python/mxnet/ndarray/numpy/_op.py | 2 + python/mxnet/numpy_dispatch_protocol.py | 1 + src/operator/numpy/np_matrix_op-inl.h | 154 ------------------ src/operator/numpy/np_matrix_op.cc | 25 +-- .../unittest/test_numpy_interoperability.py | 1 - 5 files changed, 4 insertions(+), 179 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 9c063a7005e2..871c64124c37 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -41,6 +41,7 @@ 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + @set_module('mxnet.ndarray.numpy') def zeros(shape, dtype=_np.float32, order='C', ctx=None): """Return a new array of given shape and type, filled with zeros. @@ -5388,6 +5389,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): def where(condition, x=None, y=None): """where(condition, [x, y]) Return elements chosen from `x` or `y` depending on `condition`. + .. note:: When only `condition` is provided, this function is a shorthand for ``np.asarray(condition).nonzero()``. The rest of this documentation diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 09b053c8d28f..d076d0cc51ad 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -137,6 +137,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'shares_memory', 'may_share_memory', 'diff', + 'resize', 'where', ] diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 5a4dc6a6ff36..5d058ff36000 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -49,16 +49,6 @@ struct NumpyTransposeParam : public dmlc::Parameter { } }; - -struct NumpyDiagflatParam : public dmlc::Parameter { - int k; - DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { - DMLC_DECLARE_FIELD(k).set_default(0).describe("Diagonal in question. The default is 0. " - "Use k>0 for diagonals above the main diagonal, " - "and k<0 for diagonals below the main diagonal. "); - } -}; - struct NumpyVstackParam : public dmlc::Parameter { int num_args; DMLC_DECLARE_PARAMETER(NumpyVstackParam) { @@ -147,150 +137,6 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, } } -inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k) -{ - if (ishape.ndim() == 1) { - auto s = ishape[0] + std::abs(k); - return mxnet::TShape({s, s}); - } - - if (ishape.ndim() >=2 ){ - auto s = 1; - for(int i = 0; i < ishape.ndim(); i++){ - if(ishape[i] >= 2){ - s = s * ishape[i]; - } - } - s = s + std::abs(k); - return mxnet::TShape({s,s}); - } - return mxnet::TShape({-1,-1}); -} - -inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector* in_attrs, - mxnet::ShapeVector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - const mxnet::TShape& ishape = (*in_attrs)[0]; - if (!mxnet::ndim_is_known(ishape)) { - return false; - } - const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - - mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape, - param.k); - - if (shape_is_none(oshape)) { - LOG(FATAL) << "Diagonal does not exist."; - } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - - return shape_is_known(out_attrs->at(0)); -} - -inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - - TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); - TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); - return (*out_attrs)[0] != -1; -} - -template -struct diagflat_gen { -template - MSHADOW_XINLINE static void Map(index_t i, - DType* out, - const DType* a, - mshadow::Shape<2> oshape, - int k){ - using namespace mxnet_op; - auto j = unravel(i,oshape); - if (j[1] == j[0] + k){ - auto l = j[0] < j[1] ? j[0] : j[1]; - if (back){ - KERNEL_ASSIGN(out[l],req,a[i]); - }else{ - KERNEL_ASSIGN(out[i],req,a[l]); - } - }else if(!back){ - KERNEL_ASSIGN(out[i],req,static_cast(0)); - } - } -}; - -template -void NumpyDiagflatOpImpl(const TBlob& in_data, - const TBlob& out_data, - const mxnet::TShape& ishape, - const mxnet::TShape& oshape, - index_t dsize, - const NumpyDiagflatParam& param, - mxnet_op::Stream *s, - const std::vector& req) { - - using namespace mxnet_op; - using namespace mshadow; - MSHADOW_TYPE_SWITCH(out_data.type_flag_,DType,{ - MXNET_ASSIGN_REQ_SWITCH(req[0],req_type,{ - Kernel, xpu>::Launch(s, - dsize, - out_data.dptr(), - in_data.dptr(), - Shape2(oshape[0],oshape[1]), - param.k); - }); - }); - -} - -template -void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) -{ - using namespace mxnet_op; - using namespace mshadow; - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req.size(), 1U); - CHECK_EQ(req[0], kWriteTo); - Stream *s = ctx.get_stream(); - const TBlob& in_data = inputs[0]; - const TBlob& out_data = outputs[0]; - const mxnet::TShape& ishape = inputs[0].shape_; - const mxnet::TShape& oshape = outputs[0].shape_; - const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - NumpyDiagflatOpImpl(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); -} - -template -void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mxnet_op; - using namespace mshadow; - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - Stream *s = ctx.get_stream(); - - const TBlob& in_data = inputs[0]; - const TBlob& out_data = outputs[0]; - const mxnet::TShape& ishape = inputs[0].shape_; - const mxnet::TShape& oshape = outputs[0].shape_; - const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - - NumpyDiagflatOpImpl(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); -} - template void NumpyColumnStackForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 964c8d9e9e8c..7b7e116059aa 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -1274,29 +1274,6 @@ inline bool HSplitOpShape(const nnvm::NodeAttrs& attrs, return SplitOpShapeImpl(attrs, in_attrs, out_attrs, real_axis); } -DMLC_REGISTER_PARAMETER(NumpyDiagflatParam); -NNVM_REGISTER_OP(_npi_diagflat) -.set_attr_parser(ParamParser) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data"}; - }) -.set_attr("FInferShape", NumpyDiagflatOpShape) -.set_attr("FInferType", NumpyDiagflatOpType) -.set_attr("FCompute",NumpyDiagflatOpForward) -.set_attr("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"}) -.add_argument("data","NDArray-or-Symbol","Input ndarray") -.add_arguments(NumpyDiagflatParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_npi_diagflat) -.set_attr_parser(ParamParser) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("TIsBackward",true) -.set_attr("FCompute",NumpyDiagflatOpBackward); - NNVM_REGISTER_OP(_npi_hsplit) .set_attr_parser(ParamParser) .set_num_inputs(1) @@ -1358,7 +1335,7 @@ NNVM_REGISTER_OP(_npi_diagflat) return std::vector{"data"}; }) .set_attr("FInferShape", NumpyDiagflatOpShape) -.set_attr("FInferType", NumpyDiagOpType) +.set_attr("FInferType", NumpyDiagflatOpType) .set_attr("FCompute",NumpyDiagflatOpForward) .set_attr("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"}) .add_argument("data","NDArray-or-Symbol","Input ndarray") diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 430977d68fde..30486416ce23 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1381,7 +1381,6 @@ def _prepare_workloads(): _add_workload_where() _add_workload_diff() _add_workload_resize() - _add_workload_diagflat() _prepare_workloads()