Skip to content

Commit f50d991

Browse files
authored
Apply suggestions from code review
1 parent 287361a commit f50d991

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

pytorch_lightning/metrics/classification.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
multiclass_roc,
1717
multiclass_precision_recall_curve,
1818
dice_score,
19-
iou
19+
iou,
2020
)
2121
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric
2222

tests/metrics/functional/test_classification.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -367,24 +367,21 @@ def test_dice_score(pred, target, expected):
367367
assert score == expected
368368

369369

370-
@pytest.mark.parametrize(['target', 'pred', 'half_ones', 'reduction', 'remove_bg', 'expected'], [
371-
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
372-
False, 'none', False, torch.Tensor([1, 1, 1])),
373-
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
374-
False, 'elementwise_mean', False, torch.Tensor([1])),
375-
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
376-
False, 'none', True, torch.Tensor([1, 1])),
377-
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
378-
True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
379-
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
380-
True, 'elementwise_mean', False, torch.Tensor([0.5])),
381-
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
382-
True, 'none', True, torch.Tensor([0.5, 0.5])),
370+
@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [
371+
pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])),
372+
pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])),
373+
pytest.param(False, 'none', True, torch.Tensor([1, 1])),
374+
pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
375+
pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])),
376+
pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])),
383377
])
384-
def test_iou(target, pred, half_ones, reduction, remove_bg, expected):
378+
def test_iou(half_ones, reduction, remove_bg, expected):
379+
pred = (torch.arange(120) % 3).view(-1, 1)
380+
target = (torch.arange(120) % 3).view(-1, 1)
385381
if half_ones:
386382
pred[:60] = 1
387-
assert torch.all(torch.eq(iou(pred, target, remove_bg=remove_bg, reduction=reduction), expected))
383+
iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction)
384+
assert torch.allclose(iou_val, expected, atol=1e-9)
388385

389386

390387
# example data taken from

0 commit comments

Comments
 (0)