
[Fix] Refine focal loss (PaddlePaddle#1915)
* Refine focal loss
* Fix according to comments
juncaipeng authored Mar 30, 2022
1 parent fc714ac commit bd46c4e
Showing 4 changed files with 49 additions and 33 deletions.
6 changes: 1 addition & 5 deletions paddleseg/core/train.py
@@ -42,11 +42,7 @@ def loss_computation(logits_list, labels, losses, edges=None):
         loss_i = losses['types'][i]
         coef_i = losses['coef'][i]
 
-        if loss_i.__class__.__name__ in ('BCELoss', 'FocalLoss'
-                                         ) and loss_i.edge_label:
-            # If use edges as labels According to loss type.
-            loss_list.append(coef_i * loss_i(logits, edges))
-        elif loss_i.__class__.__name__ == 'MixedLoss':
+        if loss_i.__class__.__name__ == 'MixedLoss':
             mixed_loss_list = loss_i(logits, labels)
             for mixed_loss in mixed_loss_list:
                 loss_list.append(coef_i * mixed_loss)
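Note: with the edge-label branch gone, loss_computation special-cases only MixedLoss. A minimal sketch of the resulting dispatch, assuming a hypothetical single-head config (MixedLoss, CrossEntropyLoss, and DiceLoss are classes PaddleSeg exports; shapes are illustrative, and the else branch mirrors the function's usual fallback below the shown hunk):

    import paddle
    from paddleseg.models.losses import CrossEntropyLoss, DiceLoss, MixedLoss

    # Hypothetical config: one logit head, one MixedLoss combining CE and Dice.
    losses = {
        'types': [MixedLoss([CrossEntropyLoss(), DiceLoss()], [0.8, 0.2])],
        'coef': [1.0],
    }
    logits_list = [paddle.randn([2, 2, 32, 32])]               # N,C,H,W
    labels = paddle.randint(0, 2, [2, 32, 32], dtype='int64')  # N,H,W

    loss_list = []
    for i, logits in enumerate(logits_list):
        loss_i = losses['types'][i]
        coef_i = losses['coef'][i]
        if loss_i.__class__.__name__ == 'MixedLoss':
            # MixedLoss returns one already-weighted term per sub-loss;
            # each term is additionally scaled by the head coefficient.
            for mixed_loss in loss_i(logits, labels):
                loss_list.append(coef_i * mixed_loss)
        else:
            loss_list.append(coef_i * loss_i(logits, labels))
    total_loss = sum(loss_list)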
2 changes: 1 addition & 1 deletion paddleseg/models/losses/binary_cross_entropy_loss.py
@@ -99,7 +99,7 @@ def __init__(self,
                     raise ValueError(
                         "if type of `weight` is str, it should equal to 'dynamic', but it is {}"
                         .format(self.weight))
-            elif isinstance(self.weight, paddle.Tensor):
+            elif not isinstance(self.weight, paddle.Tensor):
                 raise TypeError(
                     'The type of `weight` is wrong, it should be Tensor or str, but it is {}'
                     .format(type(self.weight)))
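Note: this is a one-character logic fix. The old check raised TypeError precisely when `weight` was a valid paddle.Tensor, and silently accepted every unsupported type. A standalone sketch of the corrected validation (the helper name is hypothetical):

    import paddle

    def check_weight(weight):
        # Corrected logic from the diff: accept the string 'dynamic' or a
        # paddle.Tensor, and reject everything else.
        if isinstance(weight, str):
            if weight != 'dynamic':
                raise ValueError(
                    "if type of `weight` is str, it should equal to 'dynamic', "
                    "but it is {}".format(weight))
        elif not isinstance(weight, paddle.Tensor):
            raise TypeError(
                'The type of `weight` is wrong, it should be Tensor or str, '
                'but it is {}'.format(type(weight)))

    check_weight('dynamic')                   # ok
    check_weight(paddle.to_tensor([1., 2.]))  # ok; raised TypeError before the fix
    # check_weight(0.5)  # would raise TypeError, as intended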
2 changes: 0 additions & 2 deletions paddleseg/models/losses/cross_entropy_loss.py
@@ -78,8 +78,6 @@ def forward(self, logit, label, semantic_weights=None):
         logit = paddle.transpose(logit, [0, 2, 3, 1])
         label = label.astype('int64')
 
-        # In F.cross_entropy, the ignore_index is invalid, which needs to be fixed.
-        # When there is 255 in the label and paddle version <= 2.1.3, the cross_entropy OP will report an error, which is fixed in paddle develop version.
         loss = F.cross_entropy(
             logit,
             label,
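Note: the deleted comment described an ignore_index limitation of F.cross_entropy in Paddle <= 2.1.3 that, by the comment's own account, was fixed upstream, so the note is stale rather than wrong. For reference, a minimal sketch of the call pattern this code relies on, with channels-last logits and ignored pixels (shapes and the choice of ignored pixels are hypothetical):

    import paddle
    import paddle.nn.functional as F

    logit = paddle.randn([2, 32, 32, 19])                      # N,H,W,C after transpose
    label = paddle.randint(0, 19, [2, 32, 32], dtype='int64')
    # Mark some pixels as ignored (hypothetical choice of pixels).
    label = paddle.where(label == 0, paddle.full_like(label, 255), label)

    # Pixels equal to ignore_index contribute neither loss nor gradient.
    loss = F.cross_entropy(
        logit, label, ignore_index=255, reduction='none', axis=-1)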
72 changes: 47 additions & 25 deletions paddleseg/models/losses/focal_loss.py
@@ -23,38 +23,60 @@
 @manager.LOSSES.add_component
 class FocalLoss(nn.Layer):
     """
-    Focal Loss.
+    The implementation of focal loss.
-    Code referenced from:
-    https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
+    The focal loss requires the label to be 0 or 1 for now.
     Args:
-        gamma (float): the coefficient of Focal Loss.
-        ignore_index (int64): Specifies a target value that is ignored
+        alpha (float, list, optional): The alpha of focal loss. alpha is the weight
+            of class 1, 1-alpha is the weight of class 0. Default: 0.25
+        gamma (float, optional): The gamma of focal loss. Default: 2.0
+        ignore_index (int64, optional): Specifies a target value that is ignored
             and does not contribute to the input gradient. Default ``255``.
     """
 
-    def __init__(self, gamma=2.0, ignore_index=255, edge_label=False):
-        super(FocalLoss, self).__init__()
+    def __init__(self, alpha=0.25, gamma=2.0, ignore_index=255):
+        super().__init__()
+        self.alpha = alpha
         self.gamma = gamma
         self.ignore_index = ignore_index
-        self.edge_label = edge_label
+        self.EPS = 1e-10
 
     def forward(self, logit, label):
-        logit = paddle.reshape(
-            logit, [logit.shape[0], logit.shape[1], -1])  # N,C,H,W => N,C,H*W
-        logit = paddle.transpose(logit, [0, 2, 1])  # N,C,H*W => N,H*W,C
-        logit = paddle.reshape(logit,
-                               [-1, logit.shape[2]])  # N,H*W,C => N*H*W,C
-        label = paddle.reshape(label, [-1, 1])
-        range_ = paddle.arange(0, label.shape[0])
-        range_ = paddle.unsqueeze(range_, axis=-1)
-        label = paddle.cast(label, dtype='int64')
-        label = paddle.concat([range_, label], axis=-1)
-        logpt = F.log_softmax(logit)
-        logpt = paddle.gather_nd(logpt, label)
-
-        pt = paddle.exp(logpt.detach())
-        loss = -1 * (1 - pt)**self.gamma * logpt
-        loss = paddle.mean(loss)
-        return loss
+        """
+        Forward computation.
+        Args:
+            logit (Tensor): Logit tensor, the data type is float32 or float64. Shape is
+                (N, C, H, W), where C is the number of classes.
+            label (Tensor): Label tensor, the data type is int64. Shape is (N, H, W),
+                where each value satisfies 0 <= label[i] <= C-1.
+        Returns:
+            (Tensor): The average loss.
+        """
+        assert logit.ndim == 4, "The ndim of logit should be 4."
+        assert logit.shape[1] == 2, "The channel of logit should be 2."
+        assert label.ndim == 3, "The ndim of label should be 3."
+
+        class_num = logit.shape[1]  # class num is 2
+        logit = paddle.transpose(logit, [0, 2, 3, 1])  # N,C,H,W => N,H,W,C
+
+        mask = label != self.ignore_index  # N,H,W
+        mask = paddle.unsqueeze(mask, 3)
+        mask = paddle.cast(mask, 'float32')
+        mask.stop_gradient = True
+
+        label = F.one_hot(label, class_num)  # N,H,W,C
+        label = paddle.cast(label, logit.dtype)
+        label.stop_gradient = True
+
+        loss = F.sigmoid_focal_loss(
+            logit=logit,
+            label=label,
+            alpha=self.alpha,
+            gamma=self.gamma,
+            reduction='none')
+        loss = loss * mask
+        avg_loss = paddle.sum(loss) / (
+            paddle.sum(paddle.cast(mask != 0., 'int32')) * class_num + self.EPS)
+        return avg_loss
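Note: the rewritten loss delegates the per-element computation to Paddle's built-in F.sigmoid_focal_loss (the RetinaNet formulation, -alpha_t * (1 - p_t)^gamma * log(p_t) with p = sigmoid(logit)), masks out ignore_index pixels, and divides the masked sum by the number of valid elements; sum(mask != 0) * class_num matches the (N, H, W, C) shape of the unreduced loss, and EPS guards against a batch where every pixel is ignored. A minimal usage sketch under the documented constraint that labels are 0 or 1 (shapes are hypothetical):

    import paddle
    from paddleseg.models.losses import FocalLoss

    logit = paddle.randn([4, 2, 64, 64])                      # N,2,H,W binary logits
    label = paddle.randint(0, 2, [4, 64, 64], dtype='int64')  # N,H,W, values in {0, 1}

    loss_fn = FocalLoss(alpha=0.25, gamma=2.0, ignore_index=255)
    avg_loss = loss_fn(logit, label)  # scalar tensor: masked average focal loss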

