Skip to content

Commit 6c25189

Browse files
committed
minor fixes
Signed-off-by: ytl0623 <[email protected]>
1 parent 0640d0f commit 6c25189

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
176176
fore_ce = self.delta * fore_ce
177177

178178
if fore_ce.shape[1] > 1:
179-
fore_ce = torch.sum(fore_ce, dim=1)
179+
fore_ce = torch.mean(fore_ce, dim=1)
180180
else:
181181
fore_ce = fore_ce.squeeze(1)
182182

@@ -241,7 +241,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
241241
y_pred : the shape should be BNH[WD], where N is the number of classes.
242242
The input should be the original logits since it will be transformed by
243243
a sigmoid/softmax in the forward function.
244-
y_true : the shape should be BNH[WD], where N is the number of classes.
244+
y_true : the shape should be BNH[WD], or B1H[WD] when to_onehot_y=True.
245245
"""
246246

247247
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)

0 commit comments

Comments
 (0)