Skip to content

Commit 39c41a3

Browse files
committed
Fix
1 parent 1a07398 commit 39c41a3

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

test/legacy_test/test_gather_op.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,46 @@ def test_gather_backward(self):
948948
np.testing.assert_allclose(res_list[0], res_list[1])
949949

950950

951+
class TestGatherOp_ZeroSize(OpTest):
952+
def setUp(self):
953+
self.op_type = "gather"
954+
self.python_api = paddle.gather
955+
self.public_python_api = paddle.gather
956+
self.config()
957+
self.init_inputs_and_outputs()
958+
959+
def test_check_output(self):
960+
self.check_output(check_pir=True)
961+
962+
def test_check_grad(self):
963+
self.check_grad(['X'], 'Out', check_pir=True)
964+
965+
def config(self):
966+
self.x_shape = (3, 0, 4)
967+
self.config_dtype()
968+
self.index = [2]
969+
self.index_type = "int32"
970+
971+
def config_dtype(self):
972+
self.x_type = "float64"
973+
974+
def init_inputs_and_outputs(self):
975+
xnp = np.random.random(self.x_shape).astype(self.x_type)
976+
self.inputs = {
977+
'X': xnp,
978+
'Index': np.array(self.index).astype(self.index_type),
979+
}
980+
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
981+
982+
983+
class TestGatherOp_ZeroSize2(TestGatherOp_ZeroSize):
984+
def config(self):
985+
self.x_shape = (10, 20)
986+
self.config_dtype()
987+
self.index = [2, 0]
988+
self.index_type = "int32"
989+
990+
951991
if __name__ == "__main__":
952992
paddle.enable_static()
953993
unittest.main()

0 commit comments

Comments
 (0)