Bugfix/topk accuracy #2423

Merged · 6 commits · Mar 6, 2024

Changes from all commits
CHANGELOG.md: 3 additions, 0 deletions
```diff
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))


+- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))
+
+
 ## [1.3.1] - 2024-02-12

 ### Fixed
```
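For context, here is a minimal repro of the bug, built from the test data added later in this PR (the snippet itself is not part of the diff). Class 1 never appears in the target but is always the second-highest prediction, so with `top_k=2` it accumulates false positives despite having zero support, and before this fix it was wrongly kept in the macro average:

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

target = torch.tensor([0, 0, 2])
preds = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])

acc = MulticlassAccuracy(num_classes=3, top_k=2, average="macro")
acc(preds, target)  # tensor(0.5000) with this fix, per the new test case;
                    # the unfixed reduction kept zero-support class 1 and gave 1/3
```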
src/torchmetrics/classification/accuracy.py: 3 additions, 1 deletion
```diff
@@ -256,7 +256,9 @@ class MulticlassAccuracy(MulticlassStatScores):
     def compute(self) -> Tensor:
         """Compute accuracy based on inputs passed in to ``update`` previously."""
         tp, fp, tn, fn = self._final_state()
-        return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
+        return _accuracy_reduce(
+            tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
+        )

     def plot(
         self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
```
src/torchmetrics/classification/precision_recall.py: 2 additions, 2 deletions
```diff
@@ -269,7 +269,7 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
+            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
         )

     def plot(
@@ -702,7 +702,7 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
+            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
         )

     def plot(
```
src/torchmetrics/functional/classification/accuracy.py: 4 additions, 2 deletions
```diff
@@ -42,6 +42,7 @@ def _accuracy_reduce(
     average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
     multidim_average: Literal["global", "samplewise"] = "global",
     multilabel: bool = False,
+    top_k: int = 1,
 ) -> Tensor:
     """Reduce classification statistics into accuracy score.
@@ -66,6 +67,7 @@ def _accuracy_reduce(
             - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
         multilabel: If input is multilabel or not
+        top_k: value for top-k accuracy, else 1

     Returns:
         Accuracy score
@@ -83,7 +85,7 @@ def _accuracy_reduce(
         return _safe_divide(tp, tp + fn)

     score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
-    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)
+    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k)


 def binary_accuracy(
@@ -266,7 +268,7 @@ def multiclass_accuracy(
     tp, fp, tn, fn = _multiclass_stat_scores_update(
         preds, target, num_classes, top_k, average, multidim_average, ignore_index
     )
-    return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average)
+    return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k)


 def multilabel_accuracy(
```
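With `top_k` now threaded through to the reduction, the functional API behaves consistently too. A quick check against the new macro test case (again, not part of the diff):

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy

target = torch.tensor([0, 0, 2])
preds = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])

# After the fix, k=1 and k=2 agree: class 1 has no support (tp + fn == 0),
# so it is dropped from the macro average in both cases.
multiclass_accuracy(preds, target, num_classes=3, top_k=1, average="macro")  # tensor(0.5000)
multiclass_accuracy(preds, target, num_classes=3, top_k=2, average="macro")  # tensor(0.5000)
```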
src/torchmetrics/functional/classification/precision_recall.py: 8 additions, 3 deletions
```diff
@@ -43,6 +43,7 @@ def _precision_recall_reduce(
     average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
     multidim_average: Literal["global", "samplewise"] = "global",
     multilabel: bool = False,
+    top_k: int = 1,
 ) -> Tensor:
     different_stat = fp if stat == "precision" else fn  # this is what differs between the two scores
     if average == "binary":
@@ -54,7 +55,7 @@ def _precision_recall_reduce(
         return _safe_divide(tp, tp + different_stat)

     score = _safe_divide(tp, tp + different_stat)
-    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)
+    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)


 def binary_precision(
@@ -235,7 +236,9 @@ def multiclass_precision(
     tp, fp, tn, fn = _multiclass_stat_scores_update(
         preds, target, num_classes, top_k, average, multidim_average, ignore_index
     )
-    return _precision_recall_reduce("precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
+    return _precision_recall_reduce(
+        "precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k
+    )


 def multilabel_precision(
@@ -519,7 +522,9 @@ def multiclass_recall(
     tp, fp, tn, fn = _multiclass_stat_scores_update(
         preds, target, num_classes, top_k, average, multidim_average, ignore_index
     )
-    return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
+    return _precision_recall_reduce(
+        "recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k
+    )


 def multilabel_recall(
```
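The corresponding functional precision/recall check, matching the expected values in the new tests below (illustrative, not part of the diff):

```python
import torch
from torchmetrics.functional.classification import multiclass_precision, multiclass_recall

target = torch.tensor([0, 0, 2])
preds = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])

multiclass_precision(preds, target, num_classes=3, top_k=2, average="macro")  # tensor(0.3333)
multiclass_recall(preds, target, num_classes=3, top_k=2, average="macro")     # tensor(0.5000)
```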
src/torchmetrics/utilities/compute.py: 2 additions, 2 deletions
```diff
@@ -56,7 +56,7 @@ def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:


 def _adjust_weights_safe_divide(
-    score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor
+    score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
 ) -> Tensor:
     if average is None or average == "none":
         return score
@@ -65,7 +65,7 @@ def _adjust_weights_safe_divide(
     else:
         weights = torch.ones_like(score)
         if not multilabel:
-            weights[tp + fp + fn == 0] = 0.0
+            weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0
     return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
```
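This changed mask is the heart of the fix. With `top_k=1`, a class with `tp + fp + fn == 0` never appears in either predictions or targets, so dropping it from the macro denominator is safe. With `top_k > 1`, every sample contributes `top_k` predicted classes, so a zero-support class can still rack up false positives; masking on `tp + fp + fn == 0` then silently keeps such classes in the average and drags it down. A standalone sketch of the effect (simplified, ignoring the `weighted` and multilabel branches handled above):

```python
import torch

# Per-class counts for the repro: top-2 predictions are {0, 1} for all
# three samples, targets are [0, 0, 2].
tp = torch.tensor([2.0, 0.0, 0.0])
fp = torch.tensor([1.0, 3.0, 0.0])
fn = torch.tensor([0.0, 0.0, 1.0])
score = tp / (tp + fn).clamp(min=1)  # per-class score: [1., 0., 0.]

# Old mask keeps class 1 (fp=3 makes tp+fp+fn nonzero); new mask drops it.
for mask in (tp + fp + fn == 0, tp + fn == 0):
    weights = torch.ones_like(score)
    weights[mask] = 0.0
    print((weights * score).sum() / weights.sum())  # tensor(0.3333), then tensor(0.5000)
```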
tests/unittests/classification/test_accuracy.py: 5 additions, 0 deletions
```diff
@@ -339,12 +339,17 @@ def test_multiclass_accuracy_half_gpu(self, inputs, dtype):
 _mc_k_target = torch.tensor([0, 1, 2])
 _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

+_mc_k_targets2 = torch.tensor([0, 0, 2])
+_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])
+

 @pytest.mark.parametrize(
     ("k", "preds", "target", "average", "expected"),
     [
         (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3)),
         (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(3 / 3)),
+        (1, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 2)),
+        (2, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 2)),
     ],
 )
 def test_top_k(k, preds, target, average, expected):
```
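Worked by hand, the expected macro values check out: with targets `[0, 0, 2]`, class 0 scores 2/2, class 2 scores 0/1, and class 1 has zero support (`tp + fn == 0`), so the macro average is (1 + 0) / 2 = 1/2 for both k=1 and k=2. Before the fix, the k=2 case kept class 1 (it had `fp=3` from always sitting in the top-2), giving (1 + 0 + 0) / 3 = 1/3.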
tests/unittests/classification/test_precision_recall.py: 5 additions, 0 deletions
```diff
@@ -331,6 +331,9 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
 _mc_k_target = tensor([0, 1, 2])
 _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

+_mc_k_targets2 = torch.tensor([0, 0, 2])
+_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])
+

 @pytest.mark.parametrize(
     ("metric_class", "metric_fn"), [(MulticlassPrecision, multiclass_precision), (MulticlassRecall, multiclass_recall)]
@@ -340,6 +343,8 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
     [
         (1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)),
         (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
+        (2, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
     ],
 )
 def test_top_k(
```
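For these new macro cases: at k=2 the top-2 set is {0, 1} for every sample, so class 0 has `tp=2, fp=1`, class 1 has `tp=0, fp=3, fn=0`, and class 2 has `tp=0, fn=1`. Dropping zero-support class 1, macro precision is (2/3 + 0) / 2 = 1/3 and macro recall is (1 + 0) / 2 = 1/2, the same values as at k=1, which is exactly the invariant the fix restores.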