@@ -50,7 +50,7 @@ def np_medain_min(data, keepdims=False):
5050 return np_res + np .sum (np .isnan (data ).astype (data .dtype ) * data )
5151
5252
53- def np_medain_min_axis (data , axis = None , keepdims = False ):
53+ def np_median_min_axis (data , axis = None , keepdims = False ):
5454 data = copy .deepcopy (data )
5555 if axis is None :
5656 return np_medain_min (data , keepdims )
@@ -232,7 +232,7 @@ class TestMedianMin(unittest.TestCase):
232232 def static_single_test_median (self , lis_test ):
233233 paddle .enable_static ()
234234 x , axis , keepdims = lis_test
235- res_np = np_medain_min_axis (x , axis = axis , keepdims = keepdims )
235+ res_np = np_median_min_axis (x , axis = axis , keepdims = keepdims )
236236 main_program = paddle .static .Program ()
237237 startup_program = paddle .static .Program ()
238238 exe = paddle .static .Executor ()
@@ -245,7 +245,7 @@ def static_single_test_median(self, lis_test):
245245
246246 def dygraph_single_test_median (self , lis_test ):
247247 x , axis , keepdims = lis_test
248- res_np = np_medain_min_axis (x , axis = axis , keepdims = keepdims )
248+ res_np = np_median_min_axis (x , axis = axis , keepdims = keepdims )
249249 if axis is None :
250250 res_pd = paddle .median (
251251 paddle .to_tensor (x ), axis , keepdims , mode = 'min'
@@ -335,7 +335,7 @@ def test_float16(self):
335335 for keepdims in [False , True ]
336336 ]
337337 for axis , keepdims in lis_tests :
338- res_np = np_medain_min_axis (x , axis = axis , keepdims = keepdims )
338+ res_np = np_median_min_axis (x , axis = axis , keepdims = keepdims )
339339 if axis is None :
340340 res_pd = paddle .median (
341341 paddle .to_tensor (x ), axis , keepdims , mode = 'min'
@@ -357,5 +357,49 @@ def test_output_dtype(self):
357357 np .testing .assert_equal (res .numpy ().dtype , np .dtype (inp_dtype ))
358358
359359
360+ class TestMedianAvg_ZeroSize (unittest .TestCase ):
361+ def dygraph_single_test_median (self , lis_test ):
362+ x , axis , keepdims = lis_test
363+ res_np = np .median (x , axis = axis , keepdims = keepdims )
364+ x_pd = paddle .to_tensor (x )
365+ x_pd .stop_gradient = False
366+ res_pd = paddle .median (x_pd , axis , keepdims )
367+ np .testing .assert_allclose (res_pd .numpy (), res_np )
368+ paddle .sum (res_pd ).backward ()
369+ np .testing .assert_allclose (x_pd .grad .shape , x_pd .shape )
370+
371+ def test_median_dygraph (self ):
372+ paddle .disable_static ()
373+ h = 0
374+ w = 4
375+ l = 2
376+ x = np .arange (h * w * l ).reshape ([h , w , l ])
377+ self .dygraph_single_test_median ([x , 1 , False ])
378+
379+
380+ class TestMedianMin_ZeroSize (unittest .TestCase ):
381+
382+ def dygraph_single_test_median (self , lis_test ):
383+ x , axis , keepdims = lis_test
384+ res_np = np_median_min_axis (x , axis = axis , keepdims = keepdims )
385+ x_pd = paddle .to_tensor (x )
386+ x_pd .stop_gradient = False
387+ if axis is None :
388+ res_pd = paddle .median (x_pd , axis , keepdims , mode = 'min' )
389+ else :
390+ res_pd , _ = paddle .median (x_pd , axis , keepdims , mode = 'min' )
391+ np .testing .assert_allclose (res_pd .numpy (), res_np )
392+ paddle .sum (res_pd ).backward ()
393+ np .testing .assert_allclose (x_pd .grad .shape , x_pd .shape )
394+
395+ def test_median_dygraph (self ):
396+ paddle .disable_static ()
397+ h = 0
398+ w = 4
399+ l = 2
400+ x = np .arange (h * w * l ).reshape ([h , w , l ]).astype ("float32" )
401+ self .dygraph_single_test_median ([x , 1 , False ])
402+
403+
360404if __name__ == '__main__' :
361405 unittest .main ()
0 commit comments