From 94c86e487b69ab613dbce816469d984f55b0b44e Mon Sep 17 00:00:00 2001 From: Minghao Liu <40382964+Tommliu@users.noreply.github.com> Date: Sat, 16 Nov 2019 11:14:28 +0800 Subject: [PATCH] Fix diag code format --- python/mxnet/_numpy_op_doc.py | 1 - src/operator/numpy/np_matrix_op-inl.h | 12 ++++++------ tests/python/unittest/test_numpy_interoperability.py | 3 +-- tests/python/unittest/test_numpy_op.py | 2 +- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index fc8fac19d7ad..9e237cb0049b 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -1122,4 +1122,3 @@ def _np_diag(array, k = 0): [0, 0, 8]]) """ pass - diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index a98a582c9ad8..508968718af0 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -1065,7 +1065,7 @@ void NumpyDiagOpImpl(const TBlob &in_data, index_t dsize, const int &k, mxnet_op::Stream *s, - const std::vector &req) { + const OpReqType &req) { using namespace mxnet_op; using namespace mshadow; if (ishape.ndim() > 1) { @@ -1084,8 +1084,8 @@ void NumpyDiagOpImpl(const TBlob &in_data, } MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - if (back && req[0] != kAddTo && req[0] != kNullOp) { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + if (back && req != kAddTo && req != kNullOp) { out_data.FlatTo1D(s) = 0; } @@ -1096,7 +1096,7 @@ void NumpyDiagOpImpl(const TBlob &in_data, }); } else { MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { Kernel, xpu>::Launch( s, dsize, out_data.dptr(), in_data.dptr(), Shape2(oshape[0], oshape[1]), k); @@ -1125,7 +1125,7 @@ void NumpyDiagOpForward(const nnvm::NodeAttrs &attrs, const NumpyDiagParam ¶m = nnvm::get(attrs.parsed); NumpyDiagOpImpl(in_data, out_data, ishape, oshape, - out_data.Size(), param.k, s, req); + out_data.Size(), param.k, s, req[0]); } template @@ -1147,7 +1147,7 @@ void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs, const NumpyDiagParam ¶m = nnvm::get(attrs.parsed); NumpyDiagOpImpl(in_data, out_data, oshape, ishape, - in_data.Size(), param.k, s, req); + in_data.Size(), param.k, s, req[0]); } } // namespace op diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 94783defe83b..486722d15057 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -58,7 +58,6 @@ def get_workloads(name): def _add_workload_diag(): - def get_mat(n): data = _np.arange(n) data = _np.add.outer(data, data) @@ -67,7 +66,7 @@ def get_mat(n): A = np.array([[1, 2], [3, 4], [5, 6]]) vals = (100 * np.arange(5)).astype('l') vals_c = (100 * np.array(get_mat(5)) + 1).astype('l') - vals_f = _np.array((100 * get_mat(5) + 1), order = 'F', dtype = 'l') + vals_f = _np.array((100 * get_mat(5) + 1), order ='F', dtype ='l') vals_f = np.array(vals_f) OpArgMngr.add_workload('diag', A, k= 2) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 003962ba28c0..39f462379940 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4510,7 +4510,7 @@ def hybrid_forward(self, F, a): if hybridize: test_diag.hybridize() - x = np.random.uniform(-1.0, 1.0, size=shape).astype(dtype) + x = np.random.uniform(-2.0, 2.0, size=shape).astype(dtype) if len(shape) != 0 else np.array(()) x.attach_grad() np_out = _np.diag(x.asnumpy(), k)