@@ -269,7 +269,8 @@ acc = torchmetrics.functional.classification.multiclass_accuracy(
269
269
270
270
### Covered domains and example metrics
271
271
272
- We currently have implemented metrics within the following domains:
272
+ In total TorchMetrics contains [ 90+ metrics] ( https://torchmetrics.readthedocs.io/en/stable/all-metrics.html ) , which
273
+ convers the following domains:
273
274
274
275
- Audio
275
276
- Classification
@@ -281,7 +282,51 @@ We currently have implemented metrics within the following domains:
281
282
- Regression
282
283
- Text
283
284
284
- In total TorchMetrics contains [ 90+ metrics] ( https://torchmetrics.readthedocs.io/en/stable/all-metrics.html ) !
285
+ Each domain may require some additional dependencies which can be installed with ` pip install torchmetrics[audio] ` ,
286
+ ` pip install torchmetrics['image'] ` etc.
287
+
288
+ ### Additional features
289
+
290
+ #### Plotting
291
+
292
+ Visualization of metrics can be important to help understand what is going on with your machine learning algorithms.
293
+ Torchmetrics have build-in plotting support (install dependencies with ` pip install torchmetrics[visual] ` ) for nearly
294
+ all modular metrics through the ` .plot ` method. Simply call the method to get a simple visualization of any metric!
295
+
296
+ ``` python
297
+ import torch
298
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix
299
+
300
+ num_classes = 3
301
+
302
+ # this will generate two distributions that comes more similar as iterations increase
303
+ w = torch.randn(num_classes)
304
+ target = lambda it : torch.multinomial((it * w).softmax(dim = - 1 ), 100 , replacement = True )
305
+ preds = lambda it : torch.multinomial((it * w).softmax(dim = - 1 ), 100 , replacement = True )
306
+
307
+ acc = MulticlassAccuracy(num_classes = num_classes, average = " micro" )
308
+ acc_per_class = MulticlassAccuracy(num_classes = num_classes, average = None )
309
+ confmat = MulticlassConfusionMatrix(num_classes = num_classes)
310
+
311
+ # plot single value
312
+ for i in range (5 ):
313
+ acc_per_class.update(preds(i), target(i))
314
+ confmat.update(preds(i), target(i))
315
+ fig1, ax1 = acc_per_class.plot()
316
+ fig2, ax2 = confmat.plot()
317
+
318
+ # plot multiple values
319
+ values = []
320
+ for i in range (10 ):
321
+ values.append(acc(preds(i), target(i)))
322
+ fig3, ax3 = acc.plot(values)
323
+ ```
324
+
325
+ <p align =" center " >
326
+ <img src =" docs/source/_static/images/plot_example.png " width =" 1000 " >
327
+ </p >
328
+
329
+ For examples of plotting different metrics try running [ this example file] ( examples/plotting.py ) .
285
330
286
331
## Contribute!
287
332
0 commit comments