diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8d8aeaca73e4..3ae61298de8e 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -178,3 +178,5 @@ List of Contributors * [Aaron Markham](https://github.com/aaronmarkham) * [Sam Skalicky](https://github.com/samskalicky) * [Per Goncalves da Silva](https://github.com/perdasilva) +* [Zhijingcheng Yu](https://github.com/jasonyu1996) +* [Cheng-Che Lee](https://github.com/stu1130) diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h index 3bc240f206b4..deab2569e489 100644 --- a/src/operator/tensor/diag_op-inl.h +++ b/src/operator/tensor/diag_op-inl.h @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file diag_op-inl.h * \brief CPU Implementation of the diag op -* \author Istvan Fehervari +* \author Istvan Fehervari, Zhijingcheng Yu */ #ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_ @@ -30,33 +30,51 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../operator_common.h" #include "../elemwise_op_common.h" +#include "./broadcast_reduce_op.h" namespace mxnet { namespace op { struct DiagParam : public dmlc::Parameter { - dmlc::optional k; + int k; + int32_t axis1; + int32_t axis2; DMLC_DECLARE_PARAMETER(DiagParam) { DMLC_DECLARE_FIELD(k) - .set_default(dmlc::optional(0)) - .describe("Diagonal in question. The default is 0. " - "Use k>0 for diagonals above the main diagonal, " - "and k<0 for diagonals below the main diagonal. " - "If input has shape (S0 S1) k must be between -S0 and S1"); + .set_default(0) + .describe("Diagonal in question. The default is 0. " + "Use k>0 for diagonals above the main diagonal, " + "and k<0 for diagonals below the main diagonal. " + "If input has shape (S0 S1) k must be between -S0 and S1"); + DMLC_DECLARE_FIELD(axis1) + .set_default(0) + .describe("The first axis of the sub-arrays of interest. 
" + "Ignored when the input is a 1-D array."); + DMLC_DECLARE_FIELD(axis2) + .set_default(1) + .describe("The second axis of the sub-arrays of interest. " + "Ignored when the input is a 1-D array."); } }; -inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) { +inline TShape DiagShapeImpl(const TShape& ishape, const int k, + const int32_t axis1, const int32_t axis2) { if (ishape.ndim() == 1) { auto s = ishape[0] + std::abs(k); return TShape({s, s}); } - auto h = ishape[0]; - auto w = ishape[1]; + int32_t x1 = CheckAxis(axis1, ishape.ndim()); + int32_t x2 = CheckAxis(axis2, ishape.ndim()); + + CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the the same axis " << x1; + + auto h = ishape[x1]; + auto w = ishape[x2]; if (k > 0) { w -= k; @@ -69,7 +87,24 @@ inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) { s = 0; } - return TShape({s}); + if (x1 > x2) { + std::swap(x1, x2); + } + + int32_t n_dim = static_cast(ishape.ndim()) - 1; + TShape oshape(n_dim); + + // remove axis1 and axis2 and append the new axis to the end + uint32_t idx = 0; + for (int32_t i = 0; i <= n_dim; ++i) { + if (i != x1 && i != x2) { + oshape[idx++] = ishape[i]; + } + } + + oshape[n_dim - 1] = s; + + return oshape; } inline bool DiagOpShape(const nnvm::NodeAttrs& attrs, @@ -79,12 +114,16 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const TShape& ishape = (*in_attrs)[0]; - if (ishape.ndim() == 0) return false; - if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d."; + if (ishape.ndim() == 0) { + return false; + } const DiagParam& param = nnvm::get(attrs.parsed); - TShape oshape = DiagShapeImpl(ishape, param.k.value()); + TShape oshape = DiagShapeImpl(ishape, + param.k, + param.axis1, + param.axis2); if (shape_is_none(oshape)) { LOG(FATAL) << "Diagonal does not exist."; } @@ -104,42 +143,144 @@ inline bool DiagOpType(const nnvm::NodeAttrs& attrs, return (*out_attrs)[0] != -1; } -template +template 
struct diag { template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a, - mshadow::Shape<2> ishape, int k) { + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a, + mshadow::Shape oshape, + mshadow::Shape ishape, + index_t stride, index_t offset, + index_t base) { using namespace mxnet_op; - int j = 0; - if (k > 0) { - j = ravel(mshadow::Shape2(i, i + k), ishape); - } else if (k < 0) { - j = ravel(mshadow::Shape2(i - k, i), ishape); + index_t idx = i / base; + index_t j = ravel(unravel(idx, oshape), ishape) + offset + stride * (i - idx * base); + if (back) { + KERNEL_ASSIGN(out[j], req, a[i]); } else { - j = ravel(mshadow::Shape2(i, i), ishape); + KERNEL_ASSIGN(out[i], req, a[j]); } - - KERNEL_ASSIGN(out[i], req, a[j]); } }; -template +template struct diag_gen { template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a, mshadow::Shape<2> oshape, int k) { using namespace mxnet_op; auto j = unravel(i, oshape); if (j[1] == (j[0] + k)) { auto l = j[0] < j[1] ? 
j[0] : j[1]; - KERNEL_ASSIGN(out[i], req, a[l]); - } else { + if (back) { + KERNEL_ASSIGN(out[l], req, a[i]); + } else { + KERNEL_ASSIGN(out[i], req, a[l]); + } + } else if (!back) { KERNEL_ASSIGN(out[i], req, static_cast(0)); } } }; +template +void DiagOpProcess(const TBlob& in_data, + const TBlob& out_data, + const TShape& ishape, + const TShape& oshape, + index_t dsize, + const DiagParam& param, + mxnet_op::Stream *s, + const std::vector& req) { + using namespace mxnet_op; + using namespace mshadow; + if (ishape.ndim() > 1) { + // input : (leading + i, body + i, trailing) + uint32_t x1 = CheckAxis(param.axis1, ishape.ndim()); + uint32_t x2 = CheckAxis(param.axis2, ishape.ndim()); + + uint32_t idim = ishape.ndim(), odim = oshape.ndim(); + + uint32_t minx = x1, maxx = x2; + if (minx > maxx) { + std::swap(minx, maxx); + } + + // merges contiguous axes that are not separated + // by axis1 or axis2 since they can be directly + // mapped to the output and there is no need + // to distinguish them + // (After this the input will have no more than + // three axes, hence improving the ravel and + // unravel efficiency) + + index_t oleading = 1, + obody = 1, + otrailing = 1; + + for (uint32_t i = 0; i < minx; ++i) { + oleading *= ishape[i]; + } + for (uint32_t i = minx + 1; i < maxx; ++i) { + obody *= ishape[i]; + } + for (uint32_t i = maxx + 1; i < idim; ++i) { + otrailing *= ishape[i]; + } + + index_t ileading = oleading, + ibody = obody * ishape[minx], + itrailing = otrailing * ishape[maxx]; + + index_t stride1 = itrailing * obody, + stride2 = otrailing; + // stride1 + stride2 is the stride for + // iterating over the diagonal in question + + if (x1 == maxx) { + std::swap(stride1, stride2); + } + + // the extra index offset introduced by k + index_t offset; + int k = param.k; + if (k > 0) { + offset = stride2 * k; + } else if (k < 0) { + offset = stride1 * -k; + } else { + offset = 0; + } + + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { 
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + if (back && req[0] != kAddTo && req[0] != kNullOp) { + out_data.FlatTo1D(s) = 0; + } + if (ileading == 1) { + Kernel, xpu>::Launch(s, dsize, out_data.dptr(), + in_data.dptr(), Shape2(obody, otrailing), + Shape2(ibody, itrailing), + stride1 + stride2, offset, oshape[odim - 1]); + } else { + Kernel, xpu>::Launch(s, dsize, out_data.dptr(), + in_data.dptr(), Shape3(oleading, obody, otrailing), + Shape3(ileading, ibody, itrailing), + stride1 + stride2, offset, oshape[odim - 1]); + } + }); + }); + } else { + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch(s, dsize, out_data.dptr(), + in_data.dptr(), Shape2(oshape[0], oshape[1]), + param.k); + }); + }); + } +} + template void DiagOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -159,21 +300,7 @@ void DiagOpForward(const nnvm::NodeAttrs& attrs, const TShape& oshape = outputs[0].shape_; const DiagParam& param = nnvm::get(attrs.parsed); - if (ishape.ndim() == 2) { - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch(s, out_data.Size(), out_data.dptr(), - in_data.dptr(), Shape2(ishape[0], ishape[1]), param.k.value()); - }); - }); - } else { - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch(s, out_data.Size(), out_data.dptr(), - in_data.dptr(), Shape2(oshape[0], oshape[1]), param.k.value()); - }); - }); - } + DiagOpProcess(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req); } template @@ -194,23 +321,10 @@ void DiagOpBackward(const nnvm::NodeAttrs& attrs, const TShape& oshape = outputs[0].shape_; const DiagParam& param = nnvm::get(attrs.parsed); - if (oshape.ndim() == 2) { - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch(s, out_data.Size(), out_data.dptr(), - 
in_data.dptr(), Shape2(oshape[0], oshape[1]), param.k.value()); - }); - }); - } else { - MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch(s, out_data.Size(), out_data.dptr(), - in_data.dptr(), Shape2(ishape[0], ishape[1]), param.k.value()); - }); - }); - } + DiagOpProcess(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); } + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/diag_op.cc b/src/operator/tensor/diag_op.cc index 1ad3b8adc028..cd5be9d0fd5c 100644 --- a/src/operator/tensor/diag_op.cc +++ b/src/operator/tensor/diag_op.cc @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file diag_op.cc * \brief -* \author Istvan Fehervari +* \author Istvan Fehervari, Zhijingcheng Yu */ #include "./diag_op-inl.h" @@ -36,9 +36,13 @@ NNVM_REGISTER_OP(diag) ``diag``'s behavior depends on the input array dimensions: -- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero -- 2-D arrays: returns elements in the diagonal as a new 1-D array -- N-D arrays: not supported yet +- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero. +- N-D arrays: extracts the diagonals of the sub-arrays with axes specified by ``axis1`` and ``axis2``. + The output shape would be decided by removing the axes numbered ``axis1`` and ``axis2`` from the + input shape and appending to the result a new axis with the size of the diagonals in question. + + For example, when the input shape is `(2, 3, 4, 5)`, ``axis1`` and ``axis2`` are 0 and 2 + respectively and ``k`` is 0, the resulting shape would be `(3, 5, 2)`. 
Examples:: @@ -65,6 +69,21 @@ Examples:: [1, 0, 0], [0, 2, 0]] + x = [[[1, 2], + [3, 4]], + + [[5, 6], + [7, 8]]] + + diag(x) = [[1, 7], + [2, 8]] + + diag(x, k=1) = [[3], + [4]] + + diag(x, axis1=-2, axis2=-1) = [[1, 4], + [5, 8]] + )code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(1) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 2bf7e848850a..3c052ed66084 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4339,7 +4339,7 @@ def test_invalid_shape(): x = mx.sym.Variable('x') y = mx.sym.Variable('y') where_sym = mx.sym.where(condition, x, y) - + assert_exception(lambda: where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]), y=mx.nd.array([[8,9],[10,11],[12,13]]), condition=mx.nd.array([1,0])), MXNetError) @@ -4982,7 +4982,7 @@ def _validate_sample_location(input_rois, input_offset, spatial_scale, pooled_w, trans_x = input_offset[roi_idx, class_id * 2, part_h, part_w] * trans_std trans_y = input_offset[roi_idx, class_id * 2 + 1, part_h, part_w] * trans_std bin_h_start, bin_w_start = ph * bin_size_h + roi_start_h, pw * bin_size_w + roi_start_w - + need_check = True while need_check: pass_check = True @@ -6812,6 +6812,50 @@ def test_diag(): diag_sym = mx.sym.diag(data=data, k=-1) check_numeric_gradient(diag_sym, [a_np]) + # Test 4d input + x1 = np.random.randint(3,9) + x2 = np.random.randint(3,9) + x3 = np.random.randint(3,9) + x4 = np.random.randint(3,9) + a_np = np.random.random((x1, x2, x3, x4)).astype(np.float32) + a = mx.nd.array(a_np).astype('float32') + + # k = 0, axis1=0, axis2=1 + r = mx.nd.diag(data=a, k=0, axis1=0, axis2=1) + assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=0, axis1=0, axis2=1)) + + # k = 1, axis1=1, axis2=0 + r = mx.nd.diag(data=a, k=1, axis1=1, axis2=0) + assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=1, axis1=1, axis2=0)) + + # k = -1, axis1=1, axis2=3 + r = mx.nd.diag(data=a, k=-1, axis1=1, axis2=3) + 
assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=-1, axis1=1, axis2=3)) + + # k = 2, axis1=-2, axis2=0 + r = mx.nd.diag(data=a, k=2, axis1=-2, axis2=0) + assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=2, axis1=-2, axis2=0)) + + # Test 4d backward, k=0, axis1=3, axis2=0 + data = mx.sym.Variable('data') + diag_sym = mx.sym.diag(data=data, k=0, axis1=3, axis2=0) + check_numeric_gradient(diag_sym, [a_np]) + + # Test 4d backward, k=1, axis1=1, axis2=2 + data = mx.sym.Variable('data') + diag_sym = mx.sym.diag(data=data, k=1, axis1=1, axis2=2) + check_numeric_gradient(diag_sym, [a_np]) + + # Test 4d backward, k=-1, axis1=2, axis2=0 + data = mx.sym.Variable('data') + diag_sym = mx.sym.diag(data=data, k=-1, axis1=2, axis2=0) + check_numeric_gradient(diag_sym, [a_np]) + + # Test 4d backward, k=-2, axis1=1, axis2=-1 + data = mx.sym.Variable('data') + diag_sym = mx.sym.diag(data=data, k=-2, axis1=1, axis2=-1) + check_numeric_gradient(diag_sym, [a_np]) + @with_seed() def test_depthtospace(): def f(x, blocksize):