Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
sanity fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu committed Dec 6, 2019
1 parent d8f99f7 commit 58c2a29
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
11 changes: 5 additions & 6 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7517,10 +7517,10 @@ def where(condition, x=None, y=None):
@set_module('mxnet.numpy')
def diagonal(a, offset=0, axis1=0, axis2=1):
"""
If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of
the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and
axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the
resulting array can be determined by removing axis1 and axis2 and appending an index to the
If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of
the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and
axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the
resulting array can be determined by removing axis1 and axis2 and appending an index to the
right equal to the size of the resulting diagonals.
Parameters
Expand Down Expand Up @@ -7560,9 +7560,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
[2, 3]],
[[4, 5],
[6, 7]]])
>>> np.diagonal(a, 0, 0, 1)
>>> np.diagonal(a, 0, 0, 1)
array([[0, 6],
[1, 7]])
"""
return _mx_nd_np.diagonal(a, offset=offset, axis1=axis1, axis2=axis2)

10 changes: 6 additions & 4 deletions src/operator/numpy/np_matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,8 +1316,8 @@ void NumpyDiagonalOpImpl(const TBlob& in_data,
stride1 + stride2, offset, oshape[odim - 1]);
} else {
Kernel<diag_n<3, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape3(oleading, obody, otrailing), Shape3(ileading, ibody, itrailing),
stride1 + stride2, offset, oshape[odim - 1]);
in_data.dptr<DType>(), Shape3(oleading, obody, otrailing),
Shape3(ileading, ibody, itrailing), stride1 + stride2, offset, oshape[odim - 1]);
}
});
});
Expand All @@ -1342,7 +1342,8 @@ void NumpyDiagonalOpForward(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& oshape = outputs[0].shape_;
const NumpyDiagonalParam& param = nnvm::get<NumpyDiagonalParam>(attrs.parsed);

NumpyDiagonalOpImpl<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
NumpyDiagonalOpImpl<xpu, false>(in_data, out_data, ishape, oshape,
out_data.Size(), param, s, req);
}

template<typename xpu>
Expand All @@ -1363,7 +1364,8 @@ void NumpyDiagonalOpBackward(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& oshape = outputs[0].shape_;
const NumpyDiagonalParam& param = nnvm::get<NumpyDiagonalParam>(attrs.parsed);

NumpyDiagonalOpImpl<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req);
NumpyDiagonalOpImpl<xpu, true>(in_data, out_data, oshape, ishape,
in_data.Size(), param, s, req);
}

struct NumpyDiagflatParam : public dmlc::Parameter<NumpyDiagflatParam> {
Expand Down

0 comments on commit 58c2a29

Please sign in to comment.