Skip to content

Commit

Permalink
[Large Tensor] Fix ravel_multi_index op (apache#17644)
Browse files Browse the repository at this point in the history
* Fixed dtype on i

* Added nightly test for ravel_multi_index
  • Loading branch information
connorgoggins committed Feb 26, 2020
1 parent 1906eff commit 13f5ad9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/operator/tensor/ravel.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ inline bool UnravelOpShape(const nnvm::NodeAttrs& attrs,

struct ravel_index {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, index_t N, index_t ndim, index_t *shape,
MSHADOW_XINLINE static void Map(index_t i, index_t N, index_t ndim, index_t *shape,
DType *unravelled, DType *ravelled) {
index_t ret = 0;
#pragma unroll
Expand Down
9 changes: 9 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,14 @@ def check_spatial_transformer():
assert res.shape[1] == 536870912
assert res.shape[2] == 2
assert res.shape[3] == 6

def check_ravel():
data = nd.random_normal(shape=(2, LARGE_TENSOR_SHAPE))
shape = (2, 10)

out = nd.ravel_multi_index(data=data, shape=shape)

assert out.shape[0] == LARGE_TENSOR_SHAPE

check_gluon_embedding()
check_fully_connected()
Expand All @@ -518,6 +526,7 @@ def check_spatial_transformer():
check_col2im()
check_embedding()
check_spatial_transformer()
check_ravel()


def test_tensor():
Expand Down

0 comments on commit 13f5ad9

Please sign in to comment.