From 1399bbd89572d4772914b69ea6b14477fc44682a Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 9 Oct 2019 13:31:19 -0700 Subject: [PATCH] fix index copy --- src/operator/contrib/index_copy-inl.h | 2 +- src/operator/contrib/index_copy.cc | 4 ++-- tests/nightly/test_large_array.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h index 9f78f0593ed1..35bfcd0e77b6 100644 --- a/src/operator/contrib/index_copy-inl.h +++ b/src/operator/contrib/index_copy-inl.h @@ -71,7 +71,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->at(0)[i], in_attrs->at(2)[i]); } } - // The the length of the fitrst dim of copied tensor + // The the length of the first dim of copied tensor // must equal to the size of index vector CHECK_EQ(in_attrs->at(1)[0], in_attrs->at(2)[0]); SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); diff --git a/src/operator/contrib/index_copy.cc b/src/operator/contrib/index_copy.cc index f272a8860d85..9a071c04b51c 100644 --- a/src/operator/contrib/index_copy.cc +++ b/src/operator/contrib/index_copy.cc @@ -28,12 +28,12 @@ namespace op { struct index_copy_fwd_cpu { template - static void Map(int i, + static void Map(index_t i, const DType* new_tensor, const IType* idx, DType* out_tensor, int dim_size) { - DType* out_ptr = out_tensor + static_cast(idx[i]) * dim_size; + DType* out_ptr = out_tensor + static_cast(idx[i]) * dim_size; const DType* new_ptr = new_tensor + i * dim_size; std::memcpy(out_ptr, new_ptr, sizeof(DType) * dim_size); } diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 580db98935d6..5674f9826510 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -601,7 +601,7 @@ def test_softmax_cross_entropy(): def test_index_copy(): x = mx.nd.zeros((LARGE_X, SMALL_Y)) t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y)) - index = mx.nd.array([LARGE_X - 1]) + index = mx.nd.array([LARGE_X - 1], dtype="int64") x = mx.nd.contrib.index_copy(x, index, t) assert x[-1][-1] == t[0][-1]