File tree Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -79,5 +79,37 @@ def test_api_int64(self):
7979 self .api_case ()
8080
8181
82+ class TestLogsumexpAPI_ZeroSize (unittest .TestCase ):
83+ def setUp (self ):
84+ self .place = (
85+ paddle .CUDAPlace (0 )
86+ if paddle .base .core .is_compiled_with_cuda ()
87+ else paddle .CPUPlace ()
88+ )
89+
90+ def api_case (self ):
91+ self .x = np .random .uniform (- 1 , 1 , self .xshape ).astype (self .dtype )
92+ self .y = np .random .uniform (- 1 , 1 , self .yshape ).astype (self .dtype )
93+ out_ref = ref_logaddexp (self .x , self .y )
94+
95+ paddle .disable_static (self .place )
96+ x = paddle .to_tensor (self .x )
97+ y = paddle .to_tensor (self .y )
98+ x .stop_gradient = False
99+ y .stop_gradient = False
100+ out = paddle .logaddexp (x , y )
101+ np .testing .assert_allclose (out .numpy (), out_ref , atol = 1e-06 )
102+
103+ loss = paddle .sum (out )
104+ loss .backward ()
105+ np .testing .assert_allclose (x .grad .shape , x .shape )
106+
107+ def test_api (self ):
108+ self .xshape = [1 , 2 , 3 , 0 ]
109+ self .yshape = [1 , 2 , 3 , 1 ]
110+ self .dtype = np .float32
111+ self .api_case ()
112+
113+
82114if __name__ == '__main__' :
83115 unittest .main ()
You can’t perform that action at this time.
0 commit comments