Skip to content

Commit

Permalink
Fix nd.pick large array issue (apache#14082)
Browse files Browse the repository at this point in the history
* large op support

* replaced size_t with index_t for M, added test case

* changed shape
  • Loading branch information
ChaiBapchya authored and stephenrawls committed Feb 16, 2019
1 parent b1e5f93 commit d207d21
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1172,12 +1172,12 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
template<int ndim, bool clip = true>
struct pick {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
const IType *idx, int M, int stride,
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
const IType *idx, index_t M, int stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
int j = static_cast<int>(idx[i]);
index_t j = static_cast<index_t>(idx[i]);
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
Expand All @@ -1194,12 +1194,12 @@ struct pick {
template<int ndim, bool clip = true>
struct pick_grad {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd,
const IType *idx, int M, int stride,
MSHADOW_XINLINE static void Map(index_t i, DType* igrad, const DType* ograd,
const IType *idx, index_t M, int stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
int j = static_cast<int>(idx[i]);
index_t j = static_cast<index_t>(idx[i]);
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
Expand Down
5 changes: 5 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ def test_where():
res = nd.sparse.where(csr_cond, a, b)
assert np.sum(res[0].asnumpy() == 1) == b.shape[1]

def test_pick():
a = mx.nd.ones(shape=(256*35, 1024*1024))
b = mx.nd.ones(shape=(256*35,))
res = mx.nd.pick(a,b)
assert res.shape == b.shape

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit d207d21

Please sign in to comment.