From 96766a4f001b597311f1f3f3bf726100c12be7e4 Mon Sep 17 00:00:00 2001 From: Chaitanya Prakash Bapat Date: Thu, 7 Feb 2019 02:28:45 -0500 Subject: [PATCH] large op support --- src/operator/tensor/broadcast_reduce_op.h | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 1edcb5a74a77..bf06d6be2d56 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1172,18 +1172,18 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs, template struct pick { template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a, - const IType *idx, int M, int stride, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a, + const IType *idx, size_t M, int stride, mshadow::Shape bshape, mshadow::Shape sshape) { using namespace broadcast; - int j = static_cast(idx[i]); + index_t j = static_cast(idx[i]); if (clip) { if (j <= 0) j = 0; - else if (j >= M) j = M - 1; + else if (j >= static_cast(M) j = static_cast(M - 1; } else { - j = j % M; - j += (j < 0) ? M : 0; + j = j % static_cast(M; + j += (j < 0) ? static_cast(M : 0; } j = ravel(unravel(i, sshape), bshape) + j*stride; out[i] = a[j]; @@ -1194,18 +1194,18 @@ struct pick { template struct pick_grad { template - MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd, - const IType *idx, int M, int stride, + MSHADOW_XINLINE static void Map(index_t i, DType* igrad, const DType* ograd, + const IType *idx, size_t M, int stride, mshadow::Shape bshape, mshadow::Shape sshape) { using namespace broadcast; - int j = static_cast(idx[i]); + index_t j = static_cast(idx[i]); if (clip) { if (j <= 0) j = 0; - else if (j >= M) j = M - 1; + else if (j >= static_cast(M) j = static_cast(M - 1; } else { - j = j % M; - j += (j < 0) ? M : 0; + j = j % static_cast(M; + j += (j < 0) ? static_cast(M : 0; } j = ravel(unravel(i, sshape), bshape) + j*stride; igrad[j] += ograd[i];