diff --git a/src/operator/tensor/ravel.h b/src/operator/tensor/ravel.h index 256fe334e971..d96b9cf44253 100644 --- a/src/operator/tensor/ravel.h +++ b/src/operator/tensor/ravel.h @@ -93,7 +93,7 @@ inline bool UnravelOpShape(const nnvm::NodeAttrs& attrs, struct ravel_index { template - MSHADOW_XINLINE static void Map(int i, index_t N, index_t ndim, index_t *shape, + MSHADOW_XINLINE static void Map(index_t i, index_t N, index_t ndim, index_t *shape, DType *unravelled, DType *ravelled) { index_t ret = 0; #pragma unroll diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 0dfeda47385f..ee57f172c1c9 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -495,6 +495,14 @@ def check_spatial_transformer(): assert res.shape[1] == 536870912 assert res.shape[2] == 2 assert res.shape[3] == 6 + + def check_ravel(): + data = nd.random_normal(shape=(2, LARGE_TENSOR_SHAPE)) + shape = (2, 10) + + out = nd.ravel_multi_index(data=data, shape=shape) + + assert out.shape[0] == LARGE_TENSOR_SHAPE check_gluon_embedding() check_fully_connected() @@ -518,6 +526,7 @@ def check_spatial_transformer(): check_col2im() check_embedding() check_spatial_transformer() + check_ravel() def test_tensor():