diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index a8091e276862..2bf7e8bf9d71 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -235,10 +235,16 @@ def test_where_partial_shape(): y = mx.sym.Variable("y") cond = mx.sym.Variable("cond") where_op = mx.sym.where(cond, x, y) - where_op.infer_shape_partial(cond=(0, 2), x=(0, 2), y =(0, 2)) + # condition must be fully known to infer shape + _, result, _ = where_op.infer_shape_partial(cond=(0, 2), x=(0, 2), y =(0, 2)) + assert result == [()] + _, result, _ = where_op.infer_shape_partial(cond=(0,), x=(2, 2), y =(2, 2)) + assert result == [()] with mx.np_compat(True): - where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y =(-1, 2)) - + _, result, _ = where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y =(-1, 2)) + assert result == [None] + _, result, _ = where_op.infer_shape_partial(cond=(-1,), x=(2, 2), y=(2, 2)) + assert result == [None] if __name__ == "__main__": import nose