Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1253] fix control_flow_op (#13555)
Browse files Browse the repository at this point in the history
* fix control_flow_op

* change type for M

* add test for sparse where op
  • Loading branch information
apeforest authored and TaoLv committed Dec 11, 2018
1 parent 75d1d4f commit 46a2990
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/operator/tensor/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct where {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond,
const DType* x, const DType* y) {
KERNEL_ASSIGN(out[i], req, (0 != cond[i]? x[i] : y[i]));
}
Expand All @@ -64,7 +64,7 @@ struct where_csr {
// CType is condition data type
// i is for i-th row in the output
template<typename DType, typename CType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* out, const IType* cond_idx,
MSHADOW_XINLINE static void Map(index_t i, DType* out, const IType* cond_idx,
const IType* cond_indptr, const CType* cond_data,
const nnvm::dim_t num_cols, const DType* x) {
using nnvm::dim_t;
Expand Down Expand Up @@ -92,8 +92,8 @@ struct where_batch {
// DType is the output data type
// CType is the condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
const DType* x, const DType* y, int M) {
MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond,
const DType* x, const DType* y, index_t M) {
KERNEL_ASSIGN(out[i], req, (0 != cond[i/M]? x[i] : y[i]));
}
};
Expand All @@ -109,7 +109,7 @@ struct where_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* grad_out,
MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
const CType* cond) {
KERNEL_ASSIGN(grad_out[i], req,
Expand All @@ -130,7 +130,7 @@ struct where_backward_csr {
// CType is condition data type
// IType is condition aux data type
template<typename DType, typename CType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* grad_out,
MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
const CType* cond_data,
const IType* cond_idx,
Expand Down Expand Up @@ -161,9 +161,9 @@ struct where_batch_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* grad_out,
MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
const CType* cond, int M) {
const CType* cond, index_t M) {
KERNEL_ASSIGN(grad_out[i], req,
((0 == cond[i/M])^negate)? grad_in[i] : static_cast<DType>(0));
}
Expand Down
11 changes: 11 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ def test_Dense(ctx=mx.cpu(0)):
res.wait_to_read()
assert res.shape == (50000000, 100)

def test_where():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
b = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y))
res = nd.where(b > 100, a, b)
assert np.sum(res[-1].asnumpy() == 1) == b.shape[1]

csr_cond = nd.sparse.cast_storage(b < 10, 'csr')
res = nd.sparse.where(csr_cond, a, b)
assert np.sum(res[0].asnumpy() == 1) == b.shape[1]


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 46a2990

Please sign in to comment.