Skip to content

Commit

Permalink
Remove get_num_classes in jaccard_index (#914)
Browse files Browse the repository at this point in the history
* remove get_num_classes
* remove test
  • Loading branch information
rusty1s authored Mar 31, 2022
1 parent 712a81b commit c537c9b
Show file tree
Hide file tree
Showing 5 changed files with 2 additions and 60 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853))
- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914))


- Added normalizer, tokenizer to ROUGE metric ([#838](https://github.com/PyTorchLightning/metrics/pull/838))
Expand Down
11 changes: 0 additions & 11 deletions tests/classification/test_jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,3 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction
reduction=reduction,
)
assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))


def test_warning_on_difference_in_number_of_classes():
"""Test that warning is thrown if the detected number of classes are different from the the specified number of
classes."""
preds = torch.randint(3, (10,))
target = torch.randint(3, (10,))
with pytest.warns(
RuntimeWarning,
):
jaccard_index(preds, target, num_classes=4)
14 changes: 1 addition & 13 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce


Expand Down Expand Up @@ -92,18 +92,6 @@ def test_to_categorical():
assert torch.allclose(result, expected.to(result.dtype))


@pytest.mark.parametrize(
["preds", "target", "num_classes", "expected_num_classes"],
[
(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
],
)
def test_get_num_classes(preds, target, num_classes, expected_num_classes):
assert get_num_classes(preds, target, num_classes) == expected_num_classes


def test_flatten_list():
"""Check that _flatten utility function works as expected."""
inp = [[1, 2, 3], [4, 5], [6]]
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
from torchmetrics.utilities.data import get_num_classes
from torchmetrics.utilities.distributed import reduce


Expand Down Expand Up @@ -129,6 +128,5 @@ def jaccard_index(
tensor(0.9660)
"""

num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes)
confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
33 changes: 0 additions & 33 deletions torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.prints import rank_zero_warn

METRIC_EPS = 1e-6


Expand Down Expand Up @@ -145,37 +143,6 @@ def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor:
return torch.argmax(x, dim=argmax_dim)


def get_num_classes(
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
) -> int:
"""Calculates the number of classes for a given prediction and target tensor.
Args:
preds: predicted values
target: true labels
num_classes: number of classes if known
Return:
An integer that represents the number of classes.
"""
num_target_classes = int(target.max().detach().item() + 1)
num_pred_classes = int(preds.max().detach().item() + 1)
num_all_classes = max(num_target_classes, num_pred_classes)

if num_classes is None:
num_classes = num_all_classes
elif num_classes != num_all_classes:
rank_zero_warn(
f"You have set {num_classes} number of classes which is"
f" different from predicted ({num_pred_classes}) and"
f" target ({num_target_classes}) number of classes",
RuntimeWarning,
)
return num_classes


def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
Expand Down

0 comments on commit c537c9b

Please sign in to comment.