Skip to content

Commit de22ff6

Browse files
committed
Fix
1 parent 079053a commit de22ff6

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

test/legacy_test/test_nn_margin_rank_loss.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,77 @@ def test_case(self):
198198
create_test_case(margin, reduction)
199199

200200

201+
def create_test_case_zero_size(margin, reduction):
202+
class MarginRankingLossCls_ZeroSize(unittest.TestCase):
203+
def init_shape(self):
204+
self.x_shape = (1, 10)
205+
self.y_shape = (0, 10)
206+
self.label_shape = (1, 10)
207+
208+
def setUp(self):
209+
self.init_shape()
210+
self.x_data = np.random.rand(*self.x_shape).astype("float64")
211+
self.y_data = np.random.rand(*self.y_shape).astype("float64")
212+
self.label_data = np.random.choice(
213+
[-1, 1], size=self.label_shape
214+
).astype("float64")
215+
self.places = []
216+
if (
217+
os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
218+
in ['1', 'true', 'on']
219+
or not core.is_compiled_with_cuda()
220+
):
221+
self.places.append(base.CPUPlace())
222+
if core.is_compiled_with_cuda():
223+
self.places.append(paddle.CUDAPlace(0))
224+
225+
def run_dynamic_functional_api(self, place):
226+
paddle.disable_static(place)
227+
x = paddle.to_tensor(self.x_data)
228+
x.stop_gradient = False
229+
y = paddle.to_tensor(self.y_data)
230+
label = paddle.to_tensor(self.label_data)
231+
232+
result = paddle.nn.functional.margin_ranking_loss(
233+
x, y, label, margin, reduction
234+
)
235+
expected = calc_margin_rank_loss(
236+
self.x_data,
237+
self.y_data,
238+
self.label_data,
239+
margin=margin,
240+
reduction=reduction,
241+
)
242+
np.testing.assert_allclose(result.numpy(), expected, rtol=1e-05)
243+
loss = paddle.sum(result)
244+
loss.backward()
245+
np.testing.assert_allclose(x.grad.shape, x.shape)
246+
paddle.enable_static()
247+
248+
def test_case(self):
249+
for place in self.places:
250+
self.run_dynamic_functional_api(place)
251+
252+
cls_name = f"TestMarginRankLossCase_ZeroSize_{margin}_{reduction}"
253+
MarginRankingLossCls_ZeroSize.__name__ = cls_name
254+
globals()[cls_name] = MarginRankingLossCls_ZeroSize
255+
256+
class MarginRankingLossCls_ZeroSize2(MarginRankingLossCls_ZeroSize):
257+
def init_shape(self):
258+
self.x_shape = (0, 10)
259+
self.y_shape = (0, 10)
260+
self.label_shape = (0, 10)
261+
262+
cls_name = f"TestMarginRankLossCase_ZeroSize2_{margin}_{reduction}"
263+
MarginRankingLossCls_ZeroSize2.__name__ = cls_name
264+
globals()[cls_name] = MarginRankingLossCls_ZeroSize2
265+
266+
267+
for margin in [0.0, 0.2]:
268+
for reduction in ['none', 'mean', 'sum']:
269+
create_test_case_zero_size(margin, reduction)
270+
271+
201272
# test case the raise message
202273
class MarginRakingLossError(unittest.TestCase):
203274
paddle.enable_static()

0 commit comments

Comments
 (0)