@@ -1028,5 +1028,63 @@ def init_dtype_type(self):
10281028 self .index_type_pd1 = "bool"
10291029
10301030
1031+ class TestIndexPutAPI_ZeroSize (unittest .TestCase ):
1032+ def setUp (self ):
1033+ self .init_dtype_type ()
1034+ self .setPlace ()
1035+
1036+ def init_dtype_type (self ):
1037+ self .dtype_np = np .float32
1038+ self .index_type_np = np .int64
1039+ self .x_shape = (10 , 0 )
1040+ self .indices_shapes = [[10 ]]
1041+ self .value_shape = [1 , 1 ]
1042+ self .dtype_pd = paddle .float32
1043+ self .index_type_pd = paddle .int64
1044+
1045+ def setPlace (self ):
1046+ self .place = []
1047+ if (
1048+ os .environ .get ('FLAGS_CI_both_cpu_and_gpu' , 'False' ).lower ()
1049+ in ['1' , 'true' , 'on' ]
1050+ or not paddle .is_compiled_with_cuda ()
1051+ ):
1052+ self .place .append ('cpu' )
1053+ if self .dtype_np is np .float16 :
1054+ self .place = []
1055+ if paddle .is_compiled_with_cuda ():
1056+ self .place .append ('gpu' )
1057+
1058+ def test_dygraph_forward (self ):
1059+ paddle .disable_static ()
1060+ for place in self .place :
1061+ paddle .device .set_device (place )
1062+ x_pd = paddle .randn (self .x_shape , dtype = self .dtype_pd )
1063+ x_np = x_pd .numpy ()
1064+ value_pd = paddle .randn (self .value_shape , dtype = self .dtype_pd )
1065+ value_np = value_pd .numpy ()
1066+ x_pd .stop_gradient = False
1067+ value_pd .stop_gradient = False
1068+ indices_pd = [
1069+ paddle .randn (indices_shape ).astype (dtype = self .index_type_pd )
1070+ for indices_shape in self .indices_shapes
1071+ ]
1072+ indices_np = [item .numpy () for item in indices_pd ]
1073+ indices_pd = tuple (indices_pd )
1074+ accumulate = False
1075+ ref_res = compute_index_put_ref (
1076+ x_np , indices_np , value_np , accumulate
1077+ )
1078+ pd_res = paddle .index_put (x_pd , indices_pd , value_pd , accumulate )
1079+ np .testing .assert_allclose (ref_res , pd_res .numpy (), atol = 1e-7 )
1080+
1081+ # check grad
1082+ pd_res .sum ().backward ()
1083+ np .testing .assert_allclose (x_pd .grad .shape , x_pd .shape )
1084+ np .testing .assert_allclose (
1085+ value_pd .grad .numpy (), np .zeros (value_pd .shape )
1086+ )
1087+
1088+
10311089if __name__ == '__main__' :
10321090 unittest .main ()
0 commit comments