Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jan 10, 2022
2 parents cedc44f + 3bd4fb0 commit be3814e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Renamed `torchmetrics.collections` to `torchmetrics.metrics_collections` to avoid clashing with system's `collections` package ([#695](https://github.com/PyTorchLightning/metrics/pull/695))


- Changed dtype of metric state from `torch.float` to `torch.long` in `ConfusionMatrix` to accommodate larger values ([#708](https://github.com/PyTorchLightning/metrics/issues/708))


### Deprecated

- Renamed IoU -> Jaccard Index ([#662](https://github.com/PyTorchLightning/metrics/pull/662))
Expand Down
21 changes: 12 additions & 9 deletions torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,26 @@ class ConfusionMatrix(Metric):
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
[1., 1.]])
tensor([[2, 0],
[1, 1]])
Example (multiclass data):
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(num_classes=3)
>>> confmat(preds, target)
tensor([[1., 1., 0.],
[0., 1., 0.],
[0., 0., 1.]])
tensor([[1, 1, 0],
[0, 1, 0],
[0, 0, 1]])
Example (multilabel data):
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
>>> confmat(preds, target) # doctest: +NORMALIZE_WHITESPACE
tensor([[[1., 0.], [0., 1.]],
[[1., 0.], [1., 0.]],
[[0., 1.], [0., 1.]]])
tensor([[[1, 0], [0, 1]],
[[1, 0], [1, 0]],
[[0, 1], [0, 1]]])
"""
is_differentiable = False
Expand Down Expand Up @@ -118,7 +118,10 @@ def __init__(
if self.normalize not in allowed_normalize:
raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}")

default = torch.zeros(num_classes, 2, 2) if multilabel else torch.zeros(num_classes, num_classes)
if multilabel:
default = torch.zeros(num_classes, 2, 2, dtype=torch.long)
else:
default = torch.zeros(num_classes, num_classes, dtype=torch.long)
self.add_state("confmat", default=default, dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
Expand Down
16 changes: 8 additions & 8 deletions torchmetrics/functional/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,26 +158,26 @@ def confusion_matrix(
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
[1., 1.]])
tensor([[2, 0],
[1, 1]])
Example (multiclass data):
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(num_classes=3)
>>> confmat(preds, target)
tensor([[1., 1., 0.],
[0., 1., 0.],
[0., 0., 1.]])
tensor([[1, 1, 0],
[0, 1, 0],
[0, 0, 1]])
Example (multilabel data):
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
>>> confmat(preds, target) # doctest: +NORMALIZE_WHITESPACE
tensor([[[1., 0.], [0., 1.]],
[[1., 0.], [1., 0.]],
[[0., 1.], [0., 1.]]])
tensor([[[1, 0], [0, 1]],
[[1, 0], [1, 0]],
[[0, 1], [0, 1]]])
"""
confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel)
Expand Down

0 comments on commit be3814e

Please sign in to comment.