Releases: Lightning-AI/torchmetrics
Minor PL development patch
Information retrieval
Information Retrieval
Information retrieval (IR) metrics are used to evaluate how well a system is retrieving information from a database or from a collection of documents. This is the case with search engines, where a query provided by the user is compared with many possible results, some of which are relevant and some are not.
When you query a search engine, you hope that results that could be useful are ranked higher on the results page. However, each query is usually compared with a different set of documents. For this reason, we had to implement a mechanism to allow users to easily compute the IR metrics in cases where each query is compared with a different number of possible candidates.
To this end, IR metrics take an additional argument called `indexes` that specifies which query each prediction refers to. All query-document pairs are then grouped by query index, and the final result is computed as the average of the metric over the groups.
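The grouping step can be sketched in plain Python. This assumes binary relevance labels and uses reciprocal rank as the per-query metric, so the result is a mean reciprocal rank. The helper names `reciprocal_rank` and `grouped_mean` are hypothetical; the library works on tensors, and its handling of queries with no relevant document may differ from the 0.0 returned here.

```python
from collections import defaultdict
from statistics import mean


def reciprocal_rank(scores, relevant):
    """Reciprocal rank for one query: 1 / rank of the first relevant document, or 0.0 if none."""
    # Document positions sorted by descending predicted score (rank 1 = best score).
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for rank, i in enumerate(order, start=1):
        if relevant[i]:
            return 1.0 / rank
    return 0.0


def grouped_mean(metric_fn, indexes, preds, target):
    """Group (pred, target) pairs by query index, then average metric_fn over the groups."""
    groups = defaultdict(lambda: ([], []))
    for query, score, label in zip(indexes, preds, target):
        groups[query][0].append(score)
        groups[query][1].append(label)
    return mean(metric_fn(scores, labels) for scores, labels in groups.values())


# Query 0 has three candidate documents, query 1 only two.
indexes = [0, 0, 0, 1, 1]
preds = [0.2, 0.9, 0.4, 0.7, 0.1]
target = [False, False, True, True, False]
print(grouped_mean(reciprocal_rank, indexes, preds, target))  # 0.75
```

Query 0 finds its relevant document at rank 2 (0.5) and query 1 at rank 1 (1.0), so each query contributes equally to the average regardless of how many candidates it had.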
In total, 6 new metrics have been added for information retrieval:
- RetrievalMAP (Mean Average Precision)
- RetrievalMRR (Mean Reciprocal Rank)
- RetrievalPrecision (Precision for IR)
- RetrievalRecall (Recall for IR)
- RetrievalNormalizedDCG (Normalized Discounted Cumulative Gain)
- RetrievalFallOut (Fall Out rate for IR)
Special thanks go to @lucadiliello for implementing all the IR metrics.
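For reference, per-query precision@k, recall@k, fall-out@k, average precision and nDCG can be sketched from their textbook definitions, with binary relevance (graded gains for nDCG) and a log2 rank discount. All function names and the cutoff parameter `k` are hypothetical, and returning 0.0 when a denominator is empty is our own convention. Averaging each function over queries, as described above, gives the corresponding retrieval metric (for example, MAP from average precision).

```python
import math


def ranked_relevance(scores, relevant):
    # Relevance labels re-ordered by descending predicted score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [relevant[i] for i in order]


def precision_at_k(scores, relevant, k):
    # Fraction of the top-k documents that are relevant.
    return sum(ranked_relevance(scores, relevant)[:k]) / k


def recall_at_k(scores, relevant, k):
    # Fraction of all relevant documents that appear in the top k.
    total = sum(relevant)
    return sum(ranked_relevance(scores, relevant)[:k]) / total if total else 0.0


def fall_out_at_k(scores, relevant, k):
    # Fraction of all non-relevant documents that appear in the top k.
    non_relevant = len(relevant) - sum(relevant)
    top = ranked_relevance(scores, relevant)[:k]
    return (len(top) - sum(top)) / non_relevant if non_relevant else 0.0


def average_precision(scores, relevant):
    # Mean of precision@i over the ranks i at which a relevant document appears.
    hits, precisions = 0, []
    for i, rel in enumerate(ranked_relevance(scores, relevant), start=1):
        if rel:
            hits += 1
            precisions.append(hits / i)
    return sum(precisions) / len(precisions) if precisions else 0.0


def ndcg(scores, gains):
    # DCG with a log2(rank + 1) discount, normalized by the DCG of the ideal ordering.
    def dcg(values):
        return sum(g / math.log2(i + 1) for i, g in enumerate(values, start=1))

    ideal = dcg(sorted(gains, reverse=True))
    return dcg(ranked_relevance(scores, gains)) / ideal if ideal else 0.0


scores = [0.9, 0.8, 0.3, 0.1]
relevant = [1, 0, 1, 0]
print(precision_at_k(scores, relevant, 2), recall_at_k(scores, relevant, 2))  # 0.5 0.5
print(average_precision(scores, relevant))  # (1/1 + 2/3) / 2
print(ndcg(scores, relevant))  # about 0.92
```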
Expanding and improving the collection
In addition to expanding our collection to the field of information retrieval, this release also includes new metrics for the classification and regression domains:
- BootStrapper metric that can wrap around any other metric in our collection for easy computation of confidence intervals
- CohenKappa is a statistic that is used to measure inter-rater reliability for qualitative (categorical) items
- MatthewsCorrcoef or phi coefficient is used in machine learning as a measure of the quality of binary (two-class) classifications
- Hinge loss is used for "maximum-margin" classification, most notably for support vector machines.
- PearsonCorrcoef is a metric for measuring the linear correlation between two sets of data
- SpearmanCorrcoef is a metric for measuring the rank correlation between two sets of data. It assesses how well the relationship between two variables can be described using a monotonic function.
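The statistics in this list follow standard formulas, sketched below in plain Python for small lists. To illustrate the confidence-interval idea behind BootStrapper, a metric is also wrapped in a simple percentile bootstrap. This is not BootStrapper's implementation; all function names and defaults (`num_bootstraps`, `alpha`, `seed`) are hypothetical, and hinge loss is left out for brevity.

```python
import math
import random


def pearson(x, y):
    # Covariance of x and y divided by the product of their standard deviations.
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def ranks(values):
    # 1-based ranks; tied values share the average of the positions they occupy.
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            result[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return result


def spearman(x, y):
    # Pearson correlation of the ranks: 1.0 for any strictly increasing relationship.
    return pearson(ranks(x), ranks(y))


def cohen_kappa(a, b):
    # (observed agreement - agreement expected by chance) / (1 - chance agreement)
    n = len(a)
    p_o = sum(x == y for x, y in zip(a, b)) / n
    p_e = sum((a.count(c) / n) * (b.count(c) / n) for c in set(a) | set(b))
    return (p_o - p_e) / (1 - p_e)


def mcc(pred, target):
    # Binary Matthews correlation coefficient (phi coefficient) from confusion counts.
    pairs = list(zip(pred, target))
    tp = sum(1 for p, t in pairs if p and t)
    tn = sum(1 for p, t in pairs if not p and not t)
    fp = sum(1 for p, t in pairs if p and not t)
    fn = sum(1 for p, t in pairs if not p and t)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


def bootstrap_ci(metric_fn, pred, target, num_bootstraps=1000, alpha=0.05, seed=0):
    # Percentile bootstrap: recompute metric_fn on samples drawn with replacement,
    # then read off the alpha/2 and 1 - alpha/2 empirical quantiles.
    rng = random.Random(seed)
    n = len(pred)
    values = []
    for _ in range(num_bootstraps):
        idx = [rng.randrange(n) for _ in range(n)]
        values.append(metric_fn([pred[i] for i in idx], [target[i] for i in idx]))
    values.sort()
    lower = values[int(alpha / 2 * num_bootstraps)]
    upper = values[int((1 - alpha / 2) * num_bootstraps) - 1]
    return lower, upper


pred = [1, 1, 0, 0, 1, 0, 1, 1]
target = [1, 0, 0, 0, 1, 0, 1, 1]
print(mcc(pred, target), bootstrap_ci(mcc, pred, target))
```

Because the bootstrap only needs a function from (pred, target) to a number, the same wrapper works for any of the other metrics, which is the property the BootStrapper description above emphasizes.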
Binned metrics
The current implementations of AveragePrecision and PrecisionRecallCurve have the drawback that they store all predictions and targets in memory in order to calculate the metric value exactly. These metrics now also come in binned versions that calculate the value at a fixed set of thresholds. This is less precise than the original implementations but much more memory efficient.
Special thanks go to @SkafteNicki for making all this happen.
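The memory trade-off of binning can be seen in a small plain-Python sketch of a binary precision-recall curve: instead of storing every score, it keeps three counters per fixed threshold and updates them batch by batch. The class name `BinnedPRSketch`, the evenly spaced thresholds, and the convention that precision is 1.0 when nothing is predicted positive are assumptions for illustration, not the library's API.

```python
class BinnedPRSketch:
    """Hypothetical binary precision/recall at fixed thresholds, accumulated over batches."""

    def __init__(self, num_thresholds=5):
        # Evenly spaced thresholds from 0.0 to 1.0 (num_thresholds >= 2 assumed).
        self.thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        # Three counters per threshold: memory does not grow with the number of samples.
        self.tp = [0] * num_thresholds
        self.fp = [0] * num_thresholds
        self.fn = [0] * num_thresholds

    def update(self, preds, target):
        for score, label in zip(preds, target):
            for k, threshold in enumerate(self.thresholds):
                predicted_positive = score >= threshold
                if predicted_positive and label:
                    self.tp[k] += 1
                elif predicted_positive:
                    self.fp[k] += 1
                elif label:
                    self.fn[k] += 1

    def compute(self):
        # Precision is taken as 1.0 when nothing is predicted positive (assumed convention).
        precision = [tp / (tp + fp) if tp + fp else 1.0 for tp, fp in zip(self.tp, self.fp)]
        recall = [tp / (tp + fn) if tp + fn else 0.0 for tp, fn in zip(self.tp, self.fn)]
        return precision, recall, self.thresholds


curve = BinnedPRSketch(num_thresholds=3)  # thresholds 0.0, 0.5, 1.0
curve.update([0.9, 0.6], [1, 0])
curve.update([0.4, 0.1], [1, 0])
print(curve.compute())  # ([0.5, 0.5, 1.0], [1.0, 0.5, 0.0], [0.0, 0.5, 1.0])
```

The loss of precision comes from the coarse threshold grid: any two scores falling between the same pair of thresholds are indistinguishable, whereas an exact curve would place a point at every distinct score.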
[0.3.0] - 2021-04-20
Added
- Added `BootStrapper` to easily calculate confidence intervals for metrics (#101)
- Added Binned metrics (#128)
- Added metrics for Information Retrieval:
- Added other metrics:
- Added `average='micro'` as an option in AUROC for multilabel problems (#110)
- Added multilabel support to `ROC` metric (#114)
- Added testing for `half` precision (#77, #135)
- Added `AverageMeter` for ad-hoc averages of values (#138)
- Added `prefix` argument to `MetricCollection` (#70)
- Added `__getitem__` as metric arithmetic operation (#142)
- Added property `is_differentiable` to metrics and test for differentiability (#154)
- Added support for `average`, `ignore_index` and `mdmc_average` in `Accuracy` metric (#166)
- Added `postfix` arg to `MetricCollection` (#188)
Changed
- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics (#68)
- Changed behavior of `confusionmatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn (#134)
- Updated FBeta arguments (#111)
- Changed `reset` method to use `detach.clone()` instead of `deepcopy` when resetting to default (#163)
- Metrics passed as dict to `MetricCollection` will now always be in deterministic order (#173)
- Allowed metrics to be passed to `MetricCollection` as arguments (#176)
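The `ExplainedVariance` change relies on the identity Var(x) = E[x^2] - E[x]^2, which lets explained variance, 1 - Var(target - preds) / Var(target), be computed from a handful of running sums instead of stored predictions. The sketch below assumes the five statistics are the observation count, the sums of errors and squared errors, and the sums of targets and squared targets; the class name and that exact choice are our assumptions, not a description of the library's code.

```python
class StreamingExplainedVariance:
    """Hypothetical sketch: explained variance from five running statistics."""

    def __init__(self):
        self.n_obs = 0
        self.sum_error = 0.0
        self.sum_squared_error = 0.0
        self.sum_target = 0.0
        self.sum_squared_target = 0.0

    def update(self, preds, target):
        # Only the five sums are updated; individual values are discarded.
        for p, t in zip(preds, target):
            error = t - p
            self.n_obs += 1
            self.sum_error += error
            self.sum_squared_error += error * error
            self.sum_target += t
            self.sum_squared_target += t * t

    def compute(self):
        n = self.n_obs
        # Var(x) = E[x^2] - E[x]^2, applied to the errors and to the targets.
        var_error = self.sum_squared_error / n - (self.sum_error / n) ** 2
        var_target = self.sum_squared_target / n - (self.sum_target / n) ** 2
        return 1.0 - var_error / var_target


ev = StreamingExplainedVariance()
ev.update([2.5, 0.0], [3.0, -0.5])  # first batch
ev.update([2.0, 8.0], [2.0, 7.0])   # second batch
print(ev.compute())  # about 0.957
```

One design caveat: the E[x^2] - E[x]^2 form can lose floating-point precision when the variance is tiny relative to the mean, which is why streaming implementations sometimes prefer Welford-style updates instead.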
Deprecated
- Renamed argument `is_multiclass` to `multiclass` (#162)
Removed
- Pruned remaining deprecated code (#92)
Fixed
- Fixed `_stable_1d_sort` to work when `n >= N` (PL^6177)
- Fixed `_computed` attribute not being correctly reset (#147)
- Fixed BLEU score (#165)
- Fixed backwards compatibility for logging with older versions of pytorch-lightning (#182)
Contributors
@alanhdu, @arvindmuralie77, @bhadreshpsavani, @Borda, @ethanwharris, @lucadiliello, @maximsch2, @SkafteNicki, @thomasgaudelet, @victorjoos
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Initial release
What is TorchMetrics
TorchMetrics is a collection of 25+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:
- A standardized interface to increase reproducibility
- Reduced boilerplate
- Distributed-training compatible
- Automatic accumulation over batches
- Automatic synchronization between multiple devices
You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy additional features:
- Module metrics are automatically placed on the correct device.
- Native support for logging metrics in Lightning to reduce even more boilerplate.
Using functional metrics
Similar to `torch.nn`, most metrics have both a module-based and a functional version. The functional versions implement the basic operations required for computing each metric. They are simple Python functions that take `torch.Tensor`s as input and return the corresponding metric as a `torch.Tensor`.
```python
import torch
# import our library
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
```
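To show what a functional metric computes for probability inputs like these, here is a plain-Python analogue of multiclass accuracy. The name `accuracy_from_probs` is hypothetical and this is not the torchmetrics function, which accepts more input formats; it only illustrates the arg-max-then-compare logic and the fact that a functional metric keeps no state between calls.

```python
def accuracy_from_probs(probs, target):
    """Hypothetical pure-Python analogue of multiclass accuracy on rows of class probabilities."""
    # The predicted class for each sample is the index of its largest probability.
    predicted = [max(range(len(row)), key=row.__getitem__) for row in probs]
    # Accuracy is the fraction of samples whose predicted class equals the target.
    correct = sum(p == t for p, t in zip(predicted, target))
    return correct / len(target)


probs = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]]
print(accuracy_from_probs(probs, [1, 0, 1]))  # 2 of 3 correct
```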
Using Module metrics
Nearly all functional metrics have a corresponding module-based metric that calls its functional counterpart underneath. The module-based metrics are characterized by having one or more internal metric states (similar to the parameters of a PyTorch module) that allow them to offer additional functionality:
- Accumulation of multiple batches
- Automatic synchronization between multiple devices
- Metric arithmetic
```python
import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches, using the accumulated internal state
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
```
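Why do module metrics keep states such as counts rather than per-batch results? A plain-Python sketch makes it visible: a hypothetical `AccuracySketch` class stores summable counters, and a hypothetical `sync` helper merges processes by adding those counters, mimicking a sum reduction across devices. In torchmetrics itself, custom metrics declare such states with `add_state(..., dist_reduce_fx="sum")`; everything else here is illustrative only.

```python
class AccuracySketch:
    """Hypothetical module-style metric that stores summable states, not raw predictions."""

    def __init__(self):
        self.reset()

    def reset(self):
        # States start at their defaults: nothing seen yet.
        self.correct = 0
        self.total = 0

    def update(self, predicted, target):
        # Accumulate counts for this batch and return the batch's correct count.
        batch_correct = sum(p == t for p, t in zip(predicted, target))
        self.correct += batch_correct
        self.total += len(target)
        return batch_correct

    def __call__(self, predicted, target):
        # Value on the current batch only; the global states are updated as a side effect.
        return self.update(predicted, target) / len(target)

    def compute(self):
        # Value over everything seen since the last reset.
        return self.correct / self.total


def sync(metrics):
    """Simulate cross-process synchronization by summing every process's states."""
    merged = AccuracySketch()
    for m in metrics:
        merged.correct += m.correct
        merged.total += m.total
    return merged


# Two simulated processes that saw different amounts of data.
proc_a, proc_b = AccuracySketch(), AccuracySketch()
print(proc_a([0, 1, 2], [0, 1, 1]))  # 2 of 3 correct
print(proc_b([1, 1], [1, 1]))        # 2 of 2 correct
print(sync([proc_a, proc_b]).compute())  # 0.8, i.e. 4 correct out of 5
```

Averaging the two per-process accuracies instead would give (2/3 + 1) / 2, about 0.83, which overweights the smaller process; summing the states gives the correct 0.8.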
Built-in metrics
- Accuracy
- AveragePrecision
- AUC
- AUROC
- F1
- Hamming Distance
- ROC
- ExplainedVariance
- MeanSquaredError
- R2Score
- bleu_score
- embedding_similarity
And many more!
Contributors
@Borda, @SkafteNicki, @williamFalcon, @teddykoker, @justusschock, @tadejsv, @edenlightning, @ydcjeff, @ddrevicky, @ananyahjha93, @awaelchli, @rohitgr7, @akihironitta, @manipopopo, @Diuven, @arnaudgelas, @s-rog, @c00k1ez, @tgaddair, @elias-ramzi, @cuent, @jpcarzolio, @bryant1410, @shivdhar, @Sordie, @krzysztofwos, @abhik-99, @bernardomig, @peblair, @InCogNiTo124, @j-dsouza, @pranjaldatta, @ananthsub, @deng-cy, @abhinavg97, @tridao, @prampey, @abrahambotros, @ozen, @ShomyLiu, @yuntai, @pwwang
If we forgot someone due to not matching commit email with GitHub account, let us know :]