Skip to content

Commit

Permalink
Allow threshold to be outside (0,1) domain (#351)
Browse files Browse the repository at this point in the history
* fix threshold

* Update CHANGELOG.md

* flake8

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
4 people authored Jul 7, 2021
1 parent b20cbda commit 0495776
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 35 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Note: we move fast, but we still preserve 0.1 version (one feature release) backward compatibility.**

## [0.x.x] - ????-??-??

## [unreleased] - YYYY-MM-??

### Added

- Added support in `nDCG` metric for target with values larger than 1 ([#343](https://github.com/PyTorchLightning/metrics/issues/343))


### Changed


### Deprecated


### Removed

- Removed the restriction that `threshold` has to be in the (0,1) range, so that logit input is supported ([#351](https://github.com/PyTorchLightning/metrics/pull/351))


### Fixed



## [0.4.1] - 2021-07-05

### Changed
Expand Down
67 changes: 36 additions & 31 deletions tests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics import StatScores
from torchmetrics.functional import stat_scores
from torchmetrics.utilities.checks import _input_format_classification

seed_all(42)


def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce=None):
def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None):
# todo: `mdmc_reduce` is unused
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k
preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
)
sk_preds, sk_target = preds.numpy(), target.numpy()

Expand Down Expand Up @@ -75,23 +75,25 @@ def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index
return sk_stats


def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k):
def _sk_stat_scores_mdim_mcls(
preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold
):
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k
preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
)

if mdmc_reduce == "global":
preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k)
return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold)
if mdmc_reduce == "samplewise":
scores = []

for i in range(preds.shape[0]):
pred_i = preds[i, ...].T
target_i = target[i, ...].T
scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k)
scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold)

scores.append(np.expand_dims(scores_i, 0))

Expand Down Expand Up @@ -128,34 +130,32 @@ def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
sts(inputs.preds[0], inputs.target[0])


def test_wrong_threshold():
with pytest.raises(ValueError):
StatScores(threshold=1.5)


@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"])
@pytest.mark.parametrize(
"preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k",
"preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold",
[
(_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None),
(_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None),
(_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None),
(_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
(_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0),
(_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5),
(_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5),
(_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5),
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5),
(_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0),
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None, 0.0),
(
_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None,
None
None, 0.0
),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0),
(
_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None,
None, 0.0
),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
],
)
class TestStatScores(MetricTester):
Expand All @@ -175,6 +175,7 @@ def test_stat_scores_class(
multiclass: Optional[bool],
ignore_index: Optional[int],
top_k: Optional[int],
threshold: Optional[float],
):
if ignore_index is not None and preds.ndim == 2:
pytest.skip("Skipping ignore_index test with binary inputs.")
Expand All @@ -192,13 +193,14 @@ def test_stat_scores_class(
multiclass=multiclass,
ignore_index=ignore_index,
top_k=top_k,
threshold=threshold
),
dist_sync_on_step=dist_sync_on_step,
metric_args={
"num_classes": num_classes,
"reduce": reduce,
"mdmc_reduce": mdmc_reduce,
"threshold": THRESHOLD,
"threshold": threshold,
"multiclass": multiclass,
"ignore_index": ignore_index,
"top_k": top_k,
Expand All @@ -218,6 +220,7 @@ def test_stat_scores_fn(
multiclass: Optional[bool],
ignore_index: Optional[int],
top_k: Optional[int],
threshold: Optional[float],
):
if ignore_index is not None and preds.ndim == 2:
pytest.skip("Skipping ignore_index test with binary inputs.")
Expand All @@ -234,12 +237,13 @@ def test_stat_scores_fn(
multiclass=multiclass,
ignore_index=ignore_index,
top_k=top_k,
threshold=threshold
),
metric_args={
"num_classes": num_classes,
"reduce": reduce,
"mdmc_reduce": mdmc_reduce,
"threshold": THRESHOLD,
"threshold": threshold,
"multiclass": multiclass,
"ignore_index": ignore_index,
"top_k": top_k,
Expand All @@ -257,6 +261,7 @@ def test_stat_scores_differentiability(
multiclass: Optional[bool],
ignore_index: Optional[int],
top_k: Optional[int],
threshold: Optional[float],
):
if ignore_index is not None and preds.ndim == 2:
pytest.skip("Skipping ignore_index test with binary inputs.")
Expand All @@ -270,7 +275,7 @@ def test_stat_scores_differentiability(
"num_classes": num_classes,
"reduce": reduce,
"mdmc_reduce": mdmc_reduce,
"threshold": THRESHOLD,
"threshold": threshold,
"multiclass": multiclass,
"ignore_index": ignore_index,
"top_k": top_k,
Expand Down
3 changes: 0 additions & 3 deletions torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@ def __init__(
self.ignore_index = ignore_index
self.top_k = top_k

if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

if reduce not in ["micro", "macro", "samples"]:
raise ValueError(f"The `reduce` {reduce} is not valid.")

Expand Down

0 comments on commit 0495776

Please sign in to comment.