Skip to content

Commit

Permalink
Mean Average Precision metric for Information Retrieval (1/5) (PL^5032)
Browse files Browse the repository at this point in the history
* init information retrieval metrics

* changed retrieval metrics names, expanded arguments and fixed typo

* added 'Retrieval' prefix to metrics and fixed conflict with already-present 'average_precision' file

* improved code formatting

* pep8 code compatibility

* features/implemented new Mean Average Precision metrics for Information Retrieval + doc

* fixed pep8 compatibility

* removed threshold parameter and fixed typo on types in RetrievalMAP and improved doc

* improved doc, put first class-specific args in RetrievalMetric and transformed RetrievalMetric in abstract class

* implemented tests for functional and class metric. fixed typo when input tensors are empty or when all targets are False

* fixed typos in doc and changed torch.true_divide to torch.div

* fixed typos pep8 compatibility

* fixed types in long division in ir_average_precision and example in mean_average_precision

* RetrievalMetric states are not lists and _metric method accepts predictions and targets for easier extension

* updated CHANGELOG file

* added '# noqa: F401' flag to not used imports

* added double space before '# noqa: F401' flag

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <[email protected]>

* change get_mini_groups in get_group_indexes

* added checks on target inputs

* minor refactoring for code cleanness

* split tests over exception raising in separate function && refactored test code into multiple functions

* fixed pep8 compatibility

* implemented suggestions of @SkafteNicki

* fixed imports for isort and added types annontations to functions in test_map.py

* isort on test_map and fixed typing

* isort on retrieval and on __init__.py and utils.py in metrics package

* fixed typo in pytorch_lightning/metrics/__init__.py regarding code style

* fixed yapf compatibility

* fixed yapf compatibility

* fixed typo in doc

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people committed Mar 22, 2021
1 parent 79548d5 commit 60c0eec
Show file tree
Hide file tree
Showing 12 changed files with 467 additions and 10 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ ignore_errors = True
# todo: add proper typing to this module...
[mypy-torchmetrics.regression.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.retrieval.*]
ignore_errors = True
30 changes: 30 additions & 0 deletions tests/functional/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import math

import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as sk_average_precision

from torchmetrics.functional.classification.ir_average_precision import retrieval_average_precision


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(sk_average_precision, retrieval_average_precision),
])
@pytest.mark.parametrize("size", [1, 4, 10, 100])
def test_against_sklearn(sklearn_metric, torch_metric, size):
"""Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'

a = np.random.randn(size)
b = np.random.randn(size) > 0

sk = torch.tensor(sklearn_metric(b, a), device=device)
pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device))

# `torch_metric`s return 0 when no label is True
# while `sklearn.average_precision_score` returns NaN
if math.isnan(sk):
assert pl == 0
else:
assert torch.allclose(sk.float(), pl.float())
120 changes: 120 additions & 0 deletions tests/retrieval/test_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import math
import random
from typing import Callable, List

import numpy as np
import pytest
import torch
from pytorch_lightning import seed_everything
from sklearn.metrics import average_precision_score as sk_average_precision
from torch import Tensor

from torchmetrics.metric import Metric
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP


@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [
[sk_average_precision, RetrievalMAP],
])
def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None:
"""Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed_everything(0)

rounds = 20
sizes = [1, 4, 10, 100]
batch_sizes = [1, 4, 10]
query_without_relevant_docs_options = ['skip', 'pos', 'neg']

def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> Tensor:
""" Compute sk metric with multiple iterations using the base `sklearn_metric`. """
sk_results = []
kwargs = {'device': device, 'dtype': torch.float32}

for b, a in zip(target, preds):
res = sklearn_metric(b, a)

if math.isnan(res):
if behaviour == 'skip':
pass
elif behaviour == 'pos':
sk_results.append(torch.tensor(1.0, **kwargs))
else:
sk_results.append(torch.tensor(0.0, **kwargs))
else:
sk_results.append(torch.tensor(res, **kwargs))
if len(sk_results) > 0:
sk_results = torch.stack(sk_results).mean()
else:
sk_results = torch.tensor(0.0, **kwargs)

return sk_results

def do_test(batch_size: int, size: int) -> None:
""" For each possible behaviour of the metric, check results are correct. """
for behaviour in query_without_relevant_docs_options:
metric = torch_class_metric(query_without_relevant_docs=behaviour)
shape = (size, )

indexes = []
preds = []
target = []

for i in range(batch_size):
indexes.append(np.ones(shape, dtype=int) * i)
preds.append(np.random.randn(*shape))
target.append(np.random.randn(*shape) > 0)

sk_results = compute_sklearn_metric(target, preds, behaviour)

indexes_tensor = torch.cat([torch.tensor(i) for i in indexes])
preds_tensor = torch.cat([torch.tensor(p) for p in preds])
target_tensor = torch.cat([torch.tensor(t) for t in target])

# lets assume data are not ordered
perm = torch.randperm(indexes_tensor.nelement())
indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size())
preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size())
target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size())

# shuffle ids to require also sorting of documents ability from the lightning metric
pl_result = metric(indexes_tensor, preds_tensor, target_tensor)

assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True)

for batch_size in batch_sizes:
for size in sizes:
for _ in range(rounds):
do_test(batch_size, size)


@pytest.mark.parametrize(['torch_class_metric'], [
[RetrievalMAP],
])
def test_input_data(torch_class_metric: Metric) -> None:
"""Check PL metrics inputs are controlled correctly. """

device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed_everything(0)

for _ in range(10):

length = random.randint(0, 20)

# check error when `query_without_relevant_docs='error'` is raised correctly
indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
preds = torch.rand(size=(length, ), device=device, dtype=torch.float32)
target = torch.tensor([False] * length, device=device, dtype=torch.bool)

metric = torch_class_metric(query_without_relevant_docs='error')

try:
metric(indexes, preds, target)
except Exception as e:
assert isinstance(e, ValueError)

# check ValueError with non-accepted argument
try:
metric = torch_class_metric(query_without_relevant_docs='casual_argument')
except Exception as e:
assert isinstance(e, ValueError)
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@
MeanSquaredLogError,
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.ir_average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
from torchmetrics.functional.classification.roc import roc # noqa: F401
Expand Down
55 changes: 55 additions & 0 deletions torchmetrics/functional/classification/ir_average_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor


def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor:
r"""
Computes average precision (for information retrieval), as explained
`here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_.
`preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``,
0 is returned. Target must be of type `bool` or `int`, otherwise an error is raised.
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor.
Return:
a single-value tensor with the average precision (AP) of the predictions `preds` wrt the labels `target`.
Example:
>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([True, False, True])
>>> retrieval_average_precision(preds, target)
tensor(0.8333)
"""

if preds.shape != target.shape or preds.device != target.device:
raise ValueError("`preds` and `target` must have the same shape and live on the same device")

if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64):
raise ValueError("`target` must be a tensor of booleans or integers")

if target.dtype is not torch.bool:
target = target.bool()

if target.sum() == 0:
return torch.tensor(0, device=preds.device)

target = target[torch.argsort(preds, dim=-1, descending=True)]
positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0]
res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean()
return res
6 changes: 3 additions & 3 deletions torchmetrics/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Callable, Optional

import torch
from torch import Tensor
from torch import Tensor, tensor

from torchmetrics.functional.regression.mean_squared_error import (
_mean_squared_error_compute,
Expand Down Expand Up @@ -64,8 +64,8 @@ def __init__(
dist_sync_fn=dist_sync_fn,
)

self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor):
"""
Expand Down
12 changes: 6 additions & 6 deletions torchmetrics/regression/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torch import Tensor, tensor

from torchmetrics.functional.regression.psnr import _psnr_compute, _psnr_update
from torchmetrics.metric import Metric
Expand Down Expand Up @@ -86,8 +86,8 @@ def __init__(
rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')

if dim is None:
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
else:
self.add_state("sum_squared_error", default=[])
self.add_state("total", default=[])
Expand All @@ -99,10 +99,10 @@ def __init__(
raise ValueError("The `data_range` must be given when `dim` is not None.")

self.data_range = None
self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min)
self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max)
self.add_state("min_target", default=tensor(0.0), dist_reduce_fx=torch.min)
self.add_state("max_target", default=tensor(0.0), dist_reduce_fx=torch.max)
else:
self.register_buffer("data_range", torch.tensor(float(data_range)))
self.register_buffer("data_range", tensor(float(data_range)))
self.base = base
self.reduction = reduction
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
Expand Down
15 changes: 15 additions & 0 deletions torchmetrics/retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
62 changes: 62 additions & 0 deletions torchmetrics/retrieval/mean_average_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
from torch import Tensor

from torchmetrics.functional.classification.ir_average_precision import retrieval_average_precision
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalMAP(RetrievalMetric):
r"""
Computes `Mean Average Precision
<https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`_.
Works with binary data. Accepts integer or float predictions from a model output.
Forward accepts
- ``indexes`` (long tensor): ``(N, ...)``
- ``preds`` (float tensor): ``(N, ...)``
- ``target`` (long or bool tensor): ``(N, ...)``
`indexes`, `preds` and `target` must have the same dimension.
`indexes` indicate to which query a prediction belongs.
Predictions will be first grouped by indexes and then MAP will be computed as the mean
of the Average Precisions over each query.
Args:
query_without_relevant_docs:
Specify what to do with queries that do not have at least a positive target. Choose from:
- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
- ``'pos'``: score on those queries is counted as ``1.0``
- ``'neg'``: score on those queries is counted as ``0.0``
exclude:
Do not take into account predictions where the target is equal to this value. default `-100`
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects
the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
Example:
>>> from torchmetrics import RetrievalMAP
>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = torch.tensor([False, False, True, False, True, False, False])
>>> map = RetrievalMAP()
>>> map(indexes, preds, target)
tensor(0.7500)
>>> map.compute()
tensor(0.7500)
"""

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
valid_indexes = target != self.exclude
return retrieval_average_precision(preds[valid_indexes], target[valid_indexes])
Loading

0 comments on commit 60c0eec

Please sign in to comment.