Skip to content

Commit

Permalink
ruff: ANN - static annotations (#1615)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Mar 14, 2023
1 parent 087522b commit c8d649d
Show file tree
Hide file tree
Showing 43 changed files with 105 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def changed_domains(
return "unittests"

# parse domains
def _crop_path(fname: str, paths: List[str]):
def _crop_path(fname: str, paths: List[str]) -> str:
for p in paths:
fname = fname.replace(p, "")
return fname
Expand Down
36 changes: 18 additions & 18 deletions examples/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch


def pesq_example():
def pesq_example() -> tuple:
"""Plot PESQ audio example."""
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

Expand All @@ -37,7 +37,7 @@ def pesq_example():
return fig, ax


def pit_example():
def pit_example() -> tuple:
"""Plot PIT audio example."""
from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.functional import scale_invariant_signal_noise_ratio
Expand All @@ -58,7 +58,7 @@ def pit_example():
return fig, ax


def sdr_example():
def sdr_example() -> tuple:
"""Plot SDR audio example."""
from torchmetrics.audio.sdr import SignalDistortionRatio

Expand All @@ -78,7 +78,7 @@ def sdr_example():
return fig, ax


def si_sdr_example():
def si_sdr_example() -> tuple:
"""Plot SI-SDR audio example."""
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio

Expand All @@ -98,7 +98,7 @@ def si_sdr_example():
return fig, ax


def snr_example():
def snr_example() -> tuple:
"""Plot SNR audio example."""
from torchmetrics.audio.snr import SignalNoiseRatio

Expand All @@ -118,7 +118,7 @@ def snr_example():
return fig, ax


def si_snr_example():
def si_snr_example() -> tuple:
"""Plot SI-SNR example."""
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio

Expand All @@ -138,7 +138,7 @@ def si_snr_example():
return fig, ax


def stoi_example():
def stoi_example() -> tuple:
"""Plot STOI example."""
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

Expand All @@ -158,7 +158,7 @@ def stoi_example():
return fig, ax


def accuracy_example():
def accuracy_example() -> tuple:
"""Plot Accuracy example."""
from torchmetrics.classification import MulticlassAccuracy

Expand Down Expand Up @@ -189,7 +189,7 @@ def accuracy_example():
return fig, ax


def mean_squared_error_example():
def mean_squared_error_example() -> tuple:
"""Plot mean squared error example."""
from torchmetrics.regression import MeanSquaredError

Expand All @@ -208,7 +208,7 @@ def mean_squared_error_example():
return fig, ax


def confusion_matrix_example():
def confusion_matrix_example() -> tuple:
"""Plot confusion matrix example."""
from torchmetrics.classification import MulticlassConfusionMatrix

Expand All @@ -222,7 +222,7 @@ def confusion_matrix_example():
return fig, ax


def spectral_distortion_index_example():
def spectral_distortion_index_example() -> tuple:
"""Plot spectral distortion index example example."""
from torchmetrics.image.d_lambda import SpectralDistortionIndex

Expand All @@ -242,7 +242,7 @@ def spectral_distortion_index_example():
return fig, ax


def error_relative_global_dimensionless_synthesis():
def error_relative_global_dimensionless_synthesis() -> tuple:
"""Plot error relative global dimensionless synthesis example."""
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis

Expand All @@ -262,7 +262,7 @@ def error_relative_global_dimensionless_synthesis():
return fig, ax


def peak_signal_noise_ratio():
def peak_signal_noise_ratio() -> tuple:
"""Plot peak signal noise ratio example."""
from torchmetrics.image.psnr import PeakSignalNoiseRatio

Expand All @@ -282,7 +282,7 @@ def peak_signal_noise_ratio():
return fig, ax


def spectral_angle_mapper():
def spectral_angle_mapper() -> tuple:
"""Plot spectral angle mapper example."""
from torchmetrics.image.sam import SpectralAngleMapper

Expand All @@ -302,7 +302,7 @@ def spectral_angle_mapper():
return fig, ax


def structural_similarity_index_measure():
def structural_similarity_index_measure() -> tuple:
"""Plot structural similarity index measure example."""
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

Expand All @@ -322,7 +322,7 @@ def structural_similarity_index_measure():
return fig, ax


def multiscale_structural_similarity_index_measure():
def multiscale_structural_similarity_index_measure() -> tuple:
"""Plot multiscale structural similarity index measure example."""
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure

Expand All @@ -342,7 +342,7 @@ def multiscale_structural_similarity_index_measure():
return fig, ax


def universal_image_quality_index():
def universal_image_quality_index() -> tuple:
"""Plot universal image quality index example."""
from torchmetrics.image.uqi import UniversalImageQualityIndex

Expand All @@ -362,7 +362,7 @@ def universal_image_quality_index():
return fig, ax


def mean_average_precision():
def mean_average_precision() -> tuple:
"""Plot MAP metric."""
from torchmetrics.detection.mean_ap import MeanAveragePrecision

Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,17 @@ extend-select = [
"PT", # see: https://pypi.org/project/flake8-pytest-style
"RET", # see: https://pypi.org/project/flake8-return
"SIM", # see: https://pypi.org/project/flake8-simplify
"YTT", # see: https://pypi.org/project/flake8-2020
"ANN", # see: https://pypi.org/project/flake8-annotations
]
ignore = [
"E731", # Do not assign a lambda expression, use a def
"F401", # Imports of __all__
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D107", # Missing docstring in `__init__`
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod
]
# Exclude a variety of commonly ignored directories.
exclude = [
Expand All @@ -103,6 +107,13 @@ exclude = [
ignore-init-module-imports = true
unfixable = ["F401"]

[tool.ruff.per-file-ignores]
"setup.py" = ["ANN202", "ANN401"]
"src/**" = ["ANN401"]
"tests/**" = [
"ANN001", "ANN002", "ANN003", "ANN201", "ANN202", "ANN204", "ANN205", "ANN401"
]

[tool.ruff.pydocstyle]
# Use Google-style docstrings.
convention = "google"
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
return re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE + re.DOTALL)


def _load_py_module(fname, pkg="torchmetrics"):
def _load_py_module(fname: str, pkg: str = "torchmetrics"):
spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname))
py = module_from_spec(spec)
spec.loader.exec_module(py)
Expand All @@ -155,7 +155,7 @@ def _load_py_module(fname, pkg="torchmetrics"):
BASE_REQUIREMENTS = _load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt")


def _prepare_extras(skip_files: Tuple[str] = ("devel.txt", "doctest.txt", "integrate.txt", "docs.txt")):
def _prepare_extras(skip_files: Tuple[str] = ("devel.txt", "doctest.txt", "integrate.txt", "docs.txt")) -> dict:
# find all extra requirements
_load_req = partial(_load_requirements, path_dir=_PATH_REQUIRE)
found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(_PATH_REQUIRE, "*.txt")))
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.nominal import CramersV # noqa: E402
from torchmetrics.nominal import PearsonsContingencyCoefficient, TheilsU, TschuprowsT # noqa: E402
from torchmetrics.nominal import PearsonsContingencyCoefficient # noqa: E402
from torchmetrics.nominal import TheilsU, TschuprowsT # noqa: E402
from torchmetrics.regression import ConcordanceCorrCoef # noqa: E402
from torchmetrics.regression import CosineSimilarity # noqa: E402
from torchmetrics.regression import ( # noqa: E402
CosineSimilarity,
ExplainedVariance,
KendallRankCorrCoef,
KLDivergence,
Expand All @@ -79,8 +80,8 @@
WeightedMeanAbsolutePercentageError,
)
from torchmetrics.retrieval import RetrievalFallOut # noqa: E402
from torchmetrics.retrieval import RetrievalHitRate # noqa: E402
from torchmetrics.retrieval import ( # noqa: E402
RetrievalHitRate,
RetrievalMAP,
RetrievalMRR,
RetrievalNormalizedDCG,
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
allowed_nan_strategy = ("error", "warn", "ignore")
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
Expand Down Expand Up @@ -138,7 +138,7 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
**kwargs: Any,
):
) -> None:
super().__init__(
"max",
-torch.tensor(float("inf")),
Expand Down Expand Up @@ -238,7 +238,7 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
**kwargs: Any,
):
) -> None:
super().__init__(
"min",
torch.tensor(float("inf")),
Expand Down Expand Up @@ -336,7 +336,7 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
**kwargs: Any,
):
) -> None:
super().__init__(
"sum",
torch.tensor(0.0),
Expand Down Expand Up @@ -435,7 +435,7 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
**kwargs: Any,
):
) -> None:
super().__init__("cat", [], nan_strategy, **kwargs)

def update(self, value: Union[float, Tensor]) -> None:
Expand Down Expand Up @@ -496,7 +496,7 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
**kwargs: Any,
):
) -> None:
super().__init__(
"sum",
torch.tensor(0.0),
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
dist_reduce_fx="sum" if self.multidim_average == "global" else "mean",
)

def update(self, preds, target) -> None:
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update metric states with predictions and targets."""
if self.validate_args:
_multiclass_stat_scores_tensor_validation(
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/detection/panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)

things, stuffs = _parse_categories(things, stuffs)
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from torch import Tensor
from torch.nn import Module

from torchmetrics.utilities import apply_to_collection, rank_zero_warn
from torchmetrics.utilities.data import (
_flatten,
_squeeze_if_scalar,
apply_to_collection,
dim_zero_cat,
dim_zero_max,
dim_zero_mean,
Expand All @@ -35,6 +35,7 @@
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
from torchmetrics.utilities.prints import rank_zero_warn


def jit_distributed_available() -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/nominal/cramers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
self.num_classes = num_classes
self.bias_correction = bias_correction
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/nominal/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
self.num_classes = num_classes

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/nominal/theils_u.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
self.num_classes = num_classes

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/nominal/tschuprows.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
self.num_classes = num_classes
self.bias_correction = bias_correction
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/regression/kendall.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided",
num_outputs: int = 1,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
if not isinstance(t_test, bool):
raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.")
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(
baseline_path: Optional[str] = None,
baseline_url: Optional[str] = None,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
self.model_name_or_path = model_name_or_path or _DEFAULT_MODEL
self.num_layers = num_layers
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/text/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
smooth: bool = False,
weights: Optional[Sequence[float]] = None,
**kwargs: Any,
):
) -> None:
super().__init__(**kwargs)
self.n_gram = n_gram
self.smooth = smooth
Expand Down
Loading

0 comments on commit c8d649d

Please sign in to comment.