Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metric aggregation testing #3517

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
af12f24
aggregation testing
SkafteNicki Sep 16, 2020
b8dc18d
add more tests
Sep 16, 2020
64b6fa2
mse
Sep 16, 2020
a334380
more tests
Sep 16, 2020
7195d57
fix tests
Sep 16, 2020
2567dc0
Merge remote-tracking branch 'upstream/master' into metrics/aggregati…
SkafteNicki Sep 17, 2020
206bf4c
fix doctest
SkafteNicki Sep 17, 2020
a9f729b
fix codefactor
SkafteNicki Sep 17, 2020
4b7c909
fix import error
SkafteNicki Sep 17, 2020
f256823
fix doctest
SkafteNicki Sep 17, 2020
3cc4dae
revert docfix
SkafteNicki Sep 18, 2020
aeac804
test for model integration
SkafteNicki Sep 18, 2020
ffc9996
fix integration test
SkafteNicki Sep 18, 2020
6582e9b
added test cases
SkafteNicki Sep 21, 2020
029b76d
fix rmsle
SkafteNicki Sep 21, 2020
7b6b3e8
aggregation testing
SkafteNicki Sep 16, 2020
d23f607
add more tests
Sep 16, 2020
d657cce
mse
Sep 16, 2020
b361bb0
more tests
Sep 16, 2020
6843e30
fix tests
Sep 16, 2020
b8a8c84
fix doctest
SkafteNicki Sep 17, 2020
6e48f1d
fix codefactor
SkafteNicki Sep 17, 2020
c73f86b
fix import error
SkafteNicki Sep 17, 2020
2477fcd
fix doctest
SkafteNicki Sep 17, 2020
b816a14
revert docfix
SkafteNicki Sep 18, 2020
427aa8c
test for model integration
SkafteNicki Sep 18, 2020
b400bb8
fix integration test
SkafteNicki Sep 18, 2020
5941795
fix psnr
SkafteNicki Sep 21, 2020
8704445
add warning/valueerror to embedding similarity
SkafteNicki Sep 21, 2020
e1b6559
fixed f scores
SkafteNicki Sep 21, 2020
d3adf8f
merge
SkafteNicki Sep 21, 2020
66f41f6
disable some test
SkafteNicki Sep 21, 2020
48beda9
fix tests
Sep 23, 2020
630f101
fixing codefactor
Sep 23, 2020
c1d6c37
fix pep8
Sep 23, 2020
46fa39c
changelog
Sep 23, 2020
750ac09
fix doctest
Sep 23, 2020
79319a0
cleaning test
Sep 23, 2020
b0a33c1
fix pickle error
Sep 23, 2020
c02fb66
pickle fix
Sep 23, 2020
3784e49
Merge remote-tracking branch 'upstream/master' into metrics/aggregati…
SkafteNicki Sep 24, 2020
235d661
fix pickle error
SkafteNicki Sep 24, 2020
3978a8a
Apply suggestions from code review
SkafteNicki Sep 28, 2020
34e36db
code cleanup + changes based on suggestions
SkafteNicki Sep 28, 2020
b923ad5
Merge branch 'metrics/aggregation_testing' of https://github.com/Skaf…
SkafteNicki Sep 28, 2020
94fb0a1
update based on suggestion
Sep 30, 2020
015b138
update based on suggestions
Sep 30, 2020
f32402c
Apply suggestions from code review
SkafteNicki Sep 30, 2020
b2eb760
Merge branch 'master' into metrics/aggregation_testing
SkafteNicki Sep 30, 2020
9490025
Merge branch 'master' into metrics/aggregation_testing
SkafteNicki Sep 30, 2020
04c9f45
Merge branch 'master' into metrics/aggregation_testing
SkafteNicki Oct 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))

- Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
114 changes: 66 additions & 48 deletions pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,18 @@
auroc,
average_precision,
confusion_matrix,
_confmat_normalize,
dice_score,
f1_score,
fbeta_score,
iou,
multiclass_precision_recall_curve,
multiclass_roc,
precision,
precision_recall_curve,
recall,
roc,
precision_recall
)
from pytorch_lightning.metrics.functional.reduction import class_reduce
from pytorch_lightning.metrics.metric import TensorMetric


Expand All @@ -44,8 +45,8 @@ class Accuracy(TensorMetric):
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Accuracy()
>>> metric(pred, target).item()
0.75
>>> metric(pred, target)
tensor(0.7500)

"""

Expand Down Expand Up @@ -84,7 +85,14 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the classification score.
"""
return accuracy(pred=pred, target=target,
num_classes=self.num_classes, class_reduction=self.class_reduction)
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
    """Turn the accuracy state into the final, class-reduced score.

    NOTE(review): declared as a ``staticmethod`` while still taking ``self``
    explicitly; presumably the caller passes the metric instance as the
    first argument -- confirm against the base class.

    Args:
        self: object exposing ``class_reduction``
        data: not used by this method
        output: mapping holding the ``'tps'`` and ``'sups'`` entries

    Return:
        The value of ``class_reduce`` with ``tps`` as numerator and ``sups``
        as both denominator and weights.
    """
    true_positives = output['tps']
    support = output['sups']
    return class_reduce(
        true_positives,
        support,
        support,
        class_reduction=self.class_reduction,
    )


class ConfusionMatrix(TensorMetric):
Expand Down Expand Up @@ -135,16 +143,16 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the confusion matrix.
"""
return confusion_matrix(pred=pred, target=target,
normalize=self.normalize,
normalize=False, # we normalize after ddp sync
num_classes=self.num_classes)

def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
"""Aggregates results by stacking them instead of concatenating before averaging.

Returns:
the aggregated results
"""
return torch.stack(tensors).mean(0)
@staticmethod
def compute(self, data: Any, output: Any):
    """Apply the optional confusion-matrix normalization after ddp sync.

    Normalization is deferred to this step so that it happens after the
    ddp sync rather than before it.

    Args:
        self: object exposing the ``normalize`` flag
        data: not used by this method
        output: the confusion matrix to post-process

    Return:
        ``output`` unchanged when ``self.normalize`` is falsy, otherwise the
        result of ``_confmat_normalize(output)``.
    """
    if not self.normalize:
        return output
    return _confmat_normalize(output)


class PrecisionRecallCurve(TensorMetric):
Expand Down Expand Up @@ -202,7 +210,8 @@ def forward(
- recall values
- threshold values
"""
return precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
return precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight, pos_label=self.pos_label)


class Precision(TensorMetric):
Expand Down Expand Up @@ -256,9 +265,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Return:
A Tensor with the classification score.
"""
return precision(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction=self.class_reduction)
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
    """Turn the precision state into the final, class-reduced score.

    NOTE(review): declared as a ``staticmethod`` while still taking ``self``
    explicitly; presumably the caller passes the metric instance as the
    first argument -- confirm against the base class.

    Args:
        self: object exposing ``class_reduction``
        data: not used by this method
        output: mapping holding the ``'tps'``, ``'fps'`` and ``'sups'`` entries

    Return:
        The value of ``class_reduce`` with ``tps`` as numerator,
        ``tps + fps`` as denominator and ``sups`` as weights.
    """
    true_positives = output['tps']
    false_positives = output['fps']
    support = output['sups']
    predicted_positives = true_positives + false_positives
    return class_reduce(
        true_positives,
        predicted_positives,
        support,
        class_reduction=self.class_reduction,
    )


class Recall(TensorMetric):
Expand Down Expand Up @@ -313,10 +328,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Return:
A Tensor with the classification score.
"""
return recall(pred=pred,
target=target,
num_classes=self.num_classes,
class_reduction=self.class_reduction)
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
    """Turn the recall state into the final, class-reduced score.

    NOTE(review): declared as a ``staticmethod`` while still taking ``self``
    explicitly; presumably the caller passes the metric instance as the
    first argument -- confirm against the base class.

    Args:
        self: object exposing ``class_reduction``
        data: not used by this method
        output: mapping holding the ``'tps'``, ``'fns'`` and ``'sups'`` entries

    Return:
        The value of ``class_reduce`` with ``tps`` as numerator,
        ``tps + fns`` as denominator and ``sups`` as weights.
    """
    true_positives = output['tps']
    false_negatives = output['fns']
    support = output['sups']
    actual_positives = true_positives + false_negatives
    return class_reduce(
        true_positives,
        actual_positives,
        support,
        class_reduction=self.class_reduction,
    )


class AveragePrecision(TensorMetric):
Expand Down Expand Up @@ -470,12 +490,28 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Return:
torch.Tensor: classification score
"""
return fbeta_score(pred=pred, target=target,
beta=self.beta, num_classes=self.num_classes,
class_reduction=self.class_reduction)
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
    """Compute the F-beta score from the tps / fps / fns / sups state.

    The four statistics need to be synced before any calculation is done,
    which is why the score is derived here from the raw counts.

    Args:
        self: object exposing ``beta`` and ``class_reduction``
        data: not used by this method
        output: mapping holding the ``'tps'``, ``'fps'``, ``'fns'`` and
            ``'sups'`` entries

    Return:
        For ``'micro'`` reduction, ``sum(num) / sum(denom)`` of the F-beta
        numerator and denominator; otherwise the result of ``class_reduce``
        applied to them with ``self.class_reduction``.
    """
    true_pos = output['tps']
    false_pos = output['fps']
    false_neg = output['fns']
    support = output['sups']

    # Precision and recall are reduced with 'micro' only when the final
    # reduction is 'micro'; every other mode keeps them unreduced ('none')
    # so the final class_reduce call below performs the reduction.
    is_micro = self.class_reduction == "micro"
    stage_reduction = 'micro' if is_micro else 'none'

    precision = class_reduce(true_pos, true_pos + false_pos, support,
                             class_reduction=stage_reduction)
    recall = class_reduce(true_pos, true_pos + false_neg, support,
                          class_reduction=stage_reduction)

    beta_squared = self.beta ** 2
    numerator = (1 + beta_squared) * precision * recall
    denominator = beta_squared * precision + recall

    if not is_micro:
        return class_reduce(numerator, denominator, support,
                            class_reduction=self.class_reduction)
    return torch.sum(numerator) / torch.sum(denominator)

class F1(TensorMetric):

class F1(FBeta):
"""
Computes the F1 score, which is the harmonic mean of the precision and recall.
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
Expand Down Expand Up @@ -507,29 +543,11 @@ def __init__(

reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="f1",
reduce_group=reduce_group,
)

self.num_classes = num_classes
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
self.class_reduction = class_reduction

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation

Args:
pred: predicted labels
target: groundtruth labels

Return:
torch.Tensor: classification score
"""
return f1_score(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction=self.class_reduction)
super().__init__(beta=1.0,
num_classes=num_classes,
class_reduction=class_reduction,
reduce_group=reduce_group)
self.name = "f1"


class ROC(TensorMetric):
Expand Down
32 changes: 23 additions & 9 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ def accuracy(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
class_reduction: str = 'micro'
class_reduction: str = 'micro',
return_state: bool = False
) -> torch.Tensor:
"""
Computes the accuracy classification score
Expand All @@ -256,7 +257,8 @@ def accuracy(
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class

return_state: if ``True``, returns an internal state (a dict with ``tps`` and ``sups``)
that can be ddp reduced before doing the final calculation
Return:
A Tensor with the accuracy score.

Expand All @@ -270,10 +272,21 @@ def accuracy(
"""
tps, fps, tns, fns, sups = stat_scores_multiple_classes(
pred=pred, target=target, num_classes=num_classes)

if return_state:
return {'tps': tps, 'sups': sups}
return class_reduce(tps, sups, sups, class_reduction=class_reduction)


def _confmat_normalize(cm):
""" Normalization function for confusion matrix """
cm = cm / cm.sum(-1, keepdim=True)
nan_elements = cm[torch.isnan(cm)].nelement()
if nan_elements != 0:
cm[torch.isnan(cm)] = 0
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
return cm


def confusion_matrix(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -311,11 +324,7 @@ def confusion_matrix(
cm = bins.reshape(num_classes, num_classes).squeeze().float()

if normalize:
cm = cm / cm.sum(-1, keepdim=True)
nan_elements = cm[torch.isnan(cm)].nelement()
if nan_elements != 0:
cm[torch.isnan(cm)] = 0
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
cm = _confmat_normalize(cm)

return cm

Expand All @@ -325,7 +334,8 @@ def precision_recall(
target: torch.Tensor,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
return_support: bool = False
return_support: bool = False,
return_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes precision and recall for different thresholds
Expand All @@ -342,6 +352,8 @@ def precision_recall(
- ``'none'``: returns calculated metric per class

return_support: returns the support for each class, needed for fbeta/f1 calculations
return_state: if ``True``, returns an internal state (a dict with ``tps``, ``fps``, ``fns`` and ``sups``)
that can be ddp reduced before doing the final calculation

Return:
Tensor with precision and recall
Expand All @@ -358,6 +370,8 @@ def precision_recall(

precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction)
recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction)
if return_state:
return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups}
if return_support:
return precision, recall, sups
return precision, recall
Expand Down
Loading