Skip to content

Commit

Permalink
Merge branch 'master' into edenlightning-patch-2
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Dec 12, 2020
2 parents c434c34 + 63fb7f9 commit 39e10a8
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 7 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/ci_test-base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ jobs:
with:
name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: always()
if: failure()

- name: Statistics
if: success()
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/ci_test-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,4 @@ jobs:
with:
name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: always()
if: failure()
3 changes: 1 addition & 2 deletions .github/workflows/ci_test-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ jobs:
with:
name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: always()
if: failure()

- name: Statistics
if: success()
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
from pytorch_lightning.metrics.classification.f_beta import FBeta, F1
from pytorch_lightning.metrics.classification.f_beta import FBeta, Fbeta, F1
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve
from pytorch_lightning.metrics.classification.roc import ROC
29 changes: 29 additions & 0 deletions pytorch_lightning/metrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_fbeta_compute
)
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


class FBeta(Metric):
Expand Down Expand Up @@ -131,6 +132,34 @@ def compute(self) -> torch.Tensor:
self.actual_positives, self.beta, self.average)


# todo: remove in v1.2
class Fbeta(FBeta):
r"""
Computes `F-score <https://en.wikipedia.org/wiki/F-score>`_
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.classification.f_beta.FBeta`
"""
def __init__(
self,
num_classes: int,
beta: float = 1.0,
threshold: float = 0.5,
average: str = "micro",
multilabel: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
rank_zero_warn(
"This `Fbeta` was deprecated in v1.0.x in favor of"
" `from pytorch_lightning.metrics.classification.f_beta import FBeta`."
" It will be removed in v1.2.0", DeprecationWarning
)
super().__init__(
num_classes, beta, threshold, average, multilabel, compute_on_step, dist_sync_on_step, process_group
)


class F1(FBeta):
"""
Computes F1 metric. F1 metrics correspond to a harmonic mean of the
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
auc,
auroc,
dice_score,
f1_score,
fbeta_score,
get_num_classes,
iou,
multiclass_auroc,
Expand Down
46 changes: 46 additions & 0 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch

from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc
from pytorch_lightning.metrics.functional.roc import roc as __roc
from pytorch_lightning.metrics.utils import (
Expand Down Expand Up @@ -871,3 +872,48 @@ def average_precision(
" It will be removed in v1.3.0", DeprecationWarning
)
return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)


# todo: remove in 1.2
def fbeta_score(
pred: torch.Tensor,
target: torch.Tensor,
beta: float,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
) -> torch.Tensor:
"""
Computes the F-beta score which is a weighted harmonic mean of precision and recall.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.fbeta`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.0.x in favor of"
" `from pytorch_lightning.metrics.functional.f_beta import fbeta`."
" It will be removed in v1.2.0", DeprecationWarning
)
if num_classes is None:
num_classes = get_num_classes(pred, target)
return __fb(preds=pred, target=target, beta=beta, num_classes=num_classes, average=class_reduction)


# todo: remove in 1.2
def f1_score(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
) -> torch.Tensor:
"""
Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.f1`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.0.x in favor of"
" `from pytorch_lightning.metrics.functional.f_beta import f1`."
" It will be removed in v1.2.0", DeprecationWarning
)
if num_classes is None:
num_classes = get_num_classes(pred, target)
return __f1(preds=pred, target=target, num_classes=num_classes, average=class_reduction)
14 changes: 14 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ def test_tbd_remove_in_v1_2_0():
checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.')


def test_tbd_remove_in_v1_2_0_metrics():
from pytorch_lightning.metrics.classification import Fbeta
from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score

with pytest.deprecated_call(match='will be removed in v1.2'):
Fbeta(2)

with pytest.deprecated_call(match='will be removed in v1.2'):
fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2)

with pytest.deprecated_call(match='will be removed in v1.2'):
f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0]))


# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
@pytest.mark.parametrize(['profiler', 'expected'], [
(True, SimpleProfiler),
Expand Down

0 comments on commit 39e10a8

Please sign in to comment.