Implement clustering accuracy #2767
Comments
Hi! Thanks for your contribution, great first issue!
@moetayuko thanks for opening this issue. Do you have a reference to a source (possibly a research paper) where the metric is described in detail?
See Sec. 6.2.1 of https://arxiv.org/abs/2206.07579. FYI, here are some random implementations I found:
@moetayuko thanks for the references, they really helped me understand how the metric is intended to work.
```python
import torch
from torchmetrics.functional.classification import multiclass_confusion_matrix

# pip install git+https://github.com/ivan-chai/torch-linear-assignment.git@main
from torch_linear_assignment import batch_linear_assignment

preds = torch.tensor([0, 0, 1, 1])
target = torch.tensor([1, 1, 0, 0])
# Rows index the true labels, columns the predicted labels.
confmat = multiclass_confusion_matrix(preds, target, num_classes=5)
print(confmat)

# batch_linear_assignment expects a batch of cost matrices, so add a batch dimension.
confmat = confmat[None]
# The solver minimizes cost, so turn match counts into costs to get the maximum-overlap assignment.
assignment = batch_linear_assignment(confmat.max() - confmat)
print(assignment)

confmat = confmat[0]
# Counts on the optimal assignment are the correctly clustered samples.
tps = confmat[torch.arange(confmat.size(0)), assignment.flatten()]
acc = tps.sum() / len(preds)
print(acc)
```
🚀 Feature
Motivation
Clustering accuracy is a popular metric. On top of classification accuracy, it employs the Hungarian algorithm to align the predicted pseudo labels with the ground-truth labels before the accuracy is computed.
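Formally, for $N$ samples with predicted cluster ids $c_i$ and ground-truth labels $y_i$, the metric is $\mathrm{ACC} = \max_m \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}[y_i = m(c_i)]$, where $m$ ranges over one-to-one mappings from clusters to labels. A minimal sketch of that definition, assuming integer labels in `0..k-1` (`clustering_accuracy_bruteforce` is a hypothetical name, not an existing API). It enumerates all $k!$ mappings, so it only illustrates the definition; the Hungarian algorithm finds the same maximum in polynomial time.

```python
from itertools import permutations
from typing import Sequence


def clustering_accuracy_bruteforce(preds: Sequence[int], target: Sequence[int]) -> float:
    """Accuracy under the best one-to-one mapping from predicted clusters to labels.

    Hypothetical helper for illustration only: it tries every mapping, so its
    cost grows factorially with the number of labels k.
    """
    if len(preds) != len(target) or len(preds) == 0:
        raise ValueError("preds and target must be non-empty and the same length")
    # Use a common label space of size k so every cluster can be mapped somewhere.
    k = max(max(preds), max(target)) + 1
    # counts[c][t] = number of samples in predicted cluster c with true label t
    counts = [[0] * k for _ in range(k)]
    for c, t in zip(preds, target):
        counts[c][t] += 1
    # Each permutation maps predicted cluster c to true label perm[c].
    best = max(
        sum(counts[c][perm[c]] for c in range(k)) for perm in permutations(range(k))
    )
    return best / len(preds)


# Same toy data as in the comments: the clusters are just relabeled, so ACC is 1.0.
print(clustering_accuracy_bruteforce([0, 0, 1, 1], [1, 1, 0, 0]))
```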
Current implementations of clustering accuracy use either `scipy.optimize.linear_sum_assignment` or the `munkres` package for the Hungarian algorithm. I'm not sure whether these dependencies are allowed in torchmetrics; if not, a custom implementation needs to be added.
Pitch
Implement clustering accuracy in `torchmetrics.clustering`.
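If SciPy and `munkres` are ruled out, the custom implementation could be fairly small. The sketch below is pure Python and meant for illustration, not as a proposed torchmetrics API: `hungarian_min_cost` and `clustering_accuracy` are hypothetical names, labels are assumed to be integers in `0..k-1`, and the solver is the textbook shortest-augmenting-path form of the Hungarian algorithm with row/column potentials. Like the snippet in the comments, it converts match counts into costs with `max - counts` before minimizing.

```python
from typing import List, Sequence


def hungarian_min_cost(cost: List[List[float]]) -> List[int]:
    """Minimum-cost assignment for an n x m cost matrix with n <= m.

    Shortest-augmenting-path Hungarian algorithm, O(n^2 * m).
    Returns assignment[i] = column assigned to row i.
    """
    n, m = len(cost), len(cost[0])
    inf = float("inf")
    u = [0.0] * (n + 1)  # row potentials (1-indexed)
    v = [0.0] * (m + 1)  # column potentials (1-indexed)
    p = [0] * (m + 1)    # p[j] = row matched to column j (0 means unmatched)
    way = [0] * (m + 1)  # previous column on the current augmenting path
    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [inf] * (m + 1)
        used = [False] * (m + 1)
        while True:
            used[j0] = True
            i0, delta, j1 = p[j0], inf, 0
            for j in range(1, m + 1):
                if not used[j]:
                    cur = cost[i0 - 1][j - 1] - u[i0] - v[j]
                    if cur < minv[j]:
                        minv[j], way[j] = cur, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            # Shift potentials so the cheapest reduced edge becomes tight.
            for j in range(m + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:  # reached a free column: augmenting path found
                break
        while True:  # flip the matching along the augmenting path
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
            if j0 == 0:
                break
    assignment = [-1] * n
    for j in range(1, m + 1):
        if p[j]:
            assignment[p[j] - 1] = j - 1
    return assignment


def clustering_accuracy(preds: Sequence[int], target: Sequence[int]) -> float:
    """Hypothetical helper: accuracy under the best cluster-to-label mapping."""
    if len(preds) != len(target) or len(preds) == 0:
        raise ValueError("preds and target must be non-empty and the same length")
    k = max(max(preds), max(target)) + 1
    # counts[c][t] = number of samples in predicted cluster c with true label t
    counts = [[0] * k for _ in range(k)]
    for c, t in zip(preds, target):
        counts[c][t] += 1
    # Maximizing matches == minimizing (max - counts), as with confmat.max() - confmat.
    top = max(max(row) for row in counts)
    assignment = hungarian_min_cost([[top - x for x in row] for row in counts])
    matched = sum(counts[c][assignment[c]] for c in range(k))
    return matched / len(preds)


print(clustering_accuracy([0, 0, 1, 1], [1, 1, 0, 0]))
```

The solver's cost depends only on the number of clusters k (roughly O(k^3) for a square matrix), while building the counts is a single O(N) pass over the samples, so a tensor-based version could reuse `multiclass_confusion_matrix` for the counting step and run the solver on the small k x k matrix.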