Mean Average Precision metric for Information Retrieval (1/5) (PL^5032)
* init information retrieval metrics
* changed retrieval metrics names, expanded arguments and fixed typo
* added 'Retrieval' prefix to metrics and fixed conflict with already-present 'average_precision' file
* improved code formatting
* pep8 code compatibility
* features/implemented new Mean Average Precision metrics for Information Retrieval + doc
* fixed pep8 compatibility
* removed threshold parameter and fixed typo on types in RetrievalMAP and improved doc
* improved doc, put first class-specific args in RetrievalMetric and transformed RetrievalMetric into an abstract class
* implemented tests for functional and class metric; fixed typo when input tensors are empty or when all targets are False
* fixed typos in doc and changed torch.true_divide to torch.div
* fixed typos, pep8 compatibility
* fixed types in long division in ir_average_precision and example in mean_average_precision
* RetrievalMetric states are not lists and _metric method accepts predictions and targets for easier extension
* updated CHANGELOG file
* added '# noqa: F401' flag to unused imports
* added double space before '# noqa: F401' flag
* Update CHANGELOG.md
* changed get_mini_groups into get_group_indexes
* added checks on target inputs
* minor refactoring for code cleanness
* split tests over exception raising into a separate function && refactored test code into multiple functions
* fixed pep8 compatibility
* implemented suggestions of @SkafteNicki
* fixed imports for isort and added type annotations to functions in test_map.py
* isort on test_map and fixed typing
* isort on retrieval and on __init__.py and utils.py in the metrics package
* fixed typo in pytorch_lightning/metrics/__init__.py regarding code style
* fixed yapf compatibility
* fixed yapf compatibility
* fixed typo in doc

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 79548d5 · commit 60c0eec
Showing 12 changed files with 467 additions and 10 deletions.
@@ -0,0 +1,30 @@
import math

import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as sk_average_precision

from torchmetrics.functional.classification.ir_average_precision import retrieval_average_precision


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
    pytest.param(sk_average_precision, retrieval_average_precision),
])
@pytest.mark.parametrize("size", [1, 4, 10, 100])
def test_against_sklearn(sklearn_metric, torch_metric, size):
    """Compare PL metrics to sklearn version."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    a = np.random.randn(size)
    b = np.random.randn(size) > 0

    sk = torch.tensor(sklearn_metric(b, a), device=device)
    pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device))

    # `torch_metric`s return 0 when no label is True,
    # while `sklearn.average_precision_score` returns NaN
    if math.isnan(sk):
        assert pl == 0
    else:
        assert torch.allclose(sk.float(), pl.float())
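The edge case the final assertion guards against is worth seeing in isolation. Below is a minimal sketch (not part of the commit, assuming the import path introduced by this diff) of the divergence: on the sklearn version targeted here, `average_precision_score` yields NaN when no label is positive, while the new functional returns 0.

import numpy as np
import torch
from sklearn.metrics import average_precision_score
from torchmetrics.functional.classification.ir_average_precision import retrieval_average_precision

preds = np.array([0.7, 0.2, 0.1])
target = np.array([False, False, False])  # a query with no relevant documents

print(average_precision_score(target, preds))  # nan (sklearn also emits a warning)
print(retrieval_average_precision(torch.tensor(preds), torch.tensor(target)))  # tensor(0)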
@@ -0,0 +1,120 @@
import math
import random
from typing import Callable, List

import numpy as np
import pytest
import torch
from pytorch_lightning import seed_everything
from sklearn.metrics import average_precision_score as sk_average_precision
from torch import Tensor

from torchmetrics.metric import Metric
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP


@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [
    [sk_average_precision, RetrievalMAP],
])
def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None:
    """Compare PL metrics to sklearn version."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed_everything(0)

    rounds = 20
    sizes = [1, 4, 10, 100]
    batch_sizes = [1, 4, 10]
    query_without_relevant_docs_options = ['skip', 'pos', 'neg']

    def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> Tensor:
        """Compute the sk metric over multiple iterations using the base `sklearn_metric`."""
        sk_results = []
        kwargs = {'device': device, 'dtype': torch.float32}

        for b, a in zip(target, preds):
            res = sklearn_metric(b, a)

            if math.isnan(res):
                if behaviour == 'skip':
                    pass
                elif behaviour == 'pos':
                    sk_results.append(torch.tensor(1.0, **kwargs))
                else:
                    sk_results.append(torch.tensor(0.0, **kwargs))
            else:
                sk_results.append(torch.tensor(res, **kwargs))

        if len(sk_results) > 0:
            sk_results = torch.stack(sk_results).mean()
        else:
            sk_results = torch.tensor(0.0, **kwargs)

        return sk_results

    def do_test(batch_size: int, size: int) -> None:
        """For each possible behaviour of the metric, check that results are correct."""
        for behaviour in query_without_relevant_docs_options:
            metric = torch_class_metric(query_without_relevant_docs=behaviour)
            shape = (size, )

            indexes = []
            preds = []
            target = []

            for i in range(batch_size):
                indexes.append(np.ones(shape, dtype=int) * i)
                preds.append(np.random.randn(*shape))
                target.append(np.random.randn(*shape) > 0)

            sk_results = compute_sklearn_metric(target, preds, behaviour)

            indexes_tensor = torch.cat([torch.tensor(i) for i in indexes])
            preds_tensor = torch.cat([torch.tensor(p) for p in preds])
            target_tensor = torch.cat([torch.tensor(t) for t in target])

            # let's assume data are not ordered
            perm = torch.randperm(indexes_tensor.nelement())
            indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size())
            preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size())
            target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size())

            # ids were shuffled above, so the lightning metric must also be able to sort documents
            pl_result = metric(indexes_tensor, preds_tensor, target_tensor)

            assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True)

    for batch_size in batch_sizes:
        for size in sizes:
            for _ in range(rounds):
                do_test(batch_size, size)


@pytest.mark.parametrize(['torch_class_metric'], [
    [RetrievalMAP],
])
def test_input_data(torch_class_metric: Metric) -> None:
    """Check that PL metrics inputs are controlled correctly."""

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed_everything(0)

    for _ in range(10):

        length = random.randint(0, 20)

        # check that the error for `query_without_relevant_docs='error'` is raised correctly
        indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
        preds = torch.rand(size=(length, ), device=device, dtype=torch.float32)
        target = torch.tensor([False] * length, device=device, dtype=torch.bool)

        metric = torch_class_metric(query_without_relevant_docs='error')

        try:
            metric(indexes, preds, target)
        except Exception as e:
            assert isinstance(e, ValueError)

        # check ValueError with non-accepted argument
        try:
            metric = torch_class_metric(query_without_relevant_docs='casual_argument')
        except Exception as e:
            assert isinstance(e, ValueError)
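One observation on the exception tests above: a bare try/except passes silently if no exception is raised at all. A sketch of a stricter variant using `pytest.raises` (an alternative idiom, not what the commit uses; the test function names are illustrative) could look like this, relying on the documented behaviour that both cases raise ValueError:

import pytest
import torch
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP


def test_error_on_query_without_relevant_docs():
    # 'error' is documented to raise when a query has no positive target
    metric = RetrievalMAP(query_without_relevant_docs='error')
    indexes = torch.tensor([0, 0], dtype=torch.int64)
    preds = torch.rand(2)
    target = torch.tensor([False, False])
    with pytest.raises(ValueError):
        metric(indexes, preds, target)


def test_error_on_invalid_option():
    # a non-accepted argument value must be rejected at construction time
    with pytest.raises(ValueError):
        RetrievalMAP(query_without_relevant_docs='casual_argument')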
@@ -47,3 +47,4 @@
    MeanSquaredLogError,
    R2Score,
)
from torchmetrics.retrieval import RetrievalMAP  # noqa: F401 E402
torchmetrics/functional/classification/ir_average_precision.py (55 additions, 0 deletions)
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor


def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor:
    r"""
    Computes average precision (for information retrieval), as explained
    `here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_.

    `preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``,
    ``0`` is returned. `target` must be of type `bool` or `int`, otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document being relevant.
        target: ground truth about each document being relevant or not. Requires a `bool` or `int` tensor.

    Return:
        a single-value tensor with the average precision (AP) of the predictions `preds` w.r.t. the labels `target`.

    Example:
        >>> preds = torch.tensor([0.2, 0.3, 0.5])
        >>> target = torch.tensor([True, False, True])
        >>> retrieval_average_precision(preds, target)
        tensor(0.8333)
    """

    if preds.shape != target.shape or preds.device != target.device:
        raise ValueError("`preds` and `target` must have the same shape and live on the same device")

    if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64):
        raise ValueError("`target` must be a tensor of booleans or integers")

    if target.dtype is not torch.bool:
        target = target.bool()

    if target.sum() == 0:
        return torch.tensor(0, device=preds.device)

    # sort targets by descending prediction score; `positions` are the 1-based
    # ranks at which the relevant documents appear
    target = target[torch.argsort(preds, dim=-1, descending=True)]
    positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0]
    res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean()
    return res
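To make the vectorized computation concrete, here is a step-by-step sketch (illustration only) that unrolls the last three lines on the docstring example: sorting by score leaves the targets in order [True, False, True], the relevant documents sit at ranks 1 and 3, and AP is the mean of the precision values 1/1 and 2/3, i.e. 0.8333.

import torch

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

order = torch.argsort(preds, descending=True)  # ranks the scores: 0.5, 0.3, 0.2
sorted_target = target[order]                  # [True, False, True]

# 1-based ranks at which relevant documents occur: [1., 3.]
positions = torch.arange(1, len(sorted_target) + 1, dtype=torch.float32)[sorted_target]

# precision@k at each relevant rank: the k-th relevant doc is found within `positions[k-1]` results
hits = torch.arange(1, len(positions) + 1, dtype=torch.float32)  # [1., 2.]
ap = (hits / positions).mean()                 # mean(1/1, 2/3) = 0.8333
print(ap)  # tensor(0.8333)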
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP  # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric  # noqa: F401
@@ -0,0 +1,62 @@
import torch
from torch import Tensor

from torchmetrics.functional.classification.ir_average_precision import retrieval_average_precision
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalMAP(RetrievalMetric):
    r"""
    Computes `Mean Average Precision
    <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`_.

    Works with binary data. Accepts integer or float predictions from a model output.

    Forward accepts

    - ``indexes`` (long tensor): ``(N, ...)``
    - ``preds`` (float tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``

    `indexes`, `preds` and `target` must have the same dimension.
    `indexes` indicate to which query a prediction belongs.
    Predictions will first be grouped by `indexes`, and then MAP is computed as the mean
    of the Average Precisions over each query.

    Args:
        query_without_relevant_docs:
            Specify what to do with queries that do not have at least one positive target. Choose from:

            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``
            - ``'pos'``: score on those queries is counted as ``1.0``
            - ``'neg'``: score on those queries is counted as ``0.0``

        exclude:
            Do not take into account predictions where the target is equal to this value. default: `-100`
        compute_on_step:
            Forward only calls ``update()`` and returns None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: None

    Example:
        >>> from torchmetrics import RetrievalMAP
        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = torch.tensor([False, False, True, False, True, False, False])
        >>> map = RetrievalMAP()
        >>> map(indexes, preds, target)
        tensor(0.7500)
        >>> map.compute()
        tensor(0.7500)
    """

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        # drop positions whose target equals the `exclude` value before scoring
        valid_indexes = target != self.exclude
        return retrieval_average_precision(preds[valid_indexes], target[valid_indexes])
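As a sanity check on the grouping semantics, the docstring result can be recomputed per query with the functional metric from this same commit (illustration only): query 0 ranks its relevant document first (AP = 1.0), query 1 ranks its relevant document second (AP = 0.5), and their mean is 0.75.

import torch
from torchmetrics.functional.classification.ir_average_precision import retrieval_average_precision

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, False])

aps = []
for q in indexes.unique():
    mask = indexes == q            # group predictions by query id
    aps.append(retrieval_average_precision(preds[mask], target[mask]))

# query 0: relevant doc ranked first  -> AP = 1.0
# query 1: relevant doc ranked second -> AP = 0.5
print(torch.stack(aps).mean())  # tensor(0.7500)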