Skip to content

Commit

Permalink
Merge branch 'master' into reduce_comms
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored May 3, 2021
2 parents 4d1ee30 + df6e4ba commit d13702d
Show file tree
Hide file tree
Showing 17 changed files with 866 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ repos:
args: [--settings-path, "./pyproject.toml"]

- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.29.0
rev: v0.31.0
hooks:
- id: yapf
name: formatting
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added Specificity metric ([#210](https://github.com/PyTorchLightning/metrics/pull/210))


- Added `is_differentiable` property to `AUC`, `AUROC`, `CohenKappa` and `AveragePrecision` ([#178](https://github.com/PyTorchLightning/metrics/pull/178))


Expand All @@ -17,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Calling `compute` before `update` will now give a warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164))



### Deprecated
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ select_topk [func]
.. autofunction:: torchmetrics.utilities.data.select_topk
:noindex:

specificity [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.specificity
:noindex:


stat_scores [func]
~~~~~~~~~~~~~~~~~~
Expand Down
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ ROC
:noindex:


Specificity
~~~~~~~~~~~

.. autoclass:: torchmetrics.Specificity
:noindex:


StatScores
~~~~~~~~~~

Expand Down
43 changes: 37 additions & 6 deletions tests/bases/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, val_to_return):
super().__init__()
self._num_updates = 0
self._val_to_return = val_to_return
self._update_called = True

def update(self, *args, **kwargs) -> None:
self._num_updates += 1
Expand Down Expand Up @@ -57,6 +58,9 @@ def test_metrics_add(second_operand, expected_result):
assert isinstance(final_add, CompositionalMetric)
assert isinstance(final_radd, CompositionalMetric)

final_add.update()
final_radd.update()

assert torch.allclose(expected_result, final_add.compute())
assert torch.allclose(expected_result, final_radd.compute())

Expand All @@ -75,6 +79,8 @@ def test_metrics_and(second_operand, expected_result):
assert isinstance(final_and, CompositionalMetric)
assert isinstance(final_rand, CompositionalMetric)

final_and.update()
final_rand.update()
assert torch.allclose(expected_result, final_and.compute())
assert torch.allclose(expected_result, final_rand.compute())

Expand All @@ -95,6 +101,7 @@ def test_metrics_eq(second_operand, expected_result):

assert isinstance(final_eq, CompositionalMetric)

final_eq.update()
# can't use allclose for bool tensors
assert (expected_result == final_eq.compute()).all()

Expand All @@ -116,6 +123,7 @@ def test_metrics_floordiv(second_operand, expected_result):

assert isinstance(final_floordiv, CompositionalMetric)

final_floordiv.update()
assert torch.allclose(expected_result, final_floordiv.compute())


Expand All @@ -135,6 +143,7 @@ def test_metrics_ge(second_operand, expected_result):

assert isinstance(final_ge, CompositionalMetric)

final_ge.update()
# can't use allclose for bool tensors
assert (expected_result == final_ge.compute()).all()

Expand All @@ -155,6 +164,7 @@ def test_metrics_gt(second_operand, expected_result):

assert isinstance(final_gt, CompositionalMetric)

final_gt.update()
# can't use allclose for bool tensors
assert (expected_result == final_gt.compute()).all()

Expand All @@ -175,6 +185,7 @@ def test_metrics_le(second_operand, expected_result):

assert isinstance(final_le, CompositionalMetric)

final_le.update()
# can't use allclose for bool tensors
assert (expected_result == final_le.compute()).all()

Expand All @@ -195,6 +206,7 @@ def test_metrics_lt(second_operand, expected_result):

assert isinstance(final_lt, CompositionalMetric)

final_lt.update()
# can't use allclose for bool tensors
assert (expected_result == final_lt.compute()).all()

Expand All @@ -210,6 +222,7 @@ def test_metrics_matmul(second_operand, expected_result):

assert isinstance(final_matmul, CompositionalMetric)

final_matmul.update()
assert torch.allclose(expected_result, final_matmul.compute())


Expand All @@ -228,6 +241,8 @@ def test_metrics_mod(second_operand, expected_result):
final_mod = first_metric % second_operand

assert isinstance(final_mod, CompositionalMetric)

final_mod.update()
# prevent Runtime error for PT 1.8 - Long did not match Float
assert torch.allclose(expected_result.to(float), final_mod.compute().to(float))

Expand All @@ -250,6 +265,8 @@ def test_metrics_mul(second_operand, expected_result):
assert isinstance(final_mul, CompositionalMetric)
assert isinstance(final_rmul, CompositionalMetric)

final_mul.update()
final_rmul.update()
assert torch.allclose(expected_result, final_mul.compute())
assert torch.allclose(expected_result, final_rmul.compute())

Expand All @@ -270,6 +287,7 @@ def test_metrics_ne(second_operand, expected_result):

assert isinstance(final_ne, CompositionalMetric)

final_ne.update()
# can't use allclose for bool tensors
assert (expected_result == final_ne.compute()).all()

Expand All @@ -288,6 +306,8 @@ def test_metrics_or(second_operand, expected_result):
assert isinstance(final_or, CompositionalMetric)
assert isinstance(final_ror, CompositionalMetric)

final_or.update()
final_ror.update()
assert torch.allclose(expected_result, final_or.compute())
assert torch.allclose(expected_result, final_ror.compute())

Expand All @@ -308,6 +328,7 @@ def test_metrics_pow(second_operand, expected_result):

assert isinstance(final_pow, CompositionalMetric)

final_pow.update()
assert torch.allclose(expected_result, final_pow.compute())


Expand All @@ -322,6 +343,8 @@ def test_metrics_rfloordiv(first_operand, expected_result):
final_rfloordiv = first_operand // second_operand

assert isinstance(final_rfloordiv, CompositionalMetric)

final_rfloordiv.update()
assert torch.allclose(expected_result, final_rfloordiv.compute())


Expand All @@ -336,6 +359,7 @@ def test_metrics_rmatmul(first_operand, expected_result):

assert isinstance(final_rmatmul, CompositionalMetric)

final_rmatmul.update()
assert torch.allclose(expected_result, final_rmatmul.compute())


Expand All @@ -350,6 +374,7 @@ def test_metrics_rmod(first_operand, expected_result):

assert isinstance(final_rmod, CompositionalMetric)

final_rmod.update()
assert torch.allclose(expected_result, final_rmod.compute())


Expand All @@ -367,7 +392,7 @@ def test_metrics_rpow(first_operand, expected_result):
final_rpow = first_operand**second_operand

assert isinstance(final_rpow, CompositionalMetric)

final_rpow.update()
assert torch.allclose(expected_result, final_rpow.compute())


Expand All @@ -386,7 +411,7 @@ def test_metrics_rsub(first_operand, expected_result):
final_rsub = first_operand - second_operand

assert isinstance(final_rsub, CompositionalMetric)

final_rsub.update()
assert torch.allclose(expected_result, final_rsub.compute())


Expand All @@ -406,7 +431,7 @@ def test_metrics_rtruediv(first_operand, expected_result):
final_rtruediv = first_operand / second_operand

assert isinstance(final_rtruediv, CompositionalMetric)

final_rtruediv.update()
assert torch.allclose(expected_result, final_rtruediv.compute())


Expand All @@ -425,7 +450,7 @@ def test_metrics_sub(second_operand, expected_result):
final_sub = first_metric - second_operand

assert isinstance(final_sub, CompositionalMetric)

final_sub.update()
assert torch.allclose(expected_result, final_sub.compute())


Expand All @@ -445,7 +470,7 @@ def test_metrics_truediv(second_operand, expected_result):
final_truediv = first_metric / second_operand

assert isinstance(final_truediv, CompositionalMetric)

final_truediv.update()
assert torch.allclose(expected_result, final_truediv.compute())


Expand All @@ -463,6 +488,8 @@ def test_metrics_xor(second_operand, expected_result):
assert isinstance(final_xor, CompositionalMetric)
assert isinstance(final_rxor, CompositionalMetric)

final_xor.update()
final_rxor.update()
assert torch.allclose(expected_result, final_xor.compute())
assert torch.allclose(expected_result, final_rxor.compute())

Expand All @@ -473,7 +500,7 @@ def test_metrics_abs():
final_abs = abs(first_metric)

assert isinstance(final_abs, CompositionalMetric)

final_abs.update()
assert torch.allclose(tensor(1), final_abs.compute())


Expand All @@ -482,6 +509,7 @@ def test_metrics_invert():

final_inverse = ~first_metric
assert isinstance(final_inverse, CompositionalMetric)
final_inverse.update()
assert torch.allclose(tensor(-2), final_inverse.compute())


Expand All @@ -490,6 +518,7 @@ def test_metrics_neg():

final_neg = neg(first_metric)
assert isinstance(final_neg, CompositionalMetric)
final_neg.update()
assert torch.allclose(tensor(-1), final_neg.compute())


Expand All @@ -498,6 +527,7 @@ def test_metrics_pos():

final_pos = pos(first_metric)
assert isinstance(final_pos, CompositionalMetric)
final_pos.update()
assert torch.allclose(tensor(1), final_pos.compute())


Expand All @@ -510,6 +540,7 @@ def test_metrics_getitem(value, idx, expected_result):

final_getitem = first_metric[idx]
assert isinstance(final_getitem, CompositionalMetric)
final_getitem.update()
assert torch.allclose(expected_result, final_getitem.compute())


Expand Down
22 changes: 22 additions & 0 deletions tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,28 @@ def test_device_and_dtype_transfer(tmpdir):
assert metric.x.dtype == torch.float16


def test_warning_on_compute_before_update():
metric = DummyMetricSum()

# make sure everything is fine with forward
with pytest.warns(None) as record:
val = metric(1)
assert not record

metric.reset()

with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'):
val = metric.compute()
assert val == 0.0

# after update things should be fine
metric.update(2.0)
with pytest.warns(None) as record:
val = metric.compute()
assert not record
assert val == 2.0


def test_metric_scripts():
torch.jit.script(DummyMetric())
torch.jit.script(DummyMetricSum())
8 changes: 1 addition & 7 deletions tests/classification/test_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,7 @@ def test_auc_functional(self, x, y, reorder):
@pytest.mark.parametrize("reorder", [True, False])
def test_auc_differentiability(self, x, y, reorder):
self.run_differentiability_test(
preds=x,
target=y,
metric_module=AUC,
metric_functional=auc,
metric_args={
"reorder": reorder
}
preds=x, target=y, metric_module=AUC, metric_functional=auc, metric_args={"reorder": reorder}
)


Expand Down
12 changes: 6 additions & 6 deletions tests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average='macr
@pytest.mark.parametrize("average", ['macro', 'weighted', 'micro'])
@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)]
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)]
)
class TestAUROC(MetricTester):

Expand Down
18 changes: 9 additions & 9 deletions tests/classification/test_cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ def _sk_cohen_kappa_multidim_multiclass(preds, target, weights=None):

@pytest.mark.parametrize("weights", ['linear', 'quadratic', None])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_cohen_kappa_binary_prob, 2),
(_input_binary.preds, _input_binary.target, _sk_cohen_kappa_binary, 2),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cohen_kappa_multilabel_prob, 2),
(_input_mlb.preds, _input_mlb.target, _sk_cohen_kappa_multilabel, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cohen_kappa_multiclass_prob, NUM_CLASSES),
(_input_mcls.preds, _input_mcls.target, _sk_cohen_kappa_multiclass, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cohen_kappa_multidim_multiclass_prob, NUM_CLASSES),
(_input_mdmc.preds, _input_mdmc.target, _sk_cohen_kappa_multidim_multiclass, NUM_CLASSES)]
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_cohen_kappa_binary_prob, 2),
(_input_binary.preds, _input_binary.target, _sk_cohen_kappa_binary, 2),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cohen_kappa_multilabel_prob, 2),
(_input_mlb.preds, _input_mlb.target, _sk_cohen_kappa_multilabel, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cohen_kappa_multiclass_prob, NUM_CLASSES),
(_input_mcls.preds, _input_mcls.target, _sk_cohen_kappa_multiclass, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cohen_kappa_multidim_multiclass_prob, NUM_CLASSES),
(_input_mdmc.preds, _input_mdmc.target, _sk_cohen_kappa_multidim_multiclass, NUM_CLASSES)]
)
class TestCohenKappa(MetricTester):
atol = 1e-5
Expand Down
Loading

0 comments on commit d13702d

Please sign in to comment.