From 4fdea848a77d5c04e400ce7dc40a4511f10d83c8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Mar 2021 00:17:10 +0100 Subject: [PATCH] tests --- .../classification/test_average_precision.py | 97 ----- .../classification/test_precision_recall.py | 348 ------------------ .../test_precision_recall_curve.py | 97 ----- .../metrics/functional/test_classification.py | 25 -- 4 files changed, 567 deletions(-) delete mode 100644 tests/metrics/classification/test_average_precision.py delete mode 100644 tests/metrics/classification/test_precision_recall.py delete mode 100644 tests/metrics/classification/test_precision_recall_curve.py diff --git a/tests/metrics/classification/test_average_precision.py b/tests/metrics/classification/test_average_precision.py deleted file mode 100644 index 7cab20883e970..0000000000000 --- a/tests/metrics/classification/test_average_precision.py +++ /dev/null @@ -1,97 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision_score - -from pytorch_lightning.metrics.classification.average_precision import AveragePrecision -from pytorch_lightning.metrics.functional.average_precision import average_precision -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_average_precision_score(y_true, probas_pred, num_classes=1): - if num_classes == 1: - return sk_average_precision_score(y_true, probas_pred) - - res = [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) - return res - - -def _sk_avg_prec_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestAveragePrecision(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=AveragePrecision, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_average_precision_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=average_precision, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize( - ['scores', 'target', 'expected_score'], - [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), - ] -) -def test_average_precision(scores, target, expected_score): - assert average_precision(scores, target) == expected_score diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py deleted file mode 100644 index c9e5467414832..0000000000000 --- a/tests/metrics/classification/test_precision_recall.py +++ /dev/null @@ -1,348 +0,0 @@ -from functools import partial -from typing import Callable, Optional - -import numpy as np -import pytest -import torch -from sklearn.metrics import precision_score, recall_score -from torchmetrics import Metric -from torchmetrics.classification.checks import _input_format_classification - -from pytorch_lightning.metrics import Precision, Recall -from pytorch_lightning.metrics.functional import precision, precision_recall, recall -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD - -torch.manual_seed(42) - - -def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): - if average == "none": - average = None - if num_classes == 1: - average = "binary" - - labels = list(range(num_classes)) - try: - labels.remove(ignore_index) - except ValueError: - pass - - sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - - sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) - - if len(labels) != num_classes and not average: - sk_scores = np.insert(sk_scores, ignore_index, np.nan) - - return sk_scores - - -def _sk_prec_recall_multidim_multiclass( - preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average -): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - - if mdmc_average == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) - elif mdmc_average == "samplewise": - scores = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) - - scores.append(np.expand_dims(scores_i, 0)) - - return np.concatenate(scores).mean(axis=0) - - -@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) -@pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index, match_str", - [ - ("wrong", None, None, None, "`average`"), - ("micro", "wrong", None, None, "`mdmc"), - ("macro", None, None, None, "number of classes"), - ("macro", None, 1, 0, "ignore_index"), - ], -) -def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): - with pytest.raises(ValueError, match=match_str): - metric( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - with pytest.raises(ValueError, match=match_str): - fn_metric( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - with pytest.raises(ValueError, match=match_str): - precision_recall( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_zero_division(metric_class, metric_fn): - """ Test that zero_division works correctly (currently should just set to 0). """ - - preds = torch.tensor([1, 2, 1, 1]) - target = torch.tensor([2, 1, 2, 1]) - - cl_metric = metric_class(average="none", num_classes=3) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="none", num_classes=3) - - assert result_cl[0] == result_fn[0] == 0 - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_no_support(metric_class, metric_fn): - """This tests a rare edge case, where there is only one class present - in target, and ignore_index is set to exactly that class - and the - average method is equal to 'weighted'. - - This would mean that the sum of weights equals zero, and would, without - taking care of this case, return NaN. However, the reduction function - should catch that and set the metric to equal the value of zero_division - in this case (zero_division is for now not configurable and equals 0). - """ - - preds = torch.tensor([1, 1, 0, 0]) - target = torch.tensor([0, 0, 0, 0]) - - cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) - - assert result_cl == result_fn == 0 - - -@pytest.mark.parametrize( - "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] -) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", - [ - (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall), - (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall), - (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall), - (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", - _sk_prec_recall_multidim_multiclass - ), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", - _sk_prec_recall_multidim_multiclass - ), - ], -) -class TestPrecisionRecall(MetricTester): - - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_class( - self, - ddp: bool, - dist_sync_on_step: bool, - preds: torch.Tensor, - target: torch.Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - is_multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, - check_dist_sync_on_step=True, - check_batch=True, - ) - - def test_precision_recall_fn( - self, - preds: torch.Tensor, - target: torch.Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - is_multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - - self.run_functional_metric_test( - preds, - target, - metric_functional=metric_fn, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, - ) - - -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -def test_precision_recall_joint(average): - """A simple test of the joint precision_recall metric. - - No need to test this thorougly, as it is just a combination of precision and recall, - which are already tested thoroughly. - """ - - precision_result = precision( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - recall_result = recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - prec_recall_result = precision_recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - assert torch.equal(precision_result, prec_recall_result[0]) - assert torch.equal(recall_result, prec_recall_result[1]) - - -_mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -@pytest.mark.parametrize( - "k, preds, target, average, expected_prec, expected_recall", - [ - (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)), - (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)), - (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)), - ], -) -def test_top_k( - metric_class, - metric_fn, - k: int, - preds: torch.Tensor, - target: torch.Tensor, - average: str, - expected_prec: torch.Tensor, - expected_recall: torch.Tensor, -): - """A simple test to check that top_k works as expected. - - Just a sanity check, the tests in StatScores should already guarantee - the corectness of results. - """ - - class_metric = metric_class(top_k=k, average=average, num_classes=3) - class_metric.update(preds, target) - - if metric_class.__name__ == "Precision": - result = expected_prec - else: - result = expected_recall - - assert torch.equal(class_metric.compute(), result) - assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) diff --git a/tests/metrics/classification/test_precision_recall_curve.py b/tests/metrics/classification/test_precision_recall_curve.py deleted file mode 100644 index 6a60e1fd36fdd..0000000000000 --- a/tests/metrics/classification/test_precision_recall_curve.py +++ /dev/null @@ -1,97 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve - -from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve -from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): - """ Adjusted comparison function that can also handles multiclass """ - if num_classes == 1: - return sk_precision_recall_curve(y_true, probas_pred) - - precision, recall, thresholds = [], [], [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) - precision.append(res[0]) - recall.append(res[1]) - thresholds.append(res[2]) - return precision, recall, thresholds - - -def _sk_prec_rc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestPrecisionRecallCurve(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=PrecisionRecallCurve, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=precision_recall_curve, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize( - ['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], - [pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])] -) -def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) - assert p.size() == r.size() - assert p.size(0) == t.size(0) + 1 - - assert torch.allclose(p, torch.tensor(expected_p).to(p)) - assert torch.allclose(r, torch.tensor(expected_r).to(r)) - assert torch.allclose(t, torch.tensor(expected_t).to(t)) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index bca50867dcb44..0ee0882b267e2 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -4,7 +4,6 @@ from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import dice_score -from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve def test_onehot(): @@ -54,30 +53,6 @@ def test_get_num_classes(pred, target, num_classes, expected_num_classes): assert get_num_classes(pred, target, num_classes) == expected_num_classes -@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ - pytest.param(1, 1., 42), - pytest.param(None, 1., 42), -]) -def test_binary_clf_curve(sample_weight, pos_label, exp_shape): - # TODO: move back the pred and target to test func arguments - # if you fix the array inside the function, you'd also have fix the shape, - # because when the array changes, you also have to fix the shape - seed_everything(0) - pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100 - target = torch.tensor([0, 1] * 50, dtype=torch.int) - if sample_weight is not None: - sample_weight = torch.ones_like(pred) * sample_weight - - fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) - - assert isinstance(tps, torch.Tensor) - assert isinstance(fps, torch.Tensor) - assert isinstance(thresh, torch.Tensor) - assert tps.shape == (exp_shape, ) - assert fps.shape == (exp_shape, ) - assert thresh.shape == (exp_shape, ) - - @pytest.mark.parametrize(['pred', 'target', 'expected'], [ pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),