diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h index 07252963c874..9d0e8cf90817 100644 --- a/src/operator/tensor/control_flow_op.h +++ b/src/operator/tensor/control_flow_op.h @@ -46,7 +46,7 @@ struct where { // DType is the output data type // CType is condition data type template - MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond, const DType* x, const DType* y) { KERNEL_ASSIGN(out[i], req, (0 != cond[i]? x[i] : y[i])); } @@ -64,7 +64,7 @@ struct where_csr { // CType is condition data type // i is for i-th row in the output template - MSHADOW_XINLINE static void Map(int i, DType* out, const IType* cond_idx, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const IType* cond_idx, const IType* cond_indptr, const CType* cond_data, const nnvm::dim_t num_cols, const DType* x) { using nnvm::dim_t; @@ -92,8 +92,8 @@ struct where_batch { // DType is the output data type // CType is the condition data type template - MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond, - const DType* x, const DType* y, int M) { + MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond, + const DType* x, const DType* y, index_t M) { KERNEL_ASSIGN(out[i], req, (0 != cond[i/M]? x[i] : y[i])); } }; @@ -109,7 +109,7 @@ struct where_backward { // DType is the output data type // CType is condition data type template - MSHADOW_XINLINE static void Map(int i, DType* grad_out, + MSHADOW_XINLINE static void Map(index_t i, DType* grad_out, const DType* grad_in, const CType* cond) { KERNEL_ASSIGN(grad_out[i], req, @@ -130,7 +130,7 @@ struct where_backward_csr { // CType is condition data type // IType is condition aux data type template - MSHADOW_XINLINE static void Map(int i, DType* grad_out, + MSHADOW_XINLINE static void Map(index_t i, DType* grad_out, const DType* grad_in, const CType* cond_data, const IType* cond_idx, @@ -161,9 +161,9 @@ struct where_batch_backward { // DType is the output data type // CType is condition data type template - MSHADOW_XINLINE static void Map(int i, DType* grad_out, + MSHADOW_XINLINE static void Map(index_t i, DType* grad_out, const DType* grad_in, - const CType* cond, int M) { + const CType* cond, index_t M) { KERNEL_ASSIGN(grad_out[i], req, ((0 == cond[i/M])^negate)? grad_in[i] : static_cast(0)); } diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index a301362f2db7..696fdb1d4175 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -134,6 +134,17 @@ def test_Dense(ctx=mx.cpu(0)): res.wait_to_read() assert res.shape == (50000000, 100) +def test_where(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) + b = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y)) + res = nd.where(b > 100, a, b) + assert np.sum(res[-1].asnumpy() == 1) == b.shape[1] + + csr_cond = nd.sparse.cast_storage(b < 10, 'csr') + res = nd.sparse.where(csr_cond, a, b) + assert np.sum(res[0].asnumpy() == 1) == b.shape[1] + if __name__ == '__main__': import nose