From 34a7cf07f60bd86822339b98ab58dce3d125dba6 Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Thu, 14 Nov 2019 13:51:32 +0800 Subject: [PATCH 1/8] Feature: diagflat --- python/mxnet/ndarray/numpy/_op.py | 39 +++++- python/mxnet/numpy/multiarray.py | 39 +++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 39 +++++- src/operator/numpy/np_matrix_op-inl.h | 132 ++++++++++++++++++ src/operator/numpy/np_matrix_op.cc | 23 +++ src/operator/numpy/np_matrix_op.cu | 6 + .../unittest/test_numpy_interoperability.py | 30 ++++ tests/python/unittest/test_numpy_op.py | 41 ++++++ 9 files changed, 347 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 3cc5b85c8384..ec8cb42db5ca 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -29,7 +29,7 @@ from ..ndarray import NDArray __all__ = ['shape', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', - 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', + 'arctan2', 'sin', 'cos', 'tan', 'diagflat', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', @@ -1441,6 +1441,43 @@ def tanh(x, out=None, **kwargs): return _unary_func_helper(x, _npi.tanh, _np.tanh, out=out, **kwargs) +@set_module('mxnet.ndarray.numpy') +def diagflat(arr, k=0): + """ + Create a two-dimensional array with the flattened input as a diagonal. + Parameters + ---------- + arr : ndarray + Input data, which is flattened and set as the `k`-th + diagonal of the output. + k : int, optional + Diagonal to set; 0, the default, corresponds to the "main" diagonal, + a positive (negative) `k` giving the number of the diagonal above + (below) the main. + Returns + ------- + out : ndarray + The 2-D output array. + See Also + -------- + diag : MATLAB work-alike for 1-D and 2-D arrays. + diagonal : Return specified diagonals. + trace : Sum along diagonals. + Examples + -------- + >>> np.diagflat([[1,2], [3,4]]) + array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + >>> np.diagflat([1,2], 1) + array([[0, 1, 0], + [0, 0, 2], + [0, 0, 0]]) + """ + return _npi.diagflat(arr, k=k) + + @set_module('mxnet.ndarray.numpy') @wrap_np_unary_func def log10(x, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index e94d4c8341b4..54811756376e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -47,7 +47,7 @@ from ..ndarray.ndarray import _storage_type __all__ = ['ndarray', 'empty', 'array', 'shape', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', - 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', + 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'diagflat', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', @@ -2981,6 +2981,43 @@ def tanh(x, out=None, **kwargs): return _mx_nd_np.tanh(x, out=out, **kwargs) +@set_module('mxnet.numpy') +def diagflat(arr, k=0): + """ + Create a two-dimensional array with the flattened input as a diagonal. + Parameters + ---------- + arr : ndarray + Input data, which is flattened and set as the `k`-th + diagonal of the output. + k : int, optional + Diagonal to set; 0, the default, corresponds to the "main" diagonal, + a positive (negative) `k` giving the number of the diagonal above + (below) the main. + Returns + ------- + out : ndarray + The 2-D output array. + See Also + -------- + diag : MATLAB work-alike for 1-D and 2-D arrays. + diagonal : Return specified diagonals. + trace : Sum along diagonals. + Examples + -------- + >>> np.diagflat([[1,2], [3,4]]) + array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + >>> np.diagflat([1,2], 1) + array([[0, 1, 0], + [0, 0, 2], + [0, 0, 0]]) + """ + return _npi.diagflat(arr, k=k) + + @set_module('mxnet.numpy') @wrap_np_unary_func def log10(x, out=None, **kwargs): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index f58159303d0f..67af2724503e 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -94,6 +94,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'copy', 'cumsum', 'diag', + 'diagflat', 'dot', 'expand_dims', 'fix', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 7da771966f1f..b90c031cc818 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -40,7 +40,7 @@ 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', - 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', + 'less_equal', 'hsplit', 'rot90', 'einsum', 'diagflat', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] @@ -4708,6 +4708,43 @@ def einsum(*operands, **kwargs): return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg)) +@set_module('mxnet.symbol.numpy') +def diagflat(arr, k=0): + """ + Create a two-dimensional array with the flattened input as a diagonal. + Parameters + ---------- + arr : ndarray + Input data, which is flattened and set as the `k`-th + diagonal of the output. + k : int, optional + Diagonal to set; 0, the default, corresponds to the "main" diagonal, + a positive (negative) `k` giving the number of the diagonal above + (below) the main. + Returns + ------- + out : ndarray + The 2-D output array. + See Also + -------- + diag : MATLAB work-alike for 1-D and 2-D arrays. + diagonal : Return specified diagonals. + trace : Sum along diagonals. + Examples + -------- + >>> np.diagflat([[1,2], [3,4]]) + array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + >>> np.diagflat([1,2], 1) + array([[0, 1, 0], + [0, 0, 2], + [0, 0, 0]]) + """ + return _npi.diagflat(arr, k=k) + + @set_module('mxnet.symbol.numpy') def shares_memory(a, b, max_work=None): """ diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 508968718af0..afce241982ff 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -1150,6 +1150,138 @@ void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs, in_data.Size(), param.k, s, req[0]); } +struct NumpyDiagflatParam : public dmlc::Parameter { + int k; + DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { + DMLC_DECLARE_FIELD(k).set_default(0).describe("Diagonal in question. The default is 0. " + "Use k>0 for diagonals above the main diagonal, " + "and k<0 for diagonals below the main diagonal. "); + } +}; + +inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k) +{ + if (ishape.ndim() == 1) { + auto s = ishape[0] + std::abs(k); + return mxnet::TShape({s, s}); + } + + if (ishape.ndim() >=2 ){ + auto s = 1; + for(int i = 0; i < ishape.ndim(); i++){ + if(ishape[i] >= 2){ + s = s * ishape[i]; + } + } + s = s + std::abs(k); + return mxnet::TShape({s,s}); + } + return mxnet::TShape({-1,-1}); +} + +inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& ishape = (*in_attrs)[0]; + if (!mxnet::ndim_is_known(ishape)) { + return false; + } + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape, + param.k); + + if (shape_is_none(oshape)) { + LOG(FATAL) << "Diagonal does not exist."; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + + return shape_is_known(out_attrs->at(0)); +} + +inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); + return (*out_attrs)[0] != -1; +} + +template +void NumpyDiagflatOpImpl(const TBlob& in_data, + const TBlob& out_data, + const mxnet::TShape& ishape, + const mxnet::TShape& oshape, + index_t dsize, + const NumpyDiagflatParam& param, + mxnet_op::Stream *s, + const std::vector& req) { + + using namespace mxnet_op; + using namespace mshadow; + MSHADOW_TYPE_SWITCH(out_data.type_flag_,DType,{ + MXNET_ASSIGN_REQ_SWITCH(req[0],req_type,{ + Kernel, xpu>::Launch(s, + dsize, + out_data.dptr(), + in_data.dptr(), + Shape2(oshape[0],oshape[1]), + param.k); + }); + }); + +} + +template +void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) +{ + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[0], kWriteTo); + Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + NumpyDiagflatOpImpl(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); +} + +template +void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + NumpyDiagflatOpImpl(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 912b32c2e8fb..578234aa3414 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -37,6 +37,7 @@ DMLC_REGISTER_PARAMETER(NumpyRot90Param); DMLC_REGISTER_PARAMETER(NumpyReshapeParam); DMLC_REGISTER_PARAMETER(NumpyXReshapeParam); DMLC_REGISTER_PARAMETER(NumpyDiagParam); +DMLC_REGISTER_PARAMETER(NumpyDiagflatParam); bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, @@ -1325,5 +1326,27 @@ NNVM_REGISTER_OP(_backward_np_diag) .set_attr("TIsBackward", true) .set_attr("FCompute", NumpyDiagOpBackward); +NNVM_REGISTER_OP(_npi_diagflat) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", NumpyDiagflatOpShape) +.set_attr("FInferType", NumpyDiagOpType) +.set_attr("FCompute",NumpyDiagflatOpForward) +.set_attr("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"}) +.add_argument("data","NDArray-or-Symbol","Input ndarray") +.add_arguments(NumpyDiagflatParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_diagflat) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward",true) +.set_attr("FCompute",NumpyDiagflatOpBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 33f5aab7717c..15439c402be0 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -124,5 +124,11 @@ NNVM_REGISTER_OP(_np_diag) NNVM_REGISTER_OP(_backward_np_diag) .set_attr("FCompute", NumpyDiagOpBackward); +NNVM_REGISTER_OP(_npi_diagflat) +.set_attr("FCompute", NumpyDiagflatOpForward); + +NNVM_REGISTER_OP(_backward_npi_diagflat) +.set_attr("FCompute", NumpyDiagflatOpBackward); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index e52d25239d22..05e7343ba0cd 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1201,6 +1201,35 @@ def _add_workload_nonzero(): OpArgMngr.add_workload('nonzero', np.array([False, False, False], dtype=np.bool_)) OpArgMngr.add_workload('nonzero', np.array([True, False, False], dtype=np.bool_)) +def _add_workload_diagflat(): + def get_mat(n): + data = _np.arange(n) + data = _np.add.outer(data,data) + return data + + A = np.array([[1,2],[3,4],[5,6]]) + vals = (100 * np.arange(5)).astype('l') + vals_c = (100 * np.array(get_mat(5)) + 1).astype('l') + vals_f = _np.array((100 * get_mat(5) + 1), order='F', dtype='l') + vals_f = np.array(vals_f) + + OpArgMngr.add_workload('diagflat', A, k=2) + OpArgMngr.add_workload('diagflat', A, k=1) + OpArgMngr.add_workload('diagflat', A, k=0) + OpArgMngr.add_workload('diagflat', A, k=-1) + OpArgMngr.add_workload('diagflat', A, k=-2) + OpArgMngr.add_workload('diagflat', A, k=-3) + OpArgMngr.add_workload('diagflat', vals, k=0) + OpArgMngr.add_workload('diagflat', vals, k=2) + OpArgMngr.add_workload('diagflat', vals, k=-2) + OpArgMngr.add_workload('diagflat', vals_c, k=0) + OpArgMngr.add_workload('diagflat', vals_c, k=2) + OpArgMngr.add_workload('diagflat', vals_c, k=-2) + OpArgMngr.add_workload('diagflat', vals_f, k=0) + OpArgMngr.add_workload('diagflat', vals_f, k=2) + OpArgMngr.add_workload('diagflat', vals_f, k=-2) + + def _add_workload_shape(): OpArgMngr.add_workload('shape', np.random.uniform(size=())) @@ -1263,6 +1292,7 @@ def _prepare_workloads(): _add_workload_cumsum() _add_workload_ravel() _add_workload_diag() + _add_workload_diagflat() _add_workload_dot() _add_workload_expand_dims() _add_workload_fix() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index ff9dd4119968..6fa22e58b918 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4164,6 +4164,47 @@ def dbg(name, data): for (iop, op) in enumerate(grad[0]): assert_almost_equal(grad[0][iop], grad[1][iop], rtol=rtol, atol=atol) +@with_seed() +@use_np +def test_np_diagflat(): + class TestDiagflat(HybridBlock): + def __init__(self, k=0): + super(TestDiagflat,self).__init__() + self._k = k + def hybrid_forward(self,F,a): + return F.np.diagflat(a, k=self._k) + shapes = [(2,),5 , (1,5), (2,2), (2,5), (3,3), (4,3),(4,4,5)] # test_shapes, remember to include zero-dim shape and zero-size shapes + dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64] # remember to include all meaningful data types for the operator + range_k = 6 + for hybridize,shape,dtype, in itertools.product([False,True],shapes,dtypes): + rtol = 1e-2 if dtype == np.float16 else 1e-3 + atol = 1e-4 if dtype == np.float16 else 1e-5 + + for k in range(-range_k,range_k): + test_diagflat = TestDiagflat(k) + if hybridize: + test_diagflat.hybridize() + + x = np.random.uniform(-1.0,1.0, size = shape).astype(dtype) + x.attach_grad() + + np_out = _np.diagflat(x.asnumpy(), k) + with mx.autograd.record(): + mx_out = test_diagflat(x) + + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(),np_out,rtol = rtol, atol = atol) + + mx_out.backward() + # Code to get the reference backward value + np_backward = np.ones(shape) + assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) + + # Test imperative once again + mx_out = np.diagflat(x, k) + np_out = _np.diagflat(x.asnumpy(), k) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + @with_seed() @use_np From bcb6f89fe9744af8b5510c4ca440ed88b5f4d4ba Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Tue, 19 Nov 2019 14:22:32 +0800 Subject: [PATCH 2/8] Fix codes position error --- src/operator/numpy/np_matrix_op-inl.h | 154 ++++++++++++++++++++++++++ src/operator/numpy/np_matrix_op.cc | 23 ++++ src/operator/numpy/np_matrix_op.cu | 6 + 3 files changed, 183 insertions(+) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index afce241982ff..9e18bb1d6ff1 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -49,6 +49,16 @@ struct NumpyTransposeParam : public dmlc::Parameter { } }; + +struct NumpyDiagflatParam : public dmlc::Parameter { + int k; + DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { + DMLC_DECLARE_FIELD(k).set_default(0).describe("Diagonal in question. The default is 0. " + "Use k>0 for diagonals above the main diagonal, " + "and k<0 for diagonals below the main diagonal. "); + } +}; + struct NumpyVstackParam : public dmlc::Parameter { int num_args; DMLC_DECLARE_PARAMETER(NumpyVstackParam) { @@ -137,6 +147,150 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, } } +inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k) +{ + if (ishape.ndim() == 1) { + auto s = ishape[0] + std::abs(k); + return mxnet::TShape({s, s}); + } + + if (ishape.ndim() >=2 ){ + auto s = 1; + for(int i = 0; i < ishape.ndim(); i++){ + if(ishape[i] >= 2){ + s = s * ishape[i]; + } + } + s = s + std::abs(k); + return mxnet::TShape({s,s}); + } + return mxnet::TShape({-1,-1}); +} + +inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const mxnet::TShape& ishape = (*in_attrs)[0]; + if (!mxnet::ndim_is_known(ishape)) { + return false; + } + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape, + param.k); + + if (shape_is_none(oshape)) { + LOG(FATAL) << "Diagonal does not exist."; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + + return shape_is_known(out_attrs->at(0)); +} + +inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); + return (*out_attrs)[0] != -1; +} + +template +struct diagflat_gen { +template + MSHADOW_XINLINE static void Map(index_t i, + DType* out, + const DType* a, + mshadow::Shape<2> oshape, + int k){ + using namespace mxnet_op; + auto j = unravel(i,oshape); + if (j[1] == j[0] + k){ + auto l = j[0] < j[1] ? j[0] : j[1]; + if (back){ + KERNEL_ASSIGN(out[l],req,a[i]); + }else{ + KERNEL_ASSIGN(out[i],req,a[l]); + } + }else if(!back){ + KERNEL_ASSIGN(out[i],req,static_cast(0)); + } + } +}; + +template +void NumpyDiagflatOpImpl(const TBlob& in_data, + const TBlob& out_data, + const mxnet::TShape& ishape, + const mxnet::TShape& oshape, + index_t dsize, + const NumpyDiagflatParam& param, + mxnet_op::Stream *s, + const std::vector& req) { + + using namespace mxnet_op; + using namespace mshadow; + MSHADOW_TYPE_SWITCH(out_data.type_flag_,DType,{ + MXNET_ASSIGN_REQ_SWITCH(req[0],req_type,{ + Kernel, xpu>::Launch(s, + dsize, + out_data.dptr(), + in_data.dptr(), + Shape2(oshape[0],oshape[1]), + param.k); + }); + }); + +} + +template +void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) +{ + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[0], kWriteTo); + Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + NumpyDiagflatOpImpl(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); +} + +template +void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + NumpyDiagflatOpImpl(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); +} + template void NumpyColumnStackForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 578234aa3414..964c8d9e9e8c 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -1274,6 +1274,29 @@ inline bool HSplitOpShape(const nnvm::NodeAttrs& attrs, return SplitOpShapeImpl(attrs, in_attrs, out_attrs, real_axis); } +DMLC_REGISTER_PARAMETER(NumpyDiagflatParam); +NNVM_REGISTER_OP(_npi_diagflat) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", NumpyDiagflatOpShape) +.set_attr("FInferType", NumpyDiagflatOpType) +.set_attr("FCompute",NumpyDiagflatOpForward) +.set_attr("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"}) +.add_argument("data","NDArray-or-Symbol","Input ndarray") +.add_arguments(NumpyDiagflatParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_diagflat) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward",true) +.set_attr("FCompute",NumpyDiagflatOpBackward); + NNVM_REGISTER_OP(_npi_hsplit) .set_attr_parser(ParamParser) .set_num_inputs(1) diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 15439c402be0..6c1e54faea2b 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -68,6 +68,12 @@ NNVM_REGISTER_OP(_backward_np_column_stack) NNVM_REGISTER_OP(_np_roll) .set_attr("FCompute", NumpyRollCompute); +NNVM_REGISTER_OP(_npi_diagflat) +.set_attr("FCompute", NumpyDiagflatOpForward); + +NNVM_REGISTER_OP(_backward_npi_diagflat) +.set_attr("FCompute", NumpyDiagflatOpBackward); + template<> void NumpyFlipForwardImpl(const OpContext& ctx, const std::vector& inputs, From e86798bb608f0423ff486029fb5fabb74d3e0aa5 Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Tue, 19 Nov 2019 21:41:15 +0800 Subject: [PATCH 3/8] Fix redefinition error --- src/operator/numpy/np_matrix_op-inl.h | 154 -------------------------- src/operator/numpy/np_matrix_op.cc | 25 +---- 2 files changed, 1 insertion(+), 178 deletions(-) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 9e18bb1d6ff1..afce241982ff 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -49,16 +49,6 @@ struct NumpyTransposeParam : public dmlc::Parameter { } }; - -struct NumpyDiagflatParam : public dmlc::Parameter { - int k; - DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { - DMLC_DECLARE_FIELD(k).set_default(0).describe("Diagonal in question. The default is 0. " - "Use k>0 for diagonals above the main diagonal, " - "and k<0 for diagonals below the main diagonal. "); - } -}; - struct NumpyVstackParam : public dmlc::Parameter { int num_args; DMLC_DECLARE_PARAMETER(NumpyVstackParam) { @@ -147,150 +137,6 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, } } -inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k) -{ - if (ishape.ndim() == 1) { - auto s = ishape[0] + std::abs(k); - return mxnet::TShape({s, s}); - } - - if (ishape.ndim() >=2 ){ - auto s = 1; - for(int i = 0; i < ishape.ndim(); i++){ - if(ishape[i] >= 2){ - s = s * ishape[i]; - } - } - s = s + std::abs(k); - return mxnet::TShape({s,s}); - } - return mxnet::TShape({-1,-1}); -} - -inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector* in_attrs, - mxnet::ShapeVector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - const mxnet::TShape& ishape = (*in_attrs)[0]; - if (!mxnet::ndim_is_known(ishape)) { - return false; - } - const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - - mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape, - param.k); - - if (shape_is_none(oshape)) { - LOG(FATAL) << "Diagonal does not exist."; - } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - - return shape_is_known(out_attrs->at(0)); -} - -inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - - TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); - TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); - return (*out_attrs)[0] != -1; -} - -template -struct diagflat_gen { -template - MSHADOW_XINLINE static void Map(index_t i, - DType* out, - const DType* a, - mshadow::Shape<2> oshape, - int k){ - using namespace mxnet_op; - auto j = unravel(i,oshape); - if (j[1] == j[0] + k){ - auto l = j[0] < j[1] ? j[0] : j[1]; - if (back){ - KERNEL_ASSIGN(out[l],req,a[i]); - }else{ - KERNEL_ASSIGN(out[i],req,a[l]); - } - }else if(!back){ - KERNEL_ASSIGN(out[i],req,static_cast(0)); - } - } -}; - -template -void NumpyDiagflatOpImpl(const TBlob& in_data, - const TBlob& out_data, - const mxnet::TShape& ishape, - const mxnet::TShape& oshape, - index_t dsize, - const NumpyDiagflatParam& param, - mxnet_op::Stream *s, - const std::vector& req) { - - using namespace mxnet_op; - using namespace mshadow; - MSHADOW_TYPE_SWITCH(out_data.type_flag_,DType,{ - MXNET_ASSIGN_REQ_SWITCH(req[0],req_type,{ - Kernel, xpu>::Launch(s, - dsize, - out_data.dptr(), - in_data.dptr(), - Shape2(oshape[0],oshape[1]), - param.k); - }); - }); - -} - -template -void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) -{ - using namespace mxnet_op; - using namespace mshadow; - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req.size(), 1U); - CHECK_EQ(req[0], kWriteTo); - Stream *s = ctx.get_stream(); - const TBlob& in_data = inputs[0]; - const TBlob& out_data = outputs[0]; - const mxnet::TShape& ishape = inputs[0].shape_; - const mxnet::TShape& oshape = outputs[0].shape_; - const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - NumpyDiagflatOpImpl(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); -} - -template -void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mxnet_op; - using namespace mshadow; - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - Stream *s = ctx.get_stream(); - - const TBlob& in_data = inputs[0]; - const TBlob& out_data = outputs[0]; - const mxnet::TShape& ishape = inputs[0].shape_; - const mxnet::TShape& oshape = outputs[0].shape_; - const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - - NumpyDiagflatOpImpl(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); -} - template void NumpyColumnStackForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 964c8d9e9e8c..7b7e116059aa 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -1274,29 +1274,6 @@ inline bool HSplitOpShape(const nnvm::NodeAttrs& attrs, return SplitOpShapeImpl(attrs, in_attrs, out_attrs, real_axis); } -DMLC_REGISTER_PARAMETER(NumpyDiagflatParam); -NNVM_REGISTER_OP(_npi_diagflat) -.set_attr_parser(ParamParser) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data"}; - }) -.set_attr("FInferShape", NumpyDiagflatOpShape) -.set_attr("FInferType", NumpyDiagflatOpType) -.set_attr("FCompute",NumpyDiagflatOpForward) -.set_attr("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"}) -.add_argument("data","NDArray-or-Symbol","Input ndarray") -.add_arguments(NumpyDiagflatParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_npi_diagflat) -.set_attr_parser(ParamParser) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("TIsBackward",true) -.set_attr("FCompute",NumpyDiagflatOpBackward); - NNVM_REGISTER_OP(_npi_hsplit) .set_attr_parser(ParamParser) .set_num_inputs(1) @@ -1358,7 +1335,7 @@ NNVM_REGISTER_OP(_npi_diagflat) return std::vector{"data"}; }) .set_attr("FInferShape", NumpyDiagflatOpShape) -.set_attr("FInferType", NumpyDiagOpType) +.set_attr("FInferType", NumpyDiagflatOpType) .set_attr("FCompute",NumpyDiagflatOpForward) .set_attr("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"}) .add_argument("data","NDArray-or-Symbol","Input ndarray") From db7b12c5ae58757647e0867048578a75d98f773a Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Tue, 19 Nov 2019 22:13:39 +0800 Subject: [PATCH 4/8] Fix style error&redefinition error --- src/operator/numpy/np_matrix_op.cc | 10 +++++----- src/operator/numpy/np_matrix_op.cu | 6 ------ 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 7b7e116059aa..863edaa292d7 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -1336,17 +1336,17 @@ NNVM_REGISTER_OP(_npi_diagflat) }) .set_attr("FInferShape", NumpyDiagflatOpShape) .set_attr("FInferType", NumpyDiagflatOpType) -.set_attr("FCompute",NumpyDiagflatOpForward) -.set_attr("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"}) -.add_argument("data","NDArray-or-Symbol","Input ndarray") +.set_attr("FCompute", NumpyDiagflatOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_diagflat"}) +.add_argument("data","NDArray-or-Symbol", "Input ndarray") .add_arguments(NumpyDiagflatParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_npi_diagflat) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) -.set_attr("TIsBackward",true) -.set_attr("FCompute",NumpyDiagflatOpBackward); +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyDiagflatOpBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 6c1e54faea2b..15439c402be0 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -68,12 +68,6 @@ NNVM_REGISTER_OP(_backward_np_column_stack) NNVM_REGISTER_OP(_np_roll) .set_attr("FCompute", NumpyRollCompute); -NNVM_REGISTER_OP(_npi_diagflat) -.set_attr("FCompute", NumpyDiagflatOpForward); - -NNVM_REGISTER_OP(_backward_npi_diagflat) -.set_attr("FCompute", NumpyDiagflatOpBackward); - template<> void NumpyFlipForwardImpl(const OpContext& ctx, const std::vector& inputs, From a249b47826662df51d329e5a78eef694c7bf3e43 Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Wed, 20 Nov 2019 00:07:09 +0800 Subject: [PATCH 5/8] Fix style error --- src/operator/numpy/np_matrix_op-inl.h | 34 +++++++++++++-------------- src/operator/numpy/np_matrix_op.cc | 2 +- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index afce241982ff..7a2f089d7040 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -1159,24 +1159,23 @@ struct NumpyDiagflatParam : public dmlc::Parameter { } }; -inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k) -{ +inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k) { if (ishape.ndim() == 1) { auto s = ishape[0] + std::abs(k); return mxnet::TShape({s, s}); } - if (ishape.ndim() >=2 ){ + if ( ishape.ndim() >= 2 ) { auto s = 1; - for(int i = 0; i < ishape.ndim(); i++){ - if(ishape[i] >= 2){ + for ( int i = 0; i < ishape.ndim(); i++ ) { + if ( ishape[i] >= 2 ) { s = s * ishape[i]; } } s = s + std::abs(k); - return mxnet::TShape({s,s}); + return mxnet::TShape({s, s}); } - return mxnet::TShape({-1,-1}); + return mxnet::TShape({-1, -1}); } inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, @@ -1184,7 +1183,7 @@ inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - + const mxnet::TShape& ishape = (*in_attrs)[0]; if (!mxnet::ndim_is_known(ishape)) { return false; @@ -1222,20 +1221,18 @@ void NumpyDiagflatOpImpl(const TBlob& in_data, const NumpyDiagflatParam& param, mxnet_op::Stream *s, const std::vector& req) { - using namespace mxnet_op; using namespace mshadow; - MSHADOW_TYPE_SWITCH(out_data.type_flag_,DType,{ - MXNET_ASSIGN_REQ_SWITCH(req[0],req_type,{ - Kernel, xpu>::Launch(s, + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch(s, dsize, out_data.dptr(), in_data.dptr(), - Shape2(oshape[0],oshape[1]), + Shape2(oshape[0], oshape[1]), param.k); }); }); - } template @@ -1243,8 +1240,7 @@ void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, - const std::vector& outputs) -{ + const std::vector& outputs) { using namespace mxnet_op; using namespace mshadow; CHECK_EQ(inputs.size(), 1U); @@ -1258,7 +1254,8 @@ void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& oshape = outputs[0].shape_; const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - NumpyDiagflatOpImpl(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); + NumpyDiagflatOpImpl(in_data, out_data, ishape, + oshape, out_data.Size(), param, s, req); } template @@ -1279,7 +1276,8 @@ void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& oshape = outputs[0].shape_; const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - NumpyDiagflatOpImpl(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); + NumpyDiagflatOpImpl(in_data, out_data, oshape, + ishape, in_data.Size(), param, s, req); } } // namespace op diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 863edaa292d7..5a6b533fd890 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -1338,7 +1338,7 @@ NNVM_REGISTER_OP(_npi_diagflat) .set_attr("FInferType", NumpyDiagflatOpType) .set_attr("FCompute", NumpyDiagflatOpForward) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_diagflat"}) -.add_argument("data","NDArray-or-Symbol", "Input ndarray") +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(NumpyDiagflatParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_npi_diagflat) From 9b4152efa99de2009a55137a4132ad406fadd02e Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Wed, 20 Nov 2019 00:35:59 +0800 Subject: [PATCH 6/8] Fix python style error --- python/mxnet/symbol/numpy/_symbol.py | 78 ++++++++++++++-------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index b90c031cc818..27de718baada 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -33,14 +33,14 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', - 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', + 'rint', 'radians', 'diagflat', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', - 'less_equal', 'hsplit', 'rot90', 'einsum', 'diagflat', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', + 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] @@ -2270,6 +2270,43 @@ def radians(x, out=None, **kwargs): return _unary_func_helper(x, _npi.radians, _np.radians, out=out, **kwargs) +@set_module('mxnet.symbol.numpy') +def diagflat(arr, k=0): + """ + Create a two-dimensional array with the flattened input as a diagonal. + Parameters + ---------- + arr : ndarray + Input data, which is flattened and set as the `k`-th + diagonal of the output. + k : int, optional + Diagonal to set; 0, the default, corresponds to the "main" diagonal, + a positive (negative) `k` giving the number of the diagonal above + (below) the main. + Returns + ------- + out : ndarray + The 2-D output array. + See Also + -------- + diag : MATLAB work-alike for 1-D and 2-D arrays. + diagonal : Return specified diagonals. + trace : Sum along diagonals. + Examples + -------- + >>> np.diagflat([[1,2], [3,4]]) + array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + >>> np.diagflat([1,2], 1) + array([[0, 1, 0], + [0, 0, 2], + [0, 0, 0]]) + """ + return _npi.diagflat(arr, k=k) + + @set_module('mxnet.symbol.numpy') @wrap_np_unary_func def deg2rad(x, out=None, **kwargs): @@ -4708,43 +4745,6 @@ def einsum(*operands, **kwargs): return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg)) -@set_module('mxnet.symbol.numpy') -def diagflat(arr, k=0): - """ - Create a two-dimensional array with the flattened input as a diagonal. - Parameters - ---------- - arr : ndarray - Input data, which is flattened and set as the `k`-th - diagonal of the output. - k : int, optional - Diagonal to set; 0, the default, corresponds to the "main" diagonal, - a positive (negative) `k` giving the number of the diagonal above - (below) the main. - Returns - ------- - out : ndarray - The 2-D output array. - See Also - -------- - diag : MATLAB work-alike for 1-D and 2-D arrays. - diagonal : Return specified diagonals. - trace : Sum along diagonals. - Examples - -------- - >>> np.diagflat([[1,2], [3,4]]) - array([[1, 0, 0, 0], - [0, 2, 0, 0], - [0, 0, 3, 0], - [0, 0, 0, 4]]) - >>> np.diagflat([1,2], 1) - array([[0, 1, 0], - [0, 0, 2], - [0, 0, 0]]) - """ - return _npi.diagflat(arr, k=k) - - @set_module('mxnet.symbol.numpy') def shares_memory(a, b, max_work=None): """ From 81a48bb691b64b07592399eeaa5391289de6ae0b Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Wed, 20 Nov 2019 11:17:32 +0800 Subject: [PATCH 7/8] Fix python style error --- tests/python/unittest/test_numpy_interoperability.py | 1 - tests/python/unittest/test_numpy_op.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 05e7343ba0cd..088306ff01c8 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1230,7 +1230,6 @@ def get_mat(n): OpArgMngr.add_workload('diagflat', vals_f, k=-2) - def _add_workload_shape(): OpArgMngr.add_workload('shape', np.random.uniform(size=())) OpArgMngr.add_workload('shape', np.random.uniform(size=(0, 1))) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 6fa22e58b918..7dd165b5421f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4164,6 +4164,7 @@ def dbg(name, data): for (iop, op) in enumerate(grad[0]): assert_almost_equal(grad[0][iop], grad[1][iop], rtol=rtol, atol=atol) + @with_seed() @use_np def test_np_diagflat(): From 28e1a2a99621a905298a62fc3de15605675b70e9 Mon Sep 17 00:00:00 2001 From: cassiniXu Date: Thu, 21 Nov 2019 16:36:41 +0800 Subject: [PATCH 8/8] Fix npi error --- python/mxnet/_numpy_op_doc.py | 38 +++++++++++++++- python/mxnet/ndarray/numpy/_op.py | 39 +--------------- python/mxnet/numpy/multiarray.py | 39 +--------------- python/mxnet/symbol/numpy/_symbol.py | 39 +--------------- src/operator/numpy/np_matrix_op-inl.h | 65 +++++++++++++-------------- src/operator/numpy/np_matrix_op.cc | 6 +-- src/operator/numpy/np_matrix_op.cu | 4 +- 7 files changed, 76 insertions(+), 154 deletions(-) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 9e237cb0049b..cf991fc8949f 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -1089,7 +1089,7 @@ def _npx_reshape(a, newshape, reverse=False, order='C'): pass -def _np_diag(array, k = 0): +def _np_diag(array, k=0): """ Extracts a diagonal or constructs a diagonal array. - 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero. @@ -1122,3 +1122,39 @@ def _np_diag(array, k = 0): [0, 0, 8]]) """ pass + + +def _np_diagflat(array, k=0): + """ + Create a two-dimensional array with the flattened input as a diagonal. + Parameters + ---------- + arr : ndarray + Input data, which is flattened and set as the `k`-th + diagonal of the output. + k : int, optional + Diagonal to set; 0, the default, corresponds to the "main" diagonal, + a positive (negative) `k` giving the number of the diagonal above + (below) the main. + Returns + ------- + out : ndarray + The 2-D output array. + See Also + -------- + diag : MATLAB work-alike for 1-D and 2-D arrays. + diagonal : Return specified diagonals. + trace : Sum along diagonals. + Examples + -------- + >>> np.diagflat([[1,2], [3,4]]) + array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + >>> np.diagflat([1,2], 1) + array([[0, 1, 0], + [0, 0, 2], + [0, 0, 0]]) + """ + pass \ No newline at end of file diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index ec8cb42db5ca..3cc5b85c8384 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -29,7 +29,7 @@ from ..ndarray import NDArray __all__ = ['shape', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', - 'arctan2', 'sin', 'cos', 'tan', 'diagflat', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', + 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', @@ -1441,43 +1441,6 @@ def tanh(x, out=None, **kwargs): return _unary_func_helper(x, _npi.tanh, _np.tanh, out=out, **kwargs) -@set_module('mxnet.ndarray.numpy') -def diagflat(arr, k=0): - """ - Create a two-dimensional array with the flattened input as a diagonal. - Parameters - ---------- - arr : ndarray - Input data, which is flattened and set as the `k`-th - diagonal of the output. - k : int, optional - Diagonal to set; 0, the default, corresponds to the "main" diagonal, - a positive (negative) `k` giving the number of the diagonal above - (below) the main. - Returns - ------- - out : ndarray - The 2-D output array. - See Also - -------- - diag : MATLAB work-alike for 1-D and 2-D arrays. - diagonal : Return specified diagonals. - trace : Sum along diagonals. - Examples - -------- - >>> np.diagflat([[1,2], [3,4]]) - array([[1, 0, 0, 0], - [0, 2, 0, 0], - [0, 0, 3, 0], - [0, 0, 0, 4]]) - >>> np.diagflat([1,2], 1) - array([[0, 1, 0], - [0, 0, 2], - [0, 0, 0]]) - """ - return _npi.diagflat(arr, k=k) - - @set_module('mxnet.ndarray.numpy') @wrap_np_unary_func def log10(x, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 54811756376e..e94d4c8341b4 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -47,7 +47,7 @@ from ..ndarray.ndarray import _storage_type __all__ = ['ndarray', 'empty', 'array', 'shape', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', - 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'diagflat', 'sinh', 'cosh', 'tanh', 'log10', + 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', @@ -2981,43 +2981,6 @@ def tanh(x, out=None, **kwargs): return _mx_nd_np.tanh(x, out=out, **kwargs) -@set_module('mxnet.numpy') -def diagflat(arr, k=0): - """ - Create a two-dimensional array with the flattened input as a diagonal. - Parameters - ---------- - arr : ndarray - Input data, which is flattened and set as the `k`-th - diagonal of the output. - k : int, optional - Diagonal to set; 0, the default, corresponds to the "main" diagonal, - a positive (negative) `k` giving the number of the diagonal above - (below) the main. - Returns - ------- - out : ndarray - The 2-D output array. - See Also - -------- - diag : MATLAB work-alike for 1-D and 2-D arrays. - diagonal : Return specified diagonals. - trace : Sum along diagonals. - Examples - -------- - >>> np.diagflat([[1,2], [3,4]]) - array([[1, 0, 0, 0], - [0, 2, 0, 0], - [0, 0, 3, 0], - [0, 0, 0, 4]]) - >>> np.diagflat([1,2], 1) - array([[0, 1, 0], - [0, 0, 2], - [0, 0, 0]]) - """ - return _npi.diagflat(arr, k=k) - - @set_module('mxnet.numpy') @wrap_np_unary_func def log10(x, out=None, **kwargs): diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 27de718baada..7da771966f1f 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', - 'rint', 'radians', 'diagflat', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', + 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', @@ -2270,43 +2270,6 @@ def radians(x, out=None, **kwargs): return _unary_func_helper(x, _npi.radians, _np.radians, out=out, **kwargs) -@set_module('mxnet.symbol.numpy') -def diagflat(arr, k=0): - """ - Create a two-dimensional array with the flattened input as a diagonal. - Parameters - ---------- - arr : ndarray - Input data, which is flattened and set as the `k`-th - diagonal of the output. - k : int, optional - Diagonal to set; 0, the default, corresponds to the "main" diagonal, - a positive (negative) `k` giving the number of the diagonal above - (below) the main. - Returns - ------- - out : ndarray - The 2-D output array. - See Also - -------- - diag : MATLAB work-alike for 1-D and 2-D arrays. - diagonal : Return specified diagonals. - trace : Sum along diagonals. - Examples - -------- - >>> np.diagflat([[1,2], [3,4]]) - array([[1, 0, 0, 0], - [0, 2, 0, 0], - [0, 0, 3, 0], - [0, 0, 0, 4]]) - >>> np.diagflat([1,2], 1) - array([[0, 1, 0], - [0, 0, 2], - [0, 0, 0]]) - """ - return _npi.diagflat(arr, k=k) - - @set_module('mxnet.symbol.numpy') @wrap_np_unary_func def deg2rad(x, out=None, **kwargs): diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 7a2f089d7040..fee534315b77 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -1153,9 +1153,10 @@ void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs, struct NumpyDiagflatParam : public dmlc::Parameter { int k; DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { - DMLC_DECLARE_FIELD(k).set_default(0).describe("Diagonal in question. The default is 0. " - "Use k>0 for diagonals above the main diagonal, " - "and k<0 for diagonals below the main diagonal. "); + DMLC_DECLARE_FIELD(k).set_default(0) + .describe("Diagonal in question. The default is 0. " + "Use k>0 for diagonals above the main diagonal, " + "and k<0 for diagonals below the main diagonal. "); } }; @@ -1165,10 +1166,10 @@ inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const i return mxnet::TShape({s, s}); } - if ( ishape.ndim() >= 2 ) { + if (ishape.ndim() >= 2) { auto s = 1; - for ( int i = 0; i < ishape.ndim(); i++ ) { - if ( ishape[i] >= 2 ) { + for (int i = 0; i < ishape.ndim(); i++) { + if (ishape[i] >= 2) { s = s * ishape[i]; } } @@ -1179,8 +1180,8 @@ inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const i } inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector* in_attrs, - mxnet::ShapeVector* out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); @@ -1190,8 +1191,7 @@ inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, } const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); - mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape, - param.k); + mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape, param.k); if (shape_is_none(oshape)) { LOG(FATAL) << "Diagonal does not exist."; @@ -1202,8 +1202,8 @@ inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, } inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); @@ -1214,33 +1214,30 @@ inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs, template void NumpyDiagflatOpImpl(const TBlob& in_data, - const TBlob& out_data, - const mxnet::TShape& ishape, - const mxnet::TShape& oshape, - index_t dsize, - const NumpyDiagflatParam& param, - mxnet_op::Stream *s, - const std::vector& req) { + const TBlob& out_data, + const mxnet::TShape& ishape, + const mxnet::TShape& oshape, + index_t dsize, + const NumpyDiagflatParam& param, + mxnet_op::Stream *s, + const std::vector& req) { using namespace mxnet_op; using namespace mshadow; MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch(s, - dsize, - out_data.dptr(), - in_data.dptr(), - Shape2(oshape[0], oshape[1]), - param.k); + Kernel, xpu>::Launch( + s, dsize, out_data.dptr(), in_data.dptr(), + Shape2(oshape[0], oshape[1]), param.k); }); }); } template void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mxnet_op; using namespace mshadow; CHECK_EQ(inputs.size(), 1U); @@ -1260,10 +1257,10 @@ void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, template void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mxnet_op; using namespace mshadow; CHECK_EQ(inputs.size(), 1U); @@ -1277,7 +1274,7 @@ void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); NumpyDiagflatOpImpl(in_data, out_data, oshape, - ishape, in_data.Size(), param, s, req); + ishape, in_data.Size(), param, s, req); } } // namespace op diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 5a6b533fd890..f07cb0e5f4b2 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -1326,7 +1326,7 @@ NNVM_REGISTER_OP(_backward_np_diag) .set_attr("TIsBackward", true) .set_attr("FCompute", NumpyDiagOpBackward); -NNVM_REGISTER_OP(_npi_diagflat) +NNVM_REGISTER_OP(_np_diagflat) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) @@ -1337,11 +1337,11 @@ NNVM_REGISTER_OP(_npi_diagflat) .set_attr("FInferShape", NumpyDiagflatOpShape) .set_attr("FInferType", NumpyDiagflatOpType) .set_attr("FCompute", NumpyDiagflatOpForward) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_diagflat"}) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_diagflat"}) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(NumpyDiagflatParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_npi_diagflat) +NNVM_REGISTER_OP(_backward_np_diagflat) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 15439c402be0..6f292ab95802 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -124,10 +124,10 @@ NNVM_REGISTER_OP(_np_diag) NNVM_REGISTER_OP(_backward_np_diag) .set_attr("FCompute", NumpyDiagOpBackward); -NNVM_REGISTER_OP(_npi_diagflat) +NNVM_REGISTER_OP(_np_diagflat) .set_attr("FCompute", NumpyDiagflatOpForward); -NNVM_REGISTER_OP(_backward_npi_diagflat) +NNVM_REGISTER_OP(_backward_np_diagflat) .set_attr("FCompute", NumpyDiagflatOpBackward); } // namespace op