Skip to content

Commit

Permalink
fix ignored custom callable in retrieval metric aggregation (#2364)
Browse files Browse the repository at this point in the history
* fix retrieval aggregation

* fix retrieval tests

* changelog

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
(cherry picked from commit dce2368)
  • Loading branch information
fschlatt authored and Borda committed Feb 12, 2024
1 parent dd4ce92 commit f9f1842
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))


- Fixed custom aggregation in retrieval metrics ([#2364](https://github.com/Lightning-AI/torchmetrics/pull/2364))


- Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366))


Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _retrieval_aggregate(
return values.median() if dim is None else values.median(dim=dim).values
if aggregation == "min":
return values.min() if dim is None else values.min(dim=dim).values
if aggregation:
if aggregation == "max":
return values.max() if dim is None else values.max(dim=dim).values
return aggregation(values, dim=dim)

Expand Down
20 changes: 18 additions & 2 deletions tests/unittests/retrieval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torch
from numpy import array
from torch import Tensor, tensor
from torchmetrics.retrieval.base import _retrieval_aggregate
from typing_extensions import Literal

from unittests.helpers import seed_all
Expand All @@ -42,6 +41,23 @@
# a version of get_group_indexes that depends on NumPy is here to avoid this dependency for the full library


def _retrieval_aggregate(
values: Tensor,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
dim: Optional[int] = None,
) -> Tensor:
"""Aggregate the final retrieval values into a single value."""
if aggregation == "mean":
return values.mean() if dim is None else values.mean(dim=dim)
if aggregation == "median":
return values.median() if dim is None else values.median(dim=dim).values
if aggregation == "min":
return values.min() if dim is None else values.min(dim=dim).values
if aggregation == "max":
return values.max() if dim is None else values.max(dim=dim).values
return aggregation(values, dim=dim)


def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]:
"""Extract group indexes.
Expand Down Expand Up @@ -74,7 +90,7 @@ def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor,


def _custom_aggregate_fn(val: Tensor, dim=None) -> Tensor:
return (val**2).mean(dim=dim)
return (val**2).mean() if dim is None else (val**2).mean(dim=dim)


def _compute_sklearn_metric(
Expand Down

0 comments on commit f9f1842

Please sign in to comment.