Skip to content

Commit

Permalink
replace where with min and max
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate authored and chajchaj committed Jan 10, 2022
1 parent 3ab9ace commit e30150d
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,27 +1665,17 @@ def cross_entropy(input,
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=axis)
if in_dygraph_mode():
if not soft_label:
if soft_label == False:
valid_label = paddle.cast(
label != ignore_index, dtype=label.dtype) * label
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))

label_min = paddle.min(valid_label)
label_max = paddle.max(valid_label)
if label_min < 0:
raise ValueError("label should not out of bound, but got{}".
format(label_min))
if label_max >= input.shape[axis]:
raise ValueError("label should not out of bound, but got{}".
format(label_max))
if core.is_compiled_with_npu():
_, _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
Expand Down Expand Up @@ -1842,7 +1832,6 @@ def cross_entropy(input,
valid_label = paddle.multiply(
paddle.cast(
label != ignore_index, dtype=label.dtype), label)

ignore_weight_mask = paddle.cast((label != ignore_index),
input.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
Expand Down

0 comments on commit e30150d

Please sign in to comment.