Commit d724a95

add sigmoid/softmax support and multi-class extension for AsymmetricUnifiedFocalLoss
Signed-off-by: ytl0623 <[email protected]>
1 parent 9a45627 commit d724a95

File tree: 2 files changed, +75 -45 lines changed

monai/losses/unified_focal_loss.py

Lines changed: 55 additions & 38 deletions
@@ -14,6 +14,7 @@
 import warnings

 import torch
+import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss

 from monai.networks import one_hot
@@ -24,7 +25,9 @@ class AsymmetricFocalTverskyLoss(_Loss):
     """
     AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.

-    Actually, it's only supported for binary image segmentation now.
+    It supports both binary and multi-class segmentation.
+
+    The logic assumes channel 0 is Background, and channels 1..N are Foreground.

     Reimplementation of the Asymmetric Focal Tversky Loss described in:
@@ -35,6 +38,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
+        use_softmax: bool = False,
         delta: float = 0.7,
         gamma: float = 0.75,
         epsilon: float = 1e-7,
@@ -43,17 +47,25 @@ def __init__(
         """
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
             gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
+        self.use_softmax = use_softmax
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        if self.use_softmax:
+            y_pred = torch.softmax(y_pred, dim=1)
+        else:
+            y_pred = torch.sigmoid(y_pred)
+
         n_pred_ch = y_pred.shape[1]

         if self.to_onehot_y:
@@ -67,17 +79,23 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

         # clip the prediction to avoid NaN
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
-        axis = list(range(2, len(y_pred.shape)))
+
+        spatial_dims = list(range(2, len(y_pred.shape)))

         # Calculate true positives (tp), false negatives (fn) and false positives (fp)
-        tp = torch.sum(y_true * y_pred, dim=axis)
-        fn = torch.sum(y_true * (1 - y_pred), dim=axis)
-        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
+        tp = torch.sum(y_true * y_pred, dim=spatial_dims)
+        fn = torch.sum(y_true * (1 - y_pred), dim=spatial_dims)
+        fp = torch.sum((1 - y_true) * y_pred, dim=spatial_dims)
         dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

         # Calculate losses separately for each class, enhancing both classes
         back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
+        fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma)
+
+        if fore_dice.shape[1] > 1:
+            fore_dice = torch.mean(fore_dice, dim=1)
+        else:
+            fore_dice = fore_dice.squeeze(1)

         # Average class scores
         loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
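A standalone sketch (editor's illustration, not part of the commit) of the multi-class path above: channel 0 contributes a plain Dice complement, while the remaining channels receive the focal weighting and are then averaged. The batch size, class count, spatial size, and random inputs below are placeholders.

    import torch

    # Placeholder multi-class example: batch=2, classes=3 (channel 0 = background), 4x4 spatial.
    y_pred = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
    y_true = torch.nn.functional.one_hot(torch.randint(0, 3, (2, 4, 4)), num_classes=3)
    y_true = y_true.permute(0, 3, 1, 2).float()

    delta, gamma, epsilon = 0.7, 0.75, 1e-7
    spatial_dims = list(range(2, len(y_pred.shape)))

    tp = torch.sum(y_true * y_pred, dim=spatial_dims)
    fn = torch.sum(y_true * (1 - y_pred), dim=spatial_dims)
    fp = torch.sum((1 - y_true) * y_pred, dim=spatial_dims)
    dice_class = (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)  # shape (2, 3)

    back_dice = 1 - dice_class[:, 0]                                  # background channel only
    fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -gamma)
    fore_dice = torch.mean(fore_dice, dim=1)                          # average the foreground channels
    print(torch.mean(torch.stack([back_dice, fore_dice], dim=-1)))    # scalar loss
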
@@ -88,7 +106,7 @@ class AsymmetricFocalLoss(_Loss):
     """
     AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.

-    Actually, it's only supported for binary image segmentation now.
+    It supports both binary and multi-class segmentation.

     Reimplementation of the Asymmetric Focal Loss described in:
@@ -99,6 +117,7 @@ class AsymmetricFocalLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
+        use_softmax: bool = False,
         delta: float = 0.7,
         gamma: float = 2,
         epsilon: float = 1e-7,
@@ -107,17 +126,27 @@ def __init__(
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
             gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
+        self.use_softmax = use_softmax
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        if self.use_softmax:
+            y_log_pred = F.log_softmax(y_pred, dim=1)
+            y_pred = torch.exp(y_log_pred)
+        else:
+            y_log_pred = F.logsigmoid(y_pred)
+            y_pred = torch.sigmoid(y_pred)
+
         n_pred_ch = y_pred.shape[1]

         if self.to_onehot_y:
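Editor's note: precomputing `y_log_pred` with `F.log_softmax` / `F.logsigmoid` (used for the cross-entropy term in the next hunk, replacing `torch.log(y_pred)`) is presumably a numerical-stability choice; the commit message does not say so, so treat this reading as an assumption. A standalone illustration of the difference at saturated logits, not part of the commit (in the old code the clamp to `epsilon` avoided the `-inf`, at the cost of flattening the log near saturation):

    import torch
    import torch.nn.functional as F

    # Strongly saturated logits: exp(-200) underflows to 0 in float32.
    logits = torch.tensor([[-200.0, 0.0, 200.0]])

    naive = torch.log(torch.softmax(logits, dim=1))  # log of an underflowed probability
    stable = F.log_softmax(logits, dim=1)            # computed in log-space

    print(naive)   # tensor([[-inf, -inf, 0.]])
    print(stable)  # tensor([[-400., -200., 0.]])
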
@@ -130,23 +159,28 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
-        cross_entropy = -y_true * torch.log(y_pred)
+        cross_entropy = -y_true * y_log_pred

         back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
         back_ce = (1 - self.delta) * back_ce

-        fore_ce = cross_entropy[:, 1]
+        fore_ce = cross_entropy[:, 1:]
         fore_ce = self.delta * fore_ce

-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
+        if fore_ce.shape[1] > 1:
+            fore_ce = torch.sum(fore_ce, dim=1)
+        else:
+            fore_ce = fore_ce.squeeze(1)
+
+        loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
         return loss


 class AsymmetricUnifiedFocalLoss(_Loss):
     """
     AsymmetricUnifiedFocalLoss is a variant of Focal Loss.

-    Actually, it's only supported for binary image segmentation now
+    It supports both binary and multi-class segmentation.

     Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
@@ -157,20 +191,21 @@ class AsymmetricUnifiedFocalLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
-        num_classes: int = 2,
         weight: float = 0.5,
         gamma: float = 0.5,
         delta: float = 0.7,
+        use_softmax: bool = False,
         reduction: LossReduction | str = LossReduction.MEAN,
     ):
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes : number of classes, it only supports 2 now. Defaults to 2.
             delta : weight of the background. Defaults to 0.7.
             gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
             weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.

         Example:
             >>> import torch
@@ -182,50 +217,32 @@ def __init__(
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
-        self.num_classes = num_classes
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.use_softmax = use_softmax
+        self.asy_focal_loss = AsymmetricFocalLoss(
+            gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y
+        )
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y
+        )

-    # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
             y_pred : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
                 The input should be the original logits since it will be transformed by
                 a sigmoid in the forward function.
             y_true : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.

         Raises:
             ValueError: When input and target are different shape
-            ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
-            ValueError: When num_classes
             ValueError: When the number of classes entered does not match the expected number
         """
         if y_pred.shape != y_true.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

-        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
-            raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
-
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
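Taken together, the loss-file changes let the public class accept raw logits for either activation path, with channel 0 treated as background per the new docstrings. Below is a minimal usage sketch based on the signatures shown above (editor's illustration: the shapes, the 32x32 spatial size, and the choice to pass targets already one-hot encoded by the caller are assumptions, not part of the commit).

    import torch
    from monai.losses import AsymmetricUnifiedFocalLoss

    # Two-channel "binary" input (channel 0 = background, channel 1 = foreground),
    # raw logits, sigmoid path (use_softmax=False is the default).
    binary_loss = AsymmetricUnifiedFocalLoss()
    logits = torch.randn(2, 2, 32, 32)
    labels = torch.randint(0, 2, (2, 32, 32))
    target = torch.nn.functional.one_hot(labels, num_classes=2).permute(0, 3, 1, 2).float()
    print(binary_loss(logits, target))

    # Multi-class input (3 channels, channel 0 = background), softmax path.
    multi_loss = AsymmetricUnifiedFocalLoss(use_softmax=True)
    logits = torch.randn(2, 3, 32, 32)
    labels = torch.randint(0, 3, (2, 32, 32))
    target = torch.nn.functional.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()
    print(multi_loss(logits, target))
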

tests/losses/test_unified_focal_loss.py

Lines changed: 20 additions & 7 deletions
@@ -20,28 +20,41 @@
 from monai.losses import AsymmetricUnifiedFocalLoss

 TEST_CASES = [
-    [ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+    # Case 0: Binary segmentation
+    [
+        {},
         {
-            "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
+            "y_pred": torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]),
             "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
         },
         0.0,
     ],
-    [ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+    # Case 1: Same as above but explicit arguments
+    [
+        {"use_softmax": False, "to_onehot_y": False},
         {
-            "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
+            "y_pred": torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]),
             "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
         },
         0.0,
     ],
+    # Case 2: Multi-class segmentation
+    [
+        {"use_softmax": True, "to_onehot_y": True},
+        {
+            "y_pred": torch.tensor([[[[-20.0]], [[-20.0]], [[20.0]]]]).repeat(2, 1, 1, 1),
+            "y_true": torch.tensor([[[[2]]]]).repeat(2, 1, 1, 1),
+        },
+        0.0,
+    ],
 ]


 class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):

     @parameterized.expand(TEST_CASES)
-    def test_result(self, input_data, expected_val):
-        loss = AsymmetricUnifiedFocalLoss()
+    def test_result(self, input_param, input_data, expected_val):
+        loss = AsymmetricUnifiedFocalLoss(**input_param)
         result = loss(**input_data)
         np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

@@ -52,7 +65,7 @@ def test_ill_shape(self):

     def test_with_cuda(self):
         loss = AsymmetricUnifiedFocalLoss()
-        i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
+        i = torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]])
         j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
         if torch.cuda.is_available():
             i = i.cuda()
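For reference, `parameterized.expand` unpacks each `TEST_CASES` entry positionally, which is why the new leading dict of constructor arguments required widening the `test_result` signature. A self-contained toy illustration of that unpacking (the `DemoTest` class and its values are hypothetical, not from the repository):

    import unittest
    from parameterized import parameterized

    CASES = [
        [{"use_softmax": True}, {"value": 1.0}, 0.0],  # (ctor kwargs, call kwargs, expected)
        [{}, {"value": 2.0}, 0.0],
    ]

    class DemoTest(unittest.TestCase):
        @parameterized.expand(CASES)
        def test_unpacking(self, ctor_kwargs, call_kwargs, expected):
            # Each inner list becomes the positional arguments of one generated test.
            self.assertIsInstance(ctor_kwargs, dict)
            self.assertIsInstance(call_kwargs, dict)
            self.assertEqual(expected, 0.0)

    if __name__ == "__main__":
        unittest.main()
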
