Information Retrieval (4/5) #146

Merged (8 commits) on Mar 30, 2021
Changes from all commits
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -21,7 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalMRR` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119))


- Added `RetrievalPrecision` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119))
- Added `RetrievalPrecision` metric for Information Retrieval ([#139](https://github.com/PyTorchLightning/metrics/pull/139))


- Added `RetrievalRecall` metric for Information Retrieval ([#146](https://github.com/PyTorchLightning/metrics/pull/146))


- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -259,3 +259,10 @@ retrieval_precision [func]

.. autofunction:: torchmetrics.functional.retrieval_precision
:noindex:


retrieval_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_recall
:noindex:
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -343,6 +343,13 @@ RetrievalPrecision
:noindex:


RetrievalRecall
~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalRecall
:noindex:


********
Wrappers
********
14 changes: 11 additions & 3 deletions tests/functional/test_retrieval.py
@@ -8,8 +8,10 @@
from tests.helpers import seed_all
from tests.retrieval.test_mrr import _reciprocal_rank as reciprocal_rank
from tests.retrieval.test_precision import _precision_at_k as precision_at_k
from tests.retrieval.test_recall import _recall_at_k as recall_at_k
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank

seed_all(1337)
@@ -46,6 +48,7 @@ def test_metrics_output_values(sklearn_metric, torch_metric, size):

@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
[precision_at_k, retrieval_precision],
[recall_at_k, retrieval_recall],
])
@pytest.mark.parametrize("size", [1, 4, 10])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@@ -73,11 +76,12 @@ def test_metrics_output_values_with_k(sklearn_metric, torch_metric, size, k):
assert torch.allclose(sk.float(), tm.float())


@pytest.mark.parametrize(['torch_metric'], (
@pytest.mark.parametrize(['torch_metric'], [
[retrieval_average_precision],
[retrieval_reciprocal_rank],
[retrieval_precision],
))
[retrieval_recall],
])
def test_input_dtypes(torch_metric) -> None:
""" Check wrong input dtypes are managed correctly. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -108,6 +112,7 @@ def test_input_dtypes(torch_metric) -> None:
[retrieval_average_precision],
[retrieval_reciprocal_rank],
[retrieval_precision],
[retrieval_recall],
))
def test_input_shapes(torch_metric) -> None:
""" Check wrong input shapes are managed correctly. """
@@ -129,7 +134,10 @@ def test_input_shapes(torch_metric) -> None:


# test metrics using top K parameter
@pytest.mark.parametrize(['torch_metric'], ([retrieval_precision], ))
@pytest.mark.parametrize(['torch_metric'], [
[retrieval_precision],
[retrieval_recall],
])
@pytest.mark.parametrize('k', [-1, 1.0])
def test_input_params(torch_metric, k) -> None:
""" Check wrong input shapes are managed correctly. """
22 changes: 11 additions & 11 deletions tests/retrieval/helpers.py
@@ -39,11 +39,11 @@ def _test_retrieval_against_sklearn(
torch_metric: Metric,
size: int,
n_documents: int,
query_without_relevant_docs_options: str,
**kwargs,
empty_target_action: str,
**kwargs
) -> None:
""" Compare PL metrics to standard version. """
metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options, **kwargs)
metric = torch_metric(empty_target_action=empty_target_action, **kwargs)
shape = (size, )

indexes = []
@@ -55,7 +55,7 @@
preds.append(np.random.randn(*shape))
target.append(np.random.randn(*shape) > 0)

sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, query_without_relevant_docs_options, **kwargs)
sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, empty_target_action, **kwargs)
sk_results = torch.tensor(sk_results)

indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]).long()
@@ -83,26 +83,26 @@ def _test_dtypes(torchmetric) -> None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
length = 10 # not important in this test

# check error when `query_without_relevant_docs='error'` is raised correctly
# check error when `empty_target_action='error'` is raised correctly
indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
preds = torch.rand(size=(length, ), device=device, dtype=torch.float32)
target = torch.tensor([False] * length, device=device, dtype=torch.bool)

metric = torchmetric(query_without_relevant_docs='error')
metric = torchmetric(empty_target_action='error')
with pytest.raises(ValueError, match="`compute` method was provided with a query with no positive target."):
metric(indexes, preds, target)

# check ValueError with invalid `query_without_relevant_docs` argument
# check ValueError with invalid `empty_target_action` argument
casual_argument = 'casual_argument'
with pytest.raises(ValueError, match=f"`query_without_relevant_docs` received a wrong value {casual_argument}."):
metric = torchmetric(query_without_relevant_docs=casual_argument)
with pytest.raises(ValueError, match=f"`empty_target_action` received a wrong value {casual_argument}."):
metric = torchmetric(empty_target_action=casual_argument)

# check input dtypes
indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
preds = torch.tensor([0] * length, device=device, dtype=torch.float32)
target = torch.tensor([0] * length, device=device, dtype=torch.int64)

metric = torchmetric(query_without_relevant_docs='error')
metric = torchmetric(empty_target_action='error')

# check error on input dtypes are raised correctly
with pytest.raises(ValueError, match="`indexes` must be a tensor of long integers"):
@@ -116,7 +116,7 @@ def _test_input_shapes(torchmetric) -> None:
def _test_input_shapes(torchmetric) -> None:
"""Check PL metrics inputs are controlled correctly. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
metric = torchmetric(query_without_relevant_docs='error')
metric = torchmetric(empty_target_action='error')

# check input shapes are checked correctly
elements_1, elements_2 = np.random.choice(np.arange(1, 20), size=2, replace=False)
6 changes: 3 additions & 3 deletions tests/retrieval/test_map.py
@@ -7,11 +7,11 @@

@pytest.mark.parametrize('size', [1, 4, 10])
@pytest.mark.parametrize('n_documents', [1, 5])
@pytest.mark.parametrize('query_without_relevant_docs_options', ['skip', 'pos', 'neg'])
def test_results(size, n_documents, query_without_relevant_docs_options):
@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg'])
def test_results(size, n_documents, empty_target_action):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
sk_average_precision, RetrievalMAP, size, n_documents, query_without_relevant_docs_options
sk_average_precision, RetrievalMAP, size, n_documents, empty_target_action
)


6 changes: 3 additions & 3 deletions tests/retrieval/test_mrr.py
@@ -30,11 +30,11 @@ def _reciprocal_rank(target: np.array, preds: np.array):

@pytest.mark.parametrize('size', [1, 4, 10])
@pytest.mark.parametrize('n_documents', [1, 5])
@pytest.mark.parametrize('query_without_relevant_docs_options', ['skip', 'pos', 'neg'])
def test_results(size, n_documents, query_without_relevant_docs_options):
@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg'])
def test_results(size, n_documents, empty_target_action):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
_reciprocal_rank, RetrievalMRR, size, n_documents, query_without_relevant_docs_options
_reciprocal_rank, RetrievalMRR, size, n_documents, empty_target_action
)


9 changes: 5 additions & 4 deletions tests/retrieval/test_precision.py
@@ -8,7 +8,8 @@
def _precision_at_k(target: np.array, preds: np.array, k: int = None):
"""
Didn't find a reliable implementation of Precision in Information Retrieval, so,
reimplementing here. A good explanation can be found ``
reimplementing here. A good explanation can be found
`here <https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf>`_.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs
@@ -26,12 +27,12 @@ def _precision_at_k(target: np.array, preds: np.array, k: int = None):

@pytest.mark.parametrize('size', [1, 4, 10])
@pytest.mark.parametrize('n_documents', [1, 5])
@pytest.mark.parametrize('query_without_relevant_docs_options', ['skip', 'pos', 'neg'])
@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg'])
@pytest.mark.parametrize('k', [None, 1, 4, 10])
def test_results(size, n_documents, query_without_relevant_docs_options, k):
def test_results(size, n_documents, empty_target_action, k):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
_precision_at_k, RetrievalPrecision, size, n_documents, query_without_relevant_docs_options, k=k
_precision_at_k, RetrievalPrecision, size, n_documents, empty_target_action, k=k
)


56 changes: 56 additions & 0 deletions tests/retrieval/test_recall.py
@@ -0,0 +1,56 @@
import numpy as np
import pytest

from tests.retrieval.helpers import _test_dtypes, _test_input_args, _test_input_shapes, _test_retrieval_against_sklearn
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall


def _recall_at_k(target: np.array, preds: np.array, k: int = None):
"""
Didn't find a reliable implementation of Recall in Information Retrieval, so it is
reimplemented here. See Wikipedia for more information about the definition.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

if k is None:
k = len(preds)

if target.sum() > 0:
order_indexes = np.argsort(preds, axis=0)[::-1]
relevant = np.sum(target[order_indexes][:k])
return relevant * 1.0 / target.sum()
else:
return np.NaN


@pytest.mark.parametrize('size', [1, 4, 10])
@pytest.mark.parametrize('n_documents', [1, 5])
@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg'])
@pytest.mark.parametrize('k', [None, 1, 4, 10])
def test_results(size, n_documents, empty_target_action, k):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
_recall_at_k,
RetrievalRecall,
size,
n_documents,
empty_target_action,
k=k
)


def test_dtypes():
""" Check dypes are managed correctly. """
_test_dtypes(RetrievalRecall)


def test_input_shapes() -> None:
"""Check inputs shapes are managed correctly. """
_test_input_shapes(RetrievalRecall)


@pytest.mark.parametrize('k', [-1, 1.0])
def test_input_params(k) -> None:
"""Check invalid args are managed correctly. """
_test_input_args(RetrievalRecall, "`k` has to be a positive integer or None", k=k)
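
A quick numeric sanity check of the reference `_recall_at_k` above against the new functional metric (illustrative only, not part of this PR's test suite). With three documents and k=2, only one of the two relevant documents lands in the top 2, so recall@2 is 0.5:

# Illustrative only; assumes the reference implementation and functional metric shown in this PR.
import numpy as np
import torch
from torchmetrics.functional import retrieval_recall

preds = np.array([0.2, 0.3, 0.5])
target = np.array([True, False, True])

# Top-2 documents by score are indices 2 (relevant) and 1 (not relevant),
# so 1 of the 2 relevant documents is retrieved: recall@2 = 0.5.
assert _recall_at_k(target, preds, k=2) == 0.5
assert retrieval_recall(torch.tensor(preds), torch.tensor(target), k=2).item() == 0.5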
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -49,5 +49,5 @@
MeanSquaredLogError,
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision # noqa: F401 E402
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision, RetrievalRecall # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -39,5 +39,6 @@
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
@@ -14,4 +14,5 @@

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
57 changes: 57 additions & 0 deletions torchmetrics/functional/retrieval/recall.py
@@ -0,0 +1,57 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_recall(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
"""
Computes the recall metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Recall>`__.
Recall is the fraction of relevant documents retrieved among all the relevant documents.

``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
otherwise an error is raised. If you want to measure Recall@K, ``k`` must be a positive integer.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: None)

Returns:
a single-value tensor with the recall (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Example:
>>> from torchmetrics.functional import retrieval_recall
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_recall(preds, target, k=2)
tensor(0.5000)
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

if k is None:
k = preds.shape[-1]

if not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")

if target.sum() == 0:
return tensor(0.0, device=preds.device)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float()
return relevant / target.sum()
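
The class-based counterpart in `torchmetrics/retrieval/retrieval_recall.py` is added by this PR but not shown in the diff above. A minimal usage sketch, assuming the constructor arguments (`empty_target_action`, `k`) and the `(indexes, preds, target)` call order exercised in `tests/retrieval/helpers.py`:

# Illustrative sketch; argument names follow the tests in this PR rather than the class source.
import torch
from torchmetrics import RetrievalRecall

# `indexes` groups predictions by query: documents 0-2 belong to query 0, documents 3-5 to query 1.
indexes = torch.tensor([0, 0, 0, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5])
target = torch.tensor([False, False, True, False, True, False])

metric = RetrievalRecall(empty_target_action='skip', k=2)
result = metric(indexes, preds, target)  # recall@2 per query, averaged over queries
print(result)  # expected tensor(1.): each query's single relevant document is inside its top 2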
1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
@@ -15,3 +15,4 @@
from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall # noqa: F401
2 changes: 1 addition & 1 deletion torchmetrics/retrieval/mean_average_precision.py
@@ -36,7 +36,7 @@ class RetrievalMAP(RetrievalMetric):
of the `Average Precisions` over each query.

Args:
query_without_relevant_docs:
empty_target_action:
Specify what to do with queries that do not have at least a positive ``target``. Choose from:

- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
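
The rename from `query_without_relevant_docs` to `empty_target_action` applies to all retrieval metrics. A brief sketch of the two behaviours pinned down by the code in this PR (`'skip'` from the docstring above, `'error'` from `tests/retrieval/helpers.py`); `'pos'` and `'neg'` are also accepted, but their exact scores are not visible in this diff:

# Illustrative only; mirrors the behaviour asserted in tests/retrieval/helpers.py.
import pytest
import torch
from torchmetrics import RetrievalMAP

indexes = torch.tensor([0, 0, 0])
preds = torch.rand(3)
target = torch.tensor([False, False, False])  # a query with no positive target

# 'skip' (default): the query is ignored; if every query is skipped, 0.0 is returned.
assert RetrievalMAP(empty_target_action='skip')(indexes, preds, target) == 0.0

# 'error': fail loudly instead of silently handling the empty query.
with pytest.raises(ValueError, match="`compute` method was provided with a query with no positive target."):
    RetrievalMAP(empty_target_action='error')(indexes, preds, target)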
2 changes: 1 addition & 1 deletion torchmetrics/retrieval/mean_reciprocal_rank.py
@@ -36,7 +36,7 @@ class RetrievalMRR(RetrievalMetric):
of the `Reciprocal Rank` over each query.

Args:
query_without_relevant_docs:
empty_target_action:
Specify what to do with queries that do not have at least a positive ``target``. Choose from:

- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned