diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 9e237cb0049b..cf991fc8949f 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -1089,7 +1089,7 @@ def _npx_reshape(a, newshape, reverse=False, order='C'): pass -def _np_diag(array, k = 0): +def _np_diag(array, k=0): """ Extracts a diagonal or constructs a diagonal array. - 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero. @@ -1122,3 +1122,39 @@ def _np_diag(array, k = 0): [0, 0, 8]]) """ pass + + +def _np_diagflat(array, k=0): + """ + Create a two-dimensional array with the flattened input as a diagonal. + Parameters + ---------- + array : ndarray + Input data, which is flattened and set as the `k`-th + diagonal of the output. + k : int, optional + Diagonal to set; 0, the default, corresponds to the "main" diagonal, + a positive (negative) `k` giving the number of the diagonal above + (below) the main. + Returns + ------- + out : ndarray + The 2-D output array. + See Also + -------- + diag : MATLAB work-alike for 1-D and 2-D arrays. + diagonal : Return specified diagonals. + trace : Sum along diagonals. 
+ Examples + -------- + >>> np.diagflat([[1,2], [3,4]]) + array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + >>> np.diagflat([1,2], 1) + array([[0, 1, 0], + [0, 0, 2], + [0, 0, 0]]) + """ + pass \ No newline at end of file diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index f58159303d0f..67af2724503e 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -94,6 +94,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'copy', 'cumsum', 'diag', + 'diagflat', 'dot', 'expand_dims', 'fix', diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 508968718af0..fee534315b77 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -1150,6 +1150,133 @@ void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs, in_data.Size(), param.k, s, req[0]); } +struct NumpyDiagflatParam : public dmlc::Parameter { + int k; + DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { + DMLC_DECLARE_FIELD(k).set_default(0) + .describe("Diagonal in question. The default is 0. " + "Use k>0 for diagonals above the main diagonal, " + "and k<0 for diagonals below the main diagonal. 
"); + } +}; + +inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k) { + if (ishape.ndim() == 1) { + auto s = ishape[0] + std::abs(k); + return mxnet::TShape({s, s}); + } + + if (ishape.ndim() >= 2) { + auto s = 1; + for (int i = 0; i < ishape.ndim(); i++) { + if (ishape[i] >= 2) { + s = s * ishape[i]; + } + } + s = s + std::abs(k); + return mxnet::TShape({s, s}); + } + return mxnet::TShape({-1, -1}); +} + +inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& ishape = (*in_attrs)[0]; + if (!mxnet::ndim_is_known(ishape)) { + return false; + } + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape, param.k); + + if (shape_is_none(oshape)) { + LOG(FATAL) << "Diagonal does not exist."; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + + return shape_is_known(out_attrs->at(0)); +} + +inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); + return (*out_attrs)[0] != -1; +} + +template +void NumpyDiagflatOpImpl(const TBlob& in_data, + const TBlob& out_data, + const mxnet::TShape& ishape, + const mxnet::TShape& oshape, + index_t dsize, + const NumpyDiagflatParam& param, + mxnet_op::Stream *s, + const std::vector& req) { + using namespace mxnet_op; + using namespace mshadow; + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, dsize, out_data.dptr(), in_data.dptr(), + Shape2(oshape[0], oshape[1]), param.k); + }); + }); +} + +template +void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs, + const 
OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[0], kWriteTo); + Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + NumpyDiagflatOpImpl(in_data, out_data, ishape, + oshape, out_data.Size(), param, s, req); +} + +template +void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const mxnet::TShape& ishape = inputs[0].shape_; + const mxnet::TShape& oshape = outputs[0].shape_; + const NumpyDiagflatParam& param = nnvm::get(attrs.parsed); + + NumpyDiagflatOpImpl(in_data, out_data, oshape, + ishape, in_data.Size(), param, s, req); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 912b32c2e8fb..f07cb0e5f4b2 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -37,6 +37,7 @@ DMLC_REGISTER_PARAMETER(NumpyRot90Param); DMLC_REGISTER_PARAMETER(NumpyReshapeParam); DMLC_REGISTER_PARAMETER(NumpyXReshapeParam); DMLC_REGISTER_PARAMETER(NumpyDiagParam); +DMLC_REGISTER_PARAMETER(NumpyDiagflatParam); bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, @@ -1325,5 +1326,27 @@ NNVM_REGISTER_OP(_backward_np_diag) .set_attr("TIsBackward", true) .set_attr("FCompute", NumpyDiagOpBackward); 
+NNVM_REGISTER_OP(_np_diagflat) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", NumpyDiagflatOpShape) +.set_attr("FInferType", NumpyDiagflatOpType) +.set_attr("FCompute", NumpyDiagflatOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_diagflat"}) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(NumpyDiagflatParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_np_diagflat) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyDiagflatOpBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 33f5aab7717c..6f292ab95802 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -124,5 +124,11 @@ NNVM_REGISTER_OP(_np_diag) NNVM_REGISTER_OP(_backward_np_diag) .set_attr("FCompute", NumpyDiagOpBackward); +NNVM_REGISTER_OP(_np_diagflat) +.set_attr("FCompute", NumpyDiagflatOpForward); + +NNVM_REGISTER_OP(_backward_np_diagflat) +.set_attr("FCompute", NumpyDiagflatOpBackward); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index e52d25239d22..088306ff01c8 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1201,6 +1201,34 @@ def _add_workload_nonzero(): OpArgMngr.add_workload('nonzero', np.array([False, False, False], dtype=np.bool_)) OpArgMngr.add_workload('nonzero', np.array([True, False, False], dtype=np.bool_)) +def _add_workload_diagflat(): + def get_mat(n): + data = _np.arange(n) + data = _np.add.outer(data,data) + return data + + A = np.array([[1,2],[3,4],[5,6]]) + vals = (100 * 
np.arange(5)).astype('l') + vals_c = (100 * np.array(get_mat(5)) + 1).astype('l') + vals_f = _np.array((100 * get_mat(5) + 1), order='F', dtype='l') + vals_f = np.array(vals_f) + + OpArgMngr.add_workload('diagflat', A, k=2) + OpArgMngr.add_workload('diagflat', A, k=1) + OpArgMngr.add_workload('diagflat', A, k=0) + OpArgMngr.add_workload('diagflat', A, k=-1) + OpArgMngr.add_workload('diagflat', A, k=-2) + OpArgMngr.add_workload('diagflat', A, k=-3) + OpArgMngr.add_workload('diagflat', vals, k=0) + OpArgMngr.add_workload('diagflat', vals, k=2) + OpArgMngr.add_workload('diagflat', vals, k=-2) + OpArgMngr.add_workload('diagflat', vals_c, k=0) + OpArgMngr.add_workload('diagflat', vals_c, k=2) + OpArgMngr.add_workload('diagflat', vals_c, k=-2) + OpArgMngr.add_workload('diagflat', vals_f, k=0) + OpArgMngr.add_workload('diagflat', vals_f, k=2) + OpArgMngr.add_workload('diagflat', vals_f, k=-2) + def _add_workload_shape(): OpArgMngr.add_workload('shape', np.random.uniform(size=())) @@ -1263,6 +1291,7 @@ def _prepare_workloads(): _add_workload_cumsum() _add_workload_ravel() _add_workload_diag() + _add_workload_diagflat() _add_workload_dot() _add_workload_expand_dims() _add_workload_fix() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index ff9dd4119968..7dd165b5421f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4165,6 +4165,48 @@ def dbg(name, data): assert_almost_equal(grad[0][iop], grad[1][iop], rtol=rtol, atol=atol) +@with_seed() +@use_np +def test_np_diagflat(): + class TestDiagflat(HybridBlock): + def __init__(self, k=0): + super(TestDiagflat,self).__init__() + self._k = k + def hybrid_forward(self,F,a): + return F.np.diagflat(a, k=self._k) + shapes = [(2,),5 , (1,5), (2,2), (2,5), (3,3), (4,3),(4,4,5)] # test_shapes, remember to include zero-dim shape and zero-size shapes + dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64] # remember 
to include all meaningful data types for the operator + range_k = 6 + for hybridize,shape,dtype, in itertools.product([False,True],shapes,dtypes): + rtol = 1e-2 if dtype == np.float16 else 1e-3 + atol = 1e-4 if dtype == np.float16 else 1e-5 + + for k in range(-range_k,range_k): + test_diagflat = TestDiagflat(k) + if hybridize: + test_diagflat.hybridize() + + x = np.random.uniform(-1.0,1.0, size = shape).astype(dtype) + x.attach_grad() + + np_out = _np.diagflat(x.asnumpy(), k) + with mx.autograd.record(): + mx_out = test_diagflat(x) + + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(),np_out,rtol = rtol, atol = atol) + + mx_out.backward() + # Code to get the reference backward value + np_backward = np.ones(shape) + assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) + + # Test imperative once again + mx_out = np.diagflat(x, k) + np_out = _np.diagflat(x.asnumpy(), k) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + @with_seed() @use_np def test_np_rand():