Fix for bug when providing superclass arguments as kwargs (#1069)
* Fix for bug when providing superclass arguments as kwargs
* Fix code formatting
* chlog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
4 people authored Jun 7, 2022
1 parent 35fcc88 commit 0a5d964
Showing 13 changed files with 101 additions and 41 deletions.
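In short: before this commit, passing a parent-class argument (of `StatScores` or `ConfusionMatrix`) through a subclass's `**kwargs` raised `TypeError: __init__() got multiple values for keyword argument '...'`, because the subclass also forwarded that same argument explicitly to `super().__init__()`. A minimal sketch of the failure mode, assuming the torchmetrics 0.9.x API (the exact error text may vary):

from torchmetrics import Accuracy, JaccardIndex

# Both calls raised before this fix, e.g.
#   TypeError: __init__() got multiple values for keyword argument 'mdmc_reduce'
# Accuracy.__init__ already passed mdmc_reduce=... to StatScores.__init__, and
# JaccardIndex.__init__ passed normalize=None to ConfusionMatrix.__init__.
Accuracy(mdmc_reduce="samplewise")
JaccardIndex(num_classes=2, normalize="true")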
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -45,7 +45,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed aggregation metrics when input only contains zero ([#1070](https://github.com/PyTorchLightning/metrics/pull/1070))

-

- Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069))


## [0.9.0] - 2022-05-30
17 changes: 17 additions & 0 deletions tests/classification/test_confusion_matrix.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Dict

import numpy as np
import pytest
@@ -30,6 +31,7 @@
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import JaccardIndex
from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix

@@ -186,3 +188,18 @@ def test_warning_on_nan(tmpdir):
match=".* nan values found in confusion matrix have been replaced with zeros.",
):
confusion_matrix(preds, target, num_classes=5, normalize="true")


@pytest.mark.parametrize(
"metric_args",
[
{"num_classes": 1, "normalize": "true"},
{"num_classes": 1, "normalize": "pred"},
{"num_classes": 1, "normalize": "all"},
{"num_classes": 1, "normalize": "none"},
{"num_classes": 1, "normalize": None},
],
)
def test_provide_superclass_kwargs(metric_args: Dict[str, Any]):
"""Test instantiating subclasses with superclass arguments as kwargs."""
JaccardIndex(**metric_args)
21 changes: 19 additions & 2 deletions tests/classification/test_stat_scores.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional
from typing import Any, Callable, Dict, Optional

import numpy as np
import pytest
@@ -30,7 +30,7 @@
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics import StatScores
from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores
from torchmetrics.functional import stat_scores
from torchmetrics.utilities.checks import _input_format_classification

@@ -326,3 +326,20 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Ten

assert torch.equal(class_metric.compute(), expected.T)
assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T)


@pytest.mark.parametrize(
"metric_args",
[
{"reduce": "micro"},
{"num_classes": 1, "reduce": "macro"},
{"reduce": "samples"},
{"mdmc_reduce": None},
{"mdmc_reduce": "samplewise"},
{"mdmc_reduce": "global"},
],
)
@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity])
def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]):
"""Test instantiating subclasses with superclass arguments as kwargs."""
metric_cls(**metric_args)
5 changes: 2 additions & 3 deletions tests/text/test_mer.py
@@ -4,16 +4,15 @@

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.text.mer import MatchErrorRate
from torchmetrics.utilities.imports import _JIWER_AVAILABLE

if _JIWER_AVAILABLE:
from jiwer import compute_measures
else:
compute_measures: Callable

from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.text.mer import MatchErrorRate


def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
return compute_measures(target, preds)["mer"]
5 changes: 2 additions & 3 deletions tests/text/test_wer.py
@@ -4,16 +4,15 @@

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.text.wer import WordErrorRate
from torchmetrics.utilities.imports import _JIWER_AVAILABLE

if _JIWER_AVAILABLE:
from jiwer import compute_measures
else:
compute_measures: Callable

from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.text.wer import WordErrorRate


def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
return compute_measures(target, preds)["wer"]
14 changes: 9 additions & 5 deletions torchmetrics/classification/accuracy.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

from torch import Tensor, tensor

@@ -23,7 +23,7 @@
_subset_accuracy_compute,
_subset_accuracy_update,
)
from torchmetrics.utilities.enums import DataType
from torchmetrics.utilities.enums import AverageMethod, DataType

from torchmetrics.classification.stat_scores import StatScores # isort:skip

@@ -170,15 +170,19 @@ def __init__(
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
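The guard above is the heart of the fix, repeated in `Dice`, `FBetaScore`, `Precision`, `Recall`, and `Specificity` below: derive `reduce`/`mdmc_reduce` from `average`/`mdmc_average` only when the caller has not supplied them directly, then forward everything through `**kwargs` so no keyword reaches `StatScores.__init__` twice. (The switch from `**kwargs: Dict[str, Any]` to `**kwargs: Any` is also the correct annotation, since it describes each keyword value, not the dict itself.) A standalone sketch of the pattern, with hypothetical class names:

from typing import Any, Optional


class Base:
    def __init__(self, reduce: str = "micro", mdmc_reduce: Optional[str] = None) -> None:
        self.reduce = reduce
        self.mdmc_reduce = mdmc_reduce


class Sub(Base):
    def __init__(
        self,
        average: Optional[str] = "micro",
        mdmc_average: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Map `average` onto the parent's `reduce` only if the caller did not
        # pass `reduce` itself; forwarding it both ways is what raised
        # "got multiple values for keyword argument 'reduce'".
        if "reduce" not in kwargs:
            kwargs["reduce"] = "macro" if average in ("weighted", "none", None) else average
        if "mdmc_reduce" not in kwargs:
            kwargs["mdmc_reduce"] = mdmc_average
        super().__init__(**kwargs)


Sub(reduce="samples")    # honored as given, no TypeError
Sub(average="weighted")  # mapped to reduce="macro", as before the fix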
13 changes: 9 additions & 4 deletions torchmetrics/classification/dice.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.dice import _dice_compute
from torchmetrics.utilities.enums import AverageMethod


class Dice(StatScores):
@@ -128,15 +129,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ("weighted", "none", None) else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
10 changes: 7 additions & 3 deletions torchmetrics/classification/f_beta.py
@@ -130,16 +130,20 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
self.beta = beta
allowed_average = list(AverageMethod)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
7 changes: 4 additions & 3 deletions torchmetrics/classification/jaccard.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

import torch
from torch import Tensor
@@ -88,11 +88,12 @@ def __init__(
absent_score: float = 0.0,
threshold: float = 0.5,
multilabel: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
kwargs["normalize"] = kwargs.get("normalize")

super().__init__(
num_classes=num_classes,
normalize=None,
threshold=threshold,
multilabel=multilabel,
**kwargs,
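`JaccardIndex` needs a slightly different treatment because it previously pinned `normalize=None` in its call to `ConfusionMatrix.__init__`; `kwargs["normalize"] = kwargs.get("normalize")` keeps that `None` default while letting a caller-supplied value flow through `**kwargs` exactly once. Assuming the 0.9.x API, both of these now construct cleanly:

from torchmetrics import JaccardIndex

JaccardIndex(num_classes=2)                    # normalize defaults to None
JaccardIndex(num_classes=2, normalize="pred")  # previously raised TypeError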
23 changes: 16 additions & 7 deletions torchmetrics/classification/precision_recall.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
from torchmetrics.utilities.enums import AverageMethod


class Precision(StatScores):
@@ -121,15 +122,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
@@ -256,15 +261,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
13 changes: 9 additions & 4 deletions torchmetrics/classification/specificity.py
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

import torch
from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.specificity import _specificity_compute
from torchmetrics.utilities.enums import AverageMethod


class Specificity(StatScores):
@@ -123,15 +124,19 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

_reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
if "reduce" not in kwargs:
kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average
if "mdmc_reduce" not in kwargs:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
6 changes: 3 additions & 3 deletions torchmetrics/functional/audio/pesq.py
@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
import pesq as pesq_backend
else:
pesq_backend = None
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape

__doctest_requires__ = {("perceptual_evaluation_speech_quality",): ["pesq"]}

5 changes: 2 additions & 3 deletions torchmetrics/functional/audio/stoi.py
@@ -13,17 +13,16 @@
# limitations under the License.
import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE

if _PYSTOI_AVAILABLE:
from pystoi import stoi as stoi_backend
else:
stoi_backend = None
__doctest_skip__ = ["short_time_objective_intelligibility"]
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def short_time_objective_intelligibility(
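The remaining text/audio changes are mechanical: unconditional imports (`torch`, `Tensor`, `_check_same_shape`, and the metric imports in the tests) had been stranded below the `if _JIWER_AVAILABLE:` / `if _PESQ_AVAILABLE:` / `if _PYSTOI_AVAILABLE:` blocks, so they are hoisted back to the top of each module and only the optional backend import stays guarded. The resulting idiom, sketched here for pesq:

from torchmetrics.utilities.imports import _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
    import pesq as pesq_backend  # optional dependency, imported only if installed
else:
    pesq_backend = None  # sentinel so the module still imports without pesq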
