Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New metric: Rand Score #2025

Merged
merged 59 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
b726b2e
working implementation
matsumotosan Aug 19, 2023
a065ef1
passing functional and basic error tests
matsumotosan Aug 19, 2023
f355a3b
working implementation
matsumotosan Aug 19, 2023
e6862da
passing functional and basic error tests
matsumotosan Aug 19, 2023
432d2d0
Merge branch '2003-mutual-info-score' of https://github.com/matsumoto…
matsumotosan Aug 21, 2023
fbfae57
clean up naming and imports
matsumotosan Aug 21, 2023
f72183d
push metric class (broken but to allow review)
matsumotosan Aug 21, 2023
7fe14e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2023
808b278
add docs files
matsumotosan Aug 21, 2023
a0308d2
releasing 1.1.0
Borda Aug 22, 2023
6eddb2e
Merge branch 'master' into 2003-mutual-info-score
SkafteNicki Aug 22, 2023
fcd44b5
Merge branch 'master' into 2003-mutual-info-score
matsumotosan Aug 22, 2023
0d3fec9
Create util functions for clustering. Fix metric implementation.
matsumotosan Aug 22, 2023
d13c6f8
Merge branch 'master' into 2003-mutual-info-score
matsumotosan Aug 22, 2023
7dad1f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2023
c36d8a0
Fix ruff-related errors
matsumotosan Aug 22, 2023
71956f4
Merge branch '2003-mutual-info-score' of https://github.com/matsumoto…
matsumotosan Aug 22, 2023
f677483
Fix docstring examples
matsumotosan Aug 22, 2023
0d361d1
Test functional metric for symmetry
matsumotosan Aug 22, 2023
1a01690
Merge branch 'master' into 2003-mutual-info-score
matsumotosan Aug 23, 2023
422ace3
changelog
SkafteNicki Aug 23, 2023
bf05b8b
Fix type hint error. Additional checks for tensor shapes.
matsumotosan Aug 23, 2023
e9a1233
Update src/torchmetrics/clustering/mutual_info_score.py
matsumotosan Aug 23, 2023
9cff876
Update src/torchmetrics/clustering/mutual_info_score.py
matsumotosan Aug 23, 2023
3ecd697
Merge branch 'master' into 2003-mutual-info-score
matsumotosan Aug 23, 2023
1c967ef
Merge branch '2003-mutual-info-score' of https://github.com/matsumoto…
matsumotosan Aug 23, 2023
e4523d4
Test contingency matrix calculation
matsumotosan Aug 24, 2023
f1cc3df
fix mutual info score calculation. all test passing.
matsumotosan Aug 24, 2023
f278c5c
fix plotting docstring
matsumotosan Aug 24, 2023
c866355
add paren
matsumotosan Aug 24, 2023
6a4a423
Merge branch 'master' into 2003-mutual-info-score
matsumotosan Aug 24, 2023
ca5ff5f
fix doc import
SkafteNicki Aug 25, 2023
157e8f8
fix on gpu
SkafteNicki Aug 25, 2023
1d6693a
add implementation
SkafteNicki Aug 25, 2023
700617b
add docs
SkafteNicki Aug 25, 2023
74dd965
fix text
SkafteNicki Aug 25, 2023
0f18828
Merge branch 'master' into newmetric/rand_score
SkafteNicki Aug 25, 2023
3f5536d
fixes to doctest + docstrings
SkafteNicki Aug 25, 2023
747a4ad
changelog
SkafteNicki Aug 25, 2023
42b10d0
fix
SkafteNicki Aug 25, 2023
8ec288a
Merge branch 'master' into newmetric/rand_score
Borda Aug 25, 2023
1a6d39f
Update tests/unittests/clustering/test_rand_score.py
SkafteNicki Aug 26, 2023
25eb88d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
4bf9595
add implementation
SkafteNicki Aug 25, 2023
349408f
add docs
SkafteNicki Aug 25, 2023
d716114
fix text
SkafteNicki Aug 25, 2023
c272ee3
fixes to doctest + docstrings
SkafteNicki Aug 25, 2023
e733451
changelog
SkafteNicki Aug 25, 2023
741f5df
fix
SkafteNicki Aug 25, 2023
b75c0b3
Update tests/unittests/clustering/test_rand_score.py
SkafteNicki Aug 26, 2023
30dfbde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
1b61095
Merge branch 'newmetric/rand_score' of https://github.com/PyTorchLigh…
SkafteNicki Aug 26, 2023
f4386c5
Update src/torchmetrics/clustering/rand_score.py
SkafteNicki Aug 26, 2023
a0581fa
Update src/torchmetrics/clustering/rand_score.py
SkafteNicki Aug 26, 2023
10c579b
fix mypy
SkafteNicki Aug 26, 2023
77efd31
Merge branch 'master' into newmetric/rand_score
Borda Aug 27, 2023
90bf3b7
skip on too much memory
SkafteNicki Aug 28, 2023
2884f37
Merge branch 'master' into newmetric/rand_score
mergify[bot] Aug 28, 2023
d567362
Merge branch 'master' into newmetric/rand_score
mergify[bot] Aug 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008)


- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025)


### Changed

-
Expand Down
21 changes: 21 additions & 0 deletions docs/source/clustering/rand_score.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Rand Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg
:tags: Clustering

.. include:: ../links.rst

##########
Rand Score
##########

Module Interface
________________

.. autoclass:: torchmetrics.clustering.RandScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.rand_score
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,4 @@
.. _GIOU: https://arxiv.org/abs/1902.09630
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.mutual_info_score import MutualInfoScore
from torchmetrics.clustering.rand_score import RandScore

__all__ = [
"MutualInfoScore",
"RandScore",
]
4 changes: 2 additions & 2 deletions src/torchmetrics/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class MutualInfoScore(Metric):

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)``
- ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)``
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels

As output of ``forward`` and ``compute`` the metric returns the following output:

Expand Down
122 changes: 122 additions & 0 deletions src/torchmetrics/clustering/rand_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.clustering.rand_score import rand_score
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["RandScore.plot"]


class RandScore(Metric):
r"""Compute `Rand Score`_ (alternative know as Rand Index).
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

.. math::
RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}

The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.

The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels

As output of ``forward`` and ``compute`` the metric returns the following output:

- ``rand_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score

Args:
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torchmetrics.clustering import RandScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> metric = RandScore()
>>> metric(preds, target)
tensor(0.6000)

"""

is_differentiable = True
higher_is_better = None
full_state_update: bool = True
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""Compute mutual information over state."""
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import RandScore
>>> metric = RandScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import RandScore
>>> metric = RandScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())

"""
return self._plot(val, ax)
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.rand_score import rand_score

__all__ = ["mutual_info_score"]
__all__ = ["mutual_info_score", "rand_score"]
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute mutual information between two clusterings.

Args:
preds: predicted classes
target: ground truth classes
preds: predicted cluster labels
target: ground truth cluster labels

Example:
>>> from torchmetrics.functional.clustering import mutual_info_score
Expand Down
82 changes: 82 additions & 0 deletions src/torchmetrics/functional/clustering/rand_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.functional.clustering.utils import (
calcualte_pair_cluster_confusion_matrix,
calculate_contingency_matrix,
check_cluster_labels,
)


def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor:
"""Update and return variables required to compute the rand score.

Args:
preds: predicted cluster labels
target: ground truth cluster labels

Returns:
contingency: contingency matrix

"""
check_cluster_labels(preds, target)
return calculate_contingency_matrix(preds, target)


def _rand_score_compute(contingency: Tensor) -> Tensor:
"""Compute the rand score based on the contingency matrix.

Args:
contingency: contingency matrix

Returns:
rand_score: rand score

"""
pair_matrix = calcualte_pair_cluster_confusion_matrix(contingency=contingency)

numerator = pair_matrix.diagonal().sum()
denominator = pair_matrix.sum()
if numerator == denominator or denominator == 0:
# Special limit cases: no clustering since the data is not split;
# or trivial clustering where each document is assigned a unique
# cluster. These are perfect matches hence return 1.0.
return torch.ones_like(numerator, dtype=torch.float32)

return numerator / denominator


def rand_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute the Rand score between two clusterings.

Args:
preds: predicted cluster labels
target: ground truth cluster labels

Returns:
scalar tensor with the rand score

Example:
>>> from torchmetrics.functional.clustering import rand_score
>>> import torch
>>> rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0]))
tensor(1.)
>>> rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1]))
tensor(0.8333)

"""
contingency = _rand_score_update(preds, target)
return _rand_score_compute(contingency)
67 changes: 67 additions & 0 deletions src/torchmetrics/functional/clustering/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,70 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None:
f"Expected real, discrete values but received {preds.dtype} for"
f"predictions and {target.dtype} for target labels instead."
)


def calcualte_pair_cluster_confusion_matrix(
preds: Optional[Tensor] = None,
target: Optional[Tensor] = None,
contingency: Optional[Tensor] = None,
) -> Tensor:
"""Calculates the pair cluster confusion matrix.

Can either be calculated from predicted cluster labels and target cluster labels or from a pre-computed
contingency matrix. The pair cluster confusion matrix is a 2x2 matrix where that defines the similarity between
two clustering by considering all pairs of samples and counting pairs that are assigned into same or different
clusters in the predicted and target clusterings.

Note that the matrix is not symmetric.

Inspired by:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cluster.pair_confusion_matrix.html

Args:
preds: predicted cluster labels
target: ground truth cluster labels
contingency: contingency matrix

Returns:
A 2x2 tensor containing the pair cluster confusion matrix.

Raises:
ValueError:
If neither `preds` and `target` nor `contingency` are provided.
ValueError:
If both `preds` and `target` and `contingency` are provided.

Example:
>>> import torch
>>> from torchmetrics.functional.clustering.utils import calcualte_pair_cluster_confusion_matrix
>>> preds = torch.tensor([0, 0, 1, 1])
>>> target = torch.tensor([1, 1, 0, 0])
>>> calcualte_pair_cluster_confusion_matrix(preds, target)
tensor([[8, 0],
[0, 4]])
>>> preds = torch.tensor([0, 0, 1, 2])
>>> target = torch.tensor([0, 0, 1, 1])
>>> calcualte_pair_cluster_confusion_matrix(preds, target)
tensor([[8, 2],
[0, 2]])

"""
if preds is None and target is None and contingency is None:
raise ValueError("Must provide either `preds` and `target` or `contingency`.")
if preds is not None and target is not None and contingency is not None:
raise ValueError("Must provide either `preds` and `target` or `contingency`, not both.")

if preds is not None and target is not None:
contingency = calculate_contingency_matrix(preds, target)

n_samples = contingency.sum()
n_c = contingency.sum(dim=1)
n_k = contingency.sum(dim=0)
sum_squared = (contingency**2).sum()

pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device)
pair_matrix[1, 1] = sum_squared - n_samples
pair_matrix[1, 0] = (contingency * n_k).sum() - sum_squared
pair_matrix[0, 1] = (contingency.T * n_c).sum() - sum_squared
pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared
return pair_matrix
Loading