Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update test case
Browse files Browse the repository at this point in the history
  • Loading branch information
roywei committed May 21, 2019
1 parent 78902b1 commit 25a9da0
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/python/unittest/test_infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,16 @@ def test_where_partial_shape():
y = mx.sym.Variable("y")
cond = mx.sym.Variable("cond")
where_op = mx.sym.where(cond, x, y)
where_op.infer_shape_partial(cond=(0, 2), x=(0, 2), y =(0, 2))
# condition must be fully known to infer shape
_, result, _ = where_op.infer_shape_partial(cond=(0, 2), x=(0, 2), y =(0, 2))
assert result == [()]
_, result, _ = where_op.infer_shape_partial(cond=(0,), x=(2, 2), y =(2, 2))
assert result == [()]
with mx.np_compat(True):
where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y =(-1, 2))

_, result, _ = where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y =(-1, 2))
assert result == [None]
_, result, _ = where_op.infer_shape_partial(cond=(-1,), x=(2, 2), y=(2, 2))
assert result == [None]

if __name__ == "__main__":
import nose
Expand Down

0 comments on commit 25a9da0

Please sign in to comment.