Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
replacing invalid input shapes with valid ones
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Feb 10, 2019
1 parent f4654f6 commit dd28023
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6484,14 +6484,14 @@ def check_slice_axis_partial_infer(data, axis, begin, end, expected_out_shape):
out = mx.sym.slice_axis(data, axis=axis, begin=begin, end=end)
assert (out.infer_shape_partial()[1][0] == expected_out_shape), out.infer_shape_partial()[1]

var1 = mx.sym.var(name="data", shape=(0, 20))
check_slice_partial_infer(var1, (None, None), (None, 10), [], (0, 10))
check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2), (0, 5))
check_slice_partial_infer(var1, (None, 3), (None, 10), [], (0, 7))
check_slice_partial_infer(var1, (None, 3), (5, 10), [], (0, 7))
check_slice_partial_infer(var1, (2, 3), (None, 10), [], (0, 7))
check_slice_partial_infer(var1, (2, 3), (None, 10), (None, 1), (0, 7))
check_slice_partial_infer(var1, (2, 3), (None, 10), (3, 3), (0, 3))
var1 = mx.sym.var(name="data", shape=(10, 20))
check_slice_partial_infer(var1, (None, None), (None, 10), [], (10, 10))
check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2), (10, 5))
check_slice_partial_infer(var1, (None, 3), (None, 10), [], (10, 7))
check_slice_partial_infer(var1, (None, 3), (5, 10), [], (5, 7))
check_slice_partial_infer(var1, (2, 3), (None, 10), [], (8, 7))
check_slice_partial_infer(var1, (2, 3), (None, 10), (None, 1), (8, 7))
check_slice_partial_infer(var1, (2, 3), (None, 10), (3, 3), (8, 3))

var1 = mx.sym.var(name="data", shape=(10, 0))
check_slice_axis_partial_infer(var1, 0, 0, 5, (5, 0))
Expand Down

0 comments on commit dd28023

Please sign in to comment.