Skip to content

Commit

Permalink
Fix tie breaking in ndcg metric (#2031)
Browse files Browse the repository at this point in the history
* fix implementation

* add tests

* chlog

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
(cherry picked from commit 1caaf28)
  • Loading branch information
SkafteNicki authored and Borda committed Sep 11, 2023
1 parent 725f493 commit 5c2db0b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 15 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))
- Fixed tie breaking in ndcg metric ([#2031](https://github.com/Lightning-AI/torchmetrics/pull/2031))


- Fixed bug in `BootStrapper` when very few samples were evaluated that could lead to crash ([#2052](https://github.com/Lightning-AI/torchmetrics/pull/2052))
Expand All @@ -24,11 +24,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug when creating multiple plots that lead to not all plots being shown ([#2060](https://github.com/Lightning-AI/torchmetrics/pull/2060))


- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))


## [1.1.1] - 2023-08-29

### Added

- Added `average` argument to `MeanAveragePrecision` ([#2018](https://github.com/Lightning-AI/torchmetrics/pull/2018)
- Added `average` argument to `MeanAveragePrecision` ([#2018](https://github.com/Lightning-AI/torchmetrics/pull/2018))

### Fixed

Expand Down
66 changes: 53 additions & 13 deletions src/torchmetrics/functional/retrieval/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,53 @@
from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def _dcg(target: Tensor) -> Tensor:
"""Compute Discounted Cumulative Gain for input tensor."""
denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)
return (target / denom).sum(dim=-1)
def _tie_average_dcg(target: Tensor, preds: Tensor, discount_cumsum: Tensor) -> Tensor:
"""Translated version of sklearns `_tie_average_dcg` function.
Args:
target: ground truth about each document relevance.
preds: estimated probabilities of each document to be relevant.
discount_cumsum: cumulative sum of the discount.
Returns:
The cumulative gain of the tied elements.
"""
_, inv, counts = torch.unique(-preds, return_inverse=True, return_counts=True)
ranked = torch.zeros_like(counts, dtype=torch.float32)
ranked.scatter_add_(0, inv, target.to(dtype=ranked.dtype))
ranked = ranked / counts
groups = counts.cumsum(dim=0) - 1
discount_sums = torch.zeros_like(counts, dtype=torch.float32)
discount_sums[0] = discount_cumsum[groups[0]]
discount_sums[1:] = discount_cumsum[groups].diff()
return (ranked * discount_sums).sum()


def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: bool) -> Tensor:
"""Translated version of sklearns `_dcg_sample_scores` function.
Args:
target: ground truth about each document relevance.
preds: estimated probabilities of each document to be relevant.
top_k: consider only the top k elements
ignore_ties: If True, ties are ignored. If False, ties are averaged.
Returns:
The cumulative gain
"""
discount = 1.0 / (torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0))
discount[top_k:] = 0.0

if ignore_ties:
ranking = preds.argsort(descending=True)
ranked = target[ranking]
cumulative_gain = (discount * ranked).sum()
else:
discount_cumsum = discount.cumsum(dim=-1)
cumulative_gain = _tie_average_dcg(target, preds, discount_cumsum)
return cumulative_gain


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
Expand Down Expand Up @@ -59,15 +102,12 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
if not (isinstance(top_k, int) and top_k > 0):
raise ValueError("`top_k` has to be a positive integer or None")

sorted_target = target[torch.argsort(preds, dim=-1, descending=True)][:top_k]
ideal_target = torch.sort(target, descending=True)[0][:top_k]

ideal_dcg = _dcg(ideal_target)
target_dcg = _dcg(sorted_target)
gain = _dcg_sample_scores(target, preds, top_k, ignore_ties=False)
normalized_gain = _dcg_sample_scores(target, target, top_k, ignore_ties=True)

# filter undefined scores
all_irrelevant = ideal_dcg == 0
target_dcg[all_irrelevant] = 0
target_dcg[~all_irrelevant] /= ideal_dcg[~all_irrelevant]
all_irrelevant = normalized_gain == 0
gain[all_irrelevant] = 0
gain[~all_irrelevant] /= normalized_gain[~all_irrelevant]

return target_dcg.mean()
return gain.mean()
13 changes: 13 additions & 0 deletions tests/unittests/retrieval/test_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import pytest
import torch
from sklearn.metrics import ndcg_score
from torch import Tensor
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
Expand Down Expand Up @@ -185,3 +186,15 @@ def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, messag
exception_type=ValueError,
kwargs_update=metric_args,
)


def test_corner_case_with_tied_scores():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/2022."""
target = torch.tensor([[10, 0, 0, 1, 5]])
preds = torch.tensor([[0.1, 0, 0, 0, 0.1]])

for k in [1, 3, 5]:
assert torch.allclose(
retrieval_normalized_dcg(preds, target, top_k=k),
torch.tensor([ndcg_score(target, preds, k=k)], dtype=torch.float32),
)

0 comments on commit 5c2db0b

Please sign in to comment.