Skip to content

Commit

Permalink
Fix a bug in where op with 1-D input (apache#12325)
Browse files Browse the repository at this point in the history
* Fix a bug in where op with 1-D input

* Add unit test
  • Loading branch information
apeforest authored and Roshrini committed Aug 24, 2018
1 parent 84b3b56 commit 396b6c3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/operator/tensor/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ inline bool WhereOpShape(const nnvm::NodeAttrs& attrs,
return true;
} else if ((*in_attrs)[0].ndim() == 1) {
CHECK_EQ((*in_attrs)[0].Size(), static_cast<size_t>(tshape[0]));
return true;
}
return false;
}
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4507,6 +4507,14 @@ def test_invalid_shape():
y=mx.nd.array([[8,9],[10,11],[12,13]]),
condition=mx.nd.array([1,0])), MXNetError)

def test_1d_cond():
cond = mx.nd.array([1, 0, 1])
x = mx.nd.array([[2, 3], [4, 5], [6, 7]])
y = mx.nd.array([[7, 8], [9, 10], [10, 11]])
expect_out = np.array([[2, 3], [9, 10], [6, 7]])
out = mx.nd.where(cond, x, y).asnumpy()
assert(expect_out.all() == out.all())

test_where_helper((5, 9), True)
test_where_helper((5, 9), False)
test_where_helper((5, 7, 9), True)
Expand All @@ -4518,6 +4526,7 @@ def test_invalid_shape():
test_where_numeric_gradient((5, 7, 9), True)
test_where_numeric_gradient((5, 7, 9), False)
test_invalid_shape()
test_1d_cond()

@with_seed()
def test_new_softmax():
Expand Down

0 comments on commit 396b6c3

Please sign in to comment.