@@ -428,6 +428,41 @@ def test_static_shape_take_along_axis(self):
428428 _ = static_f (x , ind , axis = 0 , broadcast = False )
429429
430430
431+ class TestTakeAlongAxis_ZeroSize (OpTest ):
432+ def setUp (self ):
433+ self .python_api = paddle .take_along_axis
434+ self .op_type = "take_along_axis"
435+ self .dtype = "float64"
436+ self .check_pir = True
437+
438+ x = np .zeros ((2 , 0 , 5 )).astype (self .dtype )
439+ indices = np .zeros ((2 , 3 , 5 )).astype ("int64" )
440+
441+ self .inputs = {'Input' : x , 'Index' : indices }
442+ self .attrs = {'Axis' : 1 }
443+
444+ output = np .zeros ((2 , 3 , 5 )).astype (self .dtype )
445+ self .outputs = {'Result' : output }
446+
447+ def test_check_output (self ):
448+ self .check_output_with_place (
449+ paddle .CPUPlace (), check_pir = self .check_pir
450+ )
451+ if core .is_compiled_with_cuda ():
452+ self .check_output_with_place (
453+ core .CUDAPlace (0 ), check_pir = self .check_pir
454+ )
455+
456+ def test_check_grad (self ):
457+ self .check_grad_with_place (
458+ paddle .CPUPlace (), ['Input' ], 'Result' , check_pir = self .check_pir
459+ )
460+ if core .is_compiled_with_cuda ():
461+ self .check_grad_with_place (
462+ core .CUDAPlace (0 ), ['Input' ], 'Result' , check_pir = self .check_pir
463+ )
464+
465+
431466if __name__ == "__main__" :
432467 paddle .enable_static ()
433468 unittest .main ()
0 commit comments