From 66a90adda8b7aee3e9ee8a9c79ce9cdefa1220e2 Mon Sep 17 00:00:00 2001 From: Chaitanya Prakash Bapat Date: Sat, 16 Feb 2019 00:01:30 -0500 Subject: [PATCH] Fix nd.pick large array issue (#14082) * large op support * replaced size_t with index_t for M, added test case * changed shape --- src/operator/tensor/broadcast_reduce_op.h | 12 ++++++------ tests/nightly/test_large_array.py | 5 +++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 1edcb5a74a77..6aeeadfe820d 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1172,12 +1172,12 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs, template struct pick { template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a, - const IType *idx, int M, int stride, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a, + const IType *idx, index_t M, int stride, mshadow::Shape bshape, mshadow::Shape sshape) { using namespace broadcast; - int j = static_cast(idx[i]); + index_t j = static_cast(idx[i]); if (clip) { if (j <= 0) j = 0; else if (j >= M) j = M - 1; @@ -1194,12 +1194,12 @@ struct pick { template struct pick_grad { template - MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd, - const IType *idx, int M, int stride, + MSHADOW_XINLINE static void Map(index_t i, DType* igrad, const DType* ograd, + const IType *idx, index_t M, int stride, mshadow::Shape bshape, mshadow::Shape sshape) { using namespace broadcast; - int j = static_cast(idx[i]); + index_t j = static_cast(idx[i]); if (clip) { if (j <= 0) j = 0; else if (j >= M) j = M - 1; diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 696fdb1d4175..0249f44932c2 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -145,6 +145,11 @@ def test_where(): res = nd.sparse.where(csr_cond, a, b) assert np.sum(res[0].asnumpy() == 1) == b.shape[1] +def test_pick(): + a = mx.nd.ones(shape=(256*35, 1024*1024)) + b = mx.nd.ones(shape=(256*35,)) + res = mx.nd.pick(a,b) + assert res.shape == b.shape if __name__ == '__main__': import nose