Information Retrieval (3/5) #139

Merged (8 commits, Mar 29, 2021)
Changes from 7 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalMRR` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119))


- Added `RetrievalPrecision` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119))


- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))


9 changes: 8 additions & 1 deletion docs/source/references/functional.rst
@@ -248,7 +248,14 @@ retrieval_average_precision [func]


retrieval_reciprocal_rank [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_reciprocal_rank
:noindex:


retrieval_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_precision
:noindex:
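
A minimal usage sketch of the new functional metric (it mirrors the docstring example added later in this diff, and the import path follows the change to ``torchmetrics/functional/__init__.py`` below):

    import torch
    from torchmetrics.functional import retrieval_precision

    preds = torch.tensor([0.2, 0.3, 0.5])
    target = torch.tensor([True, False, True])
    print(retrieval_precision(preds, target, k=2))  # tensor(0.5000)
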
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -336,6 +336,13 @@ RetrievalMRR
:noindex:


RetrievalPrecision
~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalPrecision
:noindex:
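
A minimal sketch of how the class-based metric can be used; the ``k`` keyword and the ``(indexes, preds, target)`` call signature are the ones exercised by the tests in this PR:

    import torch
    from torchmetrics import RetrievalPrecision

    # two queries (ids 0 and 1), each with its own retrieved documents
    indexes = torch.tensor([0, 0, 0, 1, 1])
    preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.4])
    target = torch.tensor([False, False, True, False, True])

    metric = RetrievalPrecision(k=1)
    metric.update(indexes, preds, target)
    print(metric.compute())  # tensor(1.) since the top-1 document of each query is relevant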


********
Wrappers
********
49 changes: 49 additions & 0 deletions tests/functional/test_retrieval.py
@@ -7,7 +7,9 @@

from tests.helpers import seed_all
from tests.retrieval.test_mrr import _reciprocal_rank as reciprocal_rank
from tests.retrieval.test_precision import _precision_at_k as precision_at_k
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank

seed_all(1337)
@@ -42,9 +44,39 @@ def test_metrics_output_values(sklearn_metric, torch_metric, size):
assert torch.allclose(sk.float(), tm.float())


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
[precision_at_k, retrieval_precision],
])
@pytest.mark.parametrize("size", [1, 4, 10])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_metrics_output_values_with_k(sklearn_metric, torch_metric, size, k):
""" Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# test results are computed correctly wrt std implementation
for i in range(6):
preds = np.random.randn(size)
target = np.random.randn(size) > 0

# sometimes test with integer targets
if (i % 2) == 0:
target = target.astype(np.int)

sk = torch.tensor(sklearn_metric(target, preds, k), device=device)
tm = torch_metric(torch.tensor(preds, device=device), torch.tensor(target, device=device), k)

# `torch_metric`s return 0 when no label is True
# while `sklearn`-style metrics return NaN
if math.isnan(sk):
assert tm == 0
else:
assert torch.allclose(sk.float(), tm.float())


@pytest.mark.parametrize(['torch_metric'], [
[retrieval_average_precision],
[retrieval_reciprocal_rank],
[retrieval_precision]
])
def test_input_dtypes(torch_metric) -> None:
""" Check wrong input dtypes are managed correctly. """
@@ -75,6 +107,7 @@ def test_input_dtypes(torch_metric) -> None:
@pytest.mark.parametrize(['torch_metric'], [
[retrieval_average_precision],
[retrieval_reciprocal_rank],
[retrieval_precision]
])
def test_input_shapes(torch_metric) -> None:
""" Check wrong input shapes are managed correctly. """
@@ -93,3 +126,19 @@ def test_input_shapes(torch_metric) -> None:

with pytest.raises(ValueError, match="`preds` and `target` must be of the same shape"):
torch_metric(preds, target)


# test metrics using top K parameter
@pytest.mark.parametrize(['torch_metric'], [
[retrieval_precision]
])
@pytest.mark.parametrize('k', [-1, 1.0])
def test_input_params(torch_metric, k) -> None:
""" Check wrong input shapes are managed correctly. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# test with random tensors
preds = torch.tensor([0] * 4, device=device, dtype=torch.float)
target = torch.tensor([0] * 4, device=device, dtype=torch.int64)
with pytest.raises(ValueError, match="`k` has to be a positive integer or None"):
torch_metric(preds, target, k=k)
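
The validation exercised by ``test_input_params`` above can be reproduced directly against the functional metric (illustrative values):

    import torch
    from torchmetrics.functional import retrieval_precision

    preds = torch.tensor([0.1, 0.4, 0.2])
    target = torch.tensor([1, 0, 1])

    retrieval_precision(preds, target, k=2)    # tensor(0.5000)
    retrieval_precision(preds, target, k=-1)   # raises ValueError: `k` has to be a positive integer or None
    retrieval_precision(preds, target, k=1.0)  # raises ValueError: `k` has to be a positive integer or None
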
22 changes: 17 additions & 5 deletions tests/retrieval/helpers.py
@@ -6,12 +6,13 @@
from torch import Tensor

from tests.helpers import seed_all
from torchmetrics import Metric

seed_all(1337)


def _compute_sklearn_metric(
metric: Callable, target: List[np.ndarray], preds: List[np.ndarray], behaviour: str
metric: Callable, target: List[np.ndarray], preds: List[np.ndarray], behaviour: str, **kwargs
) -> Tensor:
""" Compute metric with multiple iterations over every query predictions set. """
sk_results = []
@@ -25,7 +26,7 @@ def _compute_sklearn_metric(
else:
sk_results.append(0.0)
else:
res = metric(b, a)
res = metric(b, a, **kwargs)
sk_results.append(res)

if len(sk_results) > 0:
@@ -34,10 +35,15 @@


def _test_retrieval_against_sklearn(
sklearn_metric, torch_metric, size, n_documents, query_without_relevant_docs_options
sklearn_metric: Callable,
torch_metric: Metric,
size: int,
n_documents: int,
query_without_relevant_docs_options: str,
**kwargs
) -> None:
""" Compare PL metrics to standard version. """
metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options)
metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options, **kwargs)
shape = (size, )

indexes = []
@@ -49,7 +55,7 @@ def _test_retrieval_against_sklearn(
preds.append(np.random.randn(*shape))
target.append(np.random.randn(*shape) > 0)

sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, query_without_relevant_docs_options)
sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, query_without_relevant_docs_options, **kwargs)
sk_results = torch.tensor(sk_results)

indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]).long()
@@ -120,3 +126,9 @@ def _test_input_shapes(torchmetric) -> None:

with pytest.raises(ValueError, match="`indexes`, `preds` and `target` must be of the same shape"):
metric(indexes, preds, target)


def _test_input_args(torchmetric: Metric, message: str, **kwargs) -> None:
"""Check invalid args are managed correctly. """
with pytest.raises(ValueError, match=message):
torchmetric(**kwargs)
20 changes: 13 additions & 7 deletions tests/retrieval/test_mrr.py
@@ -1,25 +1,31 @@
import numpy as np
import pytest
from sklearn.metrics import label_ranking_average_precision_score

from tests.retrieval.helpers import _test_dtypes, _test_input_shapes, _test_retrieval_against_sklearn
from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR


def _reciprocal_rank(target: np.array, preds: np.array):
"""
Implementation of reciprocal rank because couldn't find a good implementation.
`sklearn.metrics.label_ranking_average_precision_score` is similar but works in a different way
then the number of positive labels is greater than 1.
Adaptation of `sklearn.metrics.label_ranking_average_precision_score`.
Since the original sklearn metric works as RR only when the number of positive
targets is exactly 1, here we remove every positive target that is not the most
important. Remember that in RR only the positive target with the highest score is considered.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

# going to remove T targets that are not ranked as highest
indexes = preds[target.astype(np.bool)]
if len(indexes) > 0:
target[preds != indexes.max(-1, keepdims=True)[0]] = 0 # ensure that only 1 positive label is present

if target.sum() > 0:
target = target[np.argsort(preds, axis=-1)][::-1]
rank = np.nonzero(target)[0][0] + 1
return 1.0 / rank
# sklearn `label_ranking_average_precision_score` requires at most 2 dims
return label_ranking_average_precision_score(np.expand_dims(target, axis=0), np.expand_dims(preds, axis=0))
else:
return np.NaN
return 0.0
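
A hand-worked sanity check of the reference above (illustrative values): once only a single positive target is left, ``label_ranking_average_precision_score`` reduces exactly to the reciprocal rank.

    import numpy as np
    from sklearn.metrics import label_ranking_average_precision_score

    # the only relevant document has the lowest score, so it is found at rank 3
    target = np.array([[1, 0, 0]])
    preds = np.array([[0.2, 0.3, 0.5]])
    print(label_ranking_average_precision_score(target, preds))  # 0.3333... == 1 / 3
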


@pytest.mark.parametrize('size', [1, 4, 10])
56 changes: 56 additions & 0 deletions tests/retrieval/test_precision.py
@@ -0,0 +1,56 @@
import numpy as np
import pytest

from tests.retrieval.helpers import _test_dtypes, _test_input_args, _test_input_shapes, _test_retrieval_against_sklearn
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision


def _precision_at_k(target: np.array, preds: np.array, k: int = None):
"""
Didn't find a reliable implementation of Precision in Information Retrieval,
so it is reimplemented here: the fraction of relevant documents among the top `k` retrieved ones.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

if k is None:
k = len(preds)

if target.sum() > 0:
order_indexes = np.argsort(preds, axis=0)[::-1]
relevant = np.sum(target[order_indexes][:k])
return relevant * 1.0 / k
else:
return np.NaN
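
The edge case the comparison tests have to handle: with no relevant documents this reference yields ``nan``, while ``retrieval_precision`` returns 0 (illustrative sketch, reusing ``_precision_at_k`` from above):

    import numpy as np
    import torch
    from torchmetrics.functional import retrieval_precision

    preds = np.array([0.2, 0.3, 0.5])
    target = np.array([0, 0, 0])

    print(_precision_at_k(target, preds, k=2))                                  # nan
    print(retrieval_precision(torch.tensor(preds), torch.tensor(target), k=2))  # tensor(0.)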


@pytest.mark.parametrize('size', [1, 4, 10])
@pytest.mark.parametrize('n_documents', [1, 5])
@pytest.mark.parametrize('query_without_relevant_docs_options', ['skip', 'pos', 'neg'])
@pytest.mark.parametrize('k', [None, 1, 4, 10])
def test_results(size, n_documents, query_without_relevant_docs_options, k):
""" Test metrics are computed correctly. """
_test_retrieval_against_sklearn(
_precision_at_k,
RetrievalPrecision,
size,
n_documents,
query_without_relevant_docs_options,
k=k
)


def test_dtypes():
""" Check dypes are managed correctly. """
_test_dtypes(RetrievalPrecision)


def test_input_shapes() -> None:
"""Check inputs shapes are managed correctly. """
_test_input_shapes(RetrievalPrecision)


@pytest.mark.parametrize('k', [-1, 1.0])
def test_input_params(k) -> None:
"""Check invalid args are managed correctly. """
_test_input_args(RetrievalPrecision, "`k` has to be a positive integer or None", k=k)
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -49,5 +49,5 @@
MeanSquaredLogError,
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR # noqa: F401 E402
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -38,5 +38,6 @@
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
8 changes: 4 additions & 4 deletions torchmetrics/functional/retrieval/average_precision.py
@@ -18,7 +18,7 @@


def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor:
r"""
"""
Computes average precision (for information retrieval), as explained
`here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`__.

@@ -31,11 +31,11 @@ def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor:
target: ground truth about each document being relevant or not.

Return:
a single-value tensor with the average precision (AP) of the predictions ``preds`` wrt the labels ``target``.
a single-value tensor with the average precision (AP) of the predictions ``preds`` w.r.t. the labels ``target``.

Example:
>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([True, False, True])
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_average_precision(preds, target)
tensor(0.8333)
"""
56 changes: 56 additions & 0 deletions torchmetrics/functional/retrieval/precision.py
@@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_precision(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
"""
Computes the precision metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Precision>`__.
Precision is the fraction of relevant documents among all the retrieved documents.

``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
otherwise an error is raised. If you want to measure Precision@K, ``k`` must be a positive integer.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: None)

Returns:
a single-value tensor with the precision (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Example:
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_precision(preds, target, k=2)
tensor(0.5000)
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

if k is None:
k = preds.shape[-1]

if not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")

if target.sum() == 0:
return tensor(0.0, device=preds.device)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float()
return relevant / k
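
A hand-worked trace of the top-``k`` slice above, using the values from the docstring example:

    # preds = [0.2, 0.3, 0.5], target = [True, False, True], k = 2
    # torch.argsort(preds, descending=True) -> [2, 1, 0]
    # target[[2, 1, 0]][:2] -> [True, False], i.e. 1 relevant document in the top 2
    # precision@2 = 1 / 2 = 0.5
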
2 changes: 1 addition & 1 deletion torchmetrics/functional/retrieval/reciprocal_rank.py
@@ -18,7 +18,7 @@


def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
r"""
"""
Computes reciprocal rank (for information retrieval), as explained
`here <https://en.wikipedia.org/wiki/Mean_reciprocal_rank>`__.

1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
@@ -14,3 +14,4 @@
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401
from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401