diff --git a/CHANGELOG.md b/CHANGELOG.md index b011af9c5cb..990c8ea4751 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecated +- Rename argument `is_multiclass` -> `multiclass` ([#162](https://github.com/PyTorchLightning/metrics/pull/162)) + ### Removed diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 8286acbdbb1..29d58252e16 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -59,8 +59,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -Using the is_multiclass parameter ---------------------------------- +Using the multiclass parameter +------------------------------ In some cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label - for example, if both predictions and targets are @@ -68,7 +68,7 @@ integer (binary) tensors. Or it could be the other way around, you want to treat binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs. For these cases, the metrics where this distinction would make a difference, expose the -``is_multiclass`` argument. Let's see how this is used on the example of +``multiclass`` argument. Let's see how this is used on the example of :class:`~torchmetrics.StatScores` metric. First, let's consider the case with label predictions with 2 classes, which we want to @@ -83,7 +83,7 @@ treat as binary. target = torch.tensor([1, 1, 0]) As you can see below, by default the inputs are treated -as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary - +as multi-class. We can set ``multiclass=False`` to treat the inputs as binary - which is the same as converting the predictions to float beforehand. .. 
doctest:: @@ -91,7 +91,7 @@ which is the same as converting the predictions to float beforehand. >>> stat_scores(preds, target, reduce='macro', num_classes=2) tensor([[1, 1, 1, 0, 1], [1, 0, 1, 1, 2]]) - >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False) + >>> stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False) tensor([[1, 0, 1, 1, 2]]) >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1) tensor([[1, 0, 1, 1, 2]]) @@ -104,13 +104,13 @@ but we would like to treat them as 2-class multi-class, to obtain the metric for preds = torch.tensor([0.2, 0.7, 0.3]) target = torch.tensor([1, 1, 0]) -In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class. +In this case we can set ``multiclass=True``, to treat the inputs as multi-class. .. doctest:: >>> stat_scores(preds, target, reduce='macro', num_classes=1) tensor([[1, 0, 1, 1, 2]]) - >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True) + >>> stat_scores(preds, target, reduce='macro', num_classes=2, multiclass=True) tensor([[1, 1, 1, 0, 1], [1, 0, 1, 1, 2]]) diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index f117e299d1d..fbd87408dba 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -36,7 +36,7 @@ seed_all(42) -def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): +def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None): if average == "none": average = None if num_classes == 1: @@ -49,7 +49,7 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, igno pass sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass ) sk_preds, sk_target = 
sk_preds.numpy(), sk_target.numpy() @@ -62,10 +62,10 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, is_multiclass, igno def _sk_fbeta_f1_multidim_multiclass( - preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average + preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average ): preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass ) if mdmc_average == "global": @@ -174,7 +174,7 @@ def test_no_support(metric_class, metric_fn): @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) @pytest.mark.parametrize("ignore_index", [None, 0]) @pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", + "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", [ (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_fbeta_f1), (_input_binary.preds, _input_binary.target, 1, False, None, _sk_fbeta_f1), @@ -208,7 +208,7 @@ def test_fbeta_f1( metric_class: Metric, metric_fn: Callable, sk_fn: Callable, - is_multiclass: Optional[bool], + multiclass: Optional[bool], num_classes: Optional[int], average: str, mdmc_average: Optional[str], @@ -233,7 +233,7 @@ def test_fbeta_f1( sk_fn=sk_fn, average=average, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, mdmc_average=mdmc_average, ), @@ -242,7 +242,7 @@ def test_fbeta_f1( "num_classes": num_classes, "average": average, "threshold": THRESHOLD, - "is_multiclass": is_multiclass, + "multiclass": multiclass, "ignore_index": ignore_index, "mdmc_average": mdmc_average, }, @@ -258,7 +258,7 @@ def test_fbeta_f1_functional( metric_class: Metric, metric_fn: Callable, sk_fn: Callable, - is_multiclass: Optional[bool], + multiclass: Optional[bool], 
num_classes: Optional[int], average: str, mdmc_average: Optional[str], @@ -282,7 +282,7 @@ def test_fbeta_f1_functional( sk_fn=sk_fn, average=average, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, mdmc_average=mdmc_average, ), @@ -290,7 +290,7 @@ def test_fbeta_f1_functional( "num_classes": num_classes, "average": average, "threshold": THRESHOLD, - "is_multiclass": is_multiclass, + "multiclass": multiclass, "ignore_index": ignore_index, "mdmc_average": mdmc_average, }, diff --git a/tests/classification/test_inputs.py b/tests/classification/test_inputs.py index f9098389c32..a72bea8a496 100644 --- a/tests/classification/test_inputs.py +++ b/tests/classification/test_inputs.py @@ -127,7 +127,7 @@ def _mlmd_prob_to_mc_preds_tr(x): @pytest.mark.parametrize( - "inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + "inputs, num_classes, multiclass, top_k, exp_mode, post_preds, post_target", [ ############################# # Test usual expected cases @@ -169,7 +169,7 @@ def _mlmd_prob_to_mc_preds_tr(x): (_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), ], ) -def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): +def test_usual_cases(inputs, num_classes, multiclass, top_k, exp_mode, post_preds, post_target): def __get_data_type_enum(str_exp_mode): return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode) @@ -180,7 +180,7 @@ def __get_data_type_enum(str_exp_mode): target=inputs.target[0], threshold=THRESHOLD, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, top_k=top_k, ) @@ -194,7 +194,7 @@ def __get_data_type_enum(str_exp_mode): target=inputs.target[0][[0], ...], threshold=THRESHOLD, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, top_k=top_k, ) @@ -226,7 +226,7 @@ def test_incorrect_threshold(threshold): 
@pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass", + "preds, target, num_classes, multiclass", [ # Target not integer (randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None), @@ -236,9 +236,9 @@ def test_incorrect_threshold(threshold): (-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None), # Negative probabilities (-rand(size=(7, )), randint(high=2, size=(7, )), None, None), - # is_multiclass=False and target > 1 + # multiclass=False and target > 1 (rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False), - # is_multiclass=False and preds integers with > 1 + # multiclass=False and preds integers with > 1 (randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False), # Wrong batch size (randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None), @@ -252,7 +252,7 @@ def test_incorrect_threshold(threshold): (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None), # #dims in preds = 1 + #dims in target, preds not float (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None), - # is_multiclass=False, with C dimension > 2 + # multiclass=False, with C dimension > 2 (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False), # Probs of multiclass preds do not sum up to 1 (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None), @@ -266,32 +266,32 @@ def test_incorrect_threshold(threshold): (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None), # Max preds larger than num_classes (with #dim preds = #dims target) (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None), - # Num_classes=1, but is_multiclass not false + # Num_classes=1, but multiclass not false (randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None), - # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes + # multiclass=False, but implied class dimension 
(for multi-label, from shape) != num_classes (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), # Multilabel input with implied class dimension != num_classes (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), - # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) + # Multilabel input with multiclass=True, but num_classes != 2 (or None) (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True), # Binary input, num_classes > 2 (rand(size=(7, )), randint(high=2, size=(7, )), 4, None), - # Binary input, num_classes == 2 and is_multiclass not True + # Binary input, num_classes == 2 and multiclass not True (rand(size=(7, )), randint(high=2, size=(7, )), 2, None), (rand(size=(7, )), randint(high=2, size=(7, )), 2, False), - # Binary input, num_classes == 1 and is_multiclass=True + # Binary input, num_classes == 1 and multiclass=True (rand(size=(7, )), randint(high=2, size=(7, )), 1, True), ], ) -def test_incorrect_inputs(preds, target, num_classes, is_multiclass): +def test_incorrect_inputs(preds, target, num_classes, multiclass): with pytest.raises(ValueError): _input_format_classification( - preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass ) @pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, top_k", + "preds, target, num_classes, multiclass, top_k", [ # Topk set with non (md)mc or ml prob data (_bin.preds[0], _bin.target[0], None, None, 2), @@ -304,23 +304,23 @@ def test_incorrect_inputs(preds, target, num_classes, is_multiclass): (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0), # top_k = float (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0.123), - # top_k =2 with 2 classes, is_multiclass=False + # top_k =2 with 2 classes, multiclass=False (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 
2), # top_k = number of classes (C dimension) (_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES), - # is_multiclass = True for ml prob inputs, top_k set + # multiclass = True for ml prob inputs, top_k set (_ml_prob.preds[0], _ml_prob.target[0], None, True, 2), # top_k = num_classes for ml prob inputs (_ml_prob.preds[0], _ml_prob.target[0], None, True, NUM_CLASSES), ], ) -def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k): +def test_incorrect_inputs_topk(preds, target, num_classes, multiclass, top_k): with pytest.raises(ValueError): _input_format_classification( preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, top_k=top_k, ) diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index f46392545f2..dedb86d5178 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -36,7 +36,7 @@ seed_all(42) -def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): +def _sk_prec_recall(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None): # todo: `mdmc_average` is unused if average == "none": average = None @@ -50,7 +50,7 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i pass sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass ) sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() @@ -63,10 +63,10 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i def _sk_prec_recall_multidim_multiclass( - preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average + preds, target, sk_fn, num_classes, average, multiclass, 
ignore_index, mdmc_average ): preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass ) if mdmc_average == "global": @@ -173,7 +173,7 @@ def test_no_support(metric_class, metric_fn): @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) @pytest.mark.parametrize("ignore_index", [None, 0]) @pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", + "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", [ (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall), (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall), @@ -207,7 +207,7 @@ def test_precision_recall_class( metric_class: Metric, metric_fn: Callable, sk_fn: Callable, - is_multiclass: Optional[bool], + multiclass: Optional[bool], num_classes: Optional[int], average: str, mdmc_average: Optional[str], @@ -233,7 +233,7 @@ def test_precision_recall_class( sk_fn=sk_fn, average=average, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, mdmc_average=mdmc_average, ), @@ -242,7 +242,7 @@ def test_precision_recall_class( "num_classes": num_classes, "average": average, "threshold": THRESHOLD, - "is_multiclass": is_multiclass, + "multiclass": multiclass, "ignore_index": ignore_index, "mdmc_average": mdmc_average, }, @@ -258,7 +258,7 @@ def test_precision_recall_fn( metric_class: Metric, metric_fn: Callable, sk_fn: Callable, - is_multiclass: Optional[bool], + multiclass: Optional[bool], num_classes: Optional[int], average: str, mdmc_average: Optional[str], @@ -283,7 +283,7 @@ def test_precision_recall_fn( sk_fn=sk_fn, average=average, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, 
mdmc_average=mdmc_average, ), @@ -291,7 +291,7 @@ def test_precision_recall_fn( "num_classes": num_classes, "average": average, "threshold": THRESHOLD, - "is_multiclass": is_multiclass, + "multiclass": multiclass, "ignore_index": ignore_index, "mdmc_average": mdmc_average, }, diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 3112ae37744..6436dfc4c23 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -35,10 +35,10 @@ seed_all(42) -def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None): +def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce=None): # todo: `mdmc_reduce` is unused preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k ) sk_preds, sk_target = preds.numpy(), target.numpy() @@ -73,9 +73,9 @@ def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_in return sk_stats -def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k): +def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k): preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k ) if mdmc_reduce == "global": @@ -134,7 +134,7 @@ def test_wrong_threshold(): @pytest.mark.parametrize("ignore_index", [None, 0]) @pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) @pytest.mark.parametrize( - "preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass, top_k", + "preds, 
target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k", [ (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None), (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None), @@ -167,7 +167,7 @@ def test_stat_scores_class( reduce: str, mdmc_reduce: Optional[str], num_classes: Optional[int], - is_multiclass: Optional[bool], + multiclass: Optional[bool], ignore_index: Optional[int], top_k: Optional[int], ): @@ -184,7 +184,7 @@ def test_stat_scores_class( reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, top_k=top_k, ), @@ -194,7 +194,7 @@ def test_stat_scores_class( "reduce": reduce, "mdmc_reduce": mdmc_reduce, "threshold": THRESHOLD, - "is_multiclass": is_multiclass, + "multiclass": multiclass, "ignore_index": ignore_index, "top_k": top_k, }, @@ -210,7 +210,7 @@ def test_stat_scores_fn( reduce: str, mdmc_reduce: Optional[str], num_classes: Optional[int], - is_multiclass: Optional[bool], + multiclass: Optional[bool], ignore_index: Optional[int], top_k: Optional[int], ): @@ -226,7 +226,7 @@ def test_stat_scores_fn( reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, top_k=top_k, ), @@ -235,7 +235,7 @@ def test_stat_scores_fn( "reduce": reduce, "mdmc_reduce": mdmc_reduce, "threshold": THRESHOLD, - "is_multiclass": is_multiclass, + "multiclass": multiclass, "ignore_index": ignore_index, "top_k": top_k, }, diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index d86ea72da3b..0d60d4c5d24 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -315,6 +315,7 @@ def run_precision_test_cpu( metric_module: Metric, metric_functional: Callable, ): + def metric_functional_ignore_indexes(preds, target, indexes): return metric_functional(preds, target) diff --git 
a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index b5f5c31735b..2882d51e542 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -118,11 +118,13 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): metric_functional=retrieval_fall_out, ) - @pytest.mark.parametrize(**_concat_tests( - _errors_test_class_metric_parameters_default, - _errors_test_class_metric_parameters_no_neg_target, - _errors_test_class_metric_parameters_k, - )) + @pytest.mark.parametrize( + **_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_neg_target, + _errors_test_class_metric_parameters_k, + ) + ) def test_arguments_class_metric( self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict ): @@ -137,10 +139,12 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(**_concat_tests( - _errors_test_functional_metric_parameters_default, - _errors_test_functional_metric_parameters_k, - )) + @pytest.mark.parametrize( + **_concat_tests( + _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_k, + ) + ) def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): self.run_functional_metric_arguments_test( preds=preds, diff --git a/tests/retrieval/test_ndcg.py b/tests/retrieval/test_ndcg.py index 8b54185aab4..95d24bec1a5 100644 --- a/tests/retrieval/test_ndcg.py +++ b/tests/retrieval/test_ndcg.py @@ -112,13 +112,15 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): metric_functional=retrieval_normalized_dcg, ) - @pytest.mark.parametrize(**_concat_tests( - _errors_test_class_metric_parameters_default, - _errors_test_class_metric_parameters_no_pos_target, - _errors_test_class_metric_parameters_k, - )) + @pytest.mark.parametrize( + **_concat_tests( + _errors_test_class_metric_parameters_default, + 
_errors_test_class_metric_parameters_no_pos_target, + _errors_test_class_metric_parameters_k, + ) + ) def test_arguments_class_metric( - self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict, + self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict ): self.run_metric_class_arguments_test( indexes=indexes, @@ -131,13 +133,13 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(**_concat_tests( - _errors_test_functional_metric_parameters_default, - _errors_test_functional_metric_parameters_k, - )) - def test_arguments_functional_metric( - self, preds: Tensor, target: Tensor, message: str, metric_args: dict, - ): + @pytest.mark.parametrize( + **_concat_tests( + _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_k, + ) + ) + def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): self.run_functional_metric_arguments_test( preds=preds, target=target, diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index db88b66d8e5..3e41b6146ad 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional +from warnings import warn import torch from torch import Tensor @@ -96,10 +97,10 @@ class FBeta(StatScores): this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section <references/modules:using the is_multiclass parameter>` + :ref:`documentation section <references/modules:using the multiclass parameter>` for a more detailed explanation and examples. 
compute_on_step: @@ -137,12 +138,20 @@ def __init__( mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + self.beta = beta allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: @@ -154,7 +163,7 @@ def __init__( threshold=threshold, top_k=top_k, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, @@ -243,10 +252,10 @@ class F1(FBeta): this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section <references/modules:using the is_multiclass parameter>` + :ref:`documentation section <references/modules:using the multiclass parameter>` for a more detailed explanation and examples. 
compute_on_step: @@ -278,12 +287,19 @@ def __init__( mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass super().__init__( num_classes=num_classes, @@ -293,7 +309,7 @@ def __init__( mdmc_average=mdmc_average, ignore_index=ignore_index, top_k=top_k, - is_multiclass=is_multiclass, + multiclass=multiclass, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 1a0f687a9ad..de137262878 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional +from warnings import warn import torch from torch import Tensor @@ -86,10 +87,10 @@ class Precision(StatScores): this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section <references/modules:using the is_multiclass parameter>` + :ref:`documentation section <references/modules:using the multiclass parameter>` for a more detailed explanation and examples. 
compute_on_step: @@ -129,12 +130,20 @@ def __init__( mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -145,7 +154,7 @@ def __init__( threshold=threshold, top_k=top_k, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, @@ -237,10 +246,10 @@ class Recall(StatScores): Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section <references/modules:using the is_multiclass parameter>` + :ref:`documentation section <references/modules:using the multiclass parameter>` for a more detailed explanation and examples. 
compute_on_step: @@ -280,12 +289,20 @@ def __init__( mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -296,7 +313,7 @@ def __init__( threshold=threshold, top_k=top_k, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index c96c75b5171..9e896401bb1 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional, Tuple +from warnings import warn import numpy as np import torch @@ -86,10 +87,10 @@ class StatScores(Metric): flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. 
See the parameter's - :ref:`documentation section <references/modules:using the is_multiclass parameter>` + :ref:`documentation section <references/modules:using the multiclass parameter>` for a more detailed explanation and examples. compute_on_step: @@ -140,12 +141,20 @@ def __init__( num_classes: Optional[int] = None, ignore_index: Optional[int] = None, mdmc_reduce: Optional[str] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, @@ -157,7 +166,7 @@ def __init__( self.mdmc_reduce = mdmc_reduce self.num_classes = num_classes self.threshold = threshold - self.is_multiclass = is_multiclass + self.multiclass = multiclass self.ignore_index = ignore_index self.top_k = top_k @@ -206,7 +215,7 @@ def update(self, preds: Tensor, target: Tensor): threshold=self.threshold, num_classes=self.num_classes, top_k=self.top_k, - is_multiclass=self.is_multiclass, + multiclass=self.multiclass, ignore_index=self.ignore_index, ) diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py index 70e0687e2da..f0d6cfaa095 100644 --- a/torchmetrics/functional/classification/f_beta.py +++ b/torchmetrics/functional/classification/f_beta.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional +from warnings import warn import torch from torch import Tensor @@ -80,7 +81,8 @@ def fbeta( num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: r""" Computes f_beta metric. @@ -151,10 +153,10 @@ def fbeta( only for inputs with probability predictions. If this parameter is set for multi-label inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -172,6 +174,13 @@ def fbeta( tensor(0.3333) """ + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -195,7 +204,7 @@ def fbeta( threshold=threshold, num_classes=num_classes, top_k=top_k, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, ) @@ -212,7 +221,8 @@ def f1( num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: """ Computes F1 metric. 
F1 metrics correspond to a equally weighted average of the @@ -286,10 +296,10 @@ def f1( this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -306,4 +316,10 @@ def f1( >>> f1(preds, target, num_classes=3) tensor(0.3333) """ - return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, is_multiclass) + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass) diff --git a/torchmetrics/functional/classification/precision_recall.py b/torchmetrics/functional/classification/precision_recall.py index b62142b33f3..b4d530c1e29 100644 --- a/torchmetrics/functional/classification/precision_recall.py +++ b/torchmetrics/functional/classification/precision_recall.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Tuple +from warnings import warn import torch from torch import Tensor @@ -47,7 +48,8 @@ def precision( num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: r""" Computes `Precision `_: @@ -117,10 +119,10 @@ def precision( this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. 
- is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -152,6 +154,13 @@ def precision( tensor(0.2500) """ + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -175,7 +184,7 @@ def precision( threshold=threshold, num_classes=num_classes, top_k=top_k, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, ) @@ -210,7 +219,8 @@ def recall( num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: r""" Computes `Recall `_: @@ -280,10 +290,10 @@ def recall( this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -315,6 +325,13 @@ def recall( tensor(0.2500) """ + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. 
Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -338,7 +355,7 @@ def recall( threshold=threshold, num_classes=num_classes, top_k=top_k, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, ) @@ -354,7 +371,8 @@ def precision_recall( num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tuple[Tensor, Tensor]: r""" Computes `Precision and Recall `_: @@ -427,10 +445,10 @@ def precision_recall( this parameter defaults to 1. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -463,6 +481,13 @@ def precision_recall( (tensor(0.2500), tensor(0.2500)) """ + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. 
Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -486,7 +511,7 @@ def precision_recall( threshold=threshold, num_classes=num_classes, top_k=top_k, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, ) diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 985f84dd9b2..9054c3b6433 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Tuple +from warnings import warn import torch from torch import Tensor, tensor @@ -82,12 +83,19 @@ def _stat_scores_update( num_classes: Optional[int] = None, top_k: Optional[int] = None, threshold: float = 0.5, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. 
Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k ) if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]: @@ -145,8 +153,9 @@ def stat_scores( num_classes: Optional[int] = None, top_k: Optional[int] = None, threshold: float = 0.5, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: """Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors `__ @@ -211,10 +220,10 @@ def stat_scores( flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Return: @@ -271,6 +280,12 @@ def stat_scores( >>> stat_scores(preds, target, reduce='micro') tensor([2, 2, 6, 2, 4]) """ + if is_multiclass is not None and multiclass is None: + warn( + "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. 
Use `multiclass`.", + DeprecationWarning + ) + multiclass = is_multiclass if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.") @@ -292,7 +307,7 @@ def stat_scores( top_k=top_k, threshold=threshold, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, ignore_index=ignore_index, ) return _stat_scores_compute(tp, fp, tn, fn) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 3746ed38da9..4f1dc01f442 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -26,7 +26,7 @@ def _check_same_shape(pred: Tensor, target: Tensor): raise RuntimeError("Predictions and targets are expected to have the same shape") -def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, is_multiclass: bool): +def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, multiclass: bool): """ Perform basic validation of inputs that does not require deducing any information of the type of inputs. 
@@ -50,11 +50,11 @@ def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, is_ if not 0 < threshold < 1: raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") - if is_multiclass is False and target.max() > 1: - raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") + if multiclass is False and target.max() > 1: + raise ValueError("If you set `multiclass=False`, then `target` should not exceed 1.") - if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") + if multiclass is False and not preds_float and preds.max() > 1: + raise ValueError("If you set `multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[str, int]: @@ -119,44 +119,44 @@ def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[st return case, implied_classes -def _check_num_classes_binary(num_classes: int, is_multiclass: bool): +def _check_num_classes_binary(num_classes: int, multiclass: bool): """ This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for binary data. + and `multiclass` param for binary data. """ if num_classes > 2: raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - if num_classes == 2 and not is_multiclass: + if num_classes == 2 and not multiclass: raise ValueError( - "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + "Your data is binary and `num_classes=2`, but `multiclass` is not True." " Set it to True if you want to transform binary data to multi-class format." ) - if num_classes == 1 and is_multiclass: + if num_classes == 1 and multiclass: raise ValueError( - "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." 
- " Either set `is_multiclass=None`(default) or set `num_classes=2`" + "You have binary data and have set `multiclass=True`, but `num_classes` is 1." + " Either set `multiclass=None`(default) or set `num_classes=2`" " to transform binary data to multi-class format." ) -def _check_num_classes_mc(preds: Tensor, target: Tensor, num_classes: int, is_multiclass: bool, implied_classes: int): +def _check_num_classes_mc(preds: Tensor, target: Tensor, num_classes: int, multiclass: bool, implied_classes: int): """ This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for (multi-dimensional) multi-class data. + and `multiclass` param for (multi-dimensional) multi-class data. """ - if num_classes == 1 and is_multiclass is not False: + if num_classes == 1 and multiclass is not False: raise ValueError( "You have set `num_classes=1`, but predictions are integers." " If you want to convert (multi-dimensional) multi-class data with 2 classes" - " to binary/multi-label, set `is_multiclass=False`." + " to binary/multi-label, set `multiclass=False`." ) if num_classes > 1: - if is_multiclass is False: + if multiclass is False: if implied_classes != num_classes: raise ValueError( - "You have set `is_multiclass=False`, but the implied number of classes " + "You have set `multiclass=False`, but the implied number of classes " " (from shape of inputs) does not match `num_classes`. If you are trying to" " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" " should be either None or the product of the size of extra dimensions (...)." 
@@ -170,35 +170,35 @@ def _check_num_classes_mc(preds: Tensor, target: Tensor, num_classes: int, is_mu raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") -def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): +def _check_num_classes_ml(num_classes: int, multiclass: bool, implied_classes: int): """ This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for multi-label data. + and `multiclass` param for multi-label data. """ - if is_multiclass and num_classes != 2: + if multiclass and num_classes != 2: raise ValueError( - "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." + "Your have set `multiclass=True`, but `num_classes` is not equal to 2." " If you are trying to transform multi-label data to 2 class multi-dimensional" " multi-class, you should set `num_classes` to either 2 or None." ) - if not is_multiclass and num_classes != implied_classes: + if not multiclass and num_classes != implied_classes: raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") -def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): +def _check_top_k(top_k: int, case: str, implied_classes: int, multiclass: Optional[bool], preds_float: bool): if case == DataType.BINARY: raise ValueError("You can not use `top_k` parameter with binary data.") if not isinstance(top_k, int) or top_k <= 0: raise ValueError("The `top_k` has to be an integer larger than 0.") if not preds_float: raise ValueError("You have set `top_k`, but you do not have probability predictions.") - if is_multiclass is False: - raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") - if case == DataType.MULTILABEL and is_multiclass: + if multiclass is False: + raise ValueError("If you set `multiclass=False`, you can not set `top_k`.") + if case == DataType.MULTILABEL and 
multiclass: raise ValueError( "If you want to transform multi-label data to 2 class multi-dimensional" - "multi-class data using `is_multiclass=True`, you can not use `top_k`." + "multi-class data using `multiclass=True`, you can not use `top_k`." ) if top_k >= implied_classes: raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") @@ -209,14 +209,14 @@ def _check_classification_inputs( target: Tensor, threshold: float, num_classes: Optional[int], - is_multiclass: bool, + multiclass: bool, top_k: Optional[int], ) -> str: """Performs error checking on inputs for classification. This ensures that preds and target take one of the shape/type combinations that are specified in ``_input_format_classification`` docstring. It also checks the cases of - over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class + over-rides with ``multiclass`` by checking (for multi-class and multi-dim multi-class cases) that there are only up to 2 distinct labels. In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. @@ -252,10 +252,10 @@ def _check_classification_inputs( it will take precedence over threshold. Should be left unset (``None``) for inputs with label predictions. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. 
@@ -265,7 +265,7 @@ def _check_classification_inputs( """ # Basic validation (that does not need case/type information) - _basic_input_validation(preds, target, threshold, is_multiclass) + _basic_input_validation(preds, target, threshold, multiclass) # Check that shape/types fall into one of the cases case, implied_classes = _check_shape_and_type_consistency(preds, target) @@ -277,9 +277,9 @@ def _check_classification_inputs( # Check consistency with the `C` dimension in case of multi-class data if preds.shape != target.shape: - if is_multiclass is False and implied_classes != 2: + if multiclass is False and implied_classes != 2: raise ValueError( - "You have set `is_multiclass=False`, but have more than 2 classes in your data," + "You have set `multiclass=False`, but have more than 2 classes in your data," " based on the C dimension of `preds`." ) if target.max() >= implied_classes: @@ -290,15 +290,15 @@ def _check_classification_inputs( # Check that num_classes is consistent if num_classes: if case == DataType.BINARY: - _check_num_classes_binary(num_classes, is_multiclass) + _check_num_classes_binary(num_classes, multiclass) elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS): - _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) + _check_num_classes_mc(preds, target, num_classes, multiclass, implied_classes) elif case.MULTILABEL: - _check_num_classes_ml(num_classes, is_multiclass, implied_classes) + _check_num_classes_ml(num_classes, multiclass, implied_classes) # Check that top_k is consistent if top_k is not None: - _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) + _check_top_k(top_k, case, implied_classes, multiclass, preds.is_floating_point()) return case @@ -309,7 +309,7 @@ def _input_format_classification( threshold: float = 0.5, top_k: Optional[int] = None, num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, + multiclass: Optional[bool] = None, ) -> 
Tuple[Tensor, Tensor, str]: """Convert preds and target tensors into common format. @@ -333,28 +333,28 @@ def _input_format_classification( The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` of ``(N, C, X)``, the details for each case are described below. The function also returns a ``case`` string, which describes which of the above cases the inputs belonged to - regardless - of whether this was "overridden" by other settings (like ``is_multiclass``). + of whether this was "overridden" by other settings (like ``multiclass``). In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed into a binary tensor (elements become 1 if the probability is greater than or equal to - ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds + ``threshold`` or 0 otherwise). If ``multiclass=True``, then both targets and preds become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to preds first. In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original - shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be + shape was ``(N,C)``). However, if ``multiclass=False``, then targets and preds will be returned as ``(N,1)`` tensor. In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening - all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as + all dimensions after the first one. However if ``multiclass=True``, then both are returned as ``(N, 2, C)``, by an equivalent transformation as in the binary case.
In multi-dimensional multi-class case, normally both target and preds are returned as ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and ``C``. The transformations performed here are equivalent to the multi-class case. However, if - ``is_multiclass=False`` (and there are up to two classes), then the data is returned as + ``multiclass=False`` (and there are up to two classes), then the data is returned as ``(N, X)`` binary tensors (multi-label). Note: @@ -379,10 +379,10 @@ def _input_format_classification( default value (``None``) will be interepreted as 1 for these inputs. Should be left unset (``None``) for all other types of inputs. - is_multiclass: + multiclass: Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's - :ref:`documentation section ` + :ref:`documentation section ` for a more detailed explanation and examples. Returns: @@ -407,18 +407,18 @@ def _input_format_classification( target, threshold=threshold, num_classes=num_classes, - is_multiclass=is_multiclass, + multiclass=multiclass, top_k=top_k, ) if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: preds = (preds >= threshold).int() - num_classes = num_classes if not is_multiclass else 2 + num_classes = num_classes if not multiclass else 2 if case == DataType.MULTILABEL and top_k: preds = select_topk(preds, top_k) - if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass: + if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or multiclass: if preds.is_floating_point(): num_classes = preds.shape[1] preds = select_topk(preds, top_k or 1) @@ -428,10 +428,10 @@ def _input_format_classification( target = to_onehot(target, max(2, num_classes)) - if is_multiclass is False: + if multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] 
- if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass: + if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and multiclass is not False) or multiclass: target = target.reshape(target.shape[0], target.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) else: @@ -495,7 +495,9 @@ def _input_format_classification_one_hot( def _check_retrieval_functional_inputs( - preds: Tensor, target: Tensor, allow_non_binary_target: bool = False + preds: Tensor, + target: Tensor, + allow_non_binary_target: bool = False, ) -> Tuple[Tensor, Tensor]: """Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype. @@ -532,7 +534,10 @@ def _check_retrieval_functional_inputs( def _check_retrieval_inputs( - indexes: Tensor, preds: Tensor, target: Tensor, allow_non_binary_target: bool = False + indexes: Tensor, + preds: Tensor, + target: Tensor, + allow_non_binary_target: bool = False, ) -> Tuple[Tensor, Tensor, Tensor]: """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.