Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Information Retrieval (5/5) #160

Merged
merged 40 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
c8f0d80
init transition to standard metric interface for IR metrics
lucadiliello Mar 31, 2021
cc976b0
fixed typo in dtypes checks
lucadiliello Apr 1, 2021
8555c78
removed IGNORE_IDX, refactored tests using
lucadiliello Apr 3, 2021
34bcefa
added pep8 compatibility
lucadiliello Apr 3, 2021
fbaf05f
fixed np.ndarray to np.array
lucadiliello Apr 3, 2021
644477e
remove lambda functions
lucadiliello Apr 3, 2021
36ce60a
fixed typos with numpy dtype
lucadiliello Apr 3, 2021
44e2db6
fixed typo in doc example
lucadiliello Apr 3, 2021
f516e42
fixed typo in doc examples about new indexes position
lucadiliello Apr 3, 2021
1496d1b
added paramter to class testing to divide kwargs as preds and targets…
lucadiliello Apr 4, 2021
cc42311
added typo in doc example
lucadiliello Apr 4, 2021
8bb4830
added typo with new parameter frament_kwargs in MetricTester
lucadiliello Apr 4, 2021
e7c7e96
added typo in .cpu() conversion of non-tensor values
lucadiliello Apr 4, 2021
40fd75b
improved test coverage
lucadiliello Apr 4, 2021
7b3d2f8
improved test coverage
lucadiliello Apr 4, 2021
01c43de
added check on Tensor class to avoid calling .cpu() on non-tensor values
lucadiliello Apr 4, 2021
d084ff6
implemented functional ndcg
lucadiliello Apr 5, 2021
bb62f51
improved doc and changed default values for 'empty_target_action' arg…
lucadiliello Apr 5, 2021
9267c16
Merge branch 'refactor-IR-test' into feature-ir_ndcg
lucadiliello Apr 5, 2021
b9fe785
Implemented Normalized Discounted Cumulative Gain
lucadiliello Apr 5, 2021
148a908
refactored tests lists
lucadiliello Apr 5, 2021
aa8c716
Merge branch 'refactor-IR-test' into feature-ir_ndcg
lucadiliello Apr 5, 2021
c773c66
refactoring and fixed pep8 compatibility
lucadiliello Apr 5, 2021
f918a80
Merge branch 'master' into refactor-IR-test
Borda Apr 6, 2021
62c93f5
formatting
Borda Apr 6, 2021
5b2ce58
simple
Borda Apr 6, 2021
3a3871b
agrs
Borda Apr 6, 2021
7a3f8c9
format
Borda Apr 6, 2021
d350765
_sk
Borda Apr 6, 2021
c796557
fixed typo in tests
lucadiliello Apr 6, 2021
2047906
Merge branch 'refactor-IR-test' into feature-ir_ndcg
lucadiliello Apr 6, 2021
f1c49aa
updated tests experiments from tuple to dict
lucadiliello Apr 6, 2021
a71aa81
fixed typo in doc
lucadiliello Apr 6, 2021
82e2fff
fixed merge with master
lucadiliello Apr 6, 2021
fccd809
chlog
Borda Apr 6, 2021
8304f93
Update torchmetrics/functional/retrieval/ndcg.py
lucadiliello Apr 6, 2021
8a4483d
moved dcg to global scope as _dcg
lucadiliello Apr 6, 2021
8c80aa7
Update torchmetrics/functional/retrieval/ndcg.py
SkafteNicki Apr 6, 2021
0e7d93f
Merge branch 'master' into feature-ir_ndcg
lucadiliello Apr 6, 2021
c1f32a9
Update torchmetrics/functional/retrieval/ndcg.py
SkafteNicki Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalRecall` metric for Information Retrieval ([#146](https://github.com/PyTorchLightning/metrics/pull/146))


- Added `RetrievalNormalizedDCG` metric for Information Retrieval ([#160](https://github.com/PyTorchLightning/metrics/pull/160))


- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))


Expand Down
9 changes: 8 additions & 1 deletion docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,14 @@ retrieval_precision [func]


retrieval_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_recall
:noindex:


retrieval_normalized_dcg [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_normalized_dcg
:noindex:
9 changes: 8 additions & 1 deletion docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,19 @@ RetrievalPrecision


RetrievalRecall
~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalRecall
:noindex:


RetrievalNormalizedDCG
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalNormalizedDCG
:noindex:


********
Wrappers
********
Expand Down
148 changes: 148 additions & 0 deletions tests/retrieval/test_ndcg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from sklearn.metrics import ndcg_score
from torch import Tensor

from tests.helpers import seed_all
from tests.retrieval.helpers import (
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
_errors_test_class_metric_parameters_no_pos_target,
_errors_test_functional_metric_parameters_default,
_errors_test_functional_metric_parameters_k,
)
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG

seed_all(42)


def _ndcg_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
"""
Adapting `from sklearn.metrics.ndcg_score`.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

if target.shape[0] < 2: # ranking is equal to ideal ranking with a single document
return np.array(1.0)

preds = np.expand_dims(preds, axis=0)
target = np.expand_dims(target, axis=0)

return ndcg_score(target, preds, k=k)


class TestNDCG(RetrievalMetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos'])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = {'empty_target_action': empty_target_action, 'k': k}

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalNormalizedDCG,
sk_metric=_ndcg_at_k,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize(**_default_metric_functional_input_arguments)
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=retrieval_normalized_dcg,
sk_metric=_ndcg_at_k,
metric_args={},
k=k,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_cpu(
indexes=indexes,
preds=preds,
target=target,
metric_module=RetrievalNormalizedDCG,
metric_functional=retrieval_normalized_dcg,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_gpu(
indexes=indexes,
preds=preds,
target=target,
metric_module=RetrievalNormalizedDCG,
metric_functional=retrieval_normalized_dcg,
)

@pytest.mark.parametrize(**_concat_tests(
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_pos_target,
_errors_test_class_metric_parameters_k,
))
def test_arguments_class_metric(
self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict,
):
self.run_metric_class_arguments_test(
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalNormalizedDCG,
message=message,
metric_args=metric_args,
exception_type=ValueError,
kwargs_update={},
)

@pytest.mark.parametrize(**_concat_tests(
_errors_test_functional_metric_parameters_default,
_errors_test_functional_metric_parameters_k,
))
def test_arguments_functional_metric(
self, preds: Tensor, target: Tensor, message: str, metric_args: dict,
):
self.run_functional_metric_arguments_test(
preds=preds,
target=target,
metric_functional=retrieval_normalized_dcg,
message=message,
exception_type=ValueError,
kwargs_update=metric_args,
)
8 changes: 7 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,11 @@
MeanSquaredLogError,
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision, RetrievalRecall # noqa: F401 E402
from torchmetrics.retrieval import ( # noqa: F401 E402
RetrievalMAP,
RetrievalMRR,
RetrievalNormalizedDCG,
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
61 changes: 61 additions & 0 deletions torchmetrics/functional/retrieval/ndcg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def _dcg(target):
return (target / torch.log2(torch.arange(target.shape[-1]) + 2.0)).sum()
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
"""
Computes Normalized Discounted Cumulative Gain (for information retrieval), as explained
`here <https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`__.

``preds`` and ``target`` should be of the same shape and live on the same device.
``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
otherwise an error is raised.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document relevance.
k: consider only the top k elements (default: None)

Return:
a single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.

Example:
>>> from torchmetrics.functional import retrieval_normalized_dcg
>>> preds = torch.tensor([.1, .2, .3, 4, 70])
>>> target = torch.tensor([10, 0, 0, 1, 5])
>>> retrieval_normalized_dcg(preds, target)
tensor(0.6957)
"""
preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)

k = preds.shape[-1] if k is None else k

if not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")

if not target.sum():
return tensor(0.0, device=preds.device)

sorted_target = target[torch.argsort(preds, dim=-1, descending=True)][:k]
ideal_target = torch.sort(target, descending=True)[0][:k]

return _dcg(sorted_target) / _dcg(ideal_target)
1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401
from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall # noqa: F401
94 changes: 94 additions & 0 deletions torchmetrics/retrieval/retrieval_ndcg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

from torch import Tensor, tensor

from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalNormalizedDCG(RetrievalMetric):
"""
Computes `Normalized Discounted Cumulative Gain
<https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`__.

Works with binary or positive integer target data. Accepts float predictions from a model output.

Forward accepts:

- ``preds`` (float tensor): ``(N, ...)``
- ``target`` (long or bool tensor): ``(N, ...)``
- ``indexes`` (long tensor): ``(N, ...)``

``indexes``, ``preds`` and ``target`` must have the same dimension.
``indexes`` indicate to which query a prediction belongs.
Predictions will be first grouped by ``indexes`` and then `Normalized Discounted Cumulative Gain`
will be computed as the mean of the `Normalized Discounted Cumulative Gain` over each query.

Args:
empty_target_action:
Specify what to do with queries that do not have at least a positive ``target``. Choose from:

- ``'neg'``: those queries count as ``0.0`` (default)
- ``'pos'``: those queries count as ``1.0``
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``

compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects
the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
k: consider only the top k elements for each query. default: None

Example:
>>> from torchmetrics import RetrievalNormalizedDCG
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> ndcg = RetrievalNormalizedDCG()
>>> ndcg(preds, target, indexes=indexes)
tensor(0.8467)
"""

def __init__(
self,
empty_target_action: str = 'neg',
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
k: int = None
):
super().__init__(
empty_target_action=empty_target_action,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn
)

if (k is not None) and not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")
self.k = k

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_normalized_dcg(preds, target, k=self.k)
Loading