@@ -948,6 +948,46 @@ def test_gather_backward(self):
948948 np .testing .assert_allclose (res_list [0 ], res_list [1 ])
949949
950950
951+ class TestGatherOp_ZeroSize (OpTest ):
952+ def setUp (self ):
953+ self .op_type = "gather"
954+ self .python_api = paddle .gather
955+ self .public_python_api = paddle .gather
956+ self .config ()
957+ self .init_inputs_and_outputs ()
958+
959+ def test_check_output (self ):
960+ self .check_output (check_pir = True )
961+
962+ def test_check_grad (self ):
963+ self .check_grad (['X' ], 'Out' , check_pir = True )
964+
965+ def config (self ):
966+ self .x_shape = (3 , 0 , 4 )
967+ self .config_dtype ()
968+ self .index = [2 ]
969+ self .index_type = "int32"
970+
971+ def config_dtype (self ):
972+ self .x_type = "float64"
973+
974+ def init_inputs_and_outputs (self ):
975+ xnp = np .random .random (self .x_shape ).astype (self .x_type )
976+ self .inputs = {
977+ 'X' : xnp ,
978+ 'Index' : np .array (self .index ).astype (self .index_type ),
979+ }
980+ self .outputs = {'Out' : self .inputs ["X" ][self .inputs ["Index" ]]}
981+
982+
983+ class TestGatherOp_ZeroSize2 (TestGatherOp_ZeroSize ):
984+ def config (self ):
985+ self .x_shape = (10 , 20 )
986+ self .config_dtype ()
987+ self .index = [2 , 0 ]
988+ self .index_type = "int32"
989+
990+
951991if __name__ == "__main__" :
952992 paddle .enable_static ()
953993 unittest .main ()
0 commit comments