Skip to content

Commit 6993fc3

Browse files
authored
[0-size Tensor No.188、231] Add 0-size Tensor support for paddle.nn.functional.local_response_norm (#73998)
* Fix * Fix * Fix * Fix * Fix * Fix
1 parent aa1091e commit 6993fc3

File tree

3 files changed

+94
-2
lines changed

3 files changed

+94
-2
lines changed

python/paddle/nn/functional/loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3912,10 +3912,10 @@ def triplet_margin_with_distance_loss(
39123912

39133913
if (
39143914
not isinstance(positive_dist, paddle.pir.Value)
3915-
and not paddle.all(positive_dist > 0)
3915+
and not paddle.all(positive_dist >= 0)
39163916
) or (
39173917
not isinstance(negative_dist, paddle.pir.Value)
3918-
and not paddle.all(negative_dist > 0)
3918+
and not paddle.all(negative_dist >= 0)
39193919
):
39203920
raise ValueError(
39213921
"The positive distance or negative distance should be greater than 0, "

test/legacy_test/test_lrn_op.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,5 +371,37 @@ def test_static_fp16_gpu(self):
371371
np.testing.assert_array_equal(res[0].shape, input.shape)
372372

373373

374+
class TestLocalResponseNormAPI_ZeroSize(unittest.TestCase):
375+
def setUp(self):
376+
np.random.seed(123)
377+
self.places = get_places()
378+
379+
def check_dygraph(self, place):
380+
with base.dygraph.guard(place):
381+
in_np1 = np.random.random([0, 40, 40]).astype("float32")
382+
in_np2 = np.transpose(in_np1, (0, 2, 1))
383+
384+
in1 = paddle.to_tensor(in_np1)
385+
in1.stop_gradient = False
386+
in2 = paddle.to_tensor(in_np2)
387+
388+
res1 = paddle.nn.functional.local_response_norm(
389+
x=in1, size=5, data_format='NCL'
390+
)
391+
res2 = paddle.nn.functional.local_response_norm(
392+
x=in2, size=5, data_format='NLC'
393+
)
394+
395+
res2_tran = np.transpose(res2.numpy(), (0, 2, 1))
396+
np.testing.assert_allclose(res1.numpy(), res2_tran, rtol=1e-05)
397+
398+
res1.sum().backward()
399+
np.testing.assert_allclose(in1.grad.shape, in1.shape)
400+
401+
def test_dygraph(self):
402+
for place in self.places:
403+
self.check_dygraph(place)
404+
405+
374406
if __name__ == "__main__":
375407
unittest.main()

test/legacy_test/test_triplet_margin_with_distance_loss.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,5 +489,65 @@ def test_TripletMarginWithDistanceLoss_margin(self):
489489
paddle.enable_static()
490490

491491

492+
class TestTripletMarginWithDistanceLoss_ZeroSize(unittest.TestCase):
493+
def _test_dygraph(
494+
self,
495+
place,
496+
input,
497+
positive,
498+
negative,
499+
distance_function=None,
500+
margin=0.3,
501+
swap=False,
502+
reduction='mean',
503+
expected=None,
504+
):
505+
paddle.disable_static(place)
506+
input = paddle.to_tensor(input)
507+
input.stop_gradient = False
508+
positive = paddle.to_tensor(positive)
509+
negative = paddle.to_tensor(negative)
510+
511+
dy_res = call_TripletMaginDistanceLoss_functional(
512+
input=input,
513+
positive=positive,
514+
negative=negative,
515+
distance_function=distance_function,
516+
margin=margin,
517+
swap=swap,
518+
reduction=reduction,
519+
)
520+
dy_result = dy_res.numpy()
521+
np.testing.assert_allclose(dy_result, expected, rtol=1e-5, atol=1e-8)
522+
dy_res.sum().backward()
523+
np.testing.assert_allclose(input.grad.shape, input.shape)
524+
paddle.enable_static()
525+
526+
def test_TripletMarginDistanceLoss(self):
527+
shape = (5, 0)
528+
np.random.seed(1234)
529+
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
530+
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
531+
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
532+
533+
places = get_places()
534+
reduction = 'sum'
535+
for place in places:
536+
expected = calc_triplet_margin_distance_loss(
537+
input=input,
538+
positive=positive,
539+
negative=negative,
540+
reduction=reduction,
541+
)
542+
self._test_dygraph(
543+
place=place,
544+
input=input,
545+
positive=positive,
546+
negative=negative,
547+
reduction=reduction,
548+
expected=expected,
549+
)
550+
551+
492552
if __name__ == "__main__":
493553
unittest.main()

0 commit comments

Comments
 (0)