Skip to content

Commit

Permalink
rename args: is_multiclass -> multiclass (#162)
Browse files Browse the repository at this point in the history
* rename args: is_multiclass -> multiclass

* chlog

* format
  • Loading branch information
Borda authored Apr 7, 2021
1 parent 92ea4d7 commit 96ab7bd
Show file tree
Hide file tree
Showing 16 changed files with 295 additions and 183 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

- Rename argument `is_multiclass` -> `multiclass` ([#162](https://github.com/PyTorchLightning/metrics/pull/162))


### Removed

Expand Down
14 changes: 7 additions & 7 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])


Using the is_multiclass parameter
---------------------------------
Using the multiclass parameter
------------------------------

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
integer (binary) tensors. Or it could be the other way around, you want to treat
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.

For these cases, the metrics where this distinction would make a difference, expose the
``is_multiclass`` argument. Let's see how this is used on the example of
``multiclass`` argument. Let's see how this is used on the example of
:class:`~torchmetrics.StatScores` metric.

First, let's consider the case with label predictions with 2 classes, which we want to
Expand All @@ -83,15 +83,15 @@ treat as binary.
target = torch.tensor([1, 1, 0])

As you can see below, by default the inputs are treated
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
as multi-class. We can set ``multiclass=False`` to treat the inputs as binary -
which is the same as converting the predictions to float beforehand.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
>>> stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
Expand All @@ -104,13 +104,13 @@ but we would like to treat them as 2-class multi-class, to obtain the metric for
preds = torch.tensor([0.2, 0.7, 0.3])
target = torch.tensor([1, 1, 0])

In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class.
In this case we can set ``multiclass=True``, to treat the inputs as multi-class.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
>>> stat_scores(preds, target, reduce='macro', num_classes=2, multiclass=True)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])

Expand Down
22 changes: 11 additions & 11 deletions tests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
seed_all(42)


def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None):
if average == "none":
average = None
if num_classes == 1:
Expand All @@ -49,7 +49,7 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, igno
pass

sk_preds, sk_target, _ = _input_format_classification(
preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

Expand All @@ -62,10 +62,10 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, igno


def _sk_fbeta_f1_multidim_multiclass(
preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average
):
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass
)

if mdmc_average == "global":
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_no_support(metric_class, metric_fn):
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
"preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
[
(_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_fbeta_f1),
(_input_binary.preds, _input_binary.target, 1, False, None, _sk_fbeta_f1),
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_fbeta_f1(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -233,7 +233,7 @@ def test_fbeta_f1(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
Expand All @@ -242,7 +242,7 @@ def test_fbeta_f1(
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand All @@ -258,7 +258,7 @@ def test_fbeta_f1_functional(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -282,15 +282,15 @@ def test_fbeta_f1_functional(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
metric_args={
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand Down
40 changes: 20 additions & 20 deletions tests/classification/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _mlmd_prob_to_mc_preds_tr(x):


@pytest.mark.parametrize(
"inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target",
"inputs, num_classes, multiclass, top_k, exp_mode, post_preds, post_target",
[
#############################
# Test usual expected cases
Expand Down Expand Up @@ -169,7 +169,7 @@ def _mlmd_prob_to_mc_preds_tr(x):
(_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn),
],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
def test_usual_cases(inputs, num_classes, multiclass, top_k, exp_mode, post_preds, post_target):

def __get_data_type_enum(str_exp_mode):
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)
Expand All @@ -180,7 +180,7 @@ def __get_data_type_enum(str_exp_mode):
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
top_k=top_k,
)

Expand All @@ -194,7 +194,7 @@ def __get_data_type_enum(str_exp_mode):
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
top_k=top_k,
)

Expand Down Expand Up @@ -226,7 +226,7 @@ def test_incorrect_threshold(threshold):


@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass",
"preds, target, num_classes, multiclass",
[
# Target not integer
(randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None),
Expand All @@ -236,9 +236,9 @@ def test_incorrect_threshold(threshold):
(-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None),
# Negative probabilities
(-rand(size=(7, )), randint(high=2, size=(7, )), None, None),
# is_multiclass=False and target > 1
# multiclass=False and target > 1
(rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False),
# is_multiclass=False and preds integers with > 1
# multiclass=False and preds integers with > 1
(randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False),
# Wrong batch size
(randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None),
Expand All @@ -252,7 +252,7 @@ def test_incorrect_threshold(threshold):
(rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None),
# #dims in preds = 1 + #dims in target, preds not float
(randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None),
# is_multiclass=False, with C dimension > 2
# multiclass=False, with C dimension > 2
(_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False),
# Probs of multiclass preds do not sum up to 1
(rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None),
Expand All @@ -266,32 +266,32 @@ def test_incorrect_threshold(threshold):
(randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None),
# Max preds larger than num_classes (with #dim preds = #dims target)
(randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None),
# Num_classes=1, but is_multiclass not false
# Num_classes=1, but multiclass not false
(randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None),
# is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
# multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
(randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
# Multilabel input with implied class dimension != num_classes
(rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
# Multilabel input with is_multiclass=True, but num_classes != 2 (or None)
# Multilabel input with multiclass=True, but num_classes != 2 (or None)
(rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True),
# Binary input, num_classes > 2
(rand(size=(7, )), randint(high=2, size=(7, )), 4, None),
# Binary input, num_classes == 2 and is_multiclass not True
# Binary input, num_classes == 2 and multiclass not True
(rand(size=(7, )), randint(high=2, size=(7, )), 2, None),
(rand(size=(7, )), randint(high=2, size=(7, )), 2, False),
# Binary input, num_classes == 1 and is_multiclass=True
# Binary input, num_classes == 1 and multiclass=True
(rand(size=(7, )), randint(high=2, size=(7, )), 1, True),
],
)
def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
def test_incorrect_inputs(preds, target, num_classes, multiclass):
with pytest.raises(ValueError):
_input_format_classification(
preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass
)


@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass, top_k",
"preds, target, num_classes, multiclass, top_k",
[
# Topk set with non (md)mc or ml prob data
(_bin.preds[0], _bin.target[0], None, None, 2),
Expand All @@ -304,23 +304,23 @@ def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0),
# top_k = float
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0.123),
# top_k =2 with 2 classes, is_multiclass=False
# top_k =2 with 2 classes, multiclass=False
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2),
# top_k = number of classes (C dimension)
(_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES),
# is_multiclass = True for ml prob inputs, top_k set
# multiclass = True for ml prob inputs, top_k set
(_ml_prob.preds[0], _ml_prob.target[0], None, True, 2),
# top_k = num_classes for ml prob inputs
(_ml_prob.preds[0], _ml_prob.target[0], None, True, NUM_CLASSES),
],
)
def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k):
def test_incorrect_inputs_topk(preds, target, num_classes, multiclass, top_k):
with pytest.raises(ValueError):
_input_format_classification(
preds=preds,
target=target,
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
top_k=top_k,
)
22 changes: 11 additions & 11 deletions tests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
seed_all(42)


def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
def _sk_prec_recall(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None):
# todo: `mdmc_average` is unused
if average == "none":
average = None
Expand All @@ -50,7 +50,7 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i
pass

sk_preds, sk_target, _ = _input_format_classification(
preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

Expand All @@ -63,10 +63,10 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i


def _sk_prec_recall_multidim_multiclass(
preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average
):
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass
)

if mdmc_average == "global":
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_no_support(metric_class, metric_fn):
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
"preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
[
(_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
(_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_precision_recall_class(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -233,7 +233,7 @@ def test_precision_recall_class(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
Expand All @@ -242,7 +242,7 @@ def test_precision_recall_class(
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand All @@ -258,7 +258,7 @@ def test_precision_recall_fn(
metric_class: Metric,
metric_fn: Callable,
sk_fn: Callable,
is_multiclass: Optional[bool],
multiclass: Optional[bool],
num_classes: Optional[int],
average: str,
mdmc_average: Optional[str],
Expand All @@ -283,15 +283,15 @@ def test_precision_recall_fn(
sk_fn=sk_fn,
average=average,
num_classes=num_classes,
is_multiclass=is_multiclass,
multiclass=multiclass,
ignore_index=ignore_index,
mdmc_average=mdmc_average,
),
metric_args={
"num_classes": num_classes,
"average": average,
"threshold": THRESHOLD,
"is_multiclass": is_multiclass,
"multiclass": multiclass,
"ignore_index": ignore_index,
"mdmc_average": mdmc_average,
},
Expand Down
Loading

0 comments on commit 96ab7bd

Please sign in to comment.