diff --git a/.github/assistant.py b/.github/assistant.py
index 5ffc5e69a0d..3fa2fa4b3a4 100644
--- a/.github/assistant.py
+++ b/.github/assistant.py
@@ -140,7 +140,7 @@ def changed_domains(
         return "unittests"

     # parse domains
-    def _crop_path(fname: str, paths: List[str]):
+    def _crop_path(fname: str, paths: List[str]) -> str:
         for p in paths:
             fname = fname.replace(p, "")
         return fname
diff --git a/examples/plotting.py b/examples/plotting.py
index 2ef46e03665..b00b8d4b1fb 100644
--- a/examples/plotting.py
+++ b/examples/plotting.py
@@ -17,7 +17,7 @@
 import torch


-def pesq_example():
+def pesq_example() -> tuple:
     """Plot PESQ audio example."""
     from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
@@ -37,7 +37,7 @@ def pesq_example():
     return fig, ax


-def pit_example():
+def pit_example() -> tuple:
     """Plot PIT audio example."""
     from torchmetrics.audio.pit import PermutationInvariantTraining
     from torchmetrics.functional import scale_invariant_signal_noise_ratio
@@ -58,7 +58,7 @@ def pit_example():
     return fig, ax


-def sdr_example():
+def sdr_example() -> tuple:
     """Plot SDR audio example."""
     from torchmetrics.audio.sdr import SignalDistortionRatio
@@ -78,7 +78,7 @@ def sdr_example():
     return fig, ax


-def si_sdr_example():
+def si_sdr_example() -> tuple:
     """Plot SI-SDR audio example."""
     from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio
@@ -98,7 +98,7 @@ def si_sdr_example():
     return fig, ax


-def snr_example():
+def snr_example() -> tuple:
     """Plot SNR audio example."""
     from torchmetrics.audio.snr import SignalNoiseRatio
@@ -118,7 +118,7 @@ def snr_example():
     return fig, ax


-def si_snr_example():
+def si_snr_example() -> tuple:
     """Plot SI-SNR example."""
     from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio
@@ -138,7 +138,7 @@ def si_snr_example():
     return fig, ax


-def stoi_example():
+def stoi_example() -> tuple:
     """Plot STOI example."""
     from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
@@ -158,7 +158,7 @@ def stoi_example():
     return fig, ax


-def accuracy_example():
+def accuracy_example() -> tuple:
     """Plot Accuracy example."""
     from torchmetrics.classification import MulticlassAccuracy
@@ -189,7 +189,7 @@ def accuracy_example():
     return fig, ax


-def mean_squared_error_example():
+def mean_squared_error_example() -> tuple:
     """Plot mean squared error example."""
     from torchmetrics.regression import MeanSquaredError
@@ -208,7 +208,7 @@ def mean_squared_error_example():
     return fig, ax


-def confusion_matrix_example():
+def confusion_matrix_example() -> tuple:
     """Plot confusion matrix example."""
     from torchmetrics.classification import MulticlassConfusionMatrix
@@ -222,7 +222,7 @@ def confusion_matrix_example():
     return fig, ax


-def spectral_distortion_index_example():
+def spectral_distortion_index_example() -> tuple:
     """Plot spectral distortion index example example."""
     from torchmetrics.image.d_lambda import SpectralDistortionIndex
@@ -242,7 +242,7 @@ def spectral_distortion_index_example():
     return fig, ax


-def error_relative_global_dimensionless_synthesis():
+def error_relative_global_dimensionless_synthesis() -> tuple:
     """Plot error relative global dimensionless synthesis example."""
     from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
@@ -262,7 +262,7 @@ def error_relative_global_dimensionless_synthesis():
     return fig, ax


-def peak_signal_noise_ratio():
+def peak_signal_noise_ratio() -> tuple:
     """Plot peak signal noise ratio example."""
     from torchmetrics.image.psnr import PeakSignalNoiseRatio
@@ -282,7 +282,7 @@ def peak_signal_noise_ratio():
     return fig, ax


-def spectral_angle_mapper():
+def spectral_angle_mapper() -> tuple:
     """Plot spectral angle mapper example."""
     from torchmetrics.image.sam import SpectralAngleMapper
@@ -302,7 +302,7 @@ def spectral_angle_mapper():
     return fig, ax


-def structural_similarity_index_measure():
+def structural_similarity_index_measure() -> tuple:
     """Plot structural similarity index measure example."""
     from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
@@ -322,7 +322,7 @@ def structural_similarity_index_measure():
     return fig, ax


-def multiscale_structural_similarity_index_measure():
+def multiscale_structural_similarity_index_measure() -> tuple:
     """Plot multiscale structural similarity index measure example."""
     from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure
@@ -342,7 +342,7 @@ def multiscale_structural_similarity_index_measure():
     return fig, ax


-def universal_image_quality_index():
+def universal_image_quality_index() -> tuple:
     """Plot universal image quality index example."""
     from torchmetrics.image.uqi import UniversalImageQualityIndex
@@ -362,7 +362,7 @@ def universal_image_quality_index():
     return fig, ax


-def mean_average_precision():
+def mean_average_precision() -> tuple:
     """Plot MAP metric."""
     from torchmetrics.detection.mean_ap import MeanAveragePrecision
diff --git a/pyproject.toml b/pyproject.toml
index 561c8202777..47d5682ce6e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,8 @@ extend-select = [
     "PT",  # see: https://pypi.org/project/flake8-pytest-style
     "RET",  # see: https://pypi.org/project/flake8-return
     "SIM",  # see: https://pypi.org/project/flake8-simplify
+    "YTT",  # see: https://pypi.org/project/flake8-2020
+    "ANN",  # see: https://pypi.org/project/flake8-annotations
 ]
 ignore = [
     "E731",  # Do not assign a lambda expression, use a def
@@ -87,6 +89,8 @@ ignore = [
     "D100",  # Missing docstring in public module
     "D104",  # Missing docstring in public package
     "D107",  # Missing docstring in `__init__`
+    "ANN101",  # Missing type annotation for `self` in method
+    "ANN102",  # Missing type annotation for `cls` in classmethod
 ]
 # Exclude a variety of commonly ignored directories.
 exclude = [
@@ -103,6 +107,13 @@ exclude = [
 ignore-init-module-imports = true
 unfixable = ["F401"]

+[tool.ruff.per-file-ignores]
+"setup.py" = ["ANN202", "ANN401"]
+"src/**" = ["ANN401"]
+"tests/**" = [
+    "ANN001", "ANN002", "ANN003", "ANN201", "ANN202", "ANN204", "ANN205", "ANN401"
+]
+
 [tool.ruff.pydocstyle]
 # Use Google-style docstrings.
 convention = "google"
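The two rule families enabled above behave as follows: `YTT` (flake8-2020) flags fragile `sys.version`/`sys.version_info` checks, and `ANN` (flake8-annotations) requires type annotations on function arguments and return values. Below is a minimal sketch, not taken from the torchmetrics code base (the names are invented for illustration), of what the ANN rules report and what they accept:

```python
from typing import List


def mean(values):  # flagged: ANN001 (untyped argument), ANN201 (missing return type)
    return sum(values) / len(values)


def mean_annotated(values: List[float]) -> float:  # accepted by ANN001/ANN201
    return sum(values) / len(values)


class Counter:
    # ANN204 asks for a return type on special methods; `__init__` always returns
    # None, which is exactly the `-> None` added throughout this diff.
    def __init__(self, start: int = 0) -> None:
        self.count = start
```

The `[tool.ruff.per-file-ignores]` table then relaxes the rules where strict annotations add little value: `tests/**` drops most ANN codes, `setup.py` additionally skips `ANN202` (return types on private functions), and `src/**` only suppresses `ANN401`, so `typing.Any` annotations stay usable in the library code.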
diff --git a/setup.py b/setup.py
index 0c63ade6bd5..54b19530652 100755
--- a/setup.py
+++ b/setup.py
@@ -139,7 +139,7 @@ def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
     return re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL)


-def _load_py_module(fname, pkg="torchmetrics"):
+def _load_py_module(fname: str, pkg: str = "torchmetrics"):
     spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname))
     py = module_from_spec(spec)
     spec.loader.exec_module(py)
@@ -155,7 +155,7 @@ def _load_py_module(fname, pkg="torchmetrics"):
 BASE_REQUIREMENTS = _load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt")


-def _prepare_extras(skip_files: Tuple[str] = ("devel.txt", "doctest.txt", "integrate.txt", "docs.txt")):
+def _prepare_extras(skip_files: Tuple[str] = ("devel.txt", "doctest.txt", "integrate.txt", "docs.txt")) -> dict:
     # find all extra requirements
     _load_req = partial(_load_requirements, path_dir=_PATH_REQUIRE)
     found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(_PATH_REQUIRE, "*.txt")))
diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py
index ab1b9e975b7..9bddec9a057 100644
--- a/src/torchmetrics/__init__.py
+++ b/src/torchmetrics/__init__.py
@@ -58,10 +58,11 @@
 )
 from torchmetrics.metric import Metric  # noqa: E402
 from torchmetrics.nominal import CramersV  # noqa: E402
-from torchmetrics.nominal import PearsonsContingencyCoefficient, TheilsU, TschuprowsT  # noqa: E402
+from torchmetrics.nominal import PearsonsContingencyCoefficient  # noqa: E402
+from torchmetrics.nominal import TheilsU, TschuprowsT  # noqa: E402
 from torchmetrics.regression import ConcordanceCorrCoef  # noqa: E402
+from torchmetrics.regression import CosineSimilarity  # noqa: E402
 from torchmetrics.regression import (  # noqa: E402
-    CosineSimilarity,
     ExplainedVariance,
     KendallRankCorrCoef,
     KLDivergence,
@@ -79,8 +80,8 @@
     WeightedMeanAbsolutePercentageError,
 )
 from torchmetrics.retrieval import RetrievalFallOut  # noqa: E402
+from torchmetrics.retrieval import RetrievalHitRate  # noqa: E402
 from torchmetrics.retrieval import (  # noqa: E402
-    RetrievalHitRate,
     RetrievalMAP,
     RetrievalMRR,
     RetrievalNormalizedDCG,
diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py
index 02777a55c54..85a74a9c5e1 100644
--- a/src/torchmetrics/aggregation.py
+++ b/src/torchmetrics/aggregation.py
@@ -56,7 +56,7 @@ def __init__(
         default_value: Union[Tensor, List],
         nan_strategy: Union[str, float] = "error",
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         allowed_nan_strategy = ("error", "warn", "ignore")
         if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
@@ -138,7 +138,7 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(
             "max",
             -torch.tensor(float("inf")),
@@ -238,7 +238,7 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(
             "min",
             torch.tensor(float("inf")),
@@ -336,7 +336,7 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(
             "sum",
             torch.tensor(0.0),
@@ -435,7 +435,7 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__("cat", [], nan_strategy, **kwargs)

     def update(self, value: Union[float, Tensor]) -> None:
@@ -496,7 +496,7 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(
             "sum",
             torch.tensor(0.0),
diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py
index 692e879a9e6..596dd0f3dd5 100644
--- a/src/torchmetrics/classification/exact_match.py
+++ b/src/torchmetrics/classification/exact_match.py
@@ -119,7 +119,7 @@ def __init__(
             dist_reduce_fx="sum" if self.multidim_average == "global" else "mean",
         )

-    def update(self, preds, target) -> None:
+    def update(self, preds: Tensor, target: Tensor) -> None:
         """Update metric states with predictions and targets."""
         if self.validate_args:
             _multiclass_stat_scores_tensor_validation(
diff --git a/src/torchmetrics/detection/panoptic_quality.py b/src/torchmetrics/detection/panoptic_quality.py
index b9befdca60c..a13f2c4c97e 100644
--- a/src/torchmetrics/detection/panoptic_quality.py
+++ b/src/torchmetrics/detection/panoptic_quality.py
@@ -96,7 +96,7 @@ def __init__(
         stuffs: Collection[int],
         allow_unknown_preds_category: bool = False,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)

         things, stuffs = _parse_categories(things, stuffs)
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 13568c28cc9..d4469e92bc7 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -22,10 +22,10 @@
 from torch import Tensor
 from torch.nn import Module

-from torchmetrics.utilities import apply_to_collection, rank_zero_warn
 from torchmetrics.utilities.data import (
     _flatten,
     _squeeze_if_scalar,
+    apply_to_collection,
     dim_zero_cat,
     dim_zero_max,
     dim_zero_mean,
@@ -35,6 +35,7 @@
 from torchmetrics.utilities.distributed import gather_all_tensors
 from torchmetrics.utilities.exceptions import TorchMetricsUserError
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
+from torchmetrics.utilities.prints import rank_zero_warn


 def jit_distributed_available() -> bool:
diff --git a/src/torchmetrics/nominal/cramers.py b/src/torchmetrics/nominal/cramers.py
index a52810587b6..a28bc7e2b59 100644
--- a/src/torchmetrics/nominal/cramers.py
+++ b/src/torchmetrics/nominal/cramers.py
@@ -85,7 +85,7 @@ def __init__(
         nan_strategy: Literal["replace", "drop"] = "replace",
         nan_replace_value: Optional[Union[int, float]] = 0.0,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.num_classes = num_classes
         self.bias_correction = bias_correction
diff --git a/src/torchmetrics/nominal/pearson.py b/src/torchmetrics/nominal/pearson.py
index 5729e4df373..bc3b17958a5 100644
--- a/src/torchmetrics/nominal/pearson.py
+++ b/src/torchmetrics/nominal/pearson.py
@@ -89,7 +89,7 @@ def __init__(
         nan_strategy: Literal["replace", "drop"] = "replace",
         nan_replace_value: Optional[Union[int, float]] = 0.0,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.num_classes = num_classes
diff --git a/src/torchmetrics/nominal/theils_u.py b/src/torchmetrics/nominal/theils_u.py
index e12e4144939..31cea7bc53b 100644
--- a/src/torchmetrics/nominal/theils_u.py
+++ b/src/torchmetrics/nominal/theils_u.py
@@ -73,7 +73,7 @@ def __init__(
         nan_strategy: Literal["replace", "drop"] = "replace",
         nan_replace_value: Optional[Union[int, float]] = 0.0,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.num_classes = num_classes
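Most of the hunks above and below apply the same mechanical change: under PEP 484, once a method has annotated parameters, `__init__` should also be annotated as returning `None` (Ruff reports the omission as ANN204). A minimal sketch of the resulting pattern, assuming only the public `torchmetrics.Metric` API (`add_state`, `update`, `compute`); the metric itself is invented for illustration:

```python
from typing import Any

import torch
from torchmetrics import Metric


class ToySumMetric(Metric):
    """Tiny metric showing the annotated `__init__`/`update`/`compute` pattern."""

    def __init__(self, **kwargs: Any) -> None:  # `-> None` is what ANN204 asks for
        super().__init__(**kwargs)
        self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value: torch.Tensor) -> None:
        self.total += value

    def compute(self) -> torch.Tensor:
        return self.total
```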
diff --git a/src/torchmetrics/nominal/tschuprows.py b/src/torchmetrics/nominal/tschuprows.py
index 589e91bf738..e130a40e42a 100644
--- a/src/torchmetrics/nominal/tschuprows.py
+++ b/src/torchmetrics/nominal/tschuprows.py
@@ -85,7 +85,7 @@ def __init__(
         nan_strategy: Literal["replace", "drop"] = "replace",
         nan_replace_value: Optional[Union[int, float]] = 0.0,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.num_classes = num_classes
         self.bias_correction = bias_correction
diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py
index 98a78a4e04c..eb01129c0e3 100644
--- a/src/torchmetrics/regression/kendall.py
+++ b/src/torchmetrics/regression/kendall.py
@@ -125,7 +125,7 @@ def __init__(
         alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided",
         num_outputs: int = 1,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         if not isinstance(t_test, bool):
             raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.")
diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py
index c966640f49e..b32c0e0cc75 100644
--- a/src/torchmetrics/text/bert.py
+++ b/src/torchmetrics/text/bert.py
@@ -146,7 +146,7 @@ def __init__(
         baseline_path: Optional[str] = None,
         baseline_url: Optional[str] = None,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.model_name_or_path = model_name_or_path or _DEFAULT_MODEL
         self.num_layers = num_layers
diff --git a/src/torchmetrics/text/bleu.py b/src/torchmetrics/text/bleu.py
index 3a14218160a..462b4b347c9 100644
--- a/src/torchmetrics/text/bleu.py
+++ b/src/torchmetrics/text/bleu.py
@@ -72,7 +72,7 @@ def __init__(
         smooth: bool = False,
         weights: Optional[Sequence[float]] = None,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.n_gram = n_gram
         self.smooth = smooth
diff --git a/src/torchmetrics/text/cer.py b/src/torchmetrics/text/cer.py
index 6cabcee306e..6dcd06cc9a2 100644
--- a/src/torchmetrics/text/cer.py
+++ b/src/torchmetrics/text/cer.py
@@ -70,7 +70,7 @@ class CharErrorRate(Metric):
     def __init__(
         self,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
         self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py
index 8dcce6dec4a..58228e93462 100644
--- a/src/torchmetrics/text/chrf.py
+++ b/src/torchmetrics/text/chrf.py
@@ -102,7 +102,7 @@ def __init__(
         whitespace: bool = False,
         return_sentence_level_score: bool = False,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)

         if not isinstance(n_char_order, int) or n_char_order < 1:
diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py
index 5558e8877bc..a28c6b18294 100644
--- a/src/torchmetrics/text/eed.py
+++ b/src/torchmetrics/text/eed.py
@@ -68,7 +68,7 @@ def __init__(
         deletion: float = 0.2,
         insertion: float = 1.0,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)

         if language not in ("en", "ja"):
diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py
index e96252a55b6..e7c5abb8f69 100644
--- a/src/torchmetrics/text/infolm.py
+++ b/src/torchmetrics/text/infolm.py
@@ -127,7 +127,7 @@ def __init__(
         verbose: bool = True,
         return_sentence_level_score: bool = False,
         **kwargs: Dict[str, Any],
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.model_name_or_path = model_name_or_path
         self.temperature = temperature
diff --git a/src/torchmetrics/text/mer.py b/src/torchmetrics/text/mer.py
index 3e137932da3..088a5ab9fe0 100644
--- a/src/torchmetrics/text/mer.py
+++ b/src/torchmetrics/text/mer.py
@@ -67,7 +67,7 @@ class MatchErrorRate(Metric):
     def __init__(
         self,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
         self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py
index 41eb9c5b77c..e7394450056 100644
--- a/src/torchmetrics/text/perplexity.py
+++ b/src/torchmetrics/text/perplexity.py
@@ -60,7 +60,7 @@ def __init__(
         self,
         ignore_index: Optional[int] = None,
         **kwargs: Dict[str, Any],
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         if ignore_index is not None and not isinstance(ignore_index, int):
             raise ValueError(f"Argument `ignore_index` expected to either be `None` or an `int` but got {ignore_index}")
diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py
index 509dd75af1e..9bb60a47df2 100644
--- a/src/torchmetrics/text/rouge.py
+++ b/src/torchmetrics/text/rouge.py
@@ -102,7 +102,7 @@ def __init__(
         accumulate: Literal["avg", "best"] = "best",
         rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         if use_stemmer or "rougeLsum" in rouge_keys:
             if not _NLTK_AVAILABLE:
diff --git a/src/torchmetrics/text/sacre_bleu.py b/src/torchmetrics/text/sacre_bleu.py
index 2120b77224f..00a67fe3998 100644
--- a/src/torchmetrics/text/sacre_bleu.py
+++ b/src/torchmetrics/text/sacre_bleu.py
@@ -90,7 +90,7 @@ def __init__(
         lowercase: bool = False,
         weights: Optional[Sequence[float]] = None,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs)
         if tokenize not in AVAILABLE_TOKENIZERS:
             raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")
diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py
index 78e9d3e05d1..ea7b0eed987 100644
--- a/src/torchmetrics/text/squad.py
+++ b/src/torchmetrics/text/squad.py
@@ -95,7 +95,7 @@ class SQuAD(Metric):
     def __init__(
         self,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)

         self.add_state(name="f1_score", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py
index 0f918c19a9d..5d1ffa211fc 100644
--- a/src/torchmetrics/text/ter.py
+++ b/src/torchmetrics/text/ter.py
@@ -69,7 +69,7 @@ def __init__(
         asian_support: bool = False,
         return_sentence_level_score: bool = False,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         if not isinstance(normalize, bool):
             raise ValueError(f"Expected argument `normalize` to be of type boolean but got {normalize}.")
diff --git a/src/torchmetrics/text/wer.py b/src/torchmetrics/text/wer.py
index 71e4f92ef27..4687da840b5 100644
--- a/src/torchmetrics/text/wer.py
+++ b/src/torchmetrics/text/wer.py
@@ -68,7 +68,7 @@ class WordErrorRate(Metric):
     def __init__(
         self,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
         self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
diff --git a/src/torchmetrics/text/wil.py b/src/torchmetrics/text/wil.py
index 3d3bca262b9..d0c9c8bc96d 100644
--- a/src/torchmetrics/text/wil.py
+++ b/src/torchmetrics/text/wil.py
@@ -67,7 +67,7 @@ class WordInfoLost(Metric):
     def __init__(
         self,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.add_state("errors", tensor(0.0), dist_reduce_fx="sum")
         self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum")
diff --git a/src/torchmetrics/text/wip.py b/src/torchmetrics/text/wip.py
index ca76c3a76fd..335c524a00d 100644
--- a/src/torchmetrics/text/wip.py
+++ b/src/torchmetrics/text/wip.py
@@ -68,7 +68,7 @@ class WordInfoPreserved(Metric):
     def __init__(
         self,
         **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.add_state("errors", tensor(0.0), dist_reduce_fx="sum")
         self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum")
diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py
index ccaff4d3e66..50dff939100 100644
--- a/src/torchmetrics/utilities/checks.py
+++ b/src/torchmetrics/utilities/checks.py
@@ -22,6 +22,7 @@
 import torch
 from torch import Tensor

+from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import select_topk, to_onehot
 from torchmetrics.utilities.enums import DataType

@@ -626,12 +627,12 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-6) -> bool:

 @no_type_check
 def check_forward_full_state_property(
-    metric_class,
-    init_args: Dict[str, Any] = {},
-    input_args: Dict[str, Any] = {},
+    metric_class: Metric,
+    init_args: Optional[Dict[str, Any]] = None,
+    input_args: Optional[Dict[str, Any]] = None,
     num_update_to_compare: Sequence[int] = [10, 100, 1000],
     reps: int = 5,
-) -> bool:
+) -> None:
     """Check if the new ``full_state_update`` property works as intended.

     This function checks if the property can safely be set to ``False`` which will for most metrics results in a
@@ -676,6 +677,8 @@
     ...     )
     Recommended setting `full_state_update=True`
     """
+    init_args = init_args or {}
+    input_args = input_args or {}

     class FullState(metric_class):
         full_state_update = True
@@ -728,6 +731,7 @@ class PartState(metric_class):
     faster = (mean[1, -1] < mean[0, -1]).item()  # if faster on average, we recommend upgrading

     print(f"Recommended setting `full_state_update={not faster}`")
+    return


 def is_overridden(method_name: str, instance: object, parent: object) -> bool:
diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py
index b9008a38966..1b58c97b16f 100644
--- a/src/torchmetrics/utilities/imports.py
+++ b/src/torchmetrics/utilities/imports.py
@@ -14,10 +14,14 @@
 """Import utilities."""
 import operator
 import shutil
+import sys
+from distutils.version import LooseVersion
 from typing import Optional

 from lightning_utilities.core.imports import compare_version, package_available

+_PYTHON_VERSION = ".".join(map(str, [sys.version_info.major, sys.version_info.minor, sys.version_info.micro]))
+_PYTHON_LOWER_3_8 = LooseVersion(_PYTHON_VERSION) < LooseVersion("3.8")
 _TORCH_LOWER_1_12_DEV: Optional[bool] = compare_version("torch", operator.lt, "1.12.0.dev")
 _TORCH_GREATER_EQUAL_1_9: Optional[bool] = compare_version("torch", operator.ge, "1.9.0")
 _TORCH_GREATER_EQUAL_1_10: Optional[bool] = compare_version("torch", operator.ge, "1.10.0")
diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py
index 7e1aaaf1933..7f2ce8d2cf3 100644
--- a/src/torchmetrics/wrappers/multioutput.py
+++ b/src/torchmetrics/wrappers/multioutput.py
@@ -75,7 +75,7 @@ def __init__(
         output_dim: int = -1,
         remove_nans: bool = True,
         squeeze_outputs: bool = True,
-    ):
+    ) -> None:
         super().__init__()
         self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
         self.output_dim = output_dim
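Beyond annotations, the `utilities/checks.py` hunk above also replaces the mutable `{}` defaults of `check_forward_full_state_property` with `None` and rebuilds the dictionaries inside the function (`init_args = init_args or {}`). A short sketch of why that matters, using an invented function rather than the torchmetrics one:

```python
from typing import Any, Dict, Optional


def risky(config: Dict[str, Any] = {}) -> Dict[str, Any]:
    # The default dict is created once at definition time and shared by every
    # call that relies on it, so mutations leak between calls.
    config["seen"] = config.get("seen", 0) + 1
    return config


def safe(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # A fresh dict is created on every call, mirroring `init_args = init_args or {}`.
    config = config or {}
    config["seen"] = config.get("seen", 0) + 1
    return config


print(risky())  # {'seen': 1}
print(risky())  # {'seen': 2}  <- the shared default was mutated by the first call
print(safe())   # {'seen': 1}
print(safe())   # {'seen': 1}  <- each call starts from an independent dict
```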
diff --git a/tests/integrations/lightning/boring_model.py b/tests/integrations/lightning/boring_model.py
index 9bcbf2050c1..143ef756983 100644
--- a/tests/integrations/lightning/boring_model.py
+++ b/tests/integrations/lightning/boring_model.py
@@ -20,7 +20,7 @@
 class RandomDictStringDataset(Dataset):
     """Class for creating a dictionary of random strings."""

-    def __init__(self, size, length):
+    def __init__(self, size, length) -> None:
         self.len = length
         self.data = torch.randn(length, size)
@@ -36,7 +36,7 @@ def __len__(self):
 class RandomDataset(Dataset):
     """Random dataset for testing PL Module."""

-    def __init__(self, size, length):
+    def __init__(self, size, length) -> None:
         self.len = length
         self.data = torch.randn(length, size)
@@ -66,7 +66,7 @@ def training_step(...):
             model.training_epoch_end = None
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.layer = torch.nn.Linear(32, 2)
diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py
index 3bc7878e83e..e627eac2ea8 100644
--- a/tests/integrations/test_lightning.py
+++ b/tests/integrations/test_lightning.py
@@ -37,7 +37,7 @@ def test_metric_lightning(tmpdir):
     """Test that including a metric inside a lightning module calculates a simple sum correctly."""

     class TestModel(BoringModel):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric = SumMetric()
             self.sum = 0.0
@@ -75,7 +75,7 @@ def test_metrics_reset(tmpdir):
     """

     class TestModel(LightningModule):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.layer = torch.nn.Linear(32, 1)
@@ -189,7 +189,7 @@ def test_metric_lightning_log(tmpdir):
     """Test logging a metric object and that the metric state gets reset after each epoch."""

     class TestModel(BoringModel):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric_step = SumMetric()
             self.metric_epoch = SumMetric()
@@ -232,7 +232,7 @@ def test_metric_collection_lightning_log(tmpdir):
     """Test that MetricCollection works with Lightning modules."""

     class TestModel(BoringModel):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric = MetricCollection([SumMetric(), DiffMetric()])
             self.sum = torch.tensor(0.0)
@@ -274,7 +274,7 @@ def test_scriptable(tmpdir):
     """Test that lightning modules can still be scripted even if metrics cannot."""

     class TestModel(BoringModel):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             # the metric is not used in the module's `forward`
             # so the module should be exportable to TorchScript
@@ -312,7 +312,7 @@ def test_dtype_in_pl_module_transfer(tmpdir):
     """Test that metric states don't change dtype when .half() or .float() is called on the LightningModule."""

     class BoringModel(LightningModule):
-        def __init__(self, metric_dtype=torch.float32):
+        def __init__(self, metric_dtype=torch.float32) -> None:
             super().__init__()
             self.layer = Linear(32, 32)
             self.metric = SumMetric()
diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py
index e50285c14e3..390c69f35bc 100644
--- a/tests/unittests/bases/test_collections.py
+++ b/tests/unittests/bases/test_collections.py
@@ -282,7 +282,7 @@ def test_collection_filtering():
     class DummyMetric(Metric):
         full_state_update = True

-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()

         def update(self, *args, kwarg):
@@ -294,7 +294,7 @@ def compute(self):
     class MyAccuracy(Metric):
         full_state_update = True

-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()

         def update(self, preds, target, kwarg2):
diff --git a/tests/unittests/bases/test_composition.py b/tests/unittests/bases/test_composition.py
index 2751dfcff97..632b783a451 100644
--- a/tests/unittests/bases/test_composition.py
+++ b/tests/unittests/bases/test_composition.py
@@ -25,7 +25,7 @@
 class DummyMetric(Metric):
     full_state_update = True

-    def __init__(self, val_to_return):
+    def __init__(self, val_to_return) -> None:
         super().__init__()
         self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
         self._val_to_return = val_to_return
diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py
index 41253cecfcd..d6587f117d4 100644
--- a/tests/unittests/bases/test_ddp.py
+++ b/tests/unittests/bases/test_ddp.py
@@ -106,7 +106,7 @@ def _test_non_contiguous_tensors(rank):
     class DummyCatMetric(Metric):
         full_state_update = True

-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.add_state("x", default=[], dist_reduce_fx=None)
@@ -131,7 +131,7 @@ def _test_state_dict_is_synced(rank, tmpdir):
     class DummyCatMetric(Metric):
         full_state_update = True

-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
             self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum)
diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py
index 2af3de42668..f8b5e4e7d3c 100644
--- a/tests/unittests/bases/test_metric.py
+++ b/tests/unittests/bases/test_metric.py
@@ -271,7 +271,7 @@ def test_child_metric_state_dict():
     """Test that child metric states will be added to parent state dict."""

     class TestModule(Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric = DummyMetric()
             self.metric.add_state("a", tensor(0), persistent=True)
@@ -389,7 +389,7 @@ def test_device_if_child_module(metric_class):
     """Test that if a metric is a child module all values gets moved to the correct device."""

     class TestModule(Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric = metric_class()
             self.register_buffer("dummy", torch.zeros(1))
diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py
index 5ee770e9d7f..ffba9788952 100644
--- a/tests/unittests/classification/test_group_fairness.py
+++ b/tests/unittests/classification/test_group_fairness.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-import sys
 import unittest.mock as mock
 from functools import partial
 from typing import Any, Callable, Dict, Optional
@@ -28,6 +27,7 @@
 from torchmetrics import Metric
 from torchmetrics.classification.group_fairness import BinaryFairness
 from torchmetrics.functional.classification.group_fairness import binary_fairness
+from torchmetrics.utilities.imports import _PYTHON_LOWER_3_8
 from unittests import THRESHOLD
 from unittests.classification.inputs import _group_cases
 from unittests.helpers import seed_all
@@ -216,7 +216,7 @@ def run_precision_test_gpu(
 @mock.patch("unittests.helpers.testers._assert_tensor", _assert_tensor)
 @mock.patch("unittests.helpers.testers._assert_allclose", _assert_allclose)
-@pytest.mark.skipif(sys.version_info.minor < 8, reason="`TestBinaryFairness` requires `python>=3.8`.")
+@pytest.mark.skipif(_PYTHON_LOWER_3_8, reason="`TestBinaryFairness` requires `python>=3.8`.")
 @pytest.mark.parametrize("inputs", _group_cases)
 class TestBinaryFairness(BinaryFairnessTester):
     """Test class for `BinaryFairness` metric."""
diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py
index 9b2b275f9b2..ece1d97e95c 100644
--- a/tests/unittests/helpers/testers.py
+++ b/tests/unittests/helpers/testers.py
@@ -548,7 +548,7 @@ class DummyMetric(Metric):
     name = "Dummy"
     full_state_update: Optional[bool] = True

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs) -> None:
         super().__init__(**kwargs)
         self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
@@ -567,7 +567,7 @@ class DummyListMetric(Metric):
     name = "DummyList"
     full_state_update: Optional[bool] = True

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs) -> None:
         super().__init__(**kwargs)
         self.add_state("x", [], dist_reduce_fx="cat")
diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py
index e195ea4addf..824c1e5c7eb 100644
--- a/tests/unittests/image/test_fid.py
+++ b/tests/unittests/image/test_fid.py
@@ -47,7 +47,7 @@ def test_no_train():
     """Assert that metric never leaves evaluation mode."""

     class MyModel(Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric = FrechetInceptionDistance()
@@ -108,7 +108,7 @@ def test_fid_same_input(feature):

 class _ImgDataset(Dataset):
-    def __init__(self, imgs):
+    def __init__(self, imgs) -> None:
         self.imgs = imgs

     def __getitem__(self, idx):
diff --git a/tests/unittests/image/test_inception.py b/tests/unittests/image/test_inception.py
index 66998f68003..56772c1f1e9 100644
--- a/tests/unittests/image/test_inception.py
+++ b/tests/unittests/image/test_inception.py
@@ -30,7 +30,7 @@ def test_no_train():
     """Assert that metric never leaves evaluation mode."""

     class MyModel(Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric = InceptionScore()
@@ -95,7 +95,7 @@ def test_is_update_compute():

 class _ImgDataset(Dataset):
-    def __init__(self, imgs):
+    def __init__(self, imgs) -> None:
         self.imgs = imgs

     def __getitem__(self, idx):
diff --git a/tests/unittests/image/test_kid.py b/tests/unittests/image/test_kid.py
index e9b859245ab..68dc06a9f30 100644
--- a/tests/unittests/image/test_kid.py
+++ b/tests/unittests/image/test_kid.py
@@ -30,7 +30,7 @@ def test_no_train():
     """Assert that metric never leaves evaluation mode."""

     class MyModel(Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.metric = KernelInceptionDistance()
@@ -122,7 +122,7 @@ def test_kid_same_input(feature):

 class _ImgDataset(Dataset):
-    def __init__(self, imgs):
+    def __init__(self, imgs) -> None:
         self.imgs = imgs

     def __getitem__(self, idx):
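The imports-utility and fairness-test hunks above work together: `utilities/imports.py` now exposes a `_PYTHON_LOWER_3_8` flag, and `test_group_fairness.py` uses it in `pytest.mark.skipif` instead of checking `sys.version_info.minor < 8` inline, the kind of comparison the newly enabled flake8-2020 (`YTT`) rules warn about because it misfires once the major version changes. A minimal self-contained sketch of the same pattern; the flag and test names here are invented, and a tuple comparison stands in for `LooseVersion`:

```python
import sys

import pytest

# Centralised flag in the spirit of `_PYTHON_LOWER_3_8`; comparing against the
# (major, minor) tuple keeps the check valid for any future major version.
_PY_LOWER_3_8 = sys.version_info < (3, 8)


@pytest.mark.skipif(_PY_LOWER_3_8, reason="this test requires python>=3.8")
def test_python38_only_feature() -> None:
    """Runs only on Python 3.8+, mirroring the `TestBinaryFairness` skip."""
    assert sys.version_info >= (3, 8)
```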