@@ -14,7 +14,8 @@
 import warnings
 
 import torch
+import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss
 
 from monai.networks import one_hot
@@ -24,7 +25,9 @@ class AsymmetricFocalTverskyLoss(_Loss):
2425 """
2526 AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
2627
27- Actually, it's only supported for binary image segmentation now.
28+ It supports both binary and multi-class segmentation.
29+
30+ The logic assumes channel 0 is Background, and channels 1..N are Foreground.
2831
2932 Reimplementation of the Asymmetric Focal Tversky Loss described in:
3033
@@ -35,6 +38,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
+        use_softmax: bool = False,
         delta: float = 0.7,
         gamma: float = 0.75,
         epsilon: float = 1e-7,
@@ -43,17 +47,25 @@ def __init__(
4347 """
4448 Args:
4549 to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
50+ use_softmax: whether to use softmax to transform the original logits into probabilities.
51+ If True, softmax is used. If False, sigmoid is used. Defaults to False.
4652 delta : weight of the background. Defaults to 0.7.
4753 gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
4854 epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
4955 """
5056 super ().__init__ (reduction = LossReduction (reduction ).value )
5157 self .to_onehot_y = to_onehot_y
58+ self .use_softmax = use_softmax
5259 self .delta = delta
5360 self .gamma = gamma
5461 self .epsilon = epsilon
5562
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        if self.use_softmax:
+            y_pred = torch.softmax(y_pred, dim=1)
+        else:
+            y_pred = torch.sigmoid(y_pred)
+
         n_pred_ch = y_pred.shape[1]
 
         if self.to_onehot_y:
@@ -67,17 +79,23 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         # clip the prediction to avoid NaN
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
-        axis = list(range(2, len(y_pred.shape)))
+
+        spatial_dims = list(range(2, len(y_pred.shape)))
 
         # Calculate true positives (tp), false negatives (fn) and false positives (fp)
-        tp = torch.sum(y_true * y_pred, dim=axis)
-        fn = torch.sum(y_true * (1 - y_pred), dim=axis)
-        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
+        tp = torch.sum(y_true * y_pred, dim=spatial_dims)
+        fn = torch.sum(y_true * (1 - y_pred), dim=spatial_dims)
+        fp = torch.sum((1 - y_true) * y_pred, dim=spatial_dims)
         dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
 
         # Calculate losses separately for each class, enhancing both classes
         back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
+        fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma)
+
+        if fore_dice.shape[1] > 1:
+            fore_dice = torch.mean(fore_dice, dim=1)
+        else:
+            fore_dice = fore_dice.squeeze(1)
 
         # Average class scores
         loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
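Reviewer note: each per-class Tversky index above is `(tp + eps) / (tp + delta * fn + (1 - delta) * fp + eps)`, and the foreground term `(1 - TI) * (1 - TI)^(-gamma)` simplifies to `(1 - TI)^(1 - gamma)`, so `gamma < 1` up-weights poorly segmented foreground classes. A minimal usage sketch of the new multi-class path, assuming this branch's `AsymmetricFocalTverskyLoss` stays importable from `monai.losses` (shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from monai.losses import AsymmetricFocalTverskyLoss

# 3-class logits for two 8x8 images; channel 0 is background
y_pred = torch.randn(2, 3, 8, 8)
# matching one-hot ground truth in the same BNHW layout
labels = torch.randint(0, 3, (2, 8, 8))
y_true = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

loss_fn = AsymmetricFocalTverskyLoss(use_softmax=True)  # softmax across the 3 channels
print(loss_fn(y_pred, y_true))  # scalar loss tensor
```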
@@ -88,7 +106,7 @@ class AsymmetricFocalLoss(_Loss):
88106 """
89107 AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
90108
91- Actually, it's only supported for binary image segmentation now .
109+ It supports both binary and multi-class segmentation .
92110
93111 Reimplementation of the Asymmetric Focal Loss described in:
94112
@@ -99,6 +117,7 @@ class AsymmetricFocalLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
+        use_softmax: bool = False,
         delta: float = 0.7,
         gamma: float = 2,
         epsilon: float = 1e-7,
@@ -107,17 +126,27 @@ def __init__(
107126 """
108127 Args:
109128 to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
129+ use_softmax: whether to use softmax to transform the original logits into probabilities.
130+ If True, softmax is used. If False, sigmoid is used. Defaults to False.
110131 delta : weight of the background. Defaults to 0.7.
111132 gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
112133 epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
113134 """
114135 super ().__init__ (reduction = LossReduction (reduction ).value )
115136 self .to_onehot_y = to_onehot_y
137+ self .use_softmax = use_softmax
116138 self .delta = delta
117139 self .gamma = gamma
118140 self .epsilon = epsilon
 
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        if self.use_softmax:
+            y_log_pred = F.log_softmax(y_pred, dim=1)
+            y_pred = torch.exp(y_log_pred)
+        else:
+            y_log_pred = F.logsigmoid(y_pred)
+            y_pred = torch.sigmoid(y_pred)
+
         n_pred_ch = y_pred.shape[1]
 
         if self.to_onehot_y:
@@ -130,23 +159,28 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
 
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
-        cross_entropy = -y_true * torch.log(y_pred)
+        cross_entropy = -y_true * y_log_pred
 
         back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
         back_ce = (1 - self.delta) * back_ce
 
-        fore_ce = cross_entropy[:, 1]
+        fore_ce = cross_entropy[:, 1:]
         fore_ce = self.delta * fore_ce
 
-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
+        if fore_ce.shape[1] > 1:
+            fore_ce = torch.sum(fore_ce, dim=1)
+        else:
+            fore_ce = fore_ce.squeeze(1)
+
+        loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
         return loss
 
 
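Side note on the numerics: deriving `cross_entropy` from `F.log_softmax`/`F.logsigmoid` rather than `torch.log(torch.clamp(y_pred, ...))` keeps the log term stable for extreme logits instead of exponentiating and re-taking the log. A minimal sketch of the multi-class path, again assuming this branch's `AsymmetricFocalLoss` stays importable from `monai.losses`:

```python
import torch
import torch.nn.functional as F
from monai.losses import AsymmetricFocalLoss

y_pred = torch.randn(2, 3, 8, 8)  # raw 3-class logits
labels = torch.randint(0, 3, (2, 8, 8))
y_true = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

# background term: (1 - delta) * (1 - p_0)^gamma * CE_0
# foreground term: delta * CE_c, summed over channels 1..N
loss_fn = AsymmetricFocalLoss(use_softmax=True, delta=0.7, gamma=2)
print(loss_fn(y_pred, y_true))
```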
 class AsymmetricUnifiedFocalLoss(_Loss):
     """
     AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
 
-    Actually, it's only supported for binary image segmentation now
+    It supports both binary and multi-class segmentation.
 
     Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
 
@@ -157,20 +191,21 @@ class AsymmetricUnifiedFocalLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
-        num_classes: int = 2,
         weight: float = 0.5,
         gamma: float = 0.5,
         delta: float = 0.7,
+        use_softmax: bool = False,
         reduction: LossReduction | str = LossReduction.MEAN,
     ):
166200 """
167201 Args:
168202 to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
169- num_classes : number of classes, it only supports 2 now. Defaults to 2.
170203 delta : weight of the background. Defaults to 0.7.
171204 gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
172205 epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
173206 weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
207+ use_softmax: whether to use softmax to transform the original logits into probabilities.
208+ If True, softmax is used. If False, sigmoid is used. Defaults to False.
174209
175210 Example:
176211 >>> import torch
@@ -182,50 +217,32 @@ def __init__(
182217 """
183218 super ().__init__ (reduction = LossReduction (reduction ).value )
184219 self .to_onehot_y = to_onehot_y
185- self .num_classes = num_classes
186220 self .gamma = gamma
187221 self .delta = delta
188222 self .weight : float = weight
189- self .asy_focal_loss = AsymmetricFocalLoss (gamma = self .gamma , delta = self .delta )
190- self .asy_focal_tversky_loss = AsymmetricFocalTverskyLoss (gamma = self .gamma , delta = self .delta )
223+ self .use_softmax = use_softmax
224+ self .asy_focal_loss = AsymmetricFocalLoss (
225+ gamma = self .gamma , delta = self .delta , use_softmax = use_softmax , to_onehot_y = to_onehot_y
226+ )
227+ self .asy_focal_tversky_loss = AsymmetricFocalTverskyLoss (
228+ gamma = self .gamma , delta = self .delta , use_softmax = use_softmax , to_onehot_y = to_onehot_y
229+ )
 
-    # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
             y_pred : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
                 The input should be the original logits since it will be transformed by
                 a sigmoid/softmax in the forward function.
             y_true : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
 
         Raises:
             ValueError: When input and target have different shapes
-            ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
-            ValueError: When num_classes
             ValueError: When the number of classes entered does not match the expected number
         """
         if y_pred.shape != y_true.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
 
-        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
-            raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
-
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
 
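With `to_onehot_y` and `use_softmax` now forwarded to both sub-losses, the wrapper drops its `num_classes` bookkeeping; the unchanged tail of this function (not shown in the hunk) presumably still blends the two results as `weight * focal + (1 - weight) * tversky`. A minimal end-to-end sketch, assuming `AsymmetricUnifiedFocalLoss` is importable from `monai.losses`; note the up-front shape check means targets must already be one-hot here:

```python
import torch
import torch.nn.functional as F
from monai.losses import AsymmetricUnifiedFocalLoss

y_pred = torch.randn(4, 3, 16, 16)  # 3-class logits
labels = torch.randint(0, 3, (4, 16, 16))
# one-hot targets in the same BNHW shape as y_pred (required by the shape check)
y_true = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, weight=0.5)
print(loss_fn(y_pred, y_true))
```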