Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New metric: Cluster Accuracy #2777

Open
wants to merge 59 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
3139d7b
new requirements file
SkafteNicki Oct 11, 2024
99b8428
doc file
SkafteNicki Oct 11, 2024
d5a2241
remove unused class variables
SkafteNicki Oct 11, 2024
69998f8
init files
SkafteNicki Oct 12, 2024
afde4cc
functional implementation
SkafteNicki Oct 12, 2024
d8f633d
modular implementation
SkafteNicki Oct 12, 2024
4cf6251
rename some files
SkafteNicki Oct 12, 2024
49fd21d
update init and requirements files
SkafteNicki Oct 12, 2024
ae10d90
fixes
SkafteNicki Oct 12, 2024
b9bb1b0
fix remaining issues
SkafteNicki Oct 12, 2024
3ff7cc5
Merge branch 'master' into newmetric/cluster_accuracy
SkafteNicki Oct 12, 2024
dcdd1dd
changelog
SkafteNicki Oct 12, 2024
9090b36
add install note
SkafteNicki Oct 12, 2024
4c35de9
skip doctests on missing package
SkafteNicki Oct 12, 2024
8298a34
missing requirements linking
SkafteNicki Oct 12, 2024
7a68994
guard against missing package
SkafteNicki Oct 12, 2024
184d910
lazy import instead
SkafteNicki Oct 12, 2024
805543d
fix doctests
SkafteNicki Oct 12, 2024
9d6e622
try fixing
SkafteNicki Oct 14, 2024
92f8854
Merge branch 'master' into newmetric/cluster_accuracy
SkafteNicki Oct 14, 2024
bcd7105
Merge branch 'master' into newmetric/cluster_accuracy
SkafteNicki Oct 15, 2024
d8ab2d7
Apply suggestions from code review
Borda Oct 15, 2024
9adce5d
Apply suggestions from code review
Borda Oct 15, 2024
984e229
Apply suggestions from code review
Borda Oct 15, 2024
fa30efe
req.
Borda Oct 15, 2024
a7c8be2
Merge branch 'master' into newmetric/cluster_accuracy
Borda Oct 15, 2024
636aa12
lint
Borda Oct 15, 2024
e8ee7ed
Merge branch 'newmetric/cluster_accuracy' of https://github.com/Light…
Borda Oct 15, 2024
c6a8eb5
guard against older pytorch versions
SkafteNicki Oct 16, 2024
f8cf6e2
lower aeon requirement
SkafteNicki Oct 16, 2024
63119d4
sort requirements list
SkafteNicki Oct 16, 2024
943f1f9
Merge branch 'master' into newmetric/cluster_accuracy
Borda Oct 22, 2024
a27b0b5
Merge branch 'master' into newmetric/cluster_accuracy
SkafteNicki Oct 22, 2024
4dd2e35
remove guard against older torch versions
SkafteNicki Oct 22, 2024
e499182
Merge branch 'master' into newmetric/cluster_accuracy
SkafteNicki Oct 22, 2024
71459ab
Merge branch 'master' into newmetric/cluster_accuracy
Borda Oct 23, 2024
a2fb0a8
Merge branch 'master' into newmetric/cluster_accuracy
SkafteNicki Oct 24, 2024
ad4e88f
Update requirements/clustering.txt
SkafteNicki Oct 24, 2024
099a1d9
try setting cuda home
SkafteNicki Oct 24, 2024
c52a335
try setting cuda home
SkafteNicki Oct 24, 2024
db0dcd6
continue on failing install
SkafteNicki Oct 24, 2024
e72aa17
Merge branch 'master' into newmetric/cluster_accuracy
Borda Oct 29, 2024
ada126f
Merge branch 'master' into newmetric/cluster_accuracy
Borda Oct 30, 2024
1fdc09f
Merge branch 'master' into newmetric/cluster_accuracy
baskrahmer Nov 9, 2024
195d857
revert changes
SkafteNicki Nov 11, 2024
904287a
update requirement to min py3.10
SkafteNicki Nov 11, 2024
4fb9e9c
fix
SkafteNicki Nov 11, 2024
a6ae4a2
Merge branch 'master' into newmetric/cluster_accuracy
Borda Nov 12, 2024
ae8c6df
try 2020-resolver
Borda Nov 12, 2024
5c3c092
Merge branch 'master' into newmetric/cluster_accuracy
baskrahmer Nov 22, 2024
53117ed
Merge branch 'master' into newmetric/cluster_accuracy
Borda Dec 17, 2024
f418947
Apply suggestions from code review
Borda Jan 7, 2025
09fea99
Merge branch 'master' into newmetric/cluster_accuracy
Borda Jan 7, 2025
184009f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
7b751b5
Apply suggestions from code review
Borda Jan 7, 2025
1617b82
Merge branch 'master' into newmetric/cluster_accuracy
Borda Jan 7, 2025
7c80d11
--extra-index-url
Borda Jan 8, 2025
8c1d5ec
scikit-learn ==1.5.*
Borda Jan 8, 2025
30b4c19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ jobs:
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
TOKENIZERS_PARALLELISM: false
TEST_DIRS: ${{ needs.check-diff.outputs.test-dirs }}
PIP_USE_FEATURE: "2020-resolver"
Borda marked this conversation as resolved.
Show resolved Hide resolved
PIP_EXTRA_INDEX_URL: "--extra-index-url=http://download.pytorch.org/whl/cpu/"
UNITTEST_TIMEOUT: "" # by default, it is not set

Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `ClusterAccuracy` metric to cluster package ([#2777](https://github.com/Lightning-AI/torchmetrics/pull/2777))


-


Expand Down
21 changes: 21 additions & 0 deletions docs/source/clustering/cluster_accuracy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Cluster Accuracy
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg
:tags: Clustering

.. include:: ../links.rst

################
Cluster Accuracy
################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.ClusterAccuracy
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.cluster_accuracy
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,6 @@
.. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis
.. _Cluster Accuracy: https://arxiv.org/abs/2206.07579
.. _Log AUC: https://pubmed.ncbi.nlm.nih.gov/20735049/
.. _Negative Predictive Value: https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values
14 changes: 8 additions & 6 deletions requirements/_devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
-r _tests.txt

# add extra requirements
-r audio.txt
-r clustering.txt
-r detection.txt
-r image.txt
-r text.txt
-r detection.txt
-r audio.txt
-r multimodal.txt
-r visual.txt

# add extra testing
-r image_test.txt
-r text_test.txt
-r audio_test.txt
-r detection_test.txt
-r classification_test.txt
-r clustering_test.txt
-r detection_test.txt
-r image_test.txt
-r nominal_test.txt
-r segmentation_test.txt
-r regression_test.txt
-r segmentation_test.txt
-r text_test.txt
3 changes: 2 additions & 1 deletion requirements/_docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ pydantic > 1.0.0, < 3.0.0

# integrations
-r _integrate.txt
-r visual.txt
-r audio.txt
-r clustering.txt
-r detection.txt
-r image.txt
-r multimodal.txt
-r text.txt
-r visual.txt

# Gallery extra requirements
# --------------------------
Expand Down
4 changes: 4 additions & 0 deletions requirements/clustering.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torch_linear_assignment >=0.0.2, <0.0.3
4 changes: 4 additions & 0 deletions requirements/clustering_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

aeon >= 1.0.0, <1.1.0; python_version >"3.10" # cluster accuracy
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchmetrics.clustering.adjusted_mutual_info_score import AdjustedMutualInfoScore
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore
from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore
from torchmetrics.clustering.cluster_accuracy import ClusterAccuracy
from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore
from torchmetrics.clustering.dunn_index import DunnIndex
from torchmetrics.clustering.fowlkes_mallows_index import FowlkesMallowsIndex
Expand All @@ -30,6 +31,7 @@
"AdjustedMutualInfoScore",
"AdjustedRandScore",
"CalinskiHarabaszScore",
"ClusterAccuracy",
"CompletenessScore",
"DaviesBouldinScore",
"DunnIndex",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class AdjustedMutualInfoScore(MutualInfoScore):
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(
self, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic", **kwargs: Any
Expand Down
148 changes: 148 additions & 0 deletions src/torchmetrics/clustering/cluster_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.functional.classification import multiclass_confusion_matrix
from torchmetrics.functional.clustering.cluster_accuracy import _cluster_accuracy_compute
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import (
_MATPLOTLIB_AVAILABLE,
_TORCH_LINEAR_ASSIGNMENT_AVAILABLE,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ClusterAccuracy.plot"]

if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
__doctest_skip__ = ["ClusterAccuracy", "ClusterAccuracy.plot"]


class ClusterAccuracy(Metric):
r"""Compute `Cluster Accuracy`_ between predicted and target clusters.

.. math::

\text{Cluster Accuracy} = \max_g \frac{1}{N} \sum_{n=1}^N \mathbb{1}_{g(p_n) = t_n}

Where :math:`g` is a function that maps predicted clusters :math:`p` to target clusters :math:`t`, :math:`N` is the
number of samples, :math:`p_n` is the predicted cluster for sample :math:`n`, :math:`t_n` is the target cluster for
sample :math:`n`, and :math:`\mathbb{1}` is the indicator function. The function :math:`g` is determined by solving
the linear sum assignment problem.

This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering in generally is used for unsupervised learning.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels

As output of ``forward`` and ``compute`` the metric returns the following output:

- ``acc_score`` (:class:`~torch.Tensor`): A tensor with the Cluster Accuracy score

Args:
num_classes: number of classes
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
RuntimeError:
If ``torch_linear_assignment`` is not installed. To install, run ``pip install torchmetrics[clustering]``.
ValueError
If ``num_classes`` is not a positive integer

Example::
>>> import torch
>>> from torchmetrics.clustering import ClusterAccuracy
>>> preds = torch.tensor([0, 0, 1, 1])
>>> target = torch.tensor([1, 1, 0, 0])
>>> metric = ClusterAccuracy(num_classes=2)
>>> metric(preds, target)
tensor(1.)

"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
confmat: Tensor

def __init__(self, num_classes: int, **kwargs: Any) -> None:
super().__init__(**kwargs)
if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
raise RuntimeError(
"Missing `torch_linear_assignment`. Please install it with `pip install torchmetrics[clustering]`."
)

if not isinstance(num_classes, int) or num_classes <= 0:
raise ValueError("Argument `num_classes` should be a positive integer")
self.add_state(
"confmat", default=torch.zeros((num_classes, num_classes), dtype=torch.int64), dist_reduce_fx="sum"
)
self.num_classes = num_classes

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the confusion matrix with the new predictions and targets."""
self.confmat += multiclass_confusion_matrix(preds, target, num_classes=self.num_classes)

def compute(self) -> Tensor:
"""Computes the clustering accuracy."""
return _cluster_accuracy_compute(self.confmat)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling ``metric.forward`` or ``metric.compute``
or a list of these results. If no value is provided, will automatically call `metric.compute`
and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import ClusterAccuracy
>>> metric = ClusterAccuracy(num_classes=4)
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import ClusterAccuracy
>>> metric = ClusterAccuracy(num_classes=4)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))))
>>> fig_, ax_ = metric.plot(values)

"""
return self._plot(val, ax)
1 change: 0 additions & 1 deletion src/torchmetrics/clustering/fowlkes_mallows_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class FowlkesMallowsIndex(Metric):
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand Down
1 change: 0 additions & 1 deletion src/torchmetrics/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class MutualInfoScore(Metric):
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ class NormalizedMutualInfoScore(MutualInfoScore):
plot_upper_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(
self, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic", **kwargs: Any
Expand Down
1 change: 0 additions & 1 deletion src/torchmetrics/clustering/rand_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class RandScore(Metric):
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score
from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score
from torchmetrics.functional.clustering.cluster_accuracy import cluster_accuracy
from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score
from torchmetrics.functional.clustering.dunn_index import dunn_index
from torchmetrics.functional.clustering.fowlkes_mallows_index import fowlkes_mallows_index
Expand All @@ -30,6 +31,7 @@
"adjusted_mutual_info_score",
"adjusted_rand_score",
"calinski_harabasz_score",
"cluster_accuracy",
"completeness_score",
"davies_bouldin_score",
"dunn_index",
Expand Down
67 changes: 67 additions & 0 deletions src/torchmetrics/functional/clustering/cluster_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.functional.classification import multiclass_confusion_matrix
from torchmetrics.functional.clustering.utils import check_cluster_labels
from torchmetrics.utilities.imports import _TORCH_LINEAR_ASSIGNMENT_AVAILABLE

if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
__doctest_skip__ = ["cluster_accuracy"]


def _cluster_accuracy_compute(confmat: Tensor) -> Tensor:
"""Computes the clustering accuracy from a confusion matrix."""
from torch_linear_assignment import batch_linear_assignment

confmat = confmat[None]
# solve the linear sum assignment problem
assignment = batch_linear_assignment(confmat.max() - confmat)
confmat = confmat[0]
# extract the true positives
tps = confmat[torch.arange(confmat.shape[0]), assignment.flatten()]
return tps.sum() / confmat.sum()


def cluster_accuracy(preds: Tensor, target: Tensor, num_classes: int) -> Tensor:
"""Computes the clustering accuracy between the predicted and target clusters.

Args:
preds: predicted cluster labels
target: ground truth cluster labels
num_classes: number of classes

Returns:
Scalar tensor with clustering accuracy between 0.0 and 1.0

Raises:
RuntimeError:
If `torch_linear_assignment` is not installed

Example:
>>> from torchmetrics.functional.clustering import cluster_accuracy
>>> preds = torch.tensor([0, 0, 1, 1])
>>> target = torch.tensor([1, 1, 0, 0])
>>> cluster_accuracy(preds, target, 2)
tensor(1.000)

"""
if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
raise RuntimeError(
"Missing `torch_linear_assignment`. Please install it with `pip install torchmetrics[clustering]`."
)
check_cluster_labels(preds, target)
confmat = multiclass_confusion_matrix(preds, target, num_classes=num_classes)
return _cluster_accuracy_compute(confmat)
2 changes: 1 addition & 1 deletion src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
_SCIPI_AVAILABLE = RequirementCache("scipy")
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")
_TORCH_LINEAR_ASSIGNMENT_AVAILABLE = RequirementCache("torch_linear_assignment")
_PYTDC_AVAILABLE = RequirementCache("pyTDC")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
Loading
Loading