Skip to content

Commit

Permalink
Code cleaning after classification refactor 2/n (#1252)
Browse files Browse the repository at this point in the history
* functionals cleanup
* remove old functions
* revert stat score due to dice
* clean docstring
* more docstring cleaning
* remaining changes to impl
* remove old warning
* fix arg ordering
* fix import
* try fixing docs
* fix integration testing
* fix top_k arg
* fix broken tests
* fix more unittests
* fix more unittests
* another fix
* add tasks
* fix more doctests
* fix mypy

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
6 people authored Nov 22, 2022
1 parent d66cbc3 commit 20eab43
Show file tree
Hide file tree
Showing 58 changed files with 1,434 additions and 6,648 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed minimum Pytorch version to be 1.8 ([#1263](https://github.com/Lightning-AI/metrics/pull/1263))


- Changed interface for all functional and modular classification metrics after refactor ([#1252](https://github.com/Lightning-AI/metrics/pull/1252))


### Deprecated

-
Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ import torch
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()
metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

# move the metric to device you want computations to take place
device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -169,7 +169,7 @@ def metric_ddp(rank, world_size):
dist.init_process_group("gloo", rank=rank, world_size=world_size)

# initialize model
metric = torchmetrics.Accuracy()
metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

# define a model and append your metric to it
# this allows metric states to be placed on correct accelerators when
Expand Down Expand Up @@ -263,7 +263,9 @@ import torchmetrics
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
acc = torchmetrics.functional.classification.multiclass_accuracy(
preds, target, num_classes=5
)
```

### Covered domains and example metrics
Expand Down
16 changes: 0 additions & 16 deletions docs/source/classification/precision_recall.rst

This file was deleted.

14 changes: 7 additions & 7 deletions docs/source/pages/lightning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ The example below shows how to use a metric in your `LightningModule <https://py

def __init__(self):
...
self.accuracy = torchmetrics.Accuracy()
self.accuracy = torchmetrics.Accuracy(task='multiclass')

def training_step(self, batch, batch_idx):
x, y = batch
Expand Down Expand Up @@ -80,8 +80,8 @@ value by calling ``.compute()``.

def __init__(self):
...
self.train_acc = torchmetrics.Accuracy()
self.valid_acc = torchmetrics.Accuracy()
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')

def training_step(self, batch, batch_idx):
x, y = batch
Expand All @@ -105,8 +105,8 @@ of the metrics.

def __init__(self):
...
self.train_acc = torchmetrics.Accuracy()
self.valid_acc = torchmetrics.Accuracy()
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')

def training_step(self, batch, batch_idx):
x, y = batch
Expand Down Expand Up @@ -143,7 +143,7 @@ mixed as it can lead to wrong results.

def __init__(self):
...
self.valid_acc = torchmetrics.Accuracy()
self.valid_acc = torchmetrics.Accuracy(task='multiclass')

def validation_step(self, batch, batch_idx):
logits = self(x)
Expand Down Expand Up @@ -187,7 +187,7 @@ The following contains a list of pitfalls to be aware of:

def __init__(self):
...
self.val_acc = nn.ModuleList([torchmetrics.Accuracy() for _ in range(2)])
self.val_acc = nn.ModuleList([torchmetrics.Accuracy(task='multiclass') for _ in range(2)])

def val_dataloader(self):
return [DataLoader(...), DataLoader(...)]
Expand Down
43 changes: 24 additions & 19 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

.. code-block:: python
from torchmetrics.classification import Accuracy
from torchmetrics.classification import BinaryAccuracy
train_accuracy = Accuracy()
valid_accuracy = Accuracy()
train_accuracy = BinaryAccuracy()
valid_accuracy = BinaryAccuracy()
for epoch in range(epochs):
for x, y in train_data:
Expand Down Expand Up @@ -84,14 +84,14 @@ be moved to the same device as the input of the metric:

.. code-block:: python
from torchmetrics import Accuracy
from torchmetrics.classification import BinaryAccuracy
target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))
# Metric states are always initialized on cpu, and needs to be moved to
# the correct device
confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0))
confmat = BinaryAccuracy().to(torch.device("cuda", 0))
out = confmat(preds, target)
print(out.device) # cuda:0
Expand All @@ -107,16 +107,17 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.

.. testcode::

from torchmetrics import Accuracy, MetricCollection
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy

class MyModule(torch.nn.Module):
def __init__(self):
...
# valid ways metrics will be identified as child modules
self.metric1 = Accuracy()
self.metric2 = nn.ModuleList(Accuracy())
self.metric3 = nn.ModuleDict({'accuracy': Accuracy()})
self.metric4 = MetricCollection([Accuracy()]) # torchmetrics build-in collection class
self.metric1 = BinaryAccuracy()
self.metric2 = nn.ModuleList(BinaryAccuracy())
self.metric3 = nn.ModuleDict({'accuracy': BinaryAccuracy()})
self.metric4 = MetricCollection([BinaryAccuracy()]) # torchmetrics build-in collection class

def forward(self, batch):
data, target = batch
Expand Down Expand Up @@ -254,33 +255,37 @@ Example:

.. testcode::

from torchmetrics import MetricCollection, Accuracy, Precision, Recall
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metric_collection = MetricCollection([
Accuracy(),
Precision(num_classes=3, average='macro'),
Recall(num_classes=3, average='macro')
MulticlassAccuracy(num_classes=3, average="micro"),
MulticlassPrecision(num_classes=3, average="macro"),
MulticlassRecall(num_classes=3, average="macro")
])
print(metric_collection(preds, target))

.. testoutput::
:options: +NORMALIZE_WHITESPACE

{'Accuracy': tensor(0.1250),
'Precision': tensor(0.0667),
'Recall': tensor(0.1111)}
{'MulticlassAccuracy': tensor(0.1250),
'MulticlassPrecision': tensor(0.0667),
'MulticlassRecall': tensor(0.1111)}

Similarly it can also reduce the amount of code required to log multiple metrics
inside your LightningModule. In most cases we just have to replace ``self.log`` with ``self.log_dict``.

.. testcode::

from torchmetrics import Accuracy, MetricCollection, Precision, Recall
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

class MyModule(LightningModule):
def __init__(self):
metrics = MetricCollection([Accuracy(), Precision(), Recall()])
metrics = MetricCollection([
MulticlassAccuracy(), MulticlassPrecision(), MulticlassRecall()
])
self.train_metrics = metrics.clone(prefix='train_')
self.valid_metrics = metrics.clone(prefix='val_')

Expand Down
8 changes: 5 additions & 3 deletions docs/source/pages/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ The code-snippet below shows a simple example for calculating the accuracy using
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
acc = torchmetrics.functional.accuracy(preds, target, task='multiclass', num_classes=5)

Module metrics
~~~~~~~~~~~~~~

Nearly all functional metrics have a corresponding class-based metric that calls it a functional counterpart underneath. The class-based metrics are characterized by having one or more internal metrics states (similar to the parameters of the PyTorch module) that allow them to offer additional functionalities:
Nearly all functional metrics have a corresponding class-based metric that calls it a functional counterpart underneath.
The class-based metrics are characterized by having one or more internal metrics states (similar to the parameters of
the PyTorch module) that allow them to offer additional functionalities:

* Accumulation of multiple batches
* Automatic synchronization between multiple devices
Expand All @@ -84,7 +86,7 @@ The code below shows how to use the class-based interface:
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()
metric = torchmetrics.Accuracy(task='multiclass', num_classes=5)

n_batches = 10
for i in range(n_batches):
Expand Down
Loading

0 comments on commit 20eab43

Please sign in to comment.