From d7ac7ce79a0937218268775796c549c3208dd2e0 Mon Sep 17 00:00:00 2001 From: Minghao Liu <40382964+Tommliu@users.noreply.github.com> Date: Fri, 6 Dec 2019 11:40:23 +0800 Subject: [PATCH 1/2] numpy diagonal --- python/mxnet/ndarray/numpy/_op.py | 55 ++++- python/mxnet/numpy/multiarray.py | 57 ++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 34 ++- src/operator/numpy/np_matrix_op-inl.h | 216 ++++++++++++++++++ src/operator/numpy/np_matrix_op.cc | 23 ++ src/operator/numpy/np_matrix_op.cu | 6 + .../unittest/test_numpy_interoperability.py | 18 ++ tests/python/unittest/test_numpy_op.py | 79 +++++++ 9 files changed, 486 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index aed88eaaa56b..2899119684ea 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -39,7 +39,7 @@ 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'diagonal'] @set_module('mxnet.ndarray.numpy') def shape(a): @@ -5525,3 +5525,56 @@ def where(condition, x=None, y=None): return nonzero(condition) else: return _npi.where(condition, x, y, out=None) + + +@set_module('mxnet.ndarray.numpy') +def diagonal(a, offset=0, axis1=0, axis2=1): + """ + If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of + the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and + axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the + resulting array can be determined by removing axis1 and axis2 and appending an index to the + right equal to the size of the resulting diagonals. + + Parameters + ---------- + a : Symbol + Input data from which diagonal are taken. + offset: int, Optional + Offset of the diagonal from the main diagonal + axis1: int, Optional + Axis to be used as the first axis of the 2-D sub-arrays + axis2: int, Optional + Axis to be used as the second axis of the 2-D sub-arrays + + Returns + ------- + out : Symbol + Output result + + Raises + ------- + ValueError: If the dimension of a is less than 2. + + Examples + -------- + >>> a = np.arange(4).reshape(2,2) + >>> a + array([[0, 1], + [2, 3]]) + >>> np.diagonal(a) + array([0, 3]) + >>> np.diagonal(a, 1) + array([1]) + + >>> a = np.arange(8).reshape(2,2,2) + >>>a + array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + >>> np.diagonal(a, 0, 0, 1) + array([[0, 6], + [1, 7]]) + """ + return _npi.diagonal(a, k=offset, axis1=axis1, axis2=axis2) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 1c6873d342d1..39c3bb6bb761 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -57,7 +57,8 @@ 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', - 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', + 'diagonal'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -7511,3 +7512,57 @@ def where(condition, x=None, y=None): [ 0., 3., -1.]]) """ return _mx_nd_np.where(condition, x, y) + + +@set_module('mxnet.numpy') +def diagonal(a, offset=0, axis1=0, axis2=1): + """ + If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of + the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and + axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the + resulting array can be determined by removing axis1 and axis2 and appending an index to the + right equal to the size of the resulting diagonals. + + Parameters + ---------- + a : Symbol + Input data from which diagonal are taken. + offset: int, Optional + Offset of the diagonal from the main diagonal + axis1: int, Optional + Axis to be used as the first axis of the 2-D sub-arrays + axis2: int, Optional + Axis to be used as the second axis of the 2-D sub-arrays + + Returns + ------- + out : Symbol + Output result + + Raises + ------- + ValueError: If the dimension of a is less than 2. + + Examples + -------- + >>> a = np.arange(4).reshape(2,2) + >>> a + array([[0, 1], + [2, 3]]) + >>> np.diagonal(a) + array([0, 3]) + >>> np.diagonal(a, 1) + array([1]) + + >>> a = np.arange(8).reshape(2,2,2) + >>>a + array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + >>> np.diagonal(a, 0, 0, 1) + array([[0, 6], + [1, 7]]) + """ + return _mx_nd_np.diagonal(a, offset=offset, axis1=axis1, axis2=axis2) + diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 1f68ca3c522a..937741560c4e 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -94,6 +94,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'copy', 'cumsum', 'diag', + 'diagonal', 'diagflat', 'dot', 'expand_dims', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 683bdb1cb200..b7b1a3aa397c 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -47,7 +47,7 @@ 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'diagonal'] @set_module('mxnet.symbol.numpy') @@ -5161,4 +5161,36 @@ def load_json(json_str): return _Symbol(handle) +@set_module('mxnet.symbol.numpy') +def diagonal(a, offset=0, axis1=0, axis2=1): + """ + If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of + the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and + axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the + resulting array can be determined by removing axis1 and axis2 and appending an index to the + right equal to the size of the resulting diagonals. + + Parameters + ---------- + a : Symbol + Input data from which diagonal are taken. + offset: int, Optional + Offset of the diagonal from the main diagonal + axis1: int, Optional + Axis to be used as the first axis of the 2-D sub-arrays + axis2: int, Optional + Axis to be used as the second axis of the 2-D sub-arrays + + Returns + ------- + out : Symbol + Output result + + Raises + ------- + ValueError: If the dimension of a is less than 2. + """ + return _npi.diagonal(a, k=offset, axis1=axis1, axis2=axis2) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index fee534315b77..41c60442e763 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -1150,6 +1150,222 @@ void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs, in_data.Size(), param.k, s, req[0]); } +struct NumpyDiagonalParam : public dmlc::Parameter { + int k; + int32_t axis1; + int32_t axis2; + DMLC_DECLARE_PARAMETER(NumpyDiagonalParam) { + DMLC_DECLARE_FIELD(k) + .set_default(0) + .describe("Diagonal in question. The default is 0. " + "Use k>0 for diagonals above the main diagonal, " + "and k<0 for diagonals below the main diagonal. " + "If input has shape (S0 S1) k must be between -S0 and S1"); + DMLC_DECLARE_FIELD(axis1) + .set_default(0) + .describe("The first axis of the sub-arrays of interest. " + "Ignored when the input is a 1-D array."); + DMLC_DECLARE_FIELD(axis2) + .set_default(1) + .describe("The second axis of the sub-arrays of interest. " + "Ignored when the input is a 1-D array."); + } +}; + +inline mxnet::TShape NumpyDiagonalShapeImpl(const mxnet::TShape& ishape, const int k, + const int32_t axis1, const int32_t axis2) { + int32_t x1 = CheckAxis(axis1, ishape.ndim()); + int32_t x2 = CheckAxis(axis2, ishape.ndim()); + + CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the same axis " << x1; + + auto h = ishape[x1]; + auto w = ishape[x2]; + if (k > 0) { + w -= k; + } else if (k < 0) { + h += k; + } + auto s = std::min(h, w); + if (s < 0) s = 0; + if (x1 > x2) std::swap(x1, x2); + + int32_t n_dim = ishape.ndim() - 1; + mxnet::TShape oshape(n_dim, -1); + + // remove axis1 and axis2 and append the new axis to the end + uint32_t idx = 0; + for (int i = 0; i <= n_dim; ++i) { + if (i != x1 && i != x2) { + oshape[idx++] = ishape[i]; + } + } + oshape[n_dim - 1] = s; + return oshape; +} + +inline bool NumpyDiagonalOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& ishape = (*in_attrs)[0]; + CHECK_GE(ishape.ndim(), 2) << "Input array should be at least 2d"; + if (!mxnet::ndim_is_known(ishape)) { + return false; + } + + const NumpyDiagonalParam& param = nnvm::get(attrs.parsed); + mxnet::TShape oshape = NumpyDiagonalShapeImpl(ishape, param.k, param.axis1, + param.axis2); + if (shape_is_none(oshape)) { + LOG(FATAL) << "Diagonal does not exist."; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + return shape_is_known(out_attrs->at(0)); +} + +inline bool NumpyDiagonalOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); + return (*out_attrs)[0] != -1; +} + +template +struct diag_n { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a, + mshadow::Shape oshape, + mshadow::Shape ishape, + index_t stride, index_t offset, + index_t base) { + using namespace mxnet_op; + index_t idx = i / base; + index_t j = ravel(unravel(idx, oshape), ishape) + offset + stride * (i - idx * base); + if (back) { + KERNEL_ASSIGN(out[j], req, a[i]); + } else { + KERNEL_ASSIGN(out[i], req, a[j]); + } + } +}; + +template +void NumpyDiagonalOpImpl(const TBlob& in_data, + const TBlob& out_data, + const mxnet::TShape& ishape, + const mxnet::TShape& oshape, + index_t dsize, + const NumpyDiagonalParam& param, + mxnet_op::Stream *s, + const std::vector& req) { + using namespace mxnet_op; + using namespace mshadow; + uint32_t x1 = CheckAxis(param.axis1, ishape.ndim()); + uint32_t x2 = CheckAxis(param.axis2, ishape.ndim()); + uint32_t idim = ishape.ndim(), odim = oshape.ndim(); + uint32_t minx = x1, maxx = x2; + if (minx > maxx) std::swap(minx, maxx); + + index_t oleading = 1, + obody = 1, + otrailing = 1; + for (uint32_t i = 0; i < minx; ++i) { + oleading *= ishape[i]; + } + for (uint32_t i = minx + 1; i < maxx; ++i) { + obody *= ishape[i]; + } + for (uint32_t i = maxx + 1; i < idim; ++i) { + otrailing *= ishape[i]; + } + + index_t ileading = oleading, + ibody = obody * ishape[minx], + itrailing = otrailing * ishape[maxx]; + + index_t stride1 = itrailing * obody, + stride2 = otrailing; + // stride1 + stride2 is the stride for iterating over the diagonal + + if (x1 == maxx) std::swap(stride1, stride2); + index_t offset; + int k = param.k; + if (k > 0) { + offset = stride2 * k; + } else if (k < 0) { + offset = stride1 * -k; + } else { + offset = 0; + } // the extra index offset introduced by k + + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + if (back && req[0] != kAddTo && req[0] != kNullOp) { + out_data.FlatTo1D(s) = 0; + } + if (ileading == 1) { + Kernel, xpu>::Launch(s, dsize, out_data.dptr(), + in_data.dptr(), Shape2(obody, otrailing), Shape2(ibody, itrailing), + stride1 + stride2, offset, oshape[odim - 1]); + } else { + Kernel, xpu>::Launch(s, dsize, out_data.dptr(), + in_data.dptr(), Shape3(oleading, obody, otrailing), Shape3(ileading, ibody, itrailing), + stride1 + stride2, offset, oshape[odim - 1]); + } + }); + }); +} + +template +void NumpyDiagonalOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[0], kWriteTo); + Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagonalParam& param = nnvm::get(attrs.parsed); + + NumpyDiagonalOpImpl(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); +} + +template +void NumpyDiagonalOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagonalParam& param = nnvm::get(attrs.parsed); + + NumpyDiagonalOpImpl(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); +} + struct NumpyDiagflatParam : public dmlc::Parameter { int k; DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index e496202a0b41..86fce6354811 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -37,6 +37,7 @@ DMLC_REGISTER_PARAMETER(NumpyRot90Param); DMLC_REGISTER_PARAMETER(NumpyReshapeParam); DMLC_REGISTER_PARAMETER(NumpyXReshapeParam); DMLC_REGISTER_PARAMETER(NumpyDiagParam); +DMLC_REGISTER_PARAMETER(NumpyDiagonalParam); DMLC_REGISTER_PARAMETER(NumpyDiagflatParam); @@ -1326,6 +1327,28 @@ NNVM_REGISTER_OP(_backward_np_diag) .set_attr("TIsBackward", true) .set_attr("FCompute", NumpyDiagOpBackward); +NNVM_REGISTER_OP(_npi_diagonal) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", NumpyDiagonalOpShape) +.set_attr("FInferType", NumpyDiagonalOpType) +.set_attr("FCompute", NumpyDiagonalOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_diagonal"}) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(NumpyDiagonalParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_diagonal) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyDiagonalOpBackward); + NNVM_REGISTER_OP(_np_diagflat) .set_attr_parser(ParamParser) .set_num_inputs(1) diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 6f292ab95802..a5b4e66eb9ce 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -124,6 +124,12 @@ NNVM_REGISTER_OP(_np_diag) NNVM_REGISTER_OP(_backward_np_diag) .set_attr("FCompute", NumpyDiagOpBackward); +NNVM_REGISTER_OP(_npi_diagonal) +.set_attr("FCompute", NumpyDiagonalOpForward); + +NNVM_REGISTER_OP(_backward_npi_diagonal) +.set_attr("FCompute", NumpyDiagonalOpBackward); + NNVM_REGISTER_OP(_np_diagflat) .set_attr("FCompute", NumpyDiagflatOpForward); diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 930ad5260430..c0a9f69cd170 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -95,6 +95,23 @@ def get_mat(n): OpArgMngr.add_workload('diag', vals_f, k=-2) +def _add_workload_diagonal(): + A = np.arange(12).reshape((3, 4)) + B = np.arange(8).reshape((2,2,2)) + + OpArgMngr.add_workload('diagonal', A) + OpArgMngr.add_workload('diagonal', A, offset=0) + OpArgMngr.add_workload('diagonal', A, offset=-1) + OpArgMngr.add_workload('diagonal', A, offset=1) + OpArgMngr.add_workload('diagonal', B, offset=0) + OpArgMngr.add_workload('diagonal', B, offset=1) + OpArgMngr.add_workload('diagonal', B, offset=-1) + OpArgMngr.add_workload('diagonal', B, 0, 1, 2) + OpArgMngr.add_workload('diagonal', B, 0, 0, 1) + OpArgMngr.add_workload('diagonal', B, offset=1, axis1=0, axis2=2) + OpArgMngr.add_workload('diagonal', B, 0, 2, 1) + + def _add_workload_concatenate(array_pool): OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']]) OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']], axis=1) @@ -1321,6 +1338,7 @@ def _prepare_workloads(): _add_workload_ravel() _add_workload_unravel_index() _add_workload_diag() + _add_workload_diagonal() _add_workload_diagflat() _add_workload_dot() _add_workload_expand_dims() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 9b7f7036bcda..0e0c7ea22456 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4796,6 +4796,85 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + +@with_seed() +@use_np +def test_np_diagonal(): + class TestDiagonal(HybridBlock): + def __init__(self, k=0, axis1=0, axis2=1): + super(TestDiagonal, self).__init__() + self._k = k + self._axis1 = axis1 + self._axis2 = axis2 + + def hybrid_forward(self, F, a): + return F.np.diagonal(a, self._k, self._axis1, self._axis2) + + configs = [ + [(1, 5), (0, 1)], [(2, 2),(0, 1)], + [(2, 5), (0, 1)], [(5, 5), (0, 1)], + [(2, 2, 2), (0, 1)], [(2, 4, 4), (0, 2)], + [(3, 3, 3), (1, 2)], [(4, 8, 8), (1, 2)], + [(4, 4, 4, 4), (1, 2)], [(5, 6, 7, 8), (2, 3)], + [(6, 7, 8, 9, 10), (3, 4)] + ] + dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64] + offsets = [0, 2, 4, 6] + combination = itertools.product([False, True], configs, dtypes, offsets) + for hybridize, config, dtype, k in combination: + rtol = 1e-2 if dtype == np.float16 else 1e-3 + atol = 1e-4 if dtype == np.float16 else 1e-5 + shape = config[0] + axis = config[1] + axis1 = axis[0] + axis2 = axis[1] + x = np.random.uniform(-5.0, 5.0, size=shape).astype(dtype) + x.attach_grad() + test_diagonal = TestDiagonal(k, axis1, axis2) + if hybridize: + test_diagonal.hybridize() + np_out = _np.diagonal(x.asnumpy(), offset=k, axis1=axis[0], axis2=axis[1]) + with mx.autograd.record(): + mx_out = test_diagonal(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + # check backward function + mx_out.backward() + size_out = np_out.size + shape_out = np_out.shape + ndim = len(shape) + h = shape[axis1] + w = shape[axis2] + np_backward_slice = _np.zeros((h, w)) + np_backward = _np.zeros(shape) + if k > 0: + w -= k + else: + h += k + s = min(w, h) + if s > 0: + if k >= 0: + for i in range(s): + np_backward_slice[0+i][k+i] = 1 + else: + for i in range(s): + np_backward_slice[-k+i][0+i] = 1 + ileading = int(size_out/s) + array_temp = _np.array([np_backward_slice for i in range(ileading)]) + array_temp = array_temp.reshape(shape_out[:-1] + (shape[axis1], shape[axis2])) + axis_idx = [i for i in range(ndim-2)] + axis_idx[axis1:axis1] = [ndim - 2] + axis_idx[axis2:axis2] = [ndim - 1] + np_backward = _np.transpose(array_temp, tuple(axis_idx)) + assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) + + # Test imperative once again + mx_out = np.diagonal(x, k, axis[0], axis[1]) + np_out = _np.diagonal(x.asnumpy(), offset=k, axis1=axis[0], axis2=axis[1]) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + @with_seed() @use_np def test_np_nan_to_num(): From 81062ceeffa7de2f07428cab13317dd825417a48 Mon Sep 17 00:00:00 2001 From: Minghao Liu <40382964+Tommliu@users.noreply.github.com> Date: Fri, 6 Dec 2019 13:05:44 +0800 Subject: [PATCH 2/2] diagonal fix --- python/mxnet/_numpy_op_doc.py | 54 ++++++++++++++++++++++++- python/mxnet/ndarray/numpy/_op.py | 55 +------------------------- python/mxnet/numpy/multiarray.py | 57 +-------------------------- python/mxnet/symbol/numpy/_symbol.py | 34 +--------------- src/operator/numpy/np_matrix_op-inl.h | 20 +++++----- src/operator/numpy/np_matrix_op.cc | 6 +-- src/operator/numpy/np_matrix_op.cu | 4 +- 7 files changed, 72 insertions(+), 158 deletions(-) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index cf991fc8949f..3deb27019bf9 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -1124,6 +1124,58 @@ def _np_diag(array, k=0): pass +def _np_diagonal(a, offset=0, axis1=0, axis2=1): + """ + If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of + the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and + axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the + resulting array can be determined by removing axis1 and axis2 and appending an index to the + right equal to the size of the resulting diagonals. + + Parameters + ---------- + a : Symbol + Input data from which diagonal are taken. + offset: int, Optional + Offset of the diagonal from the main diagonal + axis1: int, Optional + Axis to be used as the first axis of the 2-D sub-arrays + axis2: int, Optional + Axis to be used as the second axis of the 2-D sub-arrays + + Returns + ------- + out : Symbol + Output result + + Raises + ------- + ValueError: If the dimension of a is less than 2. + + Examples + -------- + >>> a = np.arange(4).reshape(2,2) + >>> a + array([[0, 1], + [2, 3]]) + >>> np.diagonal(a) + array([0, 3]) + >>> np.diagonal(a, 1) + array([1]) + + >>> a = np.arange(8).reshape(2,2,2) + >>>a + array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + >>> np.diagonal(a, 0, 0, 1) + array([[0, 6], + [1, 7]]) + """ + pass + + def _np_diagflat(array, k=0): """ Create a two-dimensional array with the flattened input as a diagonal. @@ -1157,4 +1209,4 @@ def _np_diagflat(array, k=0): [0, 0, 2], [0, 0, 0]]) """ - pass \ No newline at end of file + pass diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 2899119684ea..aed88eaaa56b 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -39,7 +39,7 @@ 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'diagonal'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] @set_module('mxnet.ndarray.numpy') def shape(a): @@ -5525,56 +5525,3 @@ def where(condition, x=None, y=None): return nonzero(condition) else: return _npi.where(condition, x, y, out=None) - - -@set_module('mxnet.ndarray.numpy') -def diagonal(a, offset=0, axis1=0, axis2=1): - """ - If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of - the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and - axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the - resulting array can be determined by removing axis1 and axis2 and appending an index to the - right equal to the size of the resulting diagonals. - - Parameters - ---------- - a : Symbol - Input data from which diagonal are taken. - offset: int, Optional - Offset of the diagonal from the main diagonal - axis1: int, Optional - Axis to be used as the first axis of the 2-D sub-arrays - axis2: int, Optional - Axis to be used as the second axis of the 2-D sub-arrays - - Returns - ------- - out : Symbol - Output result - - Raises - ------- - ValueError: If the dimension of a is less than 2. - - Examples - -------- - >>> a = np.arange(4).reshape(2,2) - >>> a - array([[0, 1], - [2, 3]]) - >>> np.diagonal(a) - array([0, 3]) - >>> np.diagonal(a, 1) - array([1]) - - >>> a = np.arange(8).reshape(2,2,2) - >>>a - array([[[0, 1], - [2, 3]], - [[4, 5], - [6, 7]]]) - >>> np.diagonal(a, 0, 0, 1) - array([[0, 6], - [1, 7]]) - """ - return _npi.diagonal(a, k=offset, axis1=axis1, axis2=axis2) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 39c3bb6bb761..1c6873d342d1 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -57,8 +57,7 @@ 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', - 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', - 'diagonal'] + 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -7512,57 +7511,3 @@ def where(condition, x=None, y=None): [ 0., 3., -1.]]) """ return _mx_nd_np.where(condition, x, y) - - -@set_module('mxnet.numpy') -def diagonal(a, offset=0, axis1=0, axis2=1): - """ - If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of - the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and - axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the - resulting array can be determined by removing axis1 and axis2 and appending an index to the - right equal to the size of the resulting diagonals. - - Parameters - ---------- - a : Symbol - Input data from which diagonal are taken. - offset: int, Optional - Offset of the diagonal from the main diagonal - axis1: int, Optional - Axis to be used as the first axis of the 2-D sub-arrays - axis2: int, Optional - Axis to be used as the second axis of the 2-D sub-arrays - - Returns - ------- - out : Symbol - Output result - - Raises - ------- - ValueError: If the dimension of a is less than 2. - - Examples - -------- - >>> a = np.arange(4).reshape(2,2) - >>> a - array([[0, 1], - [2, 3]]) - >>> np.diagonal(a) - array([0, 3]) - >>> np.diagonal(a, 1) - array([1]) - - >>> a = np.arange(8).reshape(2,2,2) - >>>a - array([[[0, 1], - [2, 3]], - [[4, 5], - [6, 7]]]) - >>> np.diagonal(a, 0, 0, 1) - array([[0, 6], - [1, 7]]) - """ - return _mx_nd_np.diagonal(a, offset=offset, axis1=axis1, axis2=axis2) - diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index b7b1a3aa397c..683bdb1cb200 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -47,7 +47,7 @@ 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'diagonal'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] @set_module('mxnet.symbol.numpy') @@ -5161,36 +5161,4 @@ def load_json(json_str): return _Symbol(handle) -@set_module('mxnet.symbol.numpy') -def diagonal(a, offset=0, axis1=0, axis2=1): - """ - If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of - the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and - axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the - resulting array can be determined by removing axis1 and axis2 and appending an index to the - right equal to the size of the resulting diagonals. - - Parameters - ---------- - a : Symbol - Input data from which diagonal are taken. - offset: int, Optional - Offset of the diagonal from the main diagonal - axis1: int, Optional - Axis to be used as the first axis of the 2-D sub-arrays - axis2: int, Optional - Axis to be used as the second axis of the 2-D sub-arrays - - Returns - ------- - out : Symbol - Output result - - Raises - ------- - ValueError: If the dimension of a is less than 2. - """ - return _npi.diagonal(a, k=offset, axis1=axis1, axis2=axis2) - - _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 41c60442e763..c9a5545f9e15 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -1151,11 +1151,11 @@ void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs, } struct NumpyDiagonalParam : public dmlc::Parameter { - int k; + int offset; int32_t axis1; int32_t axis2; DMLC_DECLARE_PARAMETER(NumpyDiagonalParam) { - DMLC_DECLARE_FIELD(k) + DMLC_DECLARE_FIELD(offset) .set_default(0) .describe("Diagonal in question. The default is 0. " "Use k>0 for diagonals above the main diagonal, " @@ -1173,7 +1173,7 @@ struct NumpyDiagonalParam : public dmlc::Parameter { }; inline mxnet::TShape NumpyDiagonalShapeImpl(const mxnet::TShape& ishape, const int k, - const int32_t axis1, const int32_t axis2) { + const int32_t axis1, const int32_t axis2) { int32_t x1 = CheckAxis(axis1, ishape.ndim()); int32_t x2 = CheckAxis(axis2, ishape.ndim()); @@ -1217,7 +1217,7 @@ inline bool NumpyDiagonalOpShape(const nnvm::NodeAttrs& attrs, } const NumpyDiagonalParam& param = nnvm::get(attrs.parsed); - mxnet::TShape oshape = NumpyDiagonalShapeImpl(ishape, param.k, param.axis1, + mxnet::TShape oshape = NumpyDiagonalShapeImpl(ishape, param.offset, param.axis1, param.axis2); if (shape_is_none(oshape)) { LOG(FATAL) << "Diagonal does not exist."; @@ -1296,7 +1296,7 @@ void NumpyDiagonalOpImpl(const TBlob& in_data, if (x1 == maxx) std::swap(stride1, stride2); index_t offset; - int k = param.k; + int k = param.offset; if (k > 0) { offset = stride2 * k; } else if (k < 0) { @@ -1316,8 +1316,8 @@ void NumpyDiagonalOpImpl(const TBlob& in_data, stride1 + stride2, offset, oshape[odim - 1]); } else { Kernel, xpu>::Launch(s, dsize, out_data.dptr(), - in_data.dptr(), Shape3(oleading, obody, otrailing), Shape3(ileading, ibody, itrailing), - stride1 + stride2, offset, oshape[odim - 1]); + in_data.dptr(), Shape3(oleading, obody, otrailing), + Shape3(ileading, ibody, itrailing), stride1 + stride2, offset, oshape[odim - 1]); } }); }); @@ -1342,7 +1342,8 @@ void NumpyDiagonalOpForward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& oshape = outputs[0].shape_; const NumpyDiagonalParam& param = nnvm::get(attrs.parsed); - NumpyDiagonalOpImpl(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); + NumpyDiagonalOpImpl(in_data, out_data, ishape, oshape, + out_data.Size(), param, s, req); } template @@ -1363,7 +1364,8 @@ void NumpyDiagonalOpBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& oshape = outputs[0].shape_; const NumpyDiagonalParam& param = nnvm::get(attrs.parsed); - NumpyDiagonalOpImpl(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); + NumpyDiagonalOpImpl(in_data, out_data, oshape, ishape, + in_data.Size(), param, s, req); } struct NumpyDiagflatParam : public dmlc::Parameter { diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 86fce6354811..8227aa748af7 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -1327,7 +1327,7 @@ NNVM_REGISTER_OP(_backward_np_diag) .set_attr("TIsBackward", true) .set_attr("FCompute", NumpyDiagOpBackward); -NNVM_REGISTER_OP(_npi_diagonal) +NNVM_REGISTER_OP(_np_diagonal) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) @@ -1338,11 +1338,11 @@ NNVM_REGISTER_OP(_npi_diagonal) .set_attr("FInferShape", NumpyDiagonalOpShape) .set_attr("FInferType", NumpyDiagonalOpType) .set_attr("FCompute", NumpyDiagonalOpForward) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_diagonal"}) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_diagonal"}) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(NumpyDiagonalParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_npi_diagonal) +NNVM_REGISTER_OP(_backward_np_diagonal) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index a5b4e66eb9ce..10ff0eac2c29 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -124,10 +124,10 @@ NNVM_REGISTER_OP(_np_diag) NNVM_REGISTER_OP(_backward_np_diag) .set_attr("FCompute", NumpyDiagOpBackward); -NNVM_REGISTER_OP(_npi_diagonal) +NNVM_REGISTER_OP(_np_diagonal) .set_attr("FCompute", NumpyDiagonalOpForward); -NNVM_REGISTER_OP(_backward_npi_diagonal) +NNVM_REGISTER_OP(_backward_np_diagonal) .set_attr("FCompute", NumpyDiagonalOpBackward); NNVM_REGISTER_OP(_np_diagflat)