Skip to content

Commit f205b84

Browse files
authored
Merge branch 'master' into plot/image
2 parents e526c55 + 5682c22 commit f205b84

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

README.md

+47-2
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ acc = torchmetrics.functional.classification.multiclass_accuracy(
269269

270270
### Covered domains and example metrics
271271

272-
We currently have implemented metrics within the following domains:
272+
In total TorchMetrics contains [90+ metrics](https://torchmetrics.readthedocs.io/en/stable/all-metrics.html), which
273+
convers the following domains:
273274

274275
- Audio
275276
- Classification
@@ -281,7 +282,51 @@ We currently have implemented metrics within the following domains:
281282
- Regression
282283
- Text
283284

284-
In total TorchMetrics contains [90+ metrics](https://torchmetrics.readthedocs.io/en/stable/all-metrics.html)!
285+
Each domain may require some additional dependencies which can be installed with `pip install torchmetrics[audio]`,
286+
`pip install torchmetrics['image']` etc.
287+
288+
### Additional features
289+
290+
#### Plotting
291+
292+
Visualization of metrics can be important to help understand what is going on with your machine learning algorithms.
293+
Torchmetrics have build-in plotting support (install dependencies with `pip install torchmetrics[visual]`) for nearly
294+
all modular metrics through the `.plot` method. Simply call the method to get a simple visualization of any metric!
295+
296+
```python
297+
import torch
298+
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix
299+
300+
num_classes = 3
301+
302+
# this will generate two distributions that comes more similar as iterations increase
303+
w = torch.randn(num_classes)
304+
target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
305+
preds = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
306+
307+
acc = MulticlassAccuracy(num_classes=num_classes, average="micro")
308+
acc_per_class = MulticlassAccuracy(num_classes=num_classes, average=None)
309+
confmat = MulticlassConfusionMatrix(num_classes=num_classes)
310+
311+
# plot single value
312+
for i in range(5):
313+
acc_per_class.update(preds(i), target(i))
314+
confmat.update(preds(i), target(i))
315+
fig1, ax1 = acc_per_class.plot()
316+
fig2, ax2 = confmat.plot()
317+
318+
# plot multiple values
319+
values = []
320+
for i in range(10):
321+
values.append(acc(preds(i), target(i)))
322+
fig3, ax3 = acc.plot(values)
323+
```
324+
325+
<p align="center">
326+
<img src="docs/source/_static/images/plot_example.png" width="1000">
327+
</p>
328+
329+
For examples of plotting different metrics try running [this example file](examples/plotting.py).
285330

286331
## Contribute!
287332

94.4 KB
Loading

0 commit comments

Comments
 (0)