Feature/mrr k #1961

Merged · 11 commits · Aug 3, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added warning to `PearsonCorrCoeff` if input has a very small variance for its given dtype ([#1926](https://github.com/Lightning-AI/torchmetrics/pull/1926))


- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))


### Changed

-
15 changes: 13 additions & 2 deletions src/torchmetrics/functional/retrieval/reciprocal_rank.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
def retrieval_reciprocal_rank(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
"""Compute reciprocal rank (for information retrieval). See `Mean Reciprocal Rank`_.

``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
@@ -27,10 +29,15 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
top_k: Consider only the top k elements (default: ``None``, which considers them all)

Return:
a single-value tensor with the reciprocal rank (RR) of the predictions ``preds`` w.r.t. the labels ``target``.

Raises:
ValueError:
If ``top_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
>>> preds = torch.tensor([0.2, 0.3, 0.5])
@@ -41,9 +48,13 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

top_k = top_k or preds.shape[-1]
if not (isinstance(top_k, int) and top_k > 0):
raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}.")

target = target[preds.topk(min(top_k, preds.shape[-1]), sorted=True, dim=-1)[1]]
if not target.sum():
return tensor(0.0, device=preds.device)

target = target[torch.argsort(preds, dim=-1, descending=True)]
position = torch.nonzero(target).view(-1)
return 1.0 / (position[0] + 1.0)
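
As a sanity check of the new argument, a hand-worked doctest sketch using the docstring's own inputs (expected outputs computed by hand from the ranking semantics above):

>>> import torch
>>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([False, True, False])
>>> retrieval_reciprocal_rank(preds, target)  # relevant document is ranked 2nd -> 1/2
tensor(0.5000)
>>> retrieval_reciprocal_rank(preds, target, top_k=1)  # top-1 prediction is not relevant
tensor(0.)
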
11 changes: 5 additions & 6 deletions src/torchmetrics/retrieval/average_precision.py
@@ -38,8 +38,8 @@ class RetrievalMAP(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``rmap`` (:class:`~torch.Tensor`): A tensor with the mean average precision of the predictions ``preds``
w.r.t. the labels ``target``
- ``map@k`` (:class:`~torch.Tensor`): A single-value tensor with the mean average precision (MAP)
of the predictions ``preds`` w.r.t. the labels ``target``.

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
so that for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -54,9 +54,8 @@ class RetrievalMAP(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``

ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
@@ -65,7 +64,7 @@ class RetrievalMAP(RetrievalMetric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` is not ``None`` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torch import tensor
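
A minimal usage sketch for the ``map@k`` output described above (expected value worked by hand: query 0 scores AP 1.0, query 1 scores (1/2 + 2/3)/2, so the mean is about 0.7917):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalMAP
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> rmap = RetrievalMAP()
>>> rmap(indexes, preds, target)
tensor(0.7917)
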
9 changes: 4 additions & 5 deletions src/torchmetrics/retrieval/fall_out.py
@@ -40,7 +40,7 @@ class RetrievalFallOut(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``fo`` (:class:`~torch.Tensor`): A tensor with the computed metric
- ``fo@k`` (:class:`~torch.Tensor`): A tensor with the computed metric

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
so that for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -55,9 +55,8 @@ class RetrievalFallOut(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``

ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: `None`, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: `None`, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
@@ -66,7 +65,7 @@ class RetrievalFallOut(RetrievalMetric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torchmetrics.retrieval import RetrievalFallOut
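
A short hand-worked sketch of the ``fo@k`` output (each query below retrieves one of its two non-relevant documents in the top 2, so the mean fall-out is 0.5):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalFallOut
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> fo = RetrievalFallOut(top_k=2)
>>> fo(indexes, preds, target)
tensor(0.5000)
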
9 changes: 4 additions & 5 deletions src/torchmetrics/retrieval/hit_rate.py
@@ -38,7 +38,7 @@ class RetrievalHitRate(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``hr2`` (:class:`~torch.Tensor`): A single-value tensor with the hit rate (at ``top_k``) of the predictions
- ``hr@k`` (:class:`~torch.Tensor`): A single-value tensor with the hit rate (at ``top_k``) of the predictions
``preds`` w.r.t. the labels ``target``

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -55,9 +55,8 @@ class RetrievalHitRate(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``

ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
@@ -66,7 +65,7 @@ class RetrievalHitRate(RetrievalMetric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torch import tensor
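
A hand-worked sketch of ``hr@k`` (query 0's only relevant document is ranked third, a miss at ``top_k=2``; query 1 gets a hit at rank 2, so the mean hit rate is 0.5):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalHitRate
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([True, False, False, False, True, False, True])
>>> hr2 = RetrievalHitRate(top_k=2)
>>> hr2(indexes, preds, target)
tensor(0.5000)
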
9 changes: 4 additions & 5 deletions src/torchmetrics/retrieval/ndcg.py
@@ -38,7 +38,7 @@ class RetrievalNormalizedDCG(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``ndcg`` (:class:`~torch.Tensor`): A single-value tensor with the nDCG of the predictions
- ``ndcg@k`` (:class:`~torch.Tensor`): A single-value tensor with the nDCG of the predictions
``preds`` w.r.t. the labels ``target``

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -55,9 +55,8 @@ class RetrievalNormalizedDCG(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``

ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
@@ -66,7 +65,7 @@ class RetrievalNormalizedDCG(RetrievalMetric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torch import tensor
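
A hand-worked sketch of ``ndcg@k`` (query 0 ranks its relevant document first, nDCG 1.0; query 1's relevant documents land at ranks 2 and 3, nDCG about 0.6934; mean about 0.8467):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalNormalizedDCG
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> ndcg = RetrievalNormalizedDCG()
>>> ndcg(indexes, preds, target)
tensor(0.8467)
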
11 changes: 5 additions & 6 deletions src/torchmetrics/retrieval/precision.py
@@ -38,7 +38,7 @@ class RetrievalPrecision(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``p2`` (:class:`~torch.Tensor`): A single-value tensor with the precision (at ``top_k``) of the predictions
- ``p@k`` (:class:`~torch.Tensor`): A single-value tensor with the precision (at ``top_k``) of the predictions
``preds`` w.r.t. the labels ``target``

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -54,10 +54,9 @@ class RetrievalPrecision(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``

ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
adaptive_k: adjust ``top_k`` to ``min(k, number of documents)`` for each query
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
adaptive_k: Adjust ``top_k`` to ``min(k, number of documents)`` for each query
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
@@ -66,7 +65,7 @@ class RetrievalPrecision(RetrievalMetric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.
ValueError:
If ``adaptive_k`` is not boolean.

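
A hand-worked sketch of ``p@k`` and ``adaptive_k`` (with ``top_k=5`` and ``adaptive_k=True``, k shrinks to each query's document count, 3 and 4, giving (1/3 + 2/4)/2, about 0.4167):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalPrecision(top_k=2)
>>> p2(indexes, preds, target)
tensor(0.5000)
>>> p_adaptive = RetrievalPrecision(top_k=5, adaptive_k=True)  # k becomes min(5, docs per query)
>>> p_adaptive(indexes, preds, target)
tensor(0.4167)
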
2 changes: 1 addition & 1 deletion src/torchmetrics/retrieval/precision_recall_curve.py
@@ -108,7 +108,7 @@ class RetrievalPrecisionRecallCurve(Metric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``max_k`` parameter is not `None` or an integer larger than 0.
If ``max_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torch import tensor
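
A hand-worked sketch of the curve over ``k = 1..max_k`` for a single query (precision and recall computed by hand at each cutoff; the sorted order is 0.6 relevant, 0.5 not, 0.4 relevant, 0.01 not):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalPrecisionRecallCurve
>>> indexes = tensor([0, 0, 0, 0])
>>> preds = tensor([0.4, 0.01, 0.5, 0.6])
>>> target = tensor([True, False, False, True])
>>> curve = RetrievalPrecisionRecallCurve(max_k=4)
>>> precisions, recalls, top_k = curve(indexes, preds, target)
>>> precisions
tensor([1.0000, 0.5000, 0.6667, 0.5000])
>>> recalls
tensor([0.5000, 0.5000, 1.0000, 1.0000])
>>> top_k
tensor([1, 2, 3, 4])
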
5 changes: 2 additions & 3 deletions src/torchmetrics/retrieval/r_precision.py
@@ -38,7 +38,7 @@ class RetrievalRPrecision(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``p2`` (:class:`~torch.Tensor`): A single-value tensor with the r-precision of the predictions ``preds``
- ``rp`` (:class:`~torch.Tensor`): A single-value tensor with the r-precision of the predictions ``preds``
w.r.t. the labels ``target``.

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -54,8 +54,7 @@ class RetrievalRPrecision(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``

ignore_index:
Ignore predictions where the target is equal to this number.
ignore_index: Ignore predictions where the target is equal to this number.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
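
A hand-worked sketch of the ``rp`` output (k equals each query's number of relevant documents: query 0 scores 1/1, query 1 scores 1/2, so the mean is 0.75):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> rp = RetrievalRPrecision()
>>> rp(indexes, preds, target)
tensor(0.7500)
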
6 changes: 3 additions & 3 deletions src/torchmetrics/retrieval/recall.py
@@ -38,7 +38,7 @@ class RetrievalRecall(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``r2`` (:class:`~torch.Tensor`): A single-value tensor with the recall (at ``top_k``) of the predictions
- ``r@k`` (:class:`~torch.Tensor`): A single-value tensor with the recall (at ``top_k``) of the predictions
``preds`` w.r.t. the labels ``target``

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -55,7 +55,7 @@ class RetrievalRecall(RetrievalMetric):
- ``'error'``: raise a ``ValueError``

ignore_index: Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: `None`, which considers them all)
top_k: Consider only the top k elements for each query (default: `None`, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
@@ -64,7 +64,7 @@ class RetrievalRecall(RetrievalMetric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torch import tensor
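
A hand-worked sketch of ``r@k`` (at ``top_k=2``, query 0 recovers its single relevant document, query 1 recovers one of its two, so the mean recall is 0.75):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalRecall
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> r2 = RetrievalRecall(top_k=2)
>>> r2(indexes, preds, target)
tensor(0.7500)
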
28 changes: 24 additions & 4 deletions src/torchmetrics/retrieval/reciprocal_rank.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union

from torch import Tensor

@@ -38,8 +38,8 @@ class RetrievalMRR(RetrievalMetric):

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mrr`` (:class:`~torch.Tensor`): A single-value tensor with the reciprocal rank (RR) of the predictions
``preds`` w.r.t. the labels ``target``
- ``mrr@k`` (:class:`~torch.Tensor`): A single-value tensor with the reciprocal rank (RR)
of the predictions ``preds`` w.r.t. the labels ``target``.

All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
so that for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -55,13 +55,16 @@ class RetrievalMRR(RetrievalMetric):
- ``'error'``: raise a ``ValueError``

ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``.
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` is not ``None`` or not an integer greater than 0.

Example:
>>> from torch import tensor
@@ -81,8 +84,25 @@ class RetrievalMRR(RetrievalMetric):
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
empty_target_action: str = "neg",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(
empty_target_action=empty_target_action,
ignore_index=ignore_index,
**kwargs,
)

if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}.")
self.top_k = top_k

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_reciprocal_rank(preds, target)
return retrieval_reciprocal_rank(preds, target, top_k=self.top_k)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
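
Taken together, a hand-worked sanity check of the new ``top_k`` argument on the class interface (query 1's relevant document sits at rank 2, so it drops out at ``top_k=1`` and the mean falls from 0.75 to 0.5):

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalMRR
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> mrr = RetrievalMRR()
>>> mrr(indexes, preds, target)
tensor(0.7500)
>>> mrr_at_1 = RetrievalMRR(top_k=1)
>>> mrr_at_1(indexes, preds, target)
tensor(0.5000)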