Information Retrieval (4/5) #146

Merged · merged 8 commits on Mar 30, 2021
Changes from 2 commits
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -21,7 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalMRR` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119))


- Added `RetrievalPrecision` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119))
- Added `RetrievalPrecision` metric for Information Retrieval ([#139](https://github.com/PyTorchLightning/metrics/pull/139))


- Added `RetrievalRecall` metric for Information Retrieval ([#146](https://github.com/PyTorchLightning/metrics/pull/146))


- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -259,3 +259,10 @@ retrieval_precision [func]

.. autofunction:: torchmetrics.functional.retrieval_precision
:noindex:


retrieval_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_recall
:noindex:
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -343,6 +343,13 @@ RetrievalPrecision
:noindex:


RetrievalRecall
~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalRecall
:noindex:


********
Wrappers
********
12 changes: 9 additions & 3 deletions tests/functional/test_retrieval.py
@@ -8,8 +8,10 @@
from tests.helpers import seed_all
from tests.retrieval.test_mrr import _reciprocal_rank as reciprocal_rank
from tests.retrieval.test_precision import _precision_at_k as precision_at_k
from tests.retrieval.test_recall import _recall_at_k as recall_at_k
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank

seed_all(1337)
@@ -46,6 +48,7 @@ def test_metrics_output_values(sklearn_metric, torch_metric, size):

@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
[precision_at_k, retrieval_precision],
[recall_at_k, retrieval_recall]
])
@pytest.mark.parametrize("size", [1, 4, 10])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@@ -76,7 +79,8 @@ def test_metrics_output_values_with_k(sklearn_metric, torch_metric, size, k):
@pytest.mark.parametrize(['torch_metric'], [
[retrieval_average_precision],
[retrieval_reciprocal_rank],
[retrieval_precision]
[retrieval_precision],
[retrieval_recall]
])
def test_input_dtypes(torch_metric) -> None:
""" Check wrong input dtypes are managed correctly. """
@@ -107,7 +111,8 @@ def test_input_dtypes(torch_metric) -> None:
@pytest.mark.parametrize(['torch_metric'], [
[retrieval_average_precision],
[retrieval_reciprocal_rank],
[retrieval_precision]
[retrieval_precision],
[retrieval_recall]
])
def test_input_shapes(torch_metric) -> None:
""" Check wrong input shapes are managed correctly. """
@@ -130,7 +135,8 @@ def test_input_shapes(torch_metric) -> None:

# test metrics using top K parameter
@pytest.mark.parametrize(['torch_metric'], [
[retrieval_precision]
[retrieval_precision],
[retrieval_recall]
])
@pytest.mark.parametrize('k', [-1, 1.0])
def test_input_params(torch_metric, k) -> None:
3 changes: 2 additions & 1 deletion tests/retrieval/test_precision.py
@@ -8,7 +8,8 @@
def _precision_at_k(target: np.array, preds: np.array, k: int = None):
"""
Didn't find a reliable implementation of Precision in Information Retrieval, so,
reimplementing here. A good explanation can be found ``
reimplementing here. A good explanation can be found
`here <https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf>`_.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs
56 changes: 56 additions & 0 deletions tests/retrieval/test_recall.py
@@ -0,0 +1,56 @@
import numpy as np
import pytest

from tests.retrieval.helpers import _test_dtypes, _test_input_args, _test_input_shapes, _test_retrieval_against_sklearn
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall


def _recall_at_k(target: np.array, preds: np.array, k: int = None):
"""
Didn't find a reliable implementation of Recall in Information Retrieval, so it is
reimplemented here. See the Wikipedia article on precision and recall for the definition.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

if k is None:
k = len(preds)

if target.sum() > 0:
order_indexes = np.argsort(preds, axis=0)[::-1]
relevant = np.sum(target[order_indexes][:k])
return relevant * 1.0 / target.sum()
else:
return np.NaN


@pytest.mark.parametrize('size', [1, 4, 10])
@pytest.mark.parametrize('n_documents', [1, 5])
@pytest.mark.parametrize('query_without_relevant_docs_options', ['skip', 'pos', 'neg'])
@pytest.mark.parametrize('k', [None, 1, 4, 10])
def test_results(size, n_documents, query_without_relevant_docs_options, k):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
_recall_at_k,
RetrievalRecall,
size,
n_documents,
query_without_relevant_docs_options,
k=k
)


def test_dtypes():
""" Check dypes are managed correctly. """
_test_dtypes(RetrievalRecall)


def test_input_shapes() -> None:
"""Check inputs shapes are managed correctly. """
_test_input_shapes(RetrievalRecall)


@pytest.mark.parametrize('k', [-1, 1.0])
def test_input_params(k) -> None:
"""Check invalid args are managed correctly. """
_test_input_args(RetrievalRecall, "`k` has to be a positive integer or None", k=k)
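
Side note for reviewers (an illustration, not part of the diff): the reference `_recall_at_k` above sorts documents by predicted score and counts how many of the relevant ones land in the top `k`. A minimal sketch of the same hand-worked computation against the functional metric added in this PR:

```python
# Minimal sketch (not part of this PR) cross-checking the recall@k logic by hand
# against torchmetrics.functional.retrieval_recall as introduced in this diff.
import torch
from torchmetrics.functional.retrieval.recall import retrieval_recall

preds = torch.tensor([0.2, 0.3, 0.5, 0.1])         # predicted relevance scores
target = torch.tensor([False, True, True, False])  # ground-truth relevance

# sorted by score: 0.5 (relevant), 0.3 (relevant), 0.2, 0.1
# the top-2 documents contain 2 of the 2 relevant ones -> recall@2 = 1.0
assert retrieval_recall(preds, target, k=2) == 1.0

# with k=1 only one of the two relevant documents is retrieved -> recall@1 = 0.5
assert retrieval_recall(preds, target, k=1) == 0.5
```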
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -49,5 +49,5 @@
MeanSquaredLogError,
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision # noqa: F401 E402
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision, RetrievalRecall # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -39,5 +39,6 @@
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
@@ -14,4 +14,5 @@

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
56 changes: 56 additions & 0 deletions torchmetrics/functional/retrieval/recall.py
@@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_recall(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
"""
Computes the recall metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Recall>`__.
Recall is the fraction of relevant documents retrieved among all the relevant documents.

``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
otherwise an error is raised. If you want to measure Recall@K, ``k`` must be a positive integer.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: None)

Returns:
a single-value tensor with the recall (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Example:
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_recall(preds, target, k=2)
tensor(0.5000)
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

if k is None:
k = preds.shape[-1]

if not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")

if target.sum() == 0:
return tensor(0.0, device=preds.device)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float()
return relevant / target.sum()
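
For reference (an illustration, not part of the diff), the two edge cases handled above behave as follows: when no `target` is `True` the function returns `0.0` instead of dividing by zero, and `k=None` falls back to the total number of documents:

```python
# Minimal sketch (not part of this PR) of the edge cases handled by retrieval_recall.
import torch
from torchmetrics.functional.retrieval.recall import retrieval_recall

preds = torch.tensor([0.4, 0.1, 0.3])

# no relevant documents at all: the function returns 0.0 rather than dividing by zero
print(retrieval_recall(preds, torch.tensor([False, False, False])))  # tensor(0.)

# k=None defaults to the number of documents, so every relevant document is counted
print(retrieval_recall(preds, torch.tensor([True, False, True])))    # tensor(1.)
```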
1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
@@ -15,3 +15,4 @@
from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall # noqa: F401
99 changes: 99 additions & 0 deletions torchmetrics/retrieval/retrieval_recall.py
@@ -0,0 +1,99 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

from torch import Tensor, tensor

from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.retrieval.retrieval_metric import IGNORE_IDX, RetrievalMetric


class RetrievalRecall(RetrievalMetric):
"""
Computes `Recall
<https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall>`__.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

- ``indexes`` (long tensor): ``(N, ...)``
- ``preds`` (float tensor): ``(N, ...)``
- ``target`` (long or bool tensor): ``(N, ...)``

``indexes``, ``preds`` and ``target`` must have the same dimension.
``indexes`` indicate to which query a prediction belongs.
Predictions will be first grouped by ``indexes`` and then `Recall` will be computed as the mean
of the `Recall` over each query.

Args:
query_without_relevant_docs:
[Review thread on this line, resolved]

Member: Suggested change: rename ``query_without_relevant_docs:`` to ``miss_query_action:``; the argument name does not help with the meaning, so let's keep it compact.

Member: If we make this change, then all the other retrieval metrics should also be changed.

Member: Yes, or do you have any better name suggestions? @lucadiliello cc: @PyTorchLightning/core-metrics

Contributor Author: ``miss_query_action`` sounds good. I suggest also ``empty_target_action``, because this is the case when ``target.sum() == 0``.

Member: +1 for ``empty_target_action``

Member: Agree with ``empty_target_action`` 🎉

Specify what to do with queries that do not have at least a positive ``target``. Choose from:

- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
- ``'pos'``: score on those queries is counted as ``1.0``
- ``'neg'``: score on those queries is counted as ``0.0``

exclude:
Do not take into account predictions where the ``target`` is equal to this value. default `-100`
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects
the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
k: consider only the top k elements for each query. default: None

Example:
>>> from torchmetrics import RetrievalRecall
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> r2 = RetrievalRecall(k=2)
>>> r2(indexes, preds, target)
tensor(0.7500)
"""

def __init__(
self,
query_without_relevant_docs: str = 'skip',
exclude: int = IGNORE_IDX,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
k: int = None
):
super().__init__(
query_without_relevant_docs=query_without_relevant_docs,
exclude=exclude,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn
)

if (k is not None) and not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")
self.k = k

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
valid_indexes = (target != self.exclude)
return retrieval_recall(preds[valid_indexes], target[valid_indexes], k=self.k)
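
To illustrate the class docstring above (a sketch against this PR's API as it stands in this diff, not part of the change itself): predictions are grouped by `indexes`, the per-query recalls are averaged, and queries without a positive `target` are handled according to `query_without_relevant_docs` (which, per the review thread above, may later be renamed, e.g. to `empty_target_action`).

```python
# Minimal sketch (not part of this PR) of how queries without a positive target are
# handled; argument names follow this diff and may change per the review discussion.
import torch
from torchmetrics import RetrievalRecall

indexes = torch.tensor([0, 0, 1, 1])                # two queries
preds = torch.tensor([0.9, 0.1, 0.6, 0.4])
target = torch.tensor([True, False, False, False])  # query 1 has no relevant docs

# 'skip' (default): query 1 is ignored, only query 0 (recall 1.0) contributes
print(RetrievalRecall(query_without_relevant_docs='skip')(indexes, preds, target))  # tensor(1.)

# 'neg': query 1 is scored as 0.0, so the mean over both queries is 0.5
print(RetrievalRecall(query_without_relevant_docs='neg')(indexes, preds, target))   # tensor(0.5000)
```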