Skip to content

Commit 85bc32d

Browse files
authored
[Accuracy diff No.125] Fix accuracy diff for paddle.nn.functional.multi_margin_loss API (#73739)
1 parent 31656c9 commit 85bc32d

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

python/paddle/nn/functional/loss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4170,8 +4170,9 @@ def multi_margin_loss(
41704170
)
41714171
weight = paddle.gather(weight, label, axis=0).reshape((-1, 1))
41724172
loss = paddle.mean(
4173-
paddle.pow(
4174-
paddle.clip(weight * (margin - index_sample + input), min=0.0),
4173+
weight
4174+
* paddle.pow(
4175+
paddle.clip((margin - index_sample + input), min=0.0),
41754176
p,
41764177
),
41774178
axis=1,

test/legacy_test/test_multimarginloss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def calc_multi_margin_loss(
224224
[weight[label[i]] for i in range(label.size)]
225225
).reshape(-1, 1)
226226
expected = np.mean(
227-
np.maximum(weight * (margin + input - index_sample), 0.0) ** p,
227+
weight * (np.maximum((margin + input - index_sample), 0.0) ** p),
228228
axis=1,
229229
) - weight * (margin**p / input.shape[1])
230230

0 commit comments

Comments
 (0)