From 28af34bc5134fddf544425fed9ffe04445b237e3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 15 Sep 2020 14:36:14 +0200 Subject: [PATCH] [Metrics] Class reduction similar to sklearn (#3322) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * new class reduce interface * update docs * pep8 * update_class_metrics * fix doctest * changelog * fix docs * fix codefactor * fix codefactor * formatting * fix typo * fix typo * typo pr -> per * update from suggestion * fix error * Apply suggestions from code review * Update CHANGELOG.md * formatting * timeouts * docstring formatting for reg metrics * pep * flake8 * revert workflow changes * suggestions Co-authored-by: Nicki Skafte Co-authored-by: Jirka Borovec Co-authored-by: Adrian Wälchli Co-authored-by: Jirka Borovec Co-authored-by: rohitgr7 --- CHANGELOG.md | 5 + pytorch_lightning/metrics/classification.py | 126 +++++++++------- .../metrics/functional/classification.py | 140 +++++++++--------- .../metrics/functional/reduction.py | 41 +++++ .../metrics/functional/regression.py | 69 ++++----- pytorch_lightning/metrics/regression.py | 60 ++++---- .../metrics/functional/test_classification.py | 54 ++++--- tests/metrics/functional/test_reduction.py | 17 ++- 8 files changed, 301 insertions(+), 211 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c6c38163ee785..30bfaa56c6b92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * functional interface ([#3349](https://github.com/PyTorchLightning/pytorch-lightning/pull/3349)) * class based interface + tests ([#3358](https://github.com/PyTorchLightning/pytorch-lightning/pull/3358)) + ### Changed - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) @@ -25,6 +26,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Refactor `GPUStatsMonitor` to improve training speed ([#3257](https://github.com/PyTorchLightning/pytorch-lightning/pull/3257)) +- Renamed `reduction` to `class_reduction` in classification metrics ([#3322](https://github.com/PyTorchLightning/pytorch-lightning/pull/3322)) + +- Changed `class_reduction` similar to sklearn for classification metrics ([#3322](https://github.com/PyTorchLightning/pytorch-lightning/pull/3322)) + ### Deprecated diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index 1ed645f86909c..aa14d48ead6ed 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -52,22 +52,25 @@ class Accuracy(TensorMetric): def __init__( self, num_classes: Optional[int] = None, - reduction: str = "elementwise_mean", + class_reduction: str = 'micro', reduce_group: Any = None, ): """ Args: num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
+ - ``'none'``: returns calculated metric per class + reduce_group: the process group to reduce metric results from DDP """ super().__init__(name="accuracy", reduce_group=reduce_group) self.num_classes = num_classes - self.reduction = reduction + assert class_reduction in ('micro', 'macro', 'weighted', 'none') + self.class_reduction = class_reduction def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -80,7 +83,8 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. """ - return accuracy(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) + return accuracy(pred=pred, target=target, + num_classes=self.num_classes, class_reduction=self.class_reduction) class ConfusionMatrix(TensorMetric): @@ -209,7 +213,7 @@ class Precision(TensorMetric): >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = Precision(num_classes=4) + >>> metric = Precision(num_classes=4, class_reduction='macro') >>> metric(pred, target) tensor(0.7500) @@ -218,17 +222,19 @@ class Precision(TensorMetric): def __init__( self, num_classes: Optional[int] = None, - reduction: str = "elementwise_mean", + class_reduction: str = 'micro', reduce_group: Any = None, ): """ Args: num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class + reduce_group: the process group to reduce metric results from DDP """ super().__init__( @@ -236,7 +242,8 @@ def __init__( reduce_group=reduce_group, ) self.num_classes = num_classes - self.reduction = reduction + assert class_reduction in ('micro', 'macro', 'weighted', 'none') + self.class_reduction = class_reduction def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -249,7 +256,9 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. """ - return precision(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) + return precision(pred=pred, target=target, + num_classes=self.num_classes, + class_reduction=self.class_reduction) class Recall(TensorMetric): @@ -262,24 +271,26 @@ class Recall(TensorMetric): >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = Recall() >>> metric(pred, target) - tensor(0.6250) + tensor(0.7500) """ def __init__( self, num_classes: Optional[int] = None, - reduction: str = "elementwise_mean", + class_reduction: str = 'micro', reduce_group: Any = None, ): """ Args: num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
+ - ``'none'``: returns calculated metric per class + reduce_group: the process group to reduce metric results from DDP """ super().__init__( @@ -288,7 +299,8 @@ def __init__( ) self.num_classes = num_classes - self.reduction = reduction + assert class_reduction in ('micro', 'macro', 'weighted', 'none') + self.class_reduction = class_reduction def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -301,7 +313,10 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. """ - return recall(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) + return recall(pred=pred, + target=target, + num_classes=self.num_classes, + class_reduction=self.class_reduction) class AveragePrecision(TensorMetric): @@ -409,7 +424,7 @@ class FBeta(TensorMetric): >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = FBeta(0.25) + >>> metric = FBeta(0.25, class_reduction='macro') >>> metric(pred, target) tensor(0.7361) """ @@ -418,18 +433,20 @@ def __init__( self, beta: float, num_classes: Optional[int] = None, - reduction: str = "elementwise_mean", + class_reduction: str = 'micro', reduce_group: Any = None, ): """ Args: beta: determines the weight of recall in the combined score. num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
+ - ``'none'``: returns calculated metric per class + reduce_group: the process group to reduce metric results from DDP """ super().__init__( @@ -439,7 +456,8 @@ def __init__( self.beta = beta self.num_classes = num_classes - self.reduction = reduction + assert class_reduction in ('micro', 'macro', 'weighted', 'none') + self.class_reduction = class_reduction def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -452,9 +470,9 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: torch.Tensor: classification score """ - return fbeta_score( - pred=pred, target=target, beta=self.beta, num_classes=self.num_classes, reduction=self.reduction - ) + return fbeta_score(pred=pred, target=target, + beta=self.beta, num_classes=self.num_classes, + class_reduction=self.class_reduction) class F1(TensorMetric): @@ -466,7 +484,7 @@ class F1(TensorMetric): >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = F1() + >>> metric = F1(class_reduction='macro') >>> metric(pred, target) tensor(0.6667) """ @@ -474,17 +492,19 @@ class F1(TensorMetric): def __init__( self, num_classes: Optional[int] = None, - reduction: str = "elementwise_mean", + class_reduction: str = 'micro', reduce_group: Any = None, ): """ Args: num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
+ - ``'none'``: returns calculated metric per class + reduce_group: the process group to reduce metric results from DDP """ super().__init__( @@ -493,7 +513,8 @@ def __init__( ) self.num_classes = num_classes - self.reduction = reduction + assert class_reduction in ('micro', 'macro', 'weighted', 'none') + self.class_reduction = class_reduction def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -506,7 +527,9 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: torch.Tensor: classification score """ - return f1_score(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction) + return f1_score(pred=pred, target=target, + num_classes=self.num_classes, + class_reduction=self.class_reduction) class ROC(TensorMetric): @@ -518,13 +541,10 @@ class ROC(TensorMetric): >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = ROC() - >>> fps, tps, thresholds = metric(pred, target) - >>> fps - tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]) - >>> tps - tensor([0., 0., 0., 1., 1.]) - >>> thresholds - tensor([4., 3., 2., 1., 0.]) + >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), + tensor([0., 0., 0., 1., 1.]), + tensor([4., 3., 2., 1., 0.])) """ diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 5a14aa22edd2b..75c0ab358798a 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -4,7 +4,7 @@ import torch from torch.nn import functional as F -from pytorch_lightning.metrics.functional.reduction import reduce +from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce from pytorch_lightning.utilities import FLOAT16_EPSILON, rank_zero_warn @@ -232,14 +232,14 @@ def stat_scores_multiple_classes( tns /= num_classes sups /= num_classes - return tps, fps, tns, fns, sups + return tps.float(), fps.float(), tns.float(), fns.float(), sups.float() def accuracy( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction='elementwise_mean', + class_reduction: str = 'micro' ) -> torch.Tensor: """ Computes the accuracy classification score @@ -248,15 +248,15 @@ def accuracy( pred: predicted labels target: ground truth labels num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: - A Tensor with the classification score. + A Tensor with the accuracy score. 
 Example:

         >>> x = torch.tensor([0, 1, 2, 3])
         >>> y = torch.tensor([0, 1, 2, 2])
         >>> accuracy(x, y)
         tensor(0.7500)

     """
     if not (target > 0).any() and num_classes is None:
         raise RuntimeError("cannot infer num_classes when target is all zero")

     tps, fps, tns, fns, sups = stat_scores_multiple_classes(
-        pred=pred, target=target, num_classes=num_classes, reduction=reduction)
+        pred=pred, target=target, num_classes=num_classes)

-    return tps / sups
+    return class_reduce(tps, sups, sups, class_reduction=class_reduction)


 def confusion_matrix(
@@ -325,7 +325,8 @@ def precision_recall(
     pred: torch.Tensor,
     target: torch.Tensor,
     num_classes: Optional[int] = None,
-    reduction: str = 'elementwise_mean',
+    class_reduction: str = 'micro',
+    return_support: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Computes precision and recall for different thresholds

     Args:
         pred: estimated probabilities
         target: ground-truth labels
         num_classes: number of classes
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        class_reduction: method to reduce metric score over labels

-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'micro'``: calculate metrics globally (default)
+            - ``'macro'``: calculate metrics for each label, and find their unweighted mean.
+            - ``'weighted'``: calculate metrics for each label, and find their weighted mean.
+            - ``'none'``: returns calculated metric per class
+
+        return_support: returns the support for each class, needed for fbeta/f1 calculations

     Return:
         Tensor with precision and recall

     Example:

         >>> x = torch.tensor([0, 1, 2, 3])
-        >>> y = torch.tensor([0, 1, 2, 2])
-        >>> precision_recall(x, y)
-        (tensor(0.7500), tensor(0.6250))
+        >>> y = torch.tensor([0, 2, 2, 2])
+        >>> precision_recall(x, y, class_reduction='macro')
+        (tensor(0.5000), tensor(0.3333))

     """
     tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes)

-    tps = tps.to(torch.float)
-    fps = fps.to(torch.float)
-    fns = fns.to(torch.float)
-
-    precision = tps / (tps + fps)
-    recall = tps / (tps + fns)
-
-    # solution by justus, see https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/9
-    precision[precision != precision] = 0
-    recall[recall != recall] = 0
-
-    precision = reduce(precision, reduction=reduction)
-    recall = reduce(recall, reduction=reduction)
+    precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction)
+    recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction)
+    if return_support:
+        return precision, recall, sups
     return precision, recall


 def precision(
     pred: torch.Tensor,
     target: torch.Tensor,
     num_classes: Optional[int] = None,
-    reduction: str = 'elementwise_mean',
+    class_reduction: str = 'micro',
 ) -> torch.Tensor:
     """
     Computes precision score.

     Args:
         pred: estimated probabilities
         target: ground-truth labels
         num_classes: number of classes
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        class_reduction: method to reduce metric score over labels

-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'micro'``: calculate metrics globally (default)
+            - ``'macro'``: calculate metrics for each label, and find their unweighted mean.
+            - ``'weighted'``: calculate metrics for each label, and find their weighted mean.
+            - ``'none'``: returns calculated metric per class

     Return:
         Tensor with precision.

@@ -402,14 +396,14 @@ def precision( """ return precision_recall(pred=pred, target=target, - num_classes=num_classes, reduction=reduction)[0] + num_classes=num_classes, class_reduction=class_reduction)[0] def recall( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes recall score. @@ -418,12 +412,12 @@ def recall( pred: estimated probabilities target: ground-truth labels num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: Tensor with recall. @@ -433,10 +427,10 @@ def recall( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 2, 2]) >>> recall(x, y) - tensor(0.6250) + tensor(0.7500) """ return precision_recall(pred=pred, target=target, - num_classes=num_classes, reduction=reduction)[1] + num_classes=num_classes, class_reduction=class_reduction)[1] def fbeta_score( @@ -444,7 +438,7 @@ def fbeta_score( target: torch.Tensor, beta: float, num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes the F-beta score which is a weighted harmonic mean of precision and recall. @@ -459,12 +453,12 @@ def fbeta_score( beta = 0: only precision beta -> inf: only recall num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements. + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: Tensor with the value of F-score. It is a value between 0-1. 
@@ -474,27 +468,27 @@ def fbeta_score( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 2, 2]) >>> fbeta_score(x, y, 0.2) - tensor(0.7407) + tensor(0.7500) """ - prec, rec = precision_recall(pred=pred, target=target, - num_classes=num_classes, - reduction='none') + # We need to differentiate at which point to do class reduction + intermidiate_reduction = 'none' if class_reduction != "micro" else 'micro' - nom = (1 + beta ** 2) * prec * rec + prec, rec, sups = precision_recall(pred=pred, target=target, + num_classes=num_classes, + class_reduction=intermidiate_reduction, + return_support=True) + num = (1 + beta ** 2) * prec * rec denom = ((beta ** 2) * prec + rec) - fbeta = nom / denom - - # drop NaN after zero division - fbeta[fbeta != fbeta] = 0 - - return reduce(fbeta, reduction=reduction) + if intermidiate_reduction == 'micro': + return torch.sum(num) / torch.sum(denom) + return class_reduce(num, denom, sups, class_reduction=class_reduction) def f1_score( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction='elementwise_mean', + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall. @@ -504,12 +498,12 @@ def f1_score( pred: estimated probabilities target: ground-truth labels num_classes: number of classes - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements. + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: Tensor containing F1-score @@ -519,10 +513,10 @@ def f1_score( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 2, 2]) >>> f1_score(x, y) - tensor(0.6667) + tensor(0.7500) """ return fbeta_score(pred=pred, target=target, beta=1., - num_classes=num_classes, reduction=reduction) + num_classes=num_classes, class_reduction=class_reduction) def _binary_clf_curve( diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py index b9be8ca7daeb5..d0618abd65b96 100644 --- a/pytorch_lightning/metrics/functional/reduction.py +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -22,3 +22,44 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: if reduction == 'sum': return torch.sum(to_reduce) raise ValueError('Reduction parameter unknown.') + + +def class_reduce(num: torch.Tensor, + denom: torch.Tensor, + weights: torch.Tensor, + class_reduction: str = 'none') -> torch.Tensor: + """ + Function used to reduce classification metrics of the form `num / denom * weights`. + For example for calculating standard accuracy the num would be number of + true positives per class, denom would be the support per class, and weights + would be a tensor of 1s + + Args: + num: numerator tensor + decom: denominator tensor + weights: weights for each class + class_reduction: reduction method for multiclass problems + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
+            - ``'none'``: returns calculated metric per class (default)
+
+    """
+    valid_reduction = ('micro', 'macro', 'weighted', 'none')
+    if class_reduction == 'micro':
+        return torch.sum(num) / torch.sum(denom)
+
+    # For the rest we need to take care of instances where the denom can be 0
+    # for some classes, which will produce nans for that class
+    fraction = num / denom
+    fraction[fraction != fraction] = 0
+    if class_reduction == 'macro':
+        return torch.mean(fraction)
+    elif class_reduction == 'weighted':
+        return torch.sum(fraction * (weights / torch.sum(weights)))
+    elif class_reduction == 'none':
+        return fraction
+
+    raise ValueError(f'Reduction parameter {class_reduction} unknown.'
+                     f' Choose between one of these: {valid_reduction}')
diff --git a/pytorch_lightning/metrics/functional/regression.py b/pytorch_lightning/metrics/functional/regression.py
index 89fee6d21de4c..75d8f1adf9a86 100644
--- a/pytorch_lightning/metrics/functional/regression.py
+++ b/pytorch_lightning/metrics/functional/regression.py
@@ -17,12 +17,11 @@ def mse(
     Args:
         pred: estimated labels
         target: ground truth labels
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        reduction: a method to reduce metric score over labels.

-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied

     Return:
         Tensor with MSE
@@ -51,12 +50,11 @@ def rmse(
     Args:
         pred: estimated labels
         target: ground truth labels
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        reduction: a method to reduce metric score over labels.

-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied

     Return:
         Tensor with RMSE
@@ -83,12 +81,11 @@ def mae(
     Args:
         pred: estimated labels
         target: ground truth labels
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        reduction: a method to reduce metric score over labels.

-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied

     Return:
         Tensor with MAE
@@ -117,12 +114,11 @@ def rmsle(
     Args:
         pred: estimated labels
         target: ground truth labels
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        reduction: a method to reduce metric score over labels.

-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied

     Return:
         Tensor with RMSLE
@@ -154,12 +150,11 @@ def psnr(
         target: ground truth signal
         data_range: the range of the data. If None, it is determined from the data (max - min)
         base: a base of a logarithm to use (default: 10)
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        reduction: a method to reduce metric score over labels.

- - elementwise_mean: takes the mean - - none: pass array - - sum add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied Return: Tensor with PSNR score @@ -174,7 +169,6 @@ def psnr( tensor(2.5527) """ - if data_range is None: data_range = max(target.max() - target.min(), pred.max() - pred.min()) else: @@ -187,16 +181,19 @@ def psnr( def _gaussian_kernel(channel, kernel_size, sigma, device): - def gaussian(kernel_size, sigma, device): + def _gaussian(kernel_size, sigma, device): gauss = torch.arange( - start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device + start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, + step=1, + dtype=torch.float32, + device=device ) gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2))) return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device) - gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], device) + gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], device) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) @@ -219,12 +216,11 @@ def ssim( target: ground truth image kernel_size: size of the gaussian kernel (default: (11, 11)) sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass away - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied data_range: Range of the image. If ``None``, it is determined from the image (max - min) k1: Parameter of SSIM. Default: 0.01 @@ -241,7 +237,6 @@ def ssim( tensor(0.9219) """ - if pred.dtype != target.dtype: raise TypeError( "Expected `pred` and `target` to have the same data type." diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py index acaaf26ade563..66963f19bcc56 100644 --- a/pytorch_lightning/metrics/regression.py +++ b/pytorch_lightning/metrics/regression.py @@ -47,11 +47,11 @@ def __init__( ): """ Args: - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied """ super().__init__(name='mse') self.reduction = reduction @@ -90,11 +90,11 @@ def __init__( ): """ Args: - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + reduction: a method to reduce metric score over labels. 
+ + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied """ super().__init__(name='rmse') self.reduction = reduction @@ -133,11 +133,11 @@ def __init__( ): """ Args: - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied """ super().__init__(name='mae') self.reduction = reduction @@ -176,11 +176,11 @@ def __init__( ): """ Args: - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied """ super().__init__(name='rmsle') self.reduction = reduction @@ -223,11 +223,11 @@ def __init__( Args: data_range: the range of the data. If None, it is determined from the data (max - min) base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied """ super().__init__(name='psnr') self.data_range = data_range @@ -275,11 +275,11 @@ def __init__( Args: kernel_size: Size of the gaussian kernel (default: (11, 11)) sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels (default: takes the mean) - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass away - - sum: add elements + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied data_range: Range of the image. If ``None``, it is determined from the image (max - min) k1: Parameter of SSIM. 
             Default: 0.01
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index 2074f70db5b46..f8269384b3477 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -37,10 +37,10 @@
 @pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
     pytest.param(sk_accuracy, accuracy, id='accuracy'),
-    pytest.param(partial(sk_precision, average='macro'), precision, id='precision'),
-    pytest.param(partial(sk_recall, average='macro'), recall, id='recall'),
-    pytest.param(partial(sk_f1_score, average='macro'), f1_score, id='f1_score'),
-    pytest.param(partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2), id='fbeta_score'),
+    pytest.param(partial(sk_precision, average='micro'), precision, id='precision'),
+    pytest.param(partial(sk_recall, average='micro'), recall, id='recall'),
+    pytest.param(partial(sk_f1_score, average='micro'), f1_score, id='f1_score'),
+    pytest.param(partial(sk_fbeta_score, average='micro', beta=2), partial(fbeta_score, beta=2), id='fbeta_score'),
     pytest.param(sk_confusion_matrix, confusion_matrix, id='confusion_matrix')
 ])
 def test_against_sklearn(sklearn_metric, torch_metric):
@@ -59,6 +59,27 @@ def test_against_sklearn(sklearn_metric, torch_metric):
     assert torch.allclose(sk_score, pl_score)


+@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted'])
+@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
+    pytest.param(sk_precision, precision, id='precision'),
+    pytest.param(sk_recall, recall, id='recall'),
+    pytest.param(sk_f1_score, f1_score, id='f1_score'),
+    pytest.param(partial(sk_fbeta_score, beta=2), partial(fbeta_score, beta=2), id='fbeta_score')
+])
+def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric):
+    """ Test metrics where the class_reduction parameter has a corresponding
+        value in sklearn """
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    pred = torch.randint(10, (300,), device=device)
+    target = torch.randint(10, (300,), device=device)
+    sk_score = sklearn_metric(target.cpu().detach().numpy(),
+                              pred.cpu().detach().numpy(),
+                              average=class_reduction)
+    sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
+    pl_score = torch_metric(pred, target, class_reduction=class_reduction)
+    assert torch.allclose(sk_score, pl_score)
+
+
 def test_onehot():
     test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
     expected = torch.stack([
@@ -147,14 +168,14 @@ def test_multilabel_accuracy():
     y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
     y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])

-    assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.]))
-    assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
-    assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
-    assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
-    assert torch.allclose(accuracy(y1, torch.logical_not(y1), reduction='none'), torch.tensor([0., 0.]))
+    assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.]))
+    assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.]))
+    assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.]))
+    assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.]))
+    assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.]))

     with pytest.raises(RuntimeError):
-        accuracy(y2, torch.zeros_like(y2), reduction='none')
+        accuracy(y2, torch.zeros_like(y2), class_reduction='none')


 def test_accuracy():
@@ -198,14 +219,13 @@ def test_confusion_matrix():

     assert torch.allclose(cm, to_compare)

-
 @pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
     pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
     pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
 ])
 def test_precision_recall(pred, target, expected_prec, expected_rec):
-    prec = precision(pred, target, reduction='none')
-    rec = recall(pred, target, reduction='none')
+    prec = precision(pred, target, class_reduction='none')
+    rec = recall(pred, target, class_reduction='none')

     assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
     assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)
@@ -217,10 +237,10 @@ def test_precision_recall(pred, target, expected_prec, expected_rec):
     pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
 ])
 def test_fbeta_score(pred, target, beta, exp_score):
-    score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, reduction='none')
+    score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, class_reduction='none')
     assert torch.allclose(score, torch.tensor(exp_score))

-    score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, reduction='none')
+    score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, class_reduction='none')
     assert torch.allclose(score, torch.tensor(exp_score))


@@ -230,10 +250,10 @@ def test_fbeta_score(pred, target, beta, exp_score):
     pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
 ])
 def test_f1_score(pred, target, exp_score):
-    score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none')
+    score = f1_score(torch.tensor(pred), torch.tensor(target), class_reduction='none')
     assert torch.allclose(score, torch.tensor(exp_score))

-    score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), reduction='none')
+    score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), class_reduction='none')
     assert torch.allclose(score, torch.tensor(exp_score))


diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py
index 71d2b6f7735e1..aec54c1806715 100644
--- a/tests/metrics/functional/test_reduction.py
+++ b/tests/metrics/functional/test_reduction.py
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from pytorch_lightning.metrics.functional.reduction import reduce
+from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce


 def test_reduce():
@@ -13,3 +13,18 @@ def test_reduce():

     with pytest.raises(ValueError):
         reduce(start_tensor, 'error_reduction')
+
+
+def test_class_reduce():
+    num = torch.randint(1, 10, (100,)).float()
+    denom = torch.randint(10, 20, (100,)).float()
+    weights = torch.randint(1, 100, (100,)).float()
+
+    assert torch.allclose(class_reduce(num, denom, weights, 'micro'),
+                          torch.sum(num) / torch.sum(denom))
+    assert torch.allclose(class_reduce(num, denom, weights, 'macro'),
+                          torch.mean(num / denom))
+    assert torch.allclose(class_reduce(num, denom, weights, 'weighted'),
+                          torch.sum(num / denom * (weights / torch.sum(weights))))
+    assert torch.allclose(class_reduce(num, denom, weights, 'none'),
+                          num / denom)
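
As a usage illustration of the new sklearn-style `class_reduction` argument, here is a minimal sketch against this revision. It is not part of the patch itself; the expected values in the comments are worked out by hand from the per-class counts and match the doctests updated in the diff above.

import torch

from pytorch_lightning.metrics.functional.classification import precision
from pytorch_lightning.metrics.functional.reduction import class_reduce

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# Per class: TP = [1, 1, 1, 0], TP + FP = [1, 1, 1, 1], support = [1, 1, 2, 0]
print(precision(pred, target, class_reduction='micro'))     # tensor(0.7500): 3 correct of 4 globally
print(precision(pred, target, class_reduction='macro'))     # tensor(0.7500): mean of [1, 1, 1, 0]
print(precision(pred, target, class_reduction='weighted'))  # tensor(1.): class 3 has zero support
print(precision(pred, target, class_reduction='none'))      # tensor([1., 1., 1., 0.])

# The same 'weighted' reduction driven directly through the new helper:
tps = torch.tensor([1., 1., 1., 0.])
fps = torch.tensor([0., 0., 0., 1.])
sups = torch.tensor([1., 1., 2., 0.])
print(class_reduce(tps, tps + fps, sups, class_reduction='weighted'))  # tensor(1.)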