Information Retrieval (IR) metrics implementation (MAP, MRR, P@K, R@K, HR@K) [wip] #4991

Closed
wants to merge 10 commits
8 changes: 8 additions & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -30,3 +30,11 @@
PSNR,
SSIM,
)

from pytorch_lightning.metrics.retrieval import (
MeanAveragePrecision,
MeanReciprocalRank,
PrecisionAtK,
RecallAtK,
HitRateAtK,
)
53 changes: 53 additions & 0 deletions pytorch_lightning/metrics/functional/average_precision.py
@@ -0,0 +1,53 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def average_precision(
preds: torch.Tensor,
target: torch.Tensor
):
"""
Computes average precision metric for information retrieval,
as explained here: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision

`preds` and `target` should be of the same shape and live on the same device. If no target is true, 0 is returned.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant.

Returns:
a single-value tensor with the average precision (AP) of the predictions `preds` wrt the labels `target`.

Example:

>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([True, False, True])
>>> average_precision(preds, target)
... 0.833
"""

if preds.shape != target.shape or preds.device != target.device:
raise ValueError(
f"`preds` and `target` must have the same shape and be on the same device"
)

if target.sum() == 0:
return torch.tensor(0).to(preds)

target = target[torch.argsort(preds, dim=-1, descending=True)]
positions = torch.arange(1, len(target) + 1, device=target.device)[target > 0]
res = torch.true_divide((torch.arange(len(positions), device=positions.device) + 1), positions).mean()
return res
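
For readers new to the metric, here is a minimal standalone trace (not part of the diff) of the computation above on the docstring's example: sorting by descending score puts the relevant documents at ranks 1 and 3, so AP = (1/1 + 2/3) / 2 ≈ 0.833.

import torch

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

order = torch.argsort(preds, descending=True)         # scores ranked: 0.5, 0.3, 0.2
ranked = target[order]                                 # [True, False, True]
hit_ranks = torch.arange(1, len(ranked) + 1)[ranked]   # ranks of the relevant docs: [1, 3]
running_hits = torch.arange(1, len(hit_ranks) + 1)     # cumulative hit count at those ranks: [1, 2]
print(torch.true_divide(running_hits, hit_ranks).mean())  # tensor(0.8333)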
50 changes: 50 additions & 0 deletions pytorch_lightning/metrics/functional/hit_rate.py
@@ -0,0 +1,50 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def hit_rate(
preds: torch.Tensor,
target: torch.Tensor,
k: int = 1
):
"""
Computes the hit rate @ k metric for information retrieval.
Hit Rate at K is 1 iff there is at least one relevant document among the top K.

`preds` and `target` should be of the same shape and live on the same device. If no target is true, 0 is returned.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant.
k: consider only the top k elements.

Returns:
a single-value tensor with the hit rate at k (HR@K) of the predictions `preds` wrt the labels `target`.

Example:

>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([True, False, True])
>>> hit_rate(preds, target, k=2)
... 1.0
"""

if preds.shape != target.shape or preds.device != target.device:
raise ValueError(
f"`preds` and `target` must have the same shape and be on the same device"
)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum()
return (relevant > 0).to(preds)
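
A quick sketch (not part of the diff) contrasting two values of k; it assumes this branch is installed so the new functional module is importable.

import torch
from pytorch_lightning.metrics.functional.hit_rate import hit_rate

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([False, True, False])  # only the middle document is relevant

print(hit_rate(preds, target, k=1))  # tensor(0.) - the top-ranked document (score 0.5) is not relevant
print(hit_rate(preds, target, k=2))  # tensor(1.) - the relevant document appears within the top 2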
54 changes: 54 additions & 0 deletions pytorch_lightning/metrics/functional/ir_precision.py
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def precision(
preds: torch.Tensor,
target: torch.Tensor,
k: int = 1
):
"""
Computes the precision @ k metric for information retrieval,
as explained here: https://en.wikipedia.org/wiki/Precision_and_recall#Definition_(information_retrieval_context)
Precision at K is the fraction of relevant documents among the top K.

`preds` and `target` should be of the same shape and live on the same device. If no target is true, 0 is returned.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant.
k: consider only the top k elements.

Returns:
a single-value tensor with the precision at k (P@K) of the predictions `preds` wrt the labels `target`.

Example:

>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([True, False, True])
>>> precision(preds, target, k=2)
... 0.5
"""

if preds.shape != target.shape or preds.device != target.device:
raise ValueError(
f"`preds` and `target` must have the same shape and be on the same device"
)

if target.sum() == 0:
return torch.tensor(0).to(preds)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum()
return torch.true_divide(relevant, k)
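
A tiny sanity check (not part of the diff) of the behaviour when k exceeds the number of relevant documents; again it assumes this branch is installed so the new functional module is importable.

import torch
from pytorch_lightning.metrics.functional.ir_precision import precision

preds = torch.tensor([0.8, 0.6, 0.4, 0.2])
target = torch.tensor([True, True, False, False])  # 2 relevant documents, ranked first

print(precision(preds, target, k=2))  # tensor(1.) - both of the top 2 documents are relevant
print(precision(preds, target, k=4))  # tensor(0.5000) - only 2 of the top 4 documents are relevant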
54 changes: 54 additions & 0 deletions pytorch_lightning/metrics/functional/ir_recall.py
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def recall(
preds: torch.Tensor,
target: torch.Tensor,
k: int = 1
):
"""
Computes the recall @ k metric for information retrieval,
as explained here: https://en.wikipedia.org/wiki/Precision_and_recall#Definition_(information_retrieval_context)
Recall at K is the fraction of all relevant documents that are retrieved within the top K.

`preds` and `target` should be of the same shape and live on the same device. If no target is true, 0 is returned.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant.
k: consider only the top k elements.

Returns:
a single-value tensor with the recall at k (R@K) of the predictions `preds` wrt the labels `target`.

Example:

>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([True, False, True])
>>> recall(preds, target, k=2)
... 0.5
"""

if preds.shape != target.shape or preds.device != target.device:
raise ValueError(
f"`preds` and `target` must have the same shape and be on the same device"
)

if target.sum() == 0:
return torch.tensor(0).to(preds)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum()
return torch.true_divide(relevant, target.sum())
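
For comparison with the precision sketch above (not part of the diff), recall on the same ranking keeps growing with k until every relevant document has been retrieved; it assumes this branch is installed.

import torch
from pytorch_lightning.metrics.functional.ir_recall import recall

preds = torch.tensor([0.8, 0.6, 0.4, 0.2])
target = torch.tensor([True, True, False, False])  # 2 relevant documents, ranked first

print(recall(preds, target, k=1))  # tensor(0.5000) - 1 of the 2 relevant documents is in the top 1
print(recall(preds, target, k=2))  # tensor(1.) - both relevant documents are within the top 2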
53 changes: 53 additions & 0 deletions pytorch_lightning/metrics/functional/reciprocal_rank.py
@@ -0,0 +1,53 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def reciprocal_rank(
preds: torch.Tensor,
target: torch.Tensor
):
"""
Computes reciprocal rank metric for information retrieval,
as explained here: https://en.wikipedia.org/wiki/Mean_reciprocal_rank

`preds` and `target` should be of the same shape and live on the same device. If no target is true, 0 is returned.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant.

Returns:
a single-value tensor with the reciprocal rank (RR) of the predictions `preds` wrt the labels `target`.

Example:

>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([False, True, False])
>>> reciprocal_rank(preds, target)
... 0.5
"""

if preds.shape != target.shape or preds.device != target.device:
raise ValueError(
f"`preds` and `target` must have the same shape and be on the same device"
)

if target.sum() == 0:
return torch.tensor(0).to(preds)

target = target[torch.argsort(preds, dim=-1, descending=True)]
position = torch.where(target == 1)[0]
res = 1.0 / (position[0] + 1)
return res
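
A minimal trace (not part of the diff) of the reciprocal-rank logic above on the docstring's example: the first relevant document sits at rank 2, so RR = 1/2.

import torch

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([False, True, False])

ranked = target[torch.argsort(preds, descending=True)]  # [False, True, False]
first_hit = torch.where(ranked)[0][0]                   # zero-based rank of the first relevant doc: 1
print(1.0 / (first_hit + 1))                            # tensor(0.5000)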
19 changes: 19 additions & 0 deletions pytorch_lightning/metrics/retrieval/__init__.py
@@ -0,0 +1,19 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.retrieval.mean_average_precision import MeanAveragePrecision
from pytorch_lightning.metrics.retrieval.mean_reciprocal_rank import MeanReciprocalRank
from pytorch_lightning.metrics.retrieval.precision import PrecisionAtK
from pytorch_lightning.metrics.retrieval.recall import RecallAtK
from pytorch_lightning.metrics.retrieval.hit_rate import HitRateAtK
35 changes: 35 additions & 0 deletions pytorch_lightning/metrics/retrieval/hit_rate.py
@@ -0,0 +1,35 @@
from typing import List

from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric
from pytorch_lightning.metrics.functional.hit_rate import hit_rate


class HitRateAtK(RetrievalMetric):
"""
Hit Rate at K computes the hit rate over the retrieved documents of each query and averages the results (HR@K).
Each per-query hit rate at k can be computed over a different number of predictions,
thanks to the `indexes` tensor that keeps the results of different queries separate.

Notice that HR@1 == P@1

Example:

>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = torch.tensor([False, False, True, False, True, False, False])

>>> hr_k = HitRateAtK(k=1)
>>> hr_k(indexes, preds, target)
>>> hr_k.compute()
... 0.5
"""

def __init__(self, *args, k=1, **kwargs):
super().__init__(*args, **kwargs)
self.k = k

def metric(self, group: List[int]):
_preds = self.preds[group]
_target = self.target[group]
valid_indexes = (_target != self.exclude)
return hit_rate(_preds[valid_indexes], _target[valid_indexes], k=self.k)
29 changes: 29 additions & 0 deletions pytorch_lightning/metrics/retrieval/mean_average_precision.py
@@ -0,0 +1,29 @@
from typing import List

from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric
from pytorch_lightning.metrics.functional.average_precision import average_precision


class MeanAveragePrecision(RetrievalMetric):
"""
Mean Average Precision computes the average precision over the retrieved documents of each query and averages the results (MAP).
Each per-query average precision can be computed over a different number of predictions,
thanks to the `indexes` tensor that keeps the results of different queries separate.

Example:

>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = torch.tensor([False, False, True, False, True, False, False])

>>> map = MeanAveragePrecision()
>>> map(indexes, preds, target)
>>> map.compute()
... 0.75
"""

def metric(self, group: List[int]):
_preds = self.preds[group]
_target = self.target[group]
valid_indexes = _target != self.exclude
return average_precision(_preds[valid_indexes], _target[valid_indexes])
29 changes: 29 additions & 0 deletions pytorch_lightning/metrics/retrieval/mean_reciprocal_rank.py
@@ -0,0 +1,29 @@
from typing import List

from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric
from pytorch_lightning.metrics.functional.reciprocal_rank import reciprocal_rank


class MeanReciprocalRank(RetrievalMetric):
"""
Mean Reciprocal Rank computes the reciprocal rank over the retrieved documents of each query and averages the results (MRR).
Each per-query reciprocal rank can be computed over a different number of predictions,
thanks to the `indexes` tensor that keeps the results of different queries separate.

Example:

>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = torch.tensor([False, False, True, False, True, False, False])

>>> mrr = MeanReciprocalRank()
>>> mrr(indexes, preds, target)
>>> mrr.compute()
... 0.75
"""

def metric(self, group: List[int]):
_preds = self.preds[group]
_target = self.target[group]
valid_indexes = (_target != self.exclude)
return reciprocal_rank(_preds[valid_indexes], _target[valid_indexes])
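
The grouping and the exclude-padding logic live in the RetrievalMetric base class added elsewhere in this PR; the following plain-torch sketch (not part of the diff, assumes this branch is installed) only mirrors the arithmetic that MeanReciprocalRank performs: split the predictions by query via `indexes`, compute a reciprocal rank per query, and average.

import torch
from pytorch_lightning.metrics.functional.reciprocal_rank import reciprocal_rank

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, False])

# RR is 1.0 for query 0 (relevant doc ranked first) and 0.5 for query 1 (ranked second)
per_query = [reciprocal_rank(preds[indexes == q], target[indexes == q]) for q in indexes.unique()]
print(torch.stack(per_query).mean())  # tensor(0.7500)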