|
51 | 51 | class BinaryConfusionMatrix(Metric):
|
52 | 52 | r"""Compute the `confusion matrix`_ for binary tasks.
|
53 | 53 |
|
| 54 | + The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations |
| 55 | + known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix |
| 56 | + correspond to the true class labels and column indices correspond to the predicted class labels. |
| 57 | +
|
| 58 | + For binary tasks, the confusion matrix is a 2x2 matrix with the following structure: |
| 59 | +
|
| 60 | + - :math:`C_{0, 0}`: True negatives |
| 61 | + - :math:`C_{0, 1}`: False positives |
| 62 | + - :math:`C_{1, 0}`: False negatives |
| 63 | + - :math:`C_{1, 1}`: True positives |
| 64 | +
|
54 | 65 | As input to ``forward`` and ``update`` the metric accepts the following input:
|
55 | 66 |
|
56 | 67 | - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
|
@@ -176,6 +187,17 @@ def plot(
|
176 | 187 | class MulticlassConfusionMatrix(Metric):
|
177 | 188 | r"""Compute the `confusion matrix`_ for multiclass tasks.
|
178 | 189 |
|
| 190 | + The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations |
| 191 | + known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix |
| 192 | + correspond to the true class labels and column indices correspond to the predicted class labels. |
| 193 | +
|
| 194 | + For multiclass tasks, the confusion matrix is a NxN matrix, where: |
| 195 | +
|
| 196 | + - :math:`C_{i, i}` represents the number of true positives for class :math:`i` |
| 197 | + - :math:`\sum_{j=1, j\neq i}^N C_{i, j}` represents the number of false negatives for class :math:`i` |
| 198 | + - :math:`\sum_{i=1, i\neq j}^N C_{i, j}` represents the number of false positives for class :math:`i` |
| 199 | + - the sum of the remaining cells in the matrix represents the number of true negatives for class :math:`i` |
| 200 | +
|
179 | 201 | As input to ``forward`` and ``update`` the metric accepts the following input:
|
180 | 202 |
|
181 | 203 | - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
|
@@ -305,6 +327,18 @@ def plot(
|
305 | 327 | class MultilabelConfusionMatrix(Metric):
|
306 | 328 | r"""Compute the `confusion matrix`_ for multilabel tasks.
|
307 | 329 |
|
| 330 | + The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations |
| 331 | + known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix |
| 332 | + correspond to the true class labels and column indices correspond to the predicted class labels. |
| 333 | +
|
| 334 | + For multilabel tasks, the confusion matrix is a Nx2x2 tensor, where each 2x2 matrix corresponds to the confusion |
| 335 | + for that label. The structure of each 2x2 matrix is as follows: |
| 336 | +
|
| 337 | + - :math:`C_{0, 0}`: True negatives |
| 338 | + - :math:`C_{0, 1}`: False positives |
| 339 | + - :math:`C_{1, 0}`: False negatives |
| 340 | + - :math:`C_{1, 1}`: True positives |
| 341 | +
|
308 | 342 | As input to 'update' the metric accepts the following input:
|
309 | 343 |
|
310 | 344 | - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
|
|
0 commit comments