Allow logit input in classification metrics #200

Merged on May 12, 2021
Commits (40)
db844e5  remove checks (SkafteNicki, Apr 19, 2021)
107138d  Merge branch 'master' into unnormalize (SkafteNicki, Apr 23, 2021)
972e719  accuracy fix (SkafteNicki, Apr 23, 2021)
dfe4a4f  Merge branch 'master' into unnormalize (SkafteNicki, Apr 26, 2021)
96094da  fix docstring (SkafteNicki, Apr 26, 2021)
20c209a  Merge branch 'master' into unnormalize (SkafteNicki, Apr 26, 2021)
6d87327  testing (SkafteNicki, Apr 26, 2021)
72ea906  pep8 (SkafteNicki, Apr 26, 2021)
ff5e09a  docstring (SkafteNicki, Apr 26, 2021)
c3dd42f  stat score (SkafteNicki, Apr 26, 2021)
6e93e89  hamming (SkafteNicki, Apr 27, 2021)
b3b7c29  confusion_matrix (SkafteNicki, Apr 27, 2021)
78f891a  Update CHANGELOG.md (SkafteNicki, Apr 27, 2021)
93f78b7  Merge branch 'master' into unnormalize (tchaton, Apr 28, 2021)
16700e6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 28, 2021)
944631d  Merge branch 'master' into unnormalize (Borda, Apr 28, 2021)
48c920e  Merge branch 'master' into unnormalize (Borda, Apr 28, 2021)
071b87a  Merge branch 'master' into unnormalize (SkafteNicki, Apr 30, 2021)
df830de  fix tests (SkafteNicki, Apr 30, 2021)
17fe9f2  Merge branch 'master' into unnormalize (SkafteNicki, Apr 30, 2021)
6d3fe26  Merge branch 'master' into unnormalize (mergify[bot], Apr 30, 2021)
defd12d  Merge branch 'master' into unnormalize (mergify[bot], Apr 30, 2021)
fcca45e  Merge branch 'master' into unnormalize (mergify[bot], May 3, 2021)
4fb4dee  Merge branch 'master' into unnormalize (mergify[bot], May 3, 2021)
4fa366f  Merge branch 'master' into unnormalize (mergify[bot], May 3, 2021)
000c281  Merge branch 'master' into unnormalize (mergify[bot], May 3, 2021)
875fa9e  Merge branch 'master' into unnormalize (Borda, May 4, 2021)
3c6d651  Merge branch 'master' into unnormalize (mergify[bot], May 4, 2021)
674577c  Merge branch 'master' into unnormalize (mergify[bot], May 4, 2021)
18b7a78  Merge branch 'master' into unnormalize (SkafteNicki, May 4, 2021)
ede4d1c  Merge branch 'master' into unnormalize (mergify[bot], May 4, 2021)
ec01731  Merge branch 'master' into unnormalize (mergify[bot], May 4, 2021)
3a9c94d  Merge branch 'master' into unnormalize (SkafteNicki, May 6, 2021)
096a3d9  Merge branch 'master' into unnormalize (mergify[bot], May 7, 2021)
b54b6b2  Merge branch 'master' into unnormalize (mergify[bot], May 8, 2021)
644e509  Merge branch 'master' into unnormalize (mergify[bot], May 9, 2021)
64d6b08  Merge branch 'master' into unnormalize (mergify[bot], May 10, 2021)
3716e36  Merge branch 'master' into unnormalize (mergify[bot], May 10, 2021)
d087543  Merge branch 'master' into unnormalize (mergify[bot], May 11, 2021)
53103e2  Merge branch 'master' into unnormalize (mergify[bot], May 12, 2021)
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added pre-gather reduction in the case of `dist_reduce_fx="cat"` to reduce communication cost ([#217](https://github.com/PyTorchLightning/metrics/pull/217))


- Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScore`, `Hamming`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200))


### Changed

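For illustration only (not part of the diff): a minimal sketch of what the new CHANGELOG entry describes, assuming the torchmetrics `Accuracy` API of this release; before this change, float predictions outside [0, 1] failed input validation.

```python
import torch
from torchmetrics import Accuracy

# Sketch: raw, unnormalized scores (logits) are passed straight to the metric.
# The assumption is that the metric applies softmax/sigmoid internally before
# argmax/thresholding, so no manual normalization is needed.
accuracy = Accuracy()

mc_logits = torch.randn(16, 5)           # multi-class logits, shape (N, C)
mc_target = torch.randint(0, 5, (16,))   # integer class labels, shape (N,)
print(accuracy(mc_logits, mc_target))
```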
4 changes: 2 additions & 2 deletions docs/source/references/modules.rst
@@ -34,10 +34,10 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c

"Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
"Multi-class", "(N,)", "``int``", "(N,)", "``int``"
"Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
"Multi-class with logits or probabilities", "(N, C)", "``float``", "(N,)", "``int``"
"Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
"Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
"Multi-dimensional multi-class with logits or probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"

.. note::
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
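Purely illustrative (not in the diff): the input shapes the updated table refers to, using arbitrary example sizes; "logits or probabilities" simply means any float tensor of the listed shape.

```python
import torch

N, C, EXTRA = 8, 4, 3  # example sizes, not the repo's test constants

binary_preds = torch.randn(N)            # "Binary": (N,) float scores
mc_preds = torch.randn(N, C)             # "Multi-class with logits or probabilities": (N, C)
ml_preds = torch.randn(N, C)             # "Multi-label": (N, ...) float scores
mdmc_preds = torch.randn(N, C, EXTRA)    # "Multi-dimensional multi-class with logits or probabilities": (N, C, ...)

binary_target = torch.randint(0, 2, (N,))      # binary labels, (N,)
mc_target = torch.randint(0, C, (N,))          # int class labels, (N,)
ml_target = torch.randint(0, 2, (N, C))        # binary labels, (N, ...)
mdmc_target = torch.randint(0, C, (N, EXTRA))  # int class labels, (N, ...)
```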
18 changes: 16 additions & 2 deletions tests/classification/inputs.py
@@ -28,6 +28,10 @@
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)

_input_binary_logits = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)

_input_multilabel_prob = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
@@ -38,11 +42,17 @@
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
)

_input_multilabel_logits = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
)

_input_multilabel = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
)


_input_multilabel_multidim = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
@@ -54,13 +64,17 @@

_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target)

__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True)
__mc_prob_logits = torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
__mc_prob_preds = __mc_prob_logits.abs() / __mc_prob_logits.abs().sum(dim=2, keepdim=True)

_input_multiclass_prob = Input(
preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)

_input_multiclass_logits = Input(
preds=__mc_prob_logits, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)

_input_multiclass = Input(
preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
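The convention behind these fixtures, sketched for clarity (an assumption about how the metrics interpret values, not code from the diff): `torch.rand` yields probability-style inputs in [0, 1), while `torch.randn` yields unbounded logit-style inputs that the metrics are expected to normalize themselves.

```python
import torch

probs = torch.rand(4, 3)      # probability-style preds, values in [0, 1)
logits = torch.randn(4, 3)    # logit-style preds, unbounded real values

# What a metric would effectively score after internal normalization
# (assumed behaviour): softmax over the class dimension for multi-class
# inputs, element-wise sigmoid for binary / multi-label inputs.
mc_probs_from_logits = logits.softmax(dim=-1)
ml_probs_from_logits = logits.sigmoid()
```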
7 changes: 6 additions & 1 deletion tests/classification/test_accuracy.py
@@ -18,12 +18,14 @@
from sklearn.metrics import accuracy_score as sk_accuracy
from torch import tensor

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
@@ -55,13 +57,16 @@ def _sk_accuracy(preds, target, subset_accuracy):
@pytest.mark.parametrize(
"preds, target, subset_accuracy",
[
(_input_binary_logits.preds, _input_binary_logits.target, False),
(_input_binary_prob.preds, _input_binary_prob.target, False),
(_input_binary.preds, _input_binary.target, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, True),
(_input_mlb_logits.preds, _input_mlb_logits.target, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, False),
(_input_mlb.preds, _input_mlb.target, True),
(_input_mlb.preds, _input_mlb.target, False),
(_input_mcls_prob.preds, _input_mcls_prob.target, False),
(_input_mcls_logits.preds, _input_mcls_logits.target, False),
(_input_mcls.preds, _input_mcls.target, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, True),
7 changes: 6 additions & 1 deletion tests/classification/test_confusion_matrix.py
@@ -19,12 +19,14 @@
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -112,10 +114,13 @@ def _sk_cm_multidim_multiclass(preds, target, normalize=None):
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes, multilabel",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False),
(_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False),
(_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True),
(_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True),
(_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False),
(_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False),
(_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False),
(_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False)]
7 changes: 6 additions & 1 deletion tests/classification/test_f_beta.py
@@ -20,12 +20,14 @@
from sklearn.metrics import f1_score, fbeta_score
from torch import Tensor

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -176,10 +178,13 @@ def test_no_support(metric_class, metric_fn):
@pytest.mark.parametrize(
"preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
[
(_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_fbeta_f1),
(_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_fbeta_f1),
(_input_binary.preds, _input_binary.target, 1, False, None, _sk_fbeta_f1),
(_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
(_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
(_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_fbeta_f1),
(_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
(_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
(_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_fbeta_f1),
(_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_fbeta_f1_multidim_multiclass),
7 changes: 6 additions & 1 deletion tests/classification/test_hamming_distance.py
@@ -14,12 +14,14 @@
import pytest
from sklearn.metrics import hamming_loss as sk_hamming_loss

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
@@ -43,10 +45,13 @@ def _sk_hamming_loss(preds, target):
@pytest.mark.parametrize(
"preds, target",
[
(_input_binary_logits.preds, _input_binary_logits.target),
(_input_binary_prob.preds, _input_binary_prob.target),
(_input_binary.preds, _input_binary.target),
(_input_mlb_logits.preds, _input_mlb_logits.target),
(_input_mlb_prob.preds, _input_mlb_prob.target),
(_input_mlb.preds, _input_mlb.target),
(_input_mcls_logits.preds, _input_mcls_logits.target),
(_input_mcls_prob.preds, _input_mcls_prob.target),
(_input_mcls.preds, _input_mcls.target),
(_input_mdmc_prob.preds, _input_mdmc_prob.target),
11 changes: 0 additions & 11 deletions tests/classification/test_inputs.py
@@ -218,13 +218,6 @@ def test_threshold():
########################################################################


@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5])
def test_incorrect_threshold(threshold):
preds, target = rand(size=(7, )), randint(high=2, size=(7, ))
with pytest.raises(ValueError):
_input_format_classification(preds, target, threshold=threshold)


@pytest.mark.parametrize(
"preds, target, num_classes, multiclass",
[
@@ -234,8 +227,6 @@ def test_incorrect_threshold(threshold):
(randint(high=2, size=(7, )), -randint(high=2, size=(7, )), None, None),
# Preds negative integers
(-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None),
# Negative probabilities
(-rand(size=(7, )), randint(high=2, size=(7, )), None, None),
# multiclass=False and target > 1
(rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False),
# multiclass=False and preds integers with > 1
@@ -254,8 +245,6 @@ def test_incorrect_threshold(threshold):
(randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None),
# multiclass=False, with C dimension > 2
(_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False),
# Probs of multiclass preds do not sum up to 1
(rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None),
# Max target larger or equal to C dimension
(_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, )), None, None),
# C dimension not equal to num_classes
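These deletions remove the old checks that rejected negative scores and class probabilities not summing to 1. A hedged sketch of the resulting behaviour (assuming such values are now treated as logits instead of raising `ValueError`):

```python
import torch
from torchmetrics.functional import accuracy

# Before this PR, unnormalized binary scores like these failed validation;
# the assumption here is that they are now passed through sigmoid and then
# thresholded at 0.5.
preds = torch.tensor([-2.0, 0.5, 3.0])
target = torch.tensor([0, 1, 1])
print(accuracy(preds, target))  # expected: tensor(1.) under that assumption
```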
7 changes: 6 additions & 1 deletion tests/classification/test_precision_recall.py
@@ -20,12 +20,14 @@
from sklearn.metrics import precision_score, recall_score
from torch import Tensor, tensor

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -175,10 +177,13 @@ def test_no_support(metric_class, metric_fn):
@pytest.mark.parametrize(
"preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
[
(_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_prec_recall),
(_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
(_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
(_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall),
(_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass),
7 changes: 6 additions & 1 deletion tests/classification/test_stat_scores.py
@@ -20,11 +20,13 @@
from sklearn.metrics import multilabel_confusion_matrix
from torch import Tensor, tensor

from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob, _input_multiclass
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mcls
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
@@ -136,12 +138,15 @@ def test_wrong_threshold():
@pytest.mark.parametrize(
"preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k",
[
(_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None),
(_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None),
(_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None),
(_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),