Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename args: is_multiclass -> multiclass #162

Merged
merged 7 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

- Rename argument `is_multiclass` -> `multiclass` ([#162](https://github.com/PyTorchLightning/metrics/pull/162))


### Removed

Expand Down
14 changes: 7 additions & 7 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])


Using the is_multiclass parameter
---------------------------------
Using the multiclass parameter
------------------------------

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
integer (binary) tensors. Or it could be the other way around, you want to treat
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.

For these cases, the metrics where this distinction would make a difference, expose the
``is_multiclass`` argument. Let's see how this is used on the example of
``multiclass`` argument. Let's see how this is used on the example of
:class:`~torchmetrics.StatScores` metric.

First, let's consider the case with label predictions with 2 classes, which we want to
Expand All @@ -83,15 +83,15 @@ treat as binary.
target = torch.tensor([1, 1, 0])

As you can see below, by default the inputs are treated
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
as multi-class. We can set ``multiclass=False`` to treat the inputs as binary -
which is the same as converting the predictions to float beforehand.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
>>> stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
Expand All @@ -104,13 +104,13 @@ but we would like to treat them as 2-class multi-class, to obtain the metric for
preds = torch.tensor([0.2, 0.7, 0.3])
target = torch.tensor([1, 1, 0])

In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class.
In this case we can set ``multiclass=True``, to treat the inputs as multi-class.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
>>> stat_scores(preds, target, reduce='macro', num_classes=2, multiclass=True)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])

Expand Down
22 changes: 11 additions & 11 deletions tests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
seed_all(42)


def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None):
if average == "none":
average = None
if num_classes == 1:
Expand All @@ -49,7 +49,7 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, igno
pass

sk_preds, sk_target, _ = _input_format_classification(
preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

Expand All @@ -62,10 +62,10 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, igno


def _sk_fbeta_f1_multidim_multiclass(
preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average
):
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass
)

if mdmc_average == "global":
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_no_support(metric_class, metric_fn):
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
"preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
[
(_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_fbeta_f1),
(_input_binary.preds, _input_binary.target, 1, False, None, _sk_fbeta_f1),
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_fbeta_f1(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -233,7 +233,7 @@ def test_fbeta_f1(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
Expand All @@ -242,7 +242,7 @@ def test_fbeta_f1(
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand All @@ -258,7 +258,7 @@ def test_fbeta_f1_functional(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -282,15 +282,15 @@ def test_fbeta_f1_functional(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
metric_args={
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand Down
40 changes: 20 additions & 20 deletions tests/classification/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _mlmd_prob_to_mc_preds_tr(x):


@pytest.mark.parametrize(
"inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target",
"inputs, num_classes, multiclass, top_k, exp_mode, post_preds, post_target",
[
#############################
# Test usual expected cases
Expand Down Expand Up @@ -169,7 +169,7 @@ def _mlmd_prob_to_mc_preds_tr(x):
(_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn),
],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
def test_usual_cases(inputs, num_classes, multiclass, top_k, exp_mode, post_preds, post_target):

def __get_data_type_enum(str_exp_mode):
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)
Expand All @@ -180,7 +180,7 @@ def __get_data_type_enum(str_exp_mode):
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
top_k=top_k,
)

Expand All @@ -194,7 +194,7 @@ def __get_data_type_enum(str_exp_mode):
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
top_k=top_k,
)

Expand Down Expand Up @@ -226,7 +226,7 @@ def test_incorrect_threshold(threshold):


@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass",
"preds, target, num_classes, multiclass",
[
# Target not integer
(randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None),
Expand All @@ -236,9 +236,9 @@ def test_incorrect_threshold(threshold):
(-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None),
# Negative probabilities
(-rand(size=(7, )), randint(high=2, size=(7, )), None, None),
# is_multiclass=False and target > 1
# multiclass=False and target > 1
(rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False),
# is_multiclass=False and preds integers with > 1
# multiclass=False and preds integers with > 1
(randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False),
# Wrong batch size
(randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None),
Expand All @@ -252,7 +252,7 @@ def test_incorrect_threshold(threshold):
(rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None),
# #dims in preds = 1 + #dims in target, preds not float
(randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None),
# is_multiclass=False, with C dimension > 2
# multiclass=False, with C dimension > 2
(_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False),
# Probs of multiclass preds do not sum up to 1
(rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None),
Expand All @@ -266,32 +266,32 @@ def test_incorrect_threshold(threshold):
(randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None),
# Max preds larger than num_classes (with #dim preds = #dims target)
(randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None),
# Num_classes=1, but is_multiclass not false
# Num_classes=1, but multiclass not false
(randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None),
# is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
# multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
(randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
# Multilabel input with implied class dimension != num_classes
(rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
# Multilabel input with is_multiclass=True, but num_classes != 2 (or None)
# Multilabel input with multiclass=True, but num_classes != 2 (or None)
(rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True),
# Binary input, num_classes > 2
(rand(size=(7, )), randint(high=2, size=(7, )), 4, None),
# Binary input, num_classes == 2 and is_multiclass not True
# Binary input, num_classes == 2 and multiclass not True
(rand(size=(7, )), randint(high=2, size=(7, )), 2, None),
(rand(size=(7, )), randint(high=2, size=(7, )), 2, False),
# Binary input, num_classes == 1 and is_multiclass=True
# Binary input, num_classes == 1 and multiclass=True
(rand(size=(7, )), randint(high=2, size=(7, )), 1, True),
],
)
def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
def test_incorrect_inputs(preds, target, num_classes, multiclass):
with pytest.raises(ValueError):
_input_format_classification(
preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass
)


@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass, top_k",
"preds, target, num_classes, multiclass, top_k",
[
# Topk set with non (md)mc or ml prob data
(_bin.preds[0], _bin.target[0], None, None, 2),
Expand All @@ -304,23 +304,23 @@ def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0),
# top_k = float
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0.123),
# top_k =2 with 2 classes, is_multiclass=False
# top_k = 2 with 2 classes, multiclass=False
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2),
# top_k = number of classes (C dimension)
(_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES),
# is_multiclass = True for ml prob inputs, top_k set
# multiclass = True for ml prob inputs, top_k set
(_ml_prob.preds[0], _ml_prob.target[0], None, True, 2),
# top_k = num_classes for ml prob inputs
(_ml_prob.preds[0], _ml_prob.target[0], None, True, NUM_CLASSES),
],
)
def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k):
def test_incorrect_inputs_topk(preds, target, num_classes, multiclass, top_k):
with pytest.raises(ValueError):
_input_format_classification(
preds=preds,
target=target,
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
top_k=top_k,
)
22 changes: 11 additions & 11 deletions tests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
seed_all(42)


def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
def _sk_prec_recall(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None):
# todo: `mdmc_average` is unused
if average == "none":
average = None
Expand All @@ -50,7 +50,7 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i
pass

sk_preds, sk_target, _ = _input_format_classification(
preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

Expand All @@ -63,10 +63,10 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i


def _sk_prec_recall_multidim_multiclass(
preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average
):
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass
)

if mdmc_average == "global":
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_no_support(metric_class, metric_fn):
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
"preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
[
(_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
(_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_precision_recall_class(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -233,7 +233,7 @@ def test_precision_recall_class(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
Expand All @@ -242,7 +242,7 @@ def test_precision_recall_class(
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand All @@ -258,7 +258,7 @@ def test_precision_recall_fn(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -283,15 +283,15 @@ def test_precision_recall_fn(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
metric_args={
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand Down
Loading