@@ -115,6 +115,52 @@ def initTestCase(self):
115115 self .log_target = True
116116
117117
class TestKLDivLossOp_ZeroSize1(TestKLDivLossOp):
    """KL-divergence loss op test with a zero-size input.

    The input shape has a 0-length batch dimension, so the tensors hold no
    elements. With reduction='mean' the loss is NaN (mean over zero
    elements), which is why the output check allows NaN.
    """

    def setUp(self):
        self.initTestCase()
        self.op_type = 'kldiv_loss'
        self.python_api = kl_div
        self.public_python_api = paddle.nn.functional.kl_div

        # Zero-element tensors: shape (0, 2, 7, 7) from initTestCase.
        x = np.random.uniform(-10, 10, self.x_shape).astype('float64')
        target = np.random.uniform(-10, 10, self.x_shape).astype('float64')

        self.attrs = {
            "reduction": self.reduction,
            "log_target": self.log_target,
        }

        self.inputs = {
            'X': x,
            'Target': target,
        }
        # Reference loss from the numpy implementation defined in this file.
        loss = kldiv_loss(x, target, self.reduction, self.log_target)
        self.outputs = {'Loss': loss.astype('float64')}

    def initTestCase(self):
        # 'mean' over an empty tensor divides by zero -> NaN output.
        self.x_shape = (0, 2, 7, 7)
        self.reduction = 'mean'
        self.log_target = False

    def test_check_output(self):
        # equal_nan=True: mean reduction of a zero-size input is NaN.
        self.check_output(check_pir=True, equal_nan=True)

    def test_check_grad(self):
        self.check_grad(
            ['X'],
            'Loss',
            no_grad_set={"Target"},
            check_pir=True,
        )
156+
class TestKLDivLossOp_ZeroSize2(TestKLDivLossOp_ZeroSize1):
    """Zero-size KL-divergence test variant using reduction='none'."""

    def initTestCase(self):
        # Same empty shape as the base zero-size case; only the
        # reduction mode differs.
        self.x_shape = (0, 2, 7, 7)
        self.reduction = 'none'
        self.log_target = False
162+
163+
118164class TestKLDivLossDygraph (unittest .TestCase ):
119165 def run_kl_loss (self , reduction , shape = (5 , 20 ), log_target = False ):
120166 x = np .random .uniform (- 10 , 10 , shape ).astype ('float64' )
0 commit comments