Skip to content

Commit 4259943

Browse files
authored
Fix axes in confusion matrix (#1976)
1 parent cd7ef55 commit 4259943

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3838

3939
### Fixed
4040

41+
- Fixed x/y labels when plotting confusion matrices ([#1976](https://github.com/Lightning-AI/torchmetrics/pull/1976))
42+
43+
4144
- Fixed IOU compute in cuda ([#1982](https://github.com/Lightning-AI/torchmetrics/pull/1982))
4245

4346

src/torchmetrics/classification/confusion_matrix.py

+34
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@
5151
class BinaryConfusionMatrix(Metric):
5252
r"""Compute the `confusion matrix`_ for binary tasks.
5353
54+
The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
55+
known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
56+
correspond to the true class labels and column indices correspond to the predicted class labels.
57+
58+
For binary tasks, the confusion matrix is a 2x2 matrix with the following structure:
59+
60+
- :math:`C_{0, 0}`: True negatives
61+
- :math:`C_{0, 1}`: False positives
62+
- :math:`C_{1, 0}`: False negatives
63+
- :math:`C_{1, 1}`: True positives
64+
5465
As input to ``forward`` and ``update`` the metric accepts the following input:
5566
5667
- ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
@@ -176,6 +187,17 @@ def plot(
176187
class MulticlassConfusionMatrix(Metric):
177188
r"""Compute the `confusion matrix`_ for multiclass tasks.
178189
190+
The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
191+
known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
192+
correspond to the true class labels and column indices correspond to the predicted class labels.
193+
194+
For multiclass tasks, the confusion matrix is a NxN matrix, where:
195+
196+
- :math:`C_{i, i}` represents the number of true positives for class :math:`i`
197+
- :math:`\sum_{j=1, j\neq i}^N C_{i, j}` represents the number of false negatives for class :math:`i`
198+
- :math:`\sum_{i=1, i\neq j}^N C_{i, j}` represents the number of false positives for class :math:`i`
199+
- the sum of the remaining cells in the matrix represents the number of true negatives for class :math:`i`
200+
179201
As input to ``forward`` and ``update`` the metric accepts the following input:
180202
181203
- ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
@@ -305,6 +327,18 @@ def plot(
305327
class MultilabelConfusionMatrix(Metric):
306328
r"""Compute the `confusion matrix`_ for multilabel tasks.
307329
330+
The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
331+
known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
332+
correspond to the true class labels and column indices correspond to the predicted class labels.
333+
334+
For multilabel tasks, the confusion matrix is a Nx2x2 tensor, where each 2x2 matrix corresponds to the confusion
335+
for that label. The structure of each 2x2 matrix is as follows:
336+
337+
- :math:`C_{0, 0}`: True negatives
338+
- :math:`C_{0, 1}`: False positives
339+
- :math:`C_{1, 0}`: False negatives
340+
- :math:`C_{1, 1}`: True positives
341+
308342
As input to 'update' the metric accepts the following input:
309343
310344
- ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside

src/torchmetrics/utilities/plot.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ def plot_confusion_matrix(
249249
if fig_label is not None:
250250
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
251251
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
252-
ax.set_xlabel("True class", fontsize=15)
253-
ax.set_ylabel("Predicted class", fontsize=15)
252+
ax.set_xlabel("Predicted class", fontsize=15)
253+
ax.set_ylabel("True class", fontsize=15)
254254
ax.set_xticks(list(range(n_classes)))
255255
ax.set_yticks(list(range(n_classes)))
256256
ax.set_xticklabels(labels, rotation=45, fontsize=10)

0 commit comments

Comments
 (0)