This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Fix diag code format
Tommliu committed Nov 16, 2019
1 parent 9006455 commit 94c86e4
Showing 4 changed files with 8 additions and 10 deletions.
1 change: 0 additions & 1 deletion python/mxnet/_numpy_op_doc.py
@@ -1122,4 +1122,3 @@ def _np_diag(array, k = 0):
             [0, 0, 8]])
     """
     pass
-
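The only deletion in this hunk is a stray blank line after `pass`. For context, the docstring it trails documents NumPy-compatible `diag` semantics; a short plain-NumPy sketch (my own illustration, not part of the diff) of the extract/construct duality behind the `[0, 0, 8]])` example:

    import numpy as np

    x = np.arange(9).reshape(3, 3)
    print(np.diag(x))       # 2-D input extracts the main diagonal: [0 4 8]
    print(np.diag(x, k=1))  # positive k reads above the main diagonal: [1 5]
    print(np.diag([4, 8]))  # 1-D input constructs a diagonal matrix:
                            # [[4 0]
                            #  [0 8]]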
12 changes: 6 additions & 6 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -1065,7 +1065,7 @@ void NumpyDiagOpImpl(const TBlob &in_data,
                      index_t dsize,
                      const int &k,
                      mxnet_op::Stream<xpu> *s,
-                     const std::vector<OpReqType> &req) {
+                     const OpReqType &req) {
   using namespace mxnet_op;
   using namespace mshadow;
   if (ishape.ndim() > 1) {
@@ -1084,8 +1084,8 @@ void NumpyDiagOpImpl(const TBlob &in_data,
   }

   MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-      if (back && req[0] != kAddTo && req[0] != kNullOp) {
+    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+      if (back && req != kAddTo && req != kNullOp) {
         out_data.FlatTo1D<xpu, DType>(s) = 0;
       }

@@ -1096,7 +1096,7 @@ void NumpyDiagOpImpl(const TBlob &in_data,
     });
   } else {
     MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         Kernel<diag_gen<req_type, back>, xpu>::Launch(
             s, dsize, out_data.dptr<DType>(), in_data.dptr<DType>(),
             Shape2(oshape[0], oshape[1]), k);
@@ -1125,7 +1125,7 @@ void NumpyDiagOpForward(const nnvm::NodeAttrs &attrs,
   const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);

   NumpyDiagOpImpl<xpu, false>(in_data, out_data, ishape, oshape,
-                              out_data.Size(), param.k, s, req);
+                              out_data.Size(), param.k, s, req[0]);
 }

 template <typename xpu>
@@ -1147,7 +1147,7 @@ void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs,
   const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);

   NumpyDiagOpImpl<xpu, true>(in_data, out_data, oshape, ishape,
-                             in_data.Size(), param.k, s, req);
+                             in_data.Size(), param.k, s, req[0]);
 }

 }  // namespace op
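For orientation: `OpReqType` is MXNet's per-output write mode (`kWriteTo`, `kAddTo`, `kNullOp`, ...). Because `diag` has exactly one output, `NumpyDiagOpImpl` now receives the single entry `req[0]` instead of the whole request vector, and the backward path zero-fills the gradient buffer unless it is accumulating (`kAddTo`) or skipping the write (`kNullOp`). A minimal sketch of the forward/backward pair these functions implement, assuming the `mxnet.numpy` front end built from this repository:

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    x = np.arange(9.0).reshape(3, 3)
    x.attach_grad()
    with mx.autograd.record():
        y = np.diag(x, k=1)  # forward: NumpyDiagOpForward
    y.backward()             # backward: NumpyDiagOpBackward scatters into x.grad
    print(y)                 # [1. 5.]
    print(x.grad)            # ones on the k=1 diagonal, zeros elsewhere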
3 changes: 1 addition & 2 deletions tests/python/unittest/test_numpy_interoperability.py
@@ -58,7 +58,6 @@ def get_workloads(name):


 def _add_workload_diag():
-
     def get_mat(n):
         data = _np.arange(n)
         data = _np.add.outer(data, data)
@@ -67,7 +66,7 @@ def get_mat(n):
     A = np.array([[1, 2], [3, 4], [5, 6]])
     vals = (100 * np.arange(5)).astype('l')
     vals_c = (100 * np.array(get_mat(5)) + 1).astype('l')
-    vals_f = _np.array((100 * get_mat(5) + 1), order = 'F', dtype = 'l')
+    vals_f = _np.array((100 * get_mat(5) + 1), order='F', dtype='l')
     vals_f = np.array(vals_f)

     OpArgMngr.add_workload('diag', A, k= 2)
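For reference, `get_mat` builds an n-by-n addition table with `_np.add.outer` (assuming the elided body ends with `return data`); each diagonal of that table is constant, which makes the workloads' `k` offsets easy to verify by eye:

    import numpy as _np

    def get_mat(n):
        data = _np.arange(n)
        data = _np.add.outer(data, data)
        return data

    print(get_mat(3))
    # [[0 1 2]
    #  [1 2 3]
    #  [2 3 4]]
    print(_np.diag(get_mat(3), k=1))  # [1 3]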
2 changes: 1 addition & 1 deletion tests/python/unittest/test_numpy_op.py
@@ -4510,7 +4510,7 @@ def hybrid_forward(self, F, a):
         if hybridize:
             test_diag.hybridize()

-        x = np.random.uniform(-1.0, 1.0, size=shape).astype(dtype)
+        x = np.random.uniform(-2.0, 2.0, size=shape).astype(dtype) if len(shape) != 0 else np.array(())
         x.attach_grad()

         np_out = _np.diag(x.asnumpy(), k)
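The new conditional covers the empty-shape workload: when the test's `shape` is `()`, the input becomes an empty 1-D array instead of a uniform sample. Plain NumPy shows why such an input is still well-defined for `diag` (my own check, outside the test harness):

    import numpy as _np

    x = _np.array(())         # empty 1-D array, shape (0,)
    print(x.shape)            # (0,)
    print(_np.diag(x).shape)  # (0, 0): diag of an empty vector is an empty matrix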
