diff --git a/.github/assistant.py b/.github/assistant.py index ad054d96e2f..174ed2117a4 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -16,7 +16,7 @@ import os import re import sys -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import fire from packaging.version import parse @@ -83,7 +83,7 @@ def _replace_requirement(fpath: str, old_str: str = "", new_str: str = "") -> No fp.write(req) @staticmethod - def replace_str_requirements(old_str: str, new_str: str, req_files: List[str] = REQUIREMENTS_FILES) -> None: + def replace_str_requirements(old_str: str, new_str: str, req_files: list[str] = REQUIREMENTS_FILES) -> None: """Replace a particular string in all requirements files.""" if isinstance(req_files, str): req_files = [req_files] @@ -96,7 +96,7 @@ def replace_min_requirements(fpath: str) -> None: AssistantCLI._replace_requirement(fpath, old_str=">=", new_str="==") @staticmethod - def set_oldest_versions(req_files: List[str] = REQUIREMENTS_FILES) -> None: + def set_oldest_versions(req_files: list[str] = REQUIREMENTS_FILES) -> None: """Set the oldest version for requirements.""" AssistantCLI.set_min_torch_by_python() if isinstance(req_files, str): @@ -109,8 +109,8 @@ def changed_domains( pr: Optional[int] = None, auth_token: Optional[str] = None, as_list: bool = False, - general_sub_pkgs: Tuple[str] = _PKG_WIDE_SUBPACKAGES, - ) -> Union[str, List[str]]: + general_sub_pkgs: tuple[str] = _PKG_WIDE_SUBPACKAGES, + ) -> Union[str, list[str]]: """Determine what domains were changed in particular PR.""" import github @@ -139,7 +139,7 @@ def changed_domains( return "unittests" # parse domains - def _crop_path(fname: str, paths: List[str]) -> str: + def _crop_path(fname: str, paths: list[str]) -> str: for p in paths: fname = fname.replace(p, "") return fname diff --git a/.github/workflows/_focus-diff.yml b/.github/workflows/_focus-diff.yml index faa4812f375..0e68eac2406 100644 --- a/.github/workflows/_focus-diff.yml +++ b/.github/workflows/_focus-diff.yml @@ -18,8 +18,6 @@ jobs: steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 - #with: - # python-version: 3.8 - name: Get PR diff id: diff-domains diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index d7abe9bb569..fa3eb8ccdee 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -31,7 +31,7 @@ jobs: testing-matrix: | { "os": ["ubuntu-22.04", "macos-13", "windows-2022"], - "python-version": ["3.8", "3.11"] + "python-version": ["3.9", "3.11"] } check-md-links: diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 7d44adf3aab..4e79c632692 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -42,7 +42,7 @@ jobs: - "2.5.0" include: # cover additional python and PT combinations - - { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" } + - { os: "ubuntu-20.04", python-version: "3.9", pytorch-version: "2.0.1", requires: "oldest" } - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" } - { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" } # standard mac machine, not the M1 diff --git a/.github/workflows/publish-pkg.yml b/.github/workflows/publish-pkg.yml index f7b3f46997f..490eeac9c15 100644 --- a/.github/workflows/publish-pkg.yml +++ b/.github/workflows/publish-pkg.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.10" - name: Install dependencies run: >- diff --git a/CHANGELOG.md b/CHANGELOG.md index 22225fc5471..8301a7da446 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671)) +- Dropped support for Python 3.8 ([#2827](https://github.com/Lightning-AI/torchmetrics/pull/2827)) + + - Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800)) diff --git a/_samples/bert_score-own_model.py b/_samples/bert_score-own_model.py index d5e74078c65..982d7f63876 100644 --- a/_samples/bert_score-own_model.py +++ b/_samples/bert_score-own_model.py @@ -18,7 +18,7 @@ """ from pprint import pprint -from typing import Dict, List, Union +from typing import Union import torch from torch import Tensor, nn @@ -53,7 +53,7 @@ def __init__(self) -> None: self.PAD_TOKEN: torch.zeros(1, _MODEL_DIM), } - def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Dict[str, Tensor]: + def __call__(self, sentences: Union[str, list[str]], max_len: int = _MAX_LEN) -> dict[str, Tensor]: """Call method to tokenize user input. The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method @@ -69,7 +69,7 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding values. """ - output_dict: Dict[str, Tensor] = {} + output_dict: dict[str, Tensor] = {} if isinstance(sentences, str): sentences = [sentences] # Add special tokens @@ -96,7 +96,7 @@ def get_user_model_encoder(num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_ return nn.TransformerEncoder(encoder_layer, num_layers=num_layers) -def user_forward_fn(model: Module, batch: Dict[str, Tensor]) -> Tensor: +def user_forward_fn(model: Module, batch: dict[str, Tensor]) -> Tensor: """User forward function used for the computation of model embeddings. This function might be arbitrarily complicated inside. However, to ensure functionality, it should obey the diff --git a/_samples/rouge_score-own_normalizer_and_tokenizer.py b/_samples/rouge_score-own_normalizer_and_tokenizer.py index 14e2252d438..28619effcba 100644 --- a/_samples/rouge_score-own_normalizer_and_tokenizer.py +++ b/_samples/rouge_score-own_normalizer_and_tokenizer.py @@ -18,8 +18,8 @@ """ import re +from collections.abc import Sequence from pprint import pprint -from typing import Sequence from torchmetrics.text.rouge import ROUGEScore diff --git a/examples/audio/signal_to_noise_ratio.py b/examples/audio/signal_to_noise_ratio.py index c7130a895e4..7099fc08d2b 100644 --- a/examples/audio/signal_to_noise_ratio.py +++ b/examples/audio/signal_to_noise_ratio.py @@ -8,7 +8,6 @@ # %% # Import necessary libraries -from typing import Tuple import matplotlib.animation as animation import matplotlib.pyplot as plt @@ -20,7 +19,7 @@ # Generate a clean signal (simulating a high-quality recording) -def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]: +def generate_clean_signal(length: int = 1000) -> tuple[np.ndarray, np.ndarray]: """Generate a clean signal (sine wave)""" t = np.linspace(0, 1, length) signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave, representing the clean recording diff --git a/pyproject.toml b/pyproject.toml index 5a765978081..d8e25bd895a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ ] [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 120 #[tool.ruff.pycodestyle] @@ -68,6 +68,8 @@ lint.per-file-ignores."setup.py" = [ lint.per-file-ignores."src/**" = [ "ANN401", "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. + "UP006", # todo: Use `list` instead of `List` for type annotation + "UP035", # todo: `typing.List` is deprecated, use `list` instead ] lint.per-file-ignores."tests/**" = [ "ANN001", @@ -77,9 +79,6 @@ lint.per-file-ignores."tests/**" = [ "S101", "S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue ] -lint.unfixable = [ - "F401", -] # Unlike Flake8, default to a complexity level of 10. lint.mccabe.max-complexity = 10 # Use Google-style docstrings. diff --git a/setup.py b/setup.py index 2324b660cc0..915045028ac 100755 --- a/setup.py +++ b/setup.py @@ -2,11 +2,12 @@ import glob import os import re +from collections.abc import Iterable, Iterator from functools import partial from importlib.util import module_from_spec, spec_from_file_location from itertools import chain from pathlib import Path -from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Optional, Union from pkg_resources import Requirement, yield_lines from setuptools import find_packages, setup @@ -97,7 +98,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen def _load_requirements( path_dir: str, file_name: str = "base.txt", unfreeze: bool = not _FREEZE_REQUIREMENTS -) -> List[str]: +) -> list[str]: """Load requirements from a file. >>> _load_requirements(_PATH_REQUIRE) @@ -161,7 +162,7 @@ def _load_py_module(fname: str, pkg: str = "torchmetrics"): BASE_REQUIREMENTS = _load_requirements(path_dir=_PATH_REQUIRE, file_name="base.txt") -def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.txt",)) -> dict: +def _prepare_extras(skip_pattern: str = "^_", skip_files: tuple[str] = ("base.txt",)) -> dict: """Preparing extras for the package listing requirements. Args: @@ -215,7 +216,7 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx include_package_data=True, zip_safe=False, keywords=["deep learning", "machine learning", "pytorch", "metrics", "AI"], - python_requires=">=3.8", + python_requires=">=3.9", setup_requires=[], install_requires=BASE_REQUIREMENTS, extras_require=_prepare_extras(), diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index ee4f86ffdc3..312197ccbc4 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import torch from torch import Tensor @@ -55,7 +56,7 @@ class BaseAggregator(Metric): def __init__( self, fn: Union[Callable, str], - default_value: Union[Tensor, List], + default_value: Union[Tensor, list], nan_strategy: Union[str, float] = "error", state_name: str = "value", **kwargs: Any, @@ -74,7 +75,7 @@ def __init__( def _cast_and_nan_check_input( self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """Convert input ``x`` to a tensor and check for Nans.""" if not isinstance(x, Tensor): x = torch.as_tensor(x, dtype=self.dtype, device=self.device) diff --git a/src/torchmetrics/audio/dnsmos.py b/src/torchmetrics/audio/dnsmos.py index 406b817eb05..6b6f8fc60de 100644 --- a/src/torchmetrics/audio/dnsmos.py +++ b/src/torchmetrics/audio/dnsmos.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/nisqa.py b/src/torchmetrics/audio/nisqa.py index c079d903101..d1be81a01da 100644 --- a/src/torchmetrics/audio/nisqa.py +++ b/src/torchmetrics/audio/nisqa.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index ee6a1751359..38146f6c943 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 2def91d4d01..6c28738a3f9 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal @@ -87,7 +88,7 @@ def __init__( eval_func: Literal["max", "min"] = "max", **kwargs: Any, ) -> None: - base_kwargs: Dict[str, Any] = { + base_kwargs: dict[str, Any] = { "dist_sync_on_step": kwargs.pop("dist_sync_on_step", False), "process_group": kwargs.pop("process_group", None), "dist_sync_fn": kwargs.pop("dist_sync_fn", None), diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index 9b8646aaa1f..e932e06199b 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index d8b9fd4c173..4947d5141c3 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/srmr.py b/src/torchmetrics/audio/srmr.py index 6c910738880..453f1bb7eab 100644 --- a/src/torchmetrics/audio/srmr.py +++ b/src/torchmetrics/audio/srmr.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index cea3df3514e..60d1fb9d670 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index efd337d496a..bc1a8bb5e36 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -489,7 +490,7 @@ class Accuracy(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Accuracy"], + cls: type["Accuracy"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 65e9493b14c..8e5a69092fb 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -108,7 +109,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): def __init__( self, max_fpr: Optional[float] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -257,7 +258,7 @@ def __init__( self, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -412,7 +413,7 @@ def __init__( self, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -507,9 +508,9 @@ class AUROC(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["AUROC"], + cls: type["AUROC"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 9d36774938c..221f400918c 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -255,7 +256,7 @@ def __init__( self, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -415,7 +416,7 @@ def __init__( self, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -517,9 +518,9 @@ class AveragePrecision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["AveragePrecision"], + cls: type["AveragePrecision"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 404d9089bbc..da2dc4d3d40 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -370,7 +371,7 @@ class CalibrationError(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["CalibrationError"], + cls: type["CalibrationError"], task: Literal["binary", "multiclass"], n_bins: int = 15, norm: Literal["l1", "l2", "max"] = "l1", diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 3531eb6b106..aa1d1d0780c 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -314,7 +315,7 @@ class labels. """ def __new__( # type: ignore[misc] - cls: Type["CohenKappa"], + cls: type["CohenKappa"], task: Literal["binary", "multiclass"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index f1f870bcbfd..b5f304dbd82 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Type +from typing import Any, Optional import torch from torch import Tensor @@ -150,7 +150,7 @@ def plot( val: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -294,7 +294,7 @@ def plot( val: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -441,7 +441,7 @@ def plot( val: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -517,7 +517,7 @@ class ConfusionMatrix(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["ConfusionMatrix"], + cls: type["ConfusionMatrix"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index df9807cea7f..ad160aa0d09 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Tuple, Union, no_type_check +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union, no_type_check import torch from torch import Tensor @@ -247,7 +248,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.fn.append(fn) @no_type_check - def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def _get_final_stats(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform concatenation on the stat scores if necessary, before passing them to a compute function.""" tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 10b9aedc2fc..e167b22219c 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor @@ -394,7 +395,7 @@ class ExactMatch(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["ExactMatch"], + cls: type["ExactMatch"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 526ad1ae0da..dcbe8a8b69d 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -1116,7 +1117,7 @@ class FBetaScore(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["FBetaScore"], + cls: type["FBetaScore"], task: Literal["binary", "multiclass", "multilabel"], beta: float = 1.0, threshold: float = 0.5, @@ -1183,7 +1184,7 @@ class F1Score(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["F1Score"], + cls: type["F1Score"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 8e38b24faeb..d966063856c 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor @@ -47,7 +48,7 @@ def _create_states(self, num_groups: int) -> None: self.add_state("tn", default(), dist_reduce_fx="sum") self.add_state("fn", default(), dist_reduce_fx="sum") - def _update_states(self, group_stats: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]) -> None: + def _update_states(self, group_stats: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]) -> None: for group, stats in enumerate(group_stats): tp, fp, tn, fn = stats self.tp[group] += tp @@ -147,7 +148,7 @@ def update(self, preds: Tensor, target: Tensor, groups: Tensor) -> None: def compute( self, - ) -> Dict[str, Tensor]: + ) -> dict[str, Tensor]: """Compute tp, fp, tn and fn rates based on inputs passed in to ``update`` previously.""" results = torch.stack((self.tp, self.fp, self.tn, self.fn), dim=1) @@ -267,7 +268,7 @@ def update(self, preds: Tensor, target: Tensor, groups: Tensor) -> None: def compute( self, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: """Compute fairness criteria based on inputs passed in to ``update`` previously.""" if self.task == "demographic_parity": return _compute_binary_demographic_parity(self.tp, self.fp, self.tn, self.fn) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index bd0bfa733c6..183af336ae8 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -493,7 +494,7 @@ class HammingDistance(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["HammingDistance"], + cls: type["HammingDistance"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 5514f98cccc..878ea271049 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor @@ -353,7 +354,7 @@ class HingeLoss(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["HingeLoss"], + cls: type["HingeLoss"], task: Literal["binary", "multiclass"], num_classes: Optional[int] = None, squared: bool = False, diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 385009d5a6a..5f9a15b2e4f 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -458,7 +459,7 @@ class JaccardIndex(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["JaccardIndex"], + cls: type["JaccardIndex"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 49de1f03795..ea1b7a23cb5 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -390,7 +391,7 @@ class MatthewsCorrCoef(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["MatthewsCorrCoef"], + cls: type["MatthewsCorrCoef"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/negative_predictive_value.py b/src/torchmetrics/classification/negative_predictive_value.py index d0b19dc1247..cdff97f86e2 100644 --- a/src/torchmetrics/classification/negative_predictive_value.py +++ b/src/torchmetrics/classification/negative_predictive_value.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -486,7 +487,7 @@ class NegativePredictiveValue(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["NegativePredictiveValue"], + cls: type["NegativePredictiveValue"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_fixed_recall.py b/src/torchmetrics/classification/precision_fixed_recall.py index c761f9aa8a9..a17f19aa39b 100644 --- a/src/torchmetrics/classification/precision_fixed_recall.py +++ b/src/torchmetrics/classification/precision_fixed_recall.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -115,7 +116,7 @@ class BinaryPrecisionAtFixedRecall(BinaryPrecisionRecallCurve): def __init__( self, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -126,7 +127,7 @@ def __init__( self.validate_args = validate_args self.min_recall = min_recall - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _binary_recall_at_fixed_precision_compute( @@ -258,7 +259,7 @@ def __init__( self, num_classes: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -271,7 +272,7 @@ def __init__( self.validate_args = validate_args self.min_recall = min_recall - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_recall_at_fixed_precision_arg_compute( @@ -404,7 +405,7 @@ def __init__( self, num_labels: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -417,7 +418,7 @@ def __init__( self.validate_args = validate_args self.min_recall = min_recall - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_recall_at_fixed_precision_arg_compute( @@ -485,10 +486,10 @@ class PrecisionAtFixedRecall(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["PrecisionAtFixedRecall"], + cls: type["PrecisionAtFixedRecall"], task: Literal["binary", "multiclass", "multilabel"], min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 0380545b5ac..0bd6f8b0d99 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -985,7 +986,7 @@ class Precision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Precision"], + cls: type["Precision"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, @@ -1050,7 +1051,7 @@ class Recall(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Recall"], + cls: type["Recall"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index d0f9c632c02..7a8430eeab2 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -136,7 +136,7 @@ class BinaryPrecisionRecallCurve(Metric): def __init__( self, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -171,14 +171,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _binary_precision_recall_curve_compute(state, self.thresholds) def plot( self, - curve: Optional[Tuple[Tensor, Tensor, Tensor]] = None, + curve: Optional[tuple[Tensor, Tensor, Tensor]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -330,7 +330,7 @@ class MulticlassPrecisionRecallCurve(Metric): def __init__( self, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, @@ -374,14 +374,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average) def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -530,7 +530,7 @@ class MultilabelPrecisionRecallCurve(Metric): def __init__( self, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -570,14 +570,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index) def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -667,9 +667,9 @@ class PrecisionRecallCurve(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["PrecisionRecallCurve"], + cls: type["PrecisionRecallCurve"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/ranking.py b/src/torchmetrics/classification/ranking.py index 9dda737030d..fca0d64ff86 100644 --- a/src/torchmetrics/classification/ranking.py +++ b/src/torchmetrics/classification/ranking.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index 58a460b7b2e..f34b3bb580a 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -114,7 +115,7 @@ class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): def __init__( self, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -125,7 +126,7 @@ def __init__( self.validate_args = validate_args self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _binary_recall_at_fixed_precision_compute(state, self.thresholds, self.min_precision) @@ -257,7 +258,7 @@ def __init__( self, num_classes: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -270,7 +271,7 @@ def __init__( self.validate_args = validate_args self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_recall_at_fixed_precision_arg_compute( @@ -403,7 +404,7 @@ def __init__( self, num_labels: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -416,7 +417,7 @@ def __init__( self.validate_args = validate_args self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_recall_at_fixed_precision_arg_compute( @@ -484,10 +485,10 @@ class RecallAtFixedPrecision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["RecallAtFixedPrecision"], + cls: type["RecallAtFixedPrecision"], task: Literal["binary", "multiclass", "multilabel"], min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index fc378e53ee0..b56bc856744 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -117,14 +117,14 @@ class BinaryROC(BinaryPrecisionRecallCurve): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] def plot( self, - curve: Optional[Tuple[Tensor, Tensor, Tensor]] = None, + curve: Optional[tuple[Tensor, Tensor, Tensor]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -287,17 +287,17 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): plot_upper_bound: float = 1.0 plot_legend_name: str = "Class" - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average) # type: ignore[arg-type] def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -449,17 +449,17 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): plot_upper_bound: float = 1.0 plot_legend_name: str = "Label" - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) # type: ignore[arg-type] def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -571,9 +571,9 @@ class ROC(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["ROC"], + cls: type["ROC"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py index c9bcd0bad6d..23851affe24 100644 --- a/src/torchmetrics/classification/sensitivity_specificity.py +++ b/src/torchmetrics/classification/sensitivity_specificity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -111,7 +111,7 @@ class BinarySensitivityAtSpecificity(BinaryPrecisionRecallCurve): def __init__( self, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -122,7 +122,7 @@ def __init__( self.validate_args = validate_args self.min_specificity = min_specificity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _binary_sensitivity_at_specificity_compute(state, self.thresholds, self.min_specificity) @@ -208,7 +208,7 @@ def __init__( self, num_classes: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -223,7 +223,7 @@ def __init__( self.validate_args = validate_args self.min_specificity = min_specificity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_sensitivity_at_specificity_compute( @@ -309,7 +309,7 @@ def __init__( self, num_labels: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -322,7 +322,7 @@ def __init__( self.validate_args = validate_args self.min_specificity = min_specificity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_sensitivity_at_specificity_compute( @@ -346,10 +346,10 @@ class SensitivityAtSpecificity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["SensitivityAtSpecificity"], + cls: type["SensitivityAtSpecificity"], task: Literal["binary", "multiclass", "multilabel"], min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index caca10dfa2b..dab5fde8a60 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -477,7 +478,7 @@ class Specificity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Specificity"], + cls: type["Specificity"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity_sensitivity.py b/src/torchmetrics/classification/specificity_sensitivity.py index 54f38f20b06..d8ed0f27cd4 100644 --- a/src/torchmetrics/classification/specificity_sensitivity.py +++ b/src/torchmetrics/classification/specificity_sensitivity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -111,7 +111,7 @@ class BinarySpecificityAtSensitivity(BinaryPrecisionRecallCurve): def __init__( self, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -122,7 +122,7 @@ def __init__( self.validate_args = validate_args self.min_sensitivity = min_sensitivity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _binary_specificity_at_sensitivity_compute(state, self.thresholds, self.min_sensitivity) @@ -208,7 +208,7 @@ def __init__( self, num_classes: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -223,7 +223,7 @@ def __init__( self.validate_args = validate_args self.min_sensitivity = min_sensitivity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_specificity_at_sensitivity_compute( @@ -309,7 +309,7 @@ def __init__( self, num_labels: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -322,7 +322,7 @@ def __init__( self.validate_args = validate_args self.min_sensitivity = min_sensitivity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_specificity_at_sensitivity_compute( @@ -346,10 +346,10 @@ class SpecificityAtSensitivity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["SpecificityAtSensitivity"], + cls: type["SpecificityAtSensitivity"], task: Literal["binary", "multiclass", "multilabel"], min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 96d797fd5d6..3c55c431fa2 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -79,7 +79,7 @@ def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: self.tn += tn self.fn += fn - def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def _final_state(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Aggregate states that are lists and return final states.""" tp = dim_zero_cat(self.tp) fp = dim_zero_cat(self.fp) @@ -526,7 +526,7 @@ class StatScores(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["StatScores"], + cls: type["StatScores"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/clustering/adjusted_mutual_info_score.py b/src/torchmetrics/clustering/adjusted_mutual_info_score.py index c797a5ac23b..ebcf4749d08 100644 --- a/src/torchmetrics/clustering/adjusted_mutual_info_score.py +++ b/src/torchmetrics/clustering/adjusted_mutual_info_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py index 5c1f5f49276..20278f74bc3 100644 --- a/src/torchmetrics/clustering/adjusted_rand_score.py +++ b/src/torchmetrics/clustering/adjusted_rand_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/calinski_harabasz_score.py b/src/torchmetrics/clustering/calinski_harabasz_score.py index 483e4332148..c331fba7866 100644 --- a/src/torchmetrics/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/clustering/calinski_harabasz_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/davies_bouldin_score.py b/src/torchmetrics/clustering/davies_bouldin_score.py index 40827b568cb..ddd079793cd 100644 --- a/src/torchmetrics/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/clustering/davies_bouldin_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index 5a85074443d..65d1c0c9a94 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/fowlkes_mallows_index.py b/src/torchmetrics/clustering/fowlkes_mallows_index.py index 32fcffe37a7..1317a0cee1c 100644 --- a/src/torchmetrics/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/clustering/fowlkes_mallows_index.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py index 01039a8fdcb..260ab522245 100644 --- a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index ede0c1393fc..a2be02f834e 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/normalized_mutual_info_score.py b/src/torchmetrics/clustering/normalized_mutual_info_score.py index 927a4451009..2583b0b2a9e 100644 --- a/src/torchmetrics/clustering/normalized_mutual_info_score.py +++ b/src/torchmetrics/clustering/normalized_mutual_info_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index 8ded8b27d0d..724a38b227c 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index b4ec0c4e4cc..e034027fe0d 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -13,8 +13,9 @@ # limitations under the License. # this is just a bypass for this module name collision with built-in one from collections import OrderedDict +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy -from typing import Any, ClassVar, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import torch from torch import Tensor @@ -190,17 +191,17 @@ class name of the metric: """ - _modules: Dict[str, Metric] # type: ignore[assignment] + _modules: dict[str, Metric] # type: ignore[assignment] _groups: Dict[int, List[str]] - __jit_unused_properties__: ClassVar[List[str]] = ["metric_state"] + __jit_unused_properties__: ClassVar[list[str]] = ["metric_state"] def __init__( self, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], + metrics: Union[Metric, Sequence[Metric], dict[str, Metric]], *additional_metrics: Metric, prefix: Optional[str] = None, postfix: Optional[str] = None, - compute_groups: Union[bool, List[List[str]]] = True, + compute_groups: Union[bool, list[list[str]]] = True, ) -> None: super().__init__() @@ -213,12 +214,12 @@ def __init__( self.add_metrics(metrics, *additional_metrics) @property - def metric_state(self) -> Dict[str, Dict[str, Any]]: + def metric_state(self) -> dict[str, dict[str, Any]]: """Get the current state of the metric.""" return {k: m.metric_state for k, m in self.items(keep_base=False, copy_state=False)} @torch.jit.unused - def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Call forward for each metric sequentially. Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs) @@ -341,13 +342,13 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None: mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count self._state_is_copy = copy - def compute(self) -> Dict[str, Any]: + def compute(self) -> dict[str, Any]: """Compute the result for each metric in the collection.""" return self._compute_and_reduce("compute") def _compute_and_reduce( self, method_name: Literal["compute", "forward"], *args: Any, **kwargs: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Compute result from collection and reduce into a single dictionary. Args: @@ -421,7 +422,7 @@ def persistent(self, mode: bool = True) -> None: m.persistent(mode) def add_metrics( - self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], *additional_metrics: Metric + self, metrics: Union[Metric, Sequence[Metric], dict[str, Metric]], *additional_metrics: Metric ) -> None: """Add new metrics to Metric Collection.""" if isinstance(metrics, Metric): @@ -547,7 +548,7 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]: return self._modules.keys() return self._to_renamed_dict().keys() - def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]: + def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[tuple[str, Metric]]: r"""Return an iterable of the ModuleDict key/value pairs. Args: @@ -616,7 +617,7 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection": def plot( self, - val: Optional[Union[Dict, Sequence[Dict]]] = None, + val: Optional[Union[dict, Sequence[dict]]] = None, ax: Optional[Union[_AX_TYPE, Sequence[_AX_TYPE]]] = None, together: bool = False, ) -> Sequence[_PLOT_OUT_TYPE]: diff --git a/src/torchmetrics/detection/_deprecated.py b/src/torchmetrics/detection/_deprecated.py index c162c751554..f8acd23adb6 100644 --- a/src/torchmetrics/detection/_deprecated.py +++ b/src/torchmetrics/detection/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Any, Collection +from collections.abc import Collection +from typing import Any from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality from torchmetrics.utilities.prints import _deprecated_root_import_class diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index 604daef82eb..9831842734d 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Union import numpy as np import torch @@ -34,7 +35,7 @@ log = logging.getLogger(__name__) -def compute_area(inputs: List[Any], iou_type: str = "bbox") -> Tensor: +def compute_area(inputs: list[Any], iou_type: str = "bbox") -> Tensor: """Compute area of input depending on the specified iou_type. Default output for empty input is :class:`~torch.Tensor` @@ -56,8 +57,8 @@ def compute_area(inputs: List[Any], iou_type: str = "bbox") -> Tensor: def compute_iou( - det: List[Any], - gt: List[Any], + det: list[Any], + gt: list[Any], iou_type: str = "bbox", ) -> Tensor: """Compute IOU between detections and ground-truth using the specified iou_type.""" @@ -124,7 +125,7 @@ class COCOMetricResults(BaseMetricResults): ) -def _segm_iou(det: List[Tuple[np.ndarray, np.ndarray]], gt: List[Tuple[np.ndarray, np.ndarray]]) -> Tensor: +def _segm_iou(det: list[tuple[np.ndarray, np.ndarray]], gt: list[tuple[np.ndarray, np.ndarray]]) -> Tensor: """Compute IOU between detections and ground-truths using mask-IOU. Implementation is based on pycocotools toolkit for mask_utils. @@ -315,9 +316,9 @@ def __init__( self, box_format: str = "xyxy", iou_type: str = "bbox", - iou_thresholds: Optional[List[float]] = None, - rec_thresholds: Optional[List[float]] = None, - max_detection_thresholds: Optional[List[int]] = None, + iou_thresholds: Optional[list[float]] = None, + rec_thresholds: Optional[list[float]] = None, + max_detection_thresholds: Optional[list[int]] = None, class_metrics: bool = False, **kwargs: Any, ) -> None: @@ -364,7 +365,7 @@ def __init__( self.add_state("groundtruths", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: + def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None: """Update state with predictions and targets.""" _input_validator(preds, target, iou_type=self.iou_type) # type: ignore[arg-type] @@ -393,7 +394,7 @@ def _move_list_states_to_cpu(self) -> None: current_to_cpu.append(cur_v) setattr(self, key, current_to_cpu) - def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: + def _get_safe_item_values(self, item: dict[str, Any]) -> Union[Tensor, tuple]: import pycocotools.mask as mask_utils from torchvision.ops import box_convert @@ -410,7 +411,7 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: return tuple(masks) raise Exception(f"IOU type {self.iou_type} is not supported") - def _get_classes(self) -> List: + def _get_classes(self) -> list: """Return a list of unique classes found in ground truth and detection data.""" if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: return torch.cat(self.detection_labels + self.groundtruth_labels).unique().tolist() @@ -457,8 +458,8 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: return compute_iou(det, gt, self.iou_type).to(self.device) def __evaluate_image_gt_no_preds( - self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], num_iou_thrs: int - ) -> Dict[str, Any]: + self, gt: Tensor, gt_label_mask: Tensor, area_range: tuple[int, int], num_iou_thrs: int + ) -> dict[str, Any]: """Evaluate images with a ground truth but no predictions.""" # GTs gt = [gt[i] for i in gt_label_mask] @@ -486,9 +487,9 @@ def __evaluate_image_preds_no_gt( idx: int, det_label_mask: Tensor, max_det: int, - area_range: Tuple[int, int], + area_range: tuple[int, int], num_iou_thrs: int, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Evaluate images with a prediction but no ground truth.""" # GTs num_gt = 0 @@ -520,7 +521,7 @@ def __evaluate_image_preds_no_gt( } def _evaluate_image( - self, idx: int, class_id: int, area_range: Tuple[int, int], max_det: int, ious: dict + self, idx: int, class_id: int, area_range: tuple[int, int], max_det: int, ious: dict ) -> Optional[dict]: """Perform evaluation for single class and image. @@ -651,7 +652,7 @@ def _find_best_gt_match( def _summarize( self, - results: Dict, + results: dict, avg_prec: bool = True, iou_threshold: Optional[float] = None, area_range: str = "all", @@ -694,7 +695,7 @@ def _summarize( return torch.tensor([-1.0]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1]) - def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResults]: + def _calculate(self, class_ids: list) -> tuple[MAPMetricResults, MARMetricResults]: """Calculate the precision and recall for all supplied classes to calculate mAP/mAR. Args: @@ -752,7 +753,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult return precision, recall # type: ignore[return-value] - def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]: + def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> tuple[MAPMetricResults, MARMetricResults]: """Summarizes the precision and recall values to calculate mAP/mAR. Args: @@ -800,7 +801,7 @@ def __calculate_recall_precision_scores( max_det: int, num_imgs: int, num_bbox_areas: int, - ) -> Tuple[Tensor, Tensor, Tensor]: + ) -> tuple[Tensor, Tensor, Tensor]: num_rec_thrs = len(rec_thresholds) idx_cls_pointer = idx_cls * num_bbox_areas * num_imgs idx_bbox_area_pointer = idx_bbox_area * num_imgs @@ -916,8 +917,8 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt @staticmethod def _gather_tuple_list( - list_to_gather: List[Union[tuple, Tensor]], process_group: Optional[Any] = None - ) -> List[Any]: + list_to_gather: list[Union[tuple, Tensor]], process_group: Optional[Any] = None + ) -> list[Any]: """Gather a list of tuples over multiple devices.""" world_size = dist.get_world_size(group=process_group) dist.barrier(group=process_group) @@ -928,7 +929,7 @@ def _gather_tuple_list( return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)] # type: ignore[arg-type,index] def plot( - self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[dict[str, Tensor], Sequence[dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/detection/ciou.py b/src/torchmetrics/detection/ciou.py index b6174c3b60c..1545ab62a83 100644 --- a/src/torchmetrics/detection/ciou.py +++ b/src/torchmetrics/detection/ciou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/detection/diou.py b/src/torchmetrics/detection/diou.py index 7eb3780a112..edfebb38184 100644 --- a/src/torchmetrics/detection/diou.py +++ b/src/torchmetrics/detection/diou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/detection/giou.py b/src/torchmetrics/detection/giou.py index d024adad817..cf03a73c65d 100644 --- a/src/torchmetrics/detection/giou.py +++ b/src/torchmetrics/detection/giou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index dc31a7c7497..ddb54463f54 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Literal, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Literal, Union from torch import Tensor def _input_validator( - preds: Sequence[Dict[str, Tensor]], - targets: Sequence[Dict[str, Tensor]], - iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox", + preds: Sequence[dict[str, Tensor]], + targets: Sequence[dict[str, Tensor]], + iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"]]] = "bbox", ignore_score: bool = False, ) -> None: """Ensure the correct input format of `preds` and `targets`.""" @@ -88,8 +89,8 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: def _validate_iou_type_arg( - iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox", -) -> Tuple[str]: + iou_type: Union[Literal["bbox", "segm"], tuple[str]] = "bbox", +) -> tuple[str]: """Validate that iou type argument is correct.""" allowed_iou_types = ("segm", "bbox") if isinstance(iou_type, str): diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index d4930d905ab..22d7e5225d4 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -178,7 +179,7 @@ def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor: return _iou_compute(*args, **kwargs) - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: + def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None: """Update state with predictions and targets.""" _input_validator(preds, target, ignore_score=True) @@ -204,7 +205,7 @@ def _get_safe_item_values(self, boxes: Tensor) -> Tensor: boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy") return boxes - def _get_gt_classes(self) -> List: + def _get_gt_classes(self) -> list: """Returns a list of unique classes found in ground truth and detection data.""" if len(self.groundtruth_labels) > 0: return torch.cat(self.groundtruth_labels).unique().tolist() @@ -213,7 +214,7 @@ def _get_gt_classes(self) -> List: def compute(self) -> dict: """Computes IoU based on inputs passed in to ``update`` previously.""" score = torch.cat([mat[mat != self._invalid_val] for mat in self.iou_matrix], 0).mean() - results: Dict[str, Tensor] = {f"{self._iou_type}": score} + results: dict[str, Tensor] = {f"{self._iou_type}": score} if torch.isnan(score): # if no valid boxes are found results[f"{self._iou_type}"] = torch.tensor(0.0, device=score.device) if self.class_metrics: diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index c1c63d1a9b4..a60be809bac 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -14,8 +14,9 @@ import contextlib import io import json +from collections.abc import Sequence from types import ModuleType -from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, ClassVar, List, Optional, Union import numpy as np import torch @@ -47,7 +48,7 @@ ] -def _load_backend_tools(backend: Literal["pycocotools", "faster_coco_eval"]) -> Tuple[object, object, ModuleType]: +def _load_backend_tools(backend: Literal["pycocotools", "faster_coco_eval"]) -> tuple[object, object, ModuleType]: """Load the backend tools for the given backend.""" if backend == "pycocotools": if not _PYCOCOTOOLS_AVAILABLE: @@ -355,7 +356,7 @@ class MeanAveragePrecision(Metric): warn_on_many_detections: bool = True - __jit_unused_properties__: ClassVar[List[str]] = [ + __jit_unused_properties__: ClassVar[list[str]] = [ "is_differentiable", "higher_is_better", "plot_lower_bound", @@ -372,10 +373,10 @@ class MeanAveragePrecision(Metric): def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", - iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox", - iou_thresholds: Optional[List[float]] = None, - rec_thresholds: Optional[List[float]] = None, - max_detection_thresholds: Optional[List[int]] = None, + iou_type: Union[Literal["bbox", "segm"], tuple[str]] = "bbox", + iou_thresholds: Optional[list[float]] = None, + rec_thresholds: Optional[list[float]] = None, + max_detection_thresholds: Optional[list[int]] = None, class_metrics: bool = False, extended_summary: bool = False, average: Literal["macro", "micro"] = "macro", @@ -474,7 +475,7 @@ def mask_utils(self) -> object: _, _, mask_utils = _load_backend_tools(self.backend) return mask_utils - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: + def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None: """Update metric state. Raises: @@ -596,7 +597,7 @@ def compute(self) -> dict: return result_dict - def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object, object]: + def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> tuple[object, object]: """Returns the coco datasets for the target and the predictions.""" if average == "micro": # for micro averaging we set everything to be the same class @@ -628,7 +629,7 @@ def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object return coco_preds, coco_target - def _coco_stats_to_tensor_dict(self, stats: List[float], prefix: str) -> Dict[str, Tensor]: + def _coco_stats_to_tensor_dict(self, stats: list[float], prefix: str) -> dict[str, Tensor]: """Converts the output of COCOeval.stats to a dict of tensors.""" mdt = self.max_detection_thresholds return { @@ -650,9 +651,9 @@ def _coco_stats_to_tensor_dict(self, stats: List[float], prefix: str) -> Dict[st def coco_to_tm( coco_preds: str, coco_target: str, - iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", + iou_type: Union[Literal["bbox", "segm"], list[str]] = "bbox", backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools", - ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]: + ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: """Utility function for converting .json coco format files to the input format of this metric. The function accepts a file for the predictions and a file for the target in coco format and converts them to @@ -824,8 +825,8 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: f.write(target_json) def _get_safe_item_values( - self, item: Dict[str, Any], warn: bool = False - ) -> Tuple[Optional[Tensor], Optional[Tuple]]: + self, item: dict[str, Any], warn: bool = False + ) -> tuple[Optional[Tensor], Optional[tuple]]: """Convert and return the boxes or masks from the item depending on the iou_type. Args: @@ -857,7 +858,7 @@ def _get_safe_item_values( _warning_on_too_many_detections(self.max_detection_thresholds[-1]) return output # type: ignore[return-value] - def _get_classes(self) -> List: + def _get_classes(self) -> list: """Return a list of unique classes found in ground truth and detection data.""" if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist() @@ -865,13 +866,13 @@ def _get_classes(self) -> List: def _get_coco_format( self, - labels: List[torch.Tensor], - boxes: Optional[List[torch.Tensor]] = None, - masks: Optional[List[torch.Tensor]] = None, - scores: Optional[List[torch.Tensor]] = None, - crowds: Optional[List[torch.Tensor]] = None, - area: Optional[List[torch.Tensor]] = None, - ) -> Dict: + labels: List[Tensor], + boxes: Optional[List[Tensor]] = None, + masks: Optional[List[Tensor]] = None, + scores: Optional[List[Tensor]] = None, + crowds: Optional[List[Tensor]] = None, + area: Optional[List[Tensor]] = None, + ) -> dict: """Transforms and returns all cached targets or predictions in COCO format. Format is defined at @@ -957,7 +958,7 @@ def _get_coco_format( return {"images": images, "annotations": annotations, "categories": classes} def plot( - self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[dict[str, Tensor], Sequence[dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -1042,7 +1043,7 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt self.groundtruth_mask = self._gather_tuple_list(self.groundtruth_mask, process_group) # type: ignore[arg-type] @staticmethod - def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: + def _gather_tuple_list(list_to_gather: list[tuple], process_group: Optional[Any] = None) -> list[Any]: """Gather a list of tuples over multiple devices. Args: diff --git a/src/torchmetrics/detection/panoptic_qualities.py b/src/torchmetrics/detection/panoptic_qualities.py index b4629be8e69..ce3948f3234 100644 --- a/src/torchmetrics/detection/panoptic_qualities.py +++ b/src/torchmetrics/detection/panoptic_qualities.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Optional, Sequence, Union +from collections.abc import Collection, Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/audio/_deprecated.py b/src/torchmetrics/functional/audio/_deprecated.py index 8b337318f7a..ff3c18e5458 100644 --- a/src/torchmetrics/functional/audio/_deprecated.py +++ b/src/torchmetrics/functional/audio/_deprecated.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional from torch import Tensor from typing_extensions import Literal @@ -16,7 +16,7 @@ def _permutation_invariant_training( mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise", eval_func: Literal["max", "min"] = "max", **kwargs: Any, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Wrapper for deprecated import. >>> from torch import tensor diff --git a/src/torchmetrics/functional/audio/dnsmos.py b/src/torchmetrics/functional/audio/dnsmos.py index 8f74f8374c9..2e23fba23a2 100644 --- a/src/torchmetrics/functional/audio/dnsmos.py +++ b/src/torchmetrics/functional/audio/dnsmos.py @@ -13,7 +13,7 @@ # limitations under the License. import os from functools import lru_cache -from typing import Any, Dict, Optional +from typing import Any, Optional import numpy as np import torch @@ -33,7 +33,7 @@ class InferenceSession: # type:ignore """Dummy InferenceSession.""" - def __init__(self, **kwargs: Dict[str, Any]) -> None: ... + def __init__(self, **kwargs: dict[str, Any]) -> None: ... __doctest_requires__ = { diff --git a/src/torchmetrics/functional/audio/nisqa.py b/src/torchmetrics/functional/audio/nisqa.py index 1a1c066ebf0..8718c97d8e6 100644 --- a/src/torchmetrics/functional/audio/nisqa.py +++ b/src/torchmetrics/functional/audio/nisqa.py @@ -40,7 +40,7 @@ import os import warnings from functools import lru_cache -from typing import Any, Dict, Tuple +from typing import Any import numpy as np import torch @@ -121,7 +121,7 @@ def non_intrusive_speech_quality_assessment(preds: Tensor, fs: int) -> Tensor: @lru_cache -def _load_nisqa_model() -> Tuple[nn.Module, Dict[str, Any]]: +def _load_nisqa_model() -> tuple[nn.Module, dict[str, Any]]: """Load NISQA model and its parameters. Returns: @@ -157,7 +157,7 @@ class _NISQADIM(nn.Module): # ported from https://github.com/gabrielmittag/NISQA # Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab # MIT License - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.cnn = _Framewise(args) self.time_dependency = _TimeDependency(args) @@ -173,7 +173,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _Framewise(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.model = _AdaptCNN(args) @@ -187,7 +187,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _AdaptCNN(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.pool_1 = args["cnn_pool_1"] self.pool_2 = args["cnn_pool_2"] @@ -231,7 +231,7 @@ def forward(self, x: Tensor) -> Tensor: class _TimeDependency(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.model = _SelfAttention(args) @@ -241,7 +241,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _SelfAttention(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() encoder_layer = _SelfAttentionLayer(args) self.norm1 = nn.LayerNorm(args["td_sa_d_model"]) @@ -254,7 +254,7 @@ def _reset_parameters(self) -> None: if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]: src = self.linear(src) output = src.transpose(1, 0) output = self.norm1(output) @@ -265,7 +265,7 @@ def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: class _SelfAttentionLayer(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.self_attn = nn.MultiheadAttention(args["td_sa_d_model"], args["td_sa_nhead"], args["td_sa_dropout"]) self.linear1 = nn.Linear(args["td_sa_d_model"], args["td_sa_h"]) @@ -277,7 +277,7 @@ def __init__(self, args: Dict[str, Any]) -> None: self.dropout2 = nn.Dropout(args["td_sa_dropout"]) self.activation = relu - def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]: mask = torch.arange(src.shape[0])[None, :] < n_wins[:, None] src2 = self.self_attn(src, src, src, key_padding_mask=~mask)[0] src = src + self.dropout1(src2) @@ -290,7 +290,7 @@ def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: class _Pooling(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.model = _PoolAttFF(args) @@ -300,7 +300,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _PoolAttFF(torch.nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.linear1 = nn.Linear(args["td_sa_d_model"], args["pool_att_h"]) self.linear2 = nn.Linear(args["pool_att_h"], 1) @@ -319,7 +319,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: return self.linear3(x) -def _get_librosa_melspec(y: np.ndarray, sr: int, args: Dict[str, Any]) -> np.ndarray: +def _get_librosa_melspec(y: np.ndarray, sr: int, args: dict[str, Any]) -> np.ndarray: """Compute mel spectrogram from waveform using librosa. Args: @@ -360,7 +360,7 @@ def _get_librosa_melspec(y: np.ndarray, sr: int, args: Dict[str, Any]) -> np.nda return np.stack([librosa.amplitude_to_db(m, ref=1.0, amin=1e-4, top_db=80.0) for m in melspec]) -def _segment_specs(x: Tensor, args: Dict[str, Any]) -> Tuple[Tensor, Tensor]: +def _segment_specs(x: Tensor, args: dict[str, Any]) -> tuple[Tensor, Tensor]: """Segment mel spectrogram into overlapping windows. Args: diff --git a/src/torchmetrics/functional/audio/pit.py b/src/torchmetrics/functional/audio/pit.py index a7cc72d48f7..2d126c7000f 100644 --- a/src/torchmetrics/functional/audio/pit.py +++ b/src/torchmetrics/functional/audio/pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from itertools import permutations -from typing import Any, Callable, Tuple +from typing import Any, Callable import numpy as np import torch @@ -42,7 +42,7 @@ def _gen_permutations(spk_num: int, device: torch.device) -> Tensor: def _find_best_perm_by_linear_sum_assignment( metric_mtx: Tensor, eval_func: Callable, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Solves the linear sum assignment problem. This implementation uses scipy and input is therefore transferred to cpu during calculations. @@ -68,7 +68,7 @@ def _find_best_perm_by_linear_sum_assignment( def _find_best_perm_by_exhaustive_method( metric_mtx: Tensor, eval_func: Callable, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Solves the linear sum assignment problem using exhaustive method. This is done by exhaustively calculating the metric values of all possible permutations, and returns the best metric @@ -111,7 +111,7 @@ def permutation_invariant_training( mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise", eval_func: Literal["max", "min"] = "max", **kwargs: Any, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Calculate `Permutation invariant training`_ (PIT). This metric can evaluate models for speaker independent multi-talker speech separation in a permutation diff --git a/src/torchmetrics/functional/audio/sdr.py b/src/torchmetrics/functional/audio/sdr.py index f6cb0bf2a2e..85da9597998 100644 --- a/src/torchmetrics/functional/audio/sdr.py +++ b/src/torchmetrics/functional/audio/sdr.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -53,7 +53,7 @@ def _symmetric_toeplitz(vector: Tensor) -> Tensor: ).flip(dims=(-1,)) -def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> Tuple[Tensor, Tensor]: +def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> tuple[Tensor, Tensor]: r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds`. This calculation is done using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz metric of the diff --git a/src/torchmetrics/functional/audio/srmr.py b/src/torchmetrics/functional/audio/srmr.py index 26d27f00999..20ad898fe8d 100644 --- a/src/torchmetrics/functional/audio/srmr.py +++ b/src/torchmetrics/functional/audio/srmr.py @@ -17,7 +17,7 @@ from functools import lru_cache from math import ceil, pi -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -56,7 +56,7 @@ def _make_erb_filters(fs: int, num_freqs: int, cutoff: float, device: torch.devi @lru_cache(maxsize=100) def _compute_modulation_filterbank_and_cutoffs( min_cf: float, max_cf: float, n: int, fs: float, q: int, device: torch.device -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: # this function is translated from the SRMRpy packaged spacing_factor = (max_cf / min_cf) ** (1.0 / (n - 1)) cfs = torch.zeros(n, dtype=torch.float64) @@ -73,7 +73,7 @@ def _make_modulation_filter(w0: Tensor, q: int) -> Tensor: mfb = torch.stack([_make_modulation_filter(w0, q) for w0 in 2 * pi * cfs / fs], dim=0) - def _calc_cutoffs(cfs: Tensor, fs: float, q: int) -> Tuple[Tensor, Tensor]: + def _calc_cutoffs(cfs: Tensor, fs: float, q: int) -> tuple[Tensor, Tensor]: # Calculates cutoff frequencies (3 dB) for 2nd order bandpass w0 = 2 * pi * cfs / fs b0 = torch.tan(w0 / 2) / q diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index fb802c05ec3..f88b79e2a28 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch import Tensor, tensor @@ -72,7 +72,7 @@ def _reduce_auroc( def _binary_auroc_arg_validation( max_fpr: Optional[float] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -81,7 +81,7 @@ def _binary_auroc_arg_validation( def _binary_auroc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], max_fpr: Optional[float] = None, pos_label: int = 1, @@ -111,7 +111,7 @@ def binary_auroc( preds: Tensor, target: Tensor, max_fpr: Optional[float] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -181,7 +181,7 @@ def binary_auroc( def _multiclass_auroc_arg_validation( num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -191,7 +191,7 @@ def _multiclass_auroc_arg_validation( def _multiclass_auroc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", thresholds: Optional[Tensor] = None, @@ -210,7 +210,7 @@ def multiclass_auroc( target: Tensor, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -296,7 +296,7 @@ def multiclass_auroc( def _multilabel_auroc_arg_validation( num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -306,7 +306,7 @@ def _multilabel_auroc_arg_validation( def _multilabel_auroc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], thresholds: Optional[Tensor], @@ -338,7 +338,7 @@ def multilabel_auroc( target: Tensor, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -429,7 +429,7 @@ def auroc( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 93002bb6d2b..1cddf8833ed 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -68,7 +68,7 @@ def _reduce_average_precision( def _binary_average_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], ) -> Tensor: precision, recall, _ = _binary_precision_recall_curve_compute(state, thresholds) @@ -78,7 +78,7 @@ def _binary_average_precision_compute( def binary_average_precision( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -152,7 +152,7 @@ def binary_average_precision( def _multiclass_average_precision_arg_validation( num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -162,7 +162,7 @@ def _multiclass_average_precision_arg_validation( def _multiclass_average_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", thresholds: Optional[Tensor] = None, @@ -181,7 +181,7 @@ def multiclass_average_precision( target: Tensor, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -272,7 +272,7 @@ def multiclass_average_precision( def _multilabel_average_precision_arg_validation( num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -282,7 +282,7 @@ def _multilabel_average_precision_arg_validation( def _multilabel_average_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], thresholds: Optional[Tensor], @@ -314,7 +314,7 @@ def multilabel_average_precision( target: Tensor, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -410,7 +410,7 @@ def average_precision( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index ebb5eecbe0d..ca55bb2f79b 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor @@ -28,7 +28,7 @@ def _binning_bucketize( confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute calibration bins using ``torch.bucketize``. Use for ``pytorch >=1.6``. Args: @@ -133,7 +133,7 @@ def _binary_calibration_error_tensor_validation( ) -def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: confidences, accuracies = preds, target return confidences, accuracies @@ -238,7 +238,7 @@ def _multiclass_calibration_error_tensor_validation( def _multiclass_calibration_error_update( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: if not torch.all((preds >= 0) * (preds <= 1)): preds = preds.softmax(1) confidences, predictions = preds.max(dim=1) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 1b93450ab69..92059072490 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -121,7 +121,7 @@ def _binary_confusion_matrix_format( threshold: float = 0.5, ignore_index: Optional[int] = None, convert_to_labels: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - Remove all datapoints that should be ignored @@ -300,7 +300,7 @@ def _multiclass_confusion_matrix_format( target: Tensor, ignore_index: Optional[int] = None, convert_to_labels: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - Applies argmax if preds have one more dimension than target @@ -482,7 +482,7 @@ def _multilabel_confusion_matrix_format( threshold: float = 0.5, ignore_index: Optional[int] = None, should_threshold: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index 9e6f7c49df7..d22b4cc6f89 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -42,7 +42,7 @@ def _multiclass_exact_match_update( target: Tensor, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Compute the statistics.""" if ignore_index is not None: preds = preds.clone() @@ -123,7 +123,7 @@ def multiclass_exact_match( def _multilabel_exact_match_update( preds: Tensor, target: Tensor, num_labels: int, multidim_average: Literal["global", "samplewise"] = "global" -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Compute the statistics.""" if multidim_average == "global": preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels) diff --git a/src/torchmetrics/functional/classification/group_fairness.py b/src/torchmetrics/functional/classification/group_fairness.py index edccd4a6b2d..5e52e94733f 100644 --- a/src/torchmetrics/functional/classification/group_fairness.py +++ b/src/torchmetrics/functional/classification/group_fairness.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch from typing_extensions import Literal @@ -57,7 +57,7 @@ def _binary_groups_stat_scores( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: +) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """Compute the true/false positives and true/false negatives rates for binary classification by group. Related to `Type I and Type II errors`_. @@ -84,15 +84,15 @@ def _binary_groups_stat_scores( def _groups_reduce( - group_stats: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], -) -> Dict[str, torch.Tensor]: + group_stats: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], +) -> dict[str, torch.Tensor]: """Compute rates for all the group statistics.""" return {f"group_{group}": torch.stack(stats) / torch.stack(stats).sum() for group, stats in enumerate(group_stats)} def _groups_stat_transform( - group_stats: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], -) -> Dict[str, torch.Tensor]: + group_stats: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], +) -> dict[str, torch.Tensor]: """Transform group statistics by creating a tensor for each statistic.""" return { "tp": torch.stack([stat[0] for stat in group_stats]), @@ -110,7 +110,7 @@ def binary_groups_stat_rates( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""Compute the true/false positives and true/false negatives rates for binary classification by group. Related to `Type I and Type II errors`_. @@ -163,7 +163,7 @@ def binary_groups_stat_rates( def _compute_binary_demographic_parity( tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: """Compute demographic parity based on the binary stats.""" pos_rates = _safe_divide(tp + fp, tp + fp + tn + fn) min_pos_rate_id = torch.argmin(pos_rates) @@ -180,7 +180,7 @@ def demographic_parity( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""`Demographic parity`_ compares the positivity rates between all groups. If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest @@ -242,7 +242,7 @@ def demographic_parity( def _compute_binary_equal_opportunity( tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: """Compute equal opportunity based on the binary stats.""" true_pos_rates = _safe_divide(tp, tp + fn) min_pos_rate_id = torch.argmin(true_pos_rates) @@ -262,7 +262,7 @@ def equal_opportunity( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""`Equal opportunity`_ compares the true positive rates between all groups. If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest @@ -331,7 +331,7 @@ def binary_fairness( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""Compute either `Demographic parity`_ and `Equal opportunity`_ ratio for binary classification problems. This is done by setting the ``task`` argument to either ``'demographic_parity'``, ``'equal_opportunity'`` diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index c3fe40be105..8fe7cf840b8 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor, tensor @@ -51,7 +51,7 @@ def _binary_hinge_loss_update( preds: Tensor, target: Tensor, squared: bool, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: target = target.bool() margin = torch.zeros_like(preds) margin[target] = preds[target] @@ -152,7 +152,7 @@ def _multiclass_hinge_loss_update( target: Tensor, squared: bool, multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: if not torch.all((preds >= 0) * (preds <= 1)): preds = preds.softmax(1) diff --git a/src/torchmetrics/functional/classification/precision_fixed_recall.py b/src/torchmetrics/functional/classification/precision_fixed_recall.py index a708f703dc1..078d970299e 100644 --- a/src/torchmetrics/functional/classification/precision_fixed_recall.py +++ b/src/torchmetrics/functional/classification/precision_fixed_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor @@ -44,7 +44,7 @@ def _precision_at_recall( recall: Tensor, thresholds: Tensor, min_recall: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: try: max_precision, _, best_threshold = max( (p, r, t) for p, r, t in zip(precision, recall, thresholds) if r >= min_recall @@ -64,10 +64,10 @@ def binary_precision_at_fixed_recall( preds: Tensor, target: Tensor, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible precision value given the minimum recall thresholds provided for binary tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the precision @@ -140,10 +140,10 @@ def multiclass_precision_at_fixed_recall( target: Tensor, num_classes: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible precision value given the minimum recall thresholds provided for multiclass tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the precision @@ -226,10 +226,10 @@ def multilabel_precision_at_fixed_recall( target: Tensor, num_labels: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible precision value given the minimum recall thresholds provided for multilabel tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the precision @@ -311,12 +311,12 @@ def precision_at_fixed_recall( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Optional[Tuple[Tensor, Tensor]]: +) -> Optional[tuple[Tensor, Tensor]]: r"""Compute the highest possible recall value given the minimum precision thresholds provided. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index c4607fd9489..3c5a840efa1 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import List, Optional, Union import torch from torch import Tensor, tensor @@ -31,7 +32,7 @@ def _binary_clf_curve( target: Tensor, sample_weights: Optional[Union[Sequence, Tensor]] = None, pos_label: int = 1, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate the TPs and false positives for all unique thresholds in the preds tensor. Adapted from @@ -82,7 +83,7 @@ def _binary_clf_curve( def _adjust_threshold_arg( - thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None + thresholds: Optional[Union[int, list[float], Tensor]] = None, device: Optional[torch.device] = None ) -> Optional[Tensor]: """Convert threshold arg for list and int to tensor format.""" if isinstance(thresholds, int): @@ -93,7 +94,7 @@ def _adjust_threshold_arg( def _binary_precision_recall_curve_arg_validation( - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: """Validate non tensor input. @@ -163,9 +164,9 @@ def _binary_precision_recall_curve_tensor_validation( def _binary_precision_recall_curve_format( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. - flattens additional dimensions @@ -192,7 +193,7 @@ def _binary_precision_recall_curve_update( preds: Tensor, target: Tensor, thresholds: Optional[Tensor], -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi @@ -212,7 +213,7 @@ def _binary_precision_recall_curve_update_vectorized( preds: Tensor, target: Tensor, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the multi-threshold confusion matrix to calculate the pr-curve with. This implementation is vectorized and faster than `_binary_precision_recall_curve_update_loop` for small @@ -230,7 +231,7 @@ def _binary_precision_recall_curve_update_loop( preds: Tensor, target: Tensor, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the multi-threshold confusion matrix to calculate the pr-curve with. This implementation loops over thresholds and is more memory-efficient than @@ -252,10 +253,10 @@ def _binary_precision_recall_curve_update_loop( def _binary_precision_recall_curve_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], pos_label: int = 1, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -293,10 +294,10 @@ def _binary_precision_recall_curve_compute( def binary_precision_recall_curve( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: r"""Compute the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -368,7 +369,7 @@ def binary_precision_recall_curve( def _multiclass_precision_recall_curve_arg_validation( num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, ) -> None: @@ -431,10 +432,10 @@ def _multiclass_precision_recall_curve_format( preds: Tensor, target: Tensor, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. - flattens additional dimensions @@ -468,7 +469,7 @@ def _multiclass_precision_recall_curve_update( num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi @@ -491,7 +492,7 @@ def _multiclass_precision_recall_curve_update_vectorized( target: Tensor, num_classes: int, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the multi-threshold confusion matrix to calculate the pr-curve with. This implementation is vectorized and faster than `_binary_precision_recall_curve_update_loop` for small @@ -513,7 +514,7 @@ def _multiclass_precision_recall_curve_update_loop( target: Tensor, num_classes: int, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. This implementation loops over thresholds and is more memory-efficient than @@ -535,11 +536,11 @@ def _multiclass_precision_recall_curve_update_loop( def _multiclass_precision_recall_curve_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -594,11 +595,11 @@ def multiclass_precision_recall_curve( preds: Tensor, target: Tensor, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve for multiclass tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -711,7 +712,7 @@ def multiclass_precision_recall_curve( def _multilabel_precision_recall_curve_arg_validation( num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: """Validate non tensor input. @@ -747,9 +748,9 @@ def _multilabel_precision_recall_curve_format( preds: Tensor, target: Tensor, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. - flattens additional dimensions @@ -780,7 +781,7 @@ def _multilabel_precision_recall_curve_update( target: Tensor, num_labels: int, thresholds: Optional[Tensor], -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi @@ -801,11 +802,11 @@ def _multilabel_precision_recall_curve_update( def _multilabel_precision_recall_curve_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -841,10 +842,10 @@ def multilabel_precision_recall_curve( preds: Tensor, target: Tensor, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve for multilabel tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -946,13 +947,13 @@ def precision_recall_curve( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the diff --git a/src/torchmetrics/functional/classification/ranking.py b/src/torchmetrics/functional/classification/ranking.py index 87bade5e88d..57dd389ec3a 100644 --- a/src/torchmetrics/functional/classification/ranking.py +++ b/src/torchmetrics/functional/classification/ranking.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -45,7 +45,7 @@ def _multilabel_ranking_tensor_validation( raise ValueError(f"Expected preds tensor to be floating point, but received input with dtype {preds.dtype}") -def _multilabel_coverage_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _multilabel_coverage_error_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Accumulate state for coverage error.""" offset = torch.zeros_like(preds) offset[target == 0] = preds.min().abs() + 10 # Any number >1 works @@ -109,7 +109,7 @@ def multilabel_coverage_error( return _ranking_reduce(coverage, total) -def _multilabel_ranking_average_precision_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _multilabel_ranking_average_precision_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Accumulate state for label ranking average precision.""" # Invert so that the highest score receives rank 1 neg_preds = -preds @@ -182,7 +182,7 @@ def multilabel_ranking_average_precision( return _ranking_reduce(score, num_elements) -def _multilabel_ranking_loss_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _multilabel_ranking_loss_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Accumulate state for label ranking loss. Args: diff --git a/src/torchmetrics/functional/classification/recall_fixed_precision.py b/src/torchmetrics/functional/classification/recall_fixed_precision.py index 745a9de2c34..196209bce21 100644 --- a/src/torchmetrics/functional/classification/recall_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_fixed_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch from torch import Tensor @@ -60,7 +60,7 @@ def _recall_at_precision( recall: Tensor, thresholds: Tensor, min_precision: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) best_threshold = torch.tensor(0) @@ -78,7 +78,7 @@ def _recall_at_precision( def _binary_recall_at_fixed_precision_arg_validation( min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -89,12 +89,12 @@ def _binary_recall_at_fixed_precision_arg_validation( def _binary_recall_at_fixed_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], min_precision: float, pos_label: int = 1, reduce_fn: Callable = _recall_at_precision, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: precision, recall, thresholds = _binary_precision_recall_curve_compute(state, thresholds, pos_label) return reduce_fn(precision, recall, thresholds, min_precision) @@ -103,10 +103,10 @@ def binary_recall_at_fixed_precision( preds: Tensor, target: Tensor, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible recall value given the minimum precision thresholds provided for binary tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the recall @@ -175,7 +175,7 @@ def binary_recall_at_fixed_precision( def _multiclass_recall_at_fixed_precision_arg_validation( num_classes: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -186,12 +186,12 @@ def _multiclass_recall_at_fixed_precision_arg_validation( def _multiclass_recall_at_fixed_precision_arg_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], min_precision: float, reduce_fn: Callable = _recall_at_precision, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: precision, recall, thresholds = _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) if isinstance(state, Tensor): res = [reduce_fn(p, r, thresholds, min_precision) for p, r in zip(precision, recall)] @@ -207,10 +207,10 @@ def multiclass_recall_at_fixed_precision( target: Tensor, num_classes: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible recall value given the minimum precision thresholds provided for multiclass tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a @@ -285,7 +285,7 @@ def multiclass_recall_at_fixed_precision( def _multilabel_recall_at_fixed_precision_arg_validation( num_labels: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -296,13 +296,13 @@ def _multilabel_recall_at_fixed_precision_arg_validation( def _multilabel_recall_at_fixed_precision_arg_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int], min_precision: float, reduce_fn: Callable = _recall_at_precision, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: precision, recall, thresholds = _multilabel_precision_recall_curve_compute( state, num_labels, thresholds, ignore_index ) @@ -320,10 +320,10 @@ def multilabel_recall_at_fixed_precision( target: Tensor, num_labels: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible recall value given the minimum precision thresholds provided for multilabel tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a @@ -403,12 +403,12 @@ def recall_at_fixed_precision( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Optional[Tuple[Tensor, Tensor]]: +) -> Optional[tuple[Tensor, Tensor]]: r"""Compute the highest possible recall value given the minimum precision thresholds provided. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index d61b920aa9b..741733ec98f 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -38,10 +38,10 @@ def _binary_roc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], pos_label: int = 1, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: if isinstance(state, Tensor) and thresholds is not None: tps = state[:, 1, 1] fps = state[:, 0, 1] @@ -83,10 +83,10 @@ def _binary_roc_compute( def binary_roc( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: r"""Compute the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -160,11 +160,11 @@ def binary_roc( def _multiclass_roc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: if average == "micro": return _binary_roc_compute(state, thresholds, pos_label=1) @@ -208,11 +208,11 @@ def multiclass_roc( preds: Tensor, target: Tensor, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multiclass tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -327,11 +327,11 @@ def multiclass_roc( def _multilabel_roc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: if isinstance(state, Tensor) and thresholds is not None: tps = state[:, :, 1, 1] fps = state[:, :, 0, 1] @@ -360,10 +360,10 @@ def multilabel_roc( preds: Tensor, target: Tensor, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multilabel tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -472,13 +472,13 @@ def roc( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC). The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py index b1f7e456b06..546984262b3 100644 --- a/src/torchmetrics/functional/classification/sensitivity_specificity.py +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -49,7 +49,7 @@ def _sensitivity_at_specificity( specificity: Tensor, thresholds: Tensor, min_specificity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: # get indices where specificity is greater than min_specificity indices = specificity >= min_specificity @@ -72,7 +72,7 @@ def _sensitivity_at_specificity( def _binary_sensitivity_at_specificity_arg_validation( min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -83,11 +83,11 @@ def _binary_sensitivity_at_specificity_arg_validation( def _binary_sensitivity_at_specificity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], min_specificity: float, pos_label: int = 1, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label) specificity = _convert_fpr_to_specificity(fpr) return _sensitivity_at_specificity(sensitivity, specificity, thresholds, min_specificity) @@ -97,10 +97,10 @@ def binary_sensitivity_at_specificity( preds: Tensor, target: Tensor, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible sensitivity value given the minimum specificity levels provided for binary tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -169,7 +169,7 @@ def binary_sensitivity_at_specificity( def _multiclass_sensitivity_at_specificity_arg_validation( num_classes: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -180,11 +180,11 @@ def _multiclass_sensitivity_at_specificity_arg_validation( def _multiclass_sensitivity_at_specificity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], min_specificity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -207,10 +207,10 @@ def multiclass_sensitivity_at_specificity( target: Tensor, num_classes: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible sensitivity value given minimum specificity level provided for multiclass tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the @@ -285,7 +285,7 @@ def multiclass_sensitivity_at_specificity( def _multilabel_sensitivity_at_specificity_arg_validation( num_labels: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -296,12 +296,12 @@ def _multilabel_sensitivity_at_specificity_arg_validation( def _multilabel_sensitivity_at_specificity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int], min_specificity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -324,10 +324,10 @@ def multilabel_sensitivity_at_specificity( target: Tensor, num_labels: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible sensitivity value given minimum specificity level provided for multilabel tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -408,12 +408,12 @@ def sensitivity_at_specificity( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index d85b47eb453..07e7bc4243c 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -50,7 +50,7 @@ def _specificity_at_sensitivity( sensitivity: Tensor, thresholds: Tensor, min_sensitivity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: # get indices where sensitivity is greater than min_sensitivity indices = sensitivity >= min_sensitivity @@ -73,7 +73,7 @@ def _specificity_at_sensitivity( def _binary_specificity_at_sensitivity_arg_validation( min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -84,11 +84,11 @@ def _binary_specificity_at_sensitivity_arg_validation( def _binary_specificity_at_sensitivity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], min_sensitivity: float, pos_label: int = 1, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label) specificity = _convert_fpr_to_specificity(fpr) return _specificity_at_sensitivity(specificity, sensitivity, thresholds, min_sensitivity) @@ -98,10 +98,10 @@ def binary_specificity_at_sensitivity( preds: Tensor, target: Tensor, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible specificity value given the minimum sensitivity levels provided for binary tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -170,7 +170,7 @@ def binary_specificity_at_sensitivity( def _multiclass_specificity_at_sensitivity_arg_validation( num_classes: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -181,11 +181,11 @@ def _multiclass_specificity_at_sensitivity_arg_validation( def _multiclass_specificity_at_sensitivity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], min_sensitivity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -208,10 +208,10 @@ def multiclass_specificity_at_sensitivity( target: Tensor, num_classes: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible specificity value given minimum sensitivity level provided for multiclass tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the @@ -286,7 +286,7 @@ def multiclass_specificity_at_sensitivity( def _multilabel_specificity_at_sensitivity_arg_validation( num_labels: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -297,12 +297,12 @@ def _multilabel_specificity_at_sensitivity_arg_validation( def _multilabel_specificity_at_sensitivity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int], min_sensitivity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -325,10 +325,10 @@ def multilabel_specificity_at_sensitivity( target: Tensor, num_labels: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible specificity value given minimum sensitivity level provided for multilabel tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -409,12 +409,12 @@ def specicity_at_sensitivity( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. .. warning:: @@ -445,12 +445,12 @@ def specificity_at_sensitivity( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 565c212f9bd..e9d78b1a1b7 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor @@ -96,7 +96,7 @@ def _binary_stat_scores_format( target: Tensor, threshold: float = 0.5, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range @@ -125,7 +125,7 @@ def _binary_stat_scores_update( preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global", -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute the statistics.""" sum_dim = [0, 1] if multidim_average == "global" else [1] tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() @@ -326,7 +326,7 @@ def _multiclass_stat_scores_format( preds: Tensor, target: Tensor, top_k: int = 1, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format except if ``top_k`` is not 1. - Applies argmax if preds have one more dimension than target @@ -349,7 +349,7 @@ def _multiclass_stat_scores_update( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute the statistics. - If ``multidim_average`` is equal to samplewise or ``top_k`` is not 1, we transform both preds and @@ -650,7 +650,7 @@ def _multilabel_stat_scores_tensor_validation( def _multilabel_stat_scores_format( preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range @@ -675,7 +675,7 @@ def _multilabel_stat_scores_format( def _multilabel_stat_scores_update( preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global" -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute the statistics.""" sum_dim = [0, -1] if multidim_average == "global" else [-1] tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() @@ -828,7 +828,7 @@ def _del_column(data: Tensor, idx: int) -> Tensor: def _drop_negative_ignored_indices( preds: Tensor, target: Tensor, ignore_index: int, mode: DataType -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Remove negative ignored indices. Args: @@ -866,7 +866,7 @@ def _stat_scores( preds: Tensor, target: Tensor, reduce: Optional[str] = "micro", -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate the number of tp, fp, tn, fn. Args: @@ -892,7 +892,7 @@ def _stat_scores( - If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors """ - dim: Union[int, List[int]] = 1 # for "samples" + dim: Union[int, list[int]] = 1 # for "samples" if reduce == "micro": dim = [0, 1] if preds.ndim == 2 else [1, 2] elif reduce == "macro": @@ -921,7 +921,7 @@ def _stat_scores_update( multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, mode: Optional[DataType] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate true positives, false positives, true negatives, false negatives. Raises: diff --git a/src/torchmetrics/functional/clustering/dunn_index.py b/src/torchmetrics/functional/clustering/dunn_index.py index 51120697002..ac073b7c273 100644 --- a/src/torchmetrics/functional/clustering/dunn_index.py +++ b/src/torchmetrics/functional/clustering/dunn_index.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from itertools import combinations -from typing import Tuple import torch from torch import Tensor -def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> Tuple[Tensor, Tensor]: +def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> tuple[Tensor, Tensor]: """Update and return variables required to compute the Dunn index. Args: diff --git a/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py index e7faae5175e..88b9288f9ed 100644 --- a/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor, tensor @@ -19,7 +18,7 @@ from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels -def _fowlkes_mallows_index_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _fowlkes_mallows_index_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Return contingency matrix required to compute the Fowlkes-Mallows index. Args: diff --git a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py index e98f1e26b5b..a85d9ab2a11 100644 --- a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -20,7 +19,7 @@ from torchmetrics.functional.clustering.utils import calculate_entropy, check_cluster_labels -def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Computes the homogeneity score of a clustering given the predicted and target cluster labels.""" check_cluster_labels(preds, target) @@ -36,7 +35,7 @@ def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, T return homogeneity, mutual_info, entropy_preds, entropy_target -def _completeness_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _completeness_score_compute(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Computes the completeness score of a clustering given the predicted and target cluster labels.""" homogeneity, mutual_info, entropy_preds, _ = _homogeneity_score_compute(preds, target) completeness = mutual_info / entropy_preds if entropy_preds else torch.ones_like(entropy_preds) diff --git a/src/torchmetrics/functional/detection/_deprecated.py b/src/torchmetrics/functional/detection/_deprecated.py index ce0e1ba6acf..46c8acdab56 100644 --- a/src/torchmetrics/functional/detection/_deprecated.py +++ b/src/torchmetrics/functional/detection/_deprecated.py @@ -1,4 +1,4 @@ -from typing import Collection +from collections.abc import Collection from torch import Tensor diff --git a/src/torchmetrics/functional/detection/_panoptic_quality_common.py b/src/torchmetrics/functional/detection/_panoptic_quality_common.py index c94978e3435..16d0463ba45 100644 --- a/src/torchmetrics/functional/detection/_panoptic_quality_common.py +++ b/src/torchmetrics/functional/detection/_panoptic_quality_common.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, Iterator, List, Optional, Set, Tuple, cast +from collections.abc import Collection, Iterator +from typing import Optional, cast import torch from torch import Tensor from torchmetrics.utilities import rank_zero_warn -_Color = Tuple[int, int] # A (category_id, instance_id) tuple that uniquely identifies a panoptic segment. +_Color = tuple[int, int] # A (category_id, instance_id) tuple that uniquely identifies a panoptic segment. -def _nested_tuple(nested_list: List) -> Tuple: +def _nested_tuple(nested_list: list) -> tuple: """Construct a nested tuple from a nested list. Args: @@ -34,7 +35,7 @@ def _nested_tuple(nested_list: List) -> Tuple: return tuple(map(_nested_tuple, nested_list)) if isinstance(nested_list, list) else nested_list -def _to_tuple(t: Tensor) -> Tuple: +def _to_tuple(t: Tensor) -> tuple: """Convert a tensor into a nested tuple. Args: @@ -47,7 +48,7 @@ def _to_tuple(t: Tensor) -> Tuple: return _nested_tuple(t.tolist()) -def _get_color_areas(inputs: Tensor) -> Dict[Tuple, Tensor]: +def _get_color_areas(inputs: Tensor) -> dict[tuple, Tensor]: """Measure the size of each instance. Args: @@ -62,7 +63,7 @@ def _get_color_areas(inputs: Tensor) -> Dict[Tuple, Tensor]: return dict(zip(_to_tuple(unique_keys), unique_keys_area)) -def _parse_categories(things: Collection[int], stuffs: Collection[int]) -> Tuple[Set[int], Set[int]]: +def _parse_categories(things: Collection[int], stuffs: Collection[int]) -> tuple[set[int], set[int]]: """Parse and validate metrics arguments for `things` and `stuff`. Args: @@ -121,7 +122,7 @@ def _validate_inputs(preds: Tensor, target: torch.Tensor) -> None: ) -def _get_void_color(things: Set[int], stuffs: Set[int]) -> Tuple[int, int]: +def _get_void_color(things: set[int], stuffs: set[int]) -> tuple[int, int]: """Get an unused color ID. Args: @@ -136,7 +137,7 @@ def _get_void_color(things: Set[int], stuffs: Set[int]) -> Tuple[int, int]: return unused_category_id, 0 -def _get_category_id_to_continuous_id(things: Set[int], stuffs: Set[int]) -> Dict[int, int]: +def _get_category_id_to_continuous_id(things: set[int], stuffs: set[int]) -> dict[int, int]: """Convert original IDs to continuous IDs. Args: @@ -157,7 +158,7 @@ def _get_category_id_to_continuous_id(things: Set[int], stuffs: Set[int]) -> Dic return cat_id_to_continuous_id -def _isin(arr: Tensor, values: List) -> Tensor: +def _isin(arr: Tensor, values: list) -> Tensor: """Check if all values of an arr are in another array. Implementation of torch.isin to support pre 0.10 version. Args: @@ -173,10 +174,10 @@ def _isin(arr: Tensor, values: List) -> Tensor: def _prepocess_inputs( - things: Set[int], - stuffs: Set[int], + things: set[int], + stuffs: set[int], inputs: Tensor, - void_color: Tuple[int, int], + void_color: tuple[int, int], allow_unknown_category: bool, ) -> Tensor: """Preprocesses an input tensor for metric calculation. @@ -214,9 +215,9 @@ def _prepocess_inputs( def _calculate_iou( pred_color: _Color, target_color: _Color, - pred_areas: Dict[_Color, Tensor], - target_areas: Dict[_Color, Tensor], - intersection_areas: Dict[Tuple[_Color, _Color], Tensor], + pred_areas: dict[_Color, Tensor], + target_areas: dict[_Color, Tensor], + intersection_areas: dict[tuple[_Color, _Color], Tensor], void_color: _Color, ) -> Tensor: """Helper function that calculates the IoU from precomputed areas of segments and their intersections. @@ -252,10 +253,10 @@ def _calculate_iou( def _filter_false_negatives( - target_areas: Dict[_Color, Tensor], - target_segment_matched: Set[_Color], - intersection_areas: Dict[Tuple[_Color, _Color], Tensor], - void_color: Tuple[int, int], + target_areas: dict[_Color, Tensor], + target_segment_matched: set[_Color], + intersection_areas: dict[tuple[_Color, _Color], Tensor], + void_color: tuple[int, int], ) -> Iterator[int]: """Filter false negative segments and yield their category IDs. @@ -281,10 +282,10 @@ def _filter_false_negatives( def _filter_false_positives( - pred_areas: Dict[_Color, Tensor], - pred_segment_matched: Set[_Color], - intersection_areas: Dict[Tuple[_Color, _Color], Tensor], - void_color: Tuple[int, int], + pred_areas: dict[_Color, Tensor], + pred_segment_matched: set[_Color], + intersection_areas: dict[tuple[_Color, _Color], Tensor], + void_color: tuple[int, int], ) -> Iterator[int]: """Filter false positive segments and yield their category IDs. @@ -312,10 +313,10 @@ def _filter_false_positives( def _panoptic_quality_update_sample( flatten_preds: Tensor, flatten_target: Tensor, - cat_id_to_continuous_id: Dict[int, int], - void_color: Tuple[int, int], - stuffs_modified_metric: Optional[Set[int]] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + cat_id_to_continuous_id: dict[int, int], + void_color: tuple[int, int], + stuffs_modified_metric: Optional[set[int]] = None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate stat scores required to compute the metric **for a single sample**. Computed scores: iou sum, true positives, false positives, false negatives. @@ -351,11 +352,11 @@ def _panoptic_quality_update_sample( # calculate the area of each prediction, ground truth and pairwise intersection. # NOTE: mypy needs `cast()` because the annotation for `_get_color_areas` is too generic. - pred_areas = cast(Dict[_Color, Tensor], _get_color_areas(flatten_preds)) - target_areas = cast(Dict[_Color, Tensor], _get_color_areas(flatten_target)) + pred_areas = cast(dict[_Color, Tensor], _get_color_areas(flatten_preds)) + target_areas = cast(dict[_Color, Tensor], _get_color_areas(flatten_target)) # intersection matrix of shape [num_pixels, 2, 2] intersection_matrix = torch.transpose(torch.stack((flatten_preds, flatten_target), -1), -1, -2) - intersection_areas = cast(Dict[Tuple[_Color, _Color], Tensor], _get_color_areas(intersection_matrix)) + intersection_areas = cast(dict[tuple[_Color, _Color], Tensor], _get_color_areas(intersection_matrix)) # select intersection of things of same category with iou > 0.5 pred_segment_matched = set() @@ -397,10 +398,10 @@ def _panoptic_quality_update_sample( def _panoptic_quality_update( flatten_preds: Tensor, flatten_target: Tensor, - cat_id_to_continuous_id: Dict[int, int], - void_color: Tuple[int, int], - modified_metric_stuffs: Optional[Set[int]] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + cat_id_to_continuous_id: dict[int, int], + void_color: tuple[int, int], + modified_metric_stuffs: Optional[set[int]] = None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate stat scores required to compute the metric for a full batch. Computed scores: iou sum, true positives, false positives, false negatives. @@ -449,7 +450,7 @@ def _panoptic_quality_compute( true_positives: Tensor, false_positives: Tensor, false_negatives: Tensor, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Compute the final panoptic quality from interim values. Args: diff --git a/src/torchmetrics/functional/detection/panoptic_qualities.py b/src/torchmetrics/functional/detection/panoptic_qualities.py index 019d243f0ba..c1b3c7f27b6 100644 --- a/src/torchmetrics/functional/detection/panoptic_qualities.py +++ b/src/torchmetrics/functional/detection/panoptic_qualities.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection +from collections.abc import Collection import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index 892d07afaa6..8ee36b77a43 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Optional, Union from torch import Tensor from typing_extensions import Literal @@ -57,7 +58,7 @@ def _error_relative_global_dimensionless_synthesis( return error_relative_global_dimensionless_synthesis(preds=preds, target=target, ratio=ratio, reduction=reduction) -def _image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: +def _image_gradients(img: Tensor) -> tuple[Tensor, Tensor]: """Wrapper for deprecated import. >>> import torch @@ -79,10 +80,10 @@ def _image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: def _peak_signal_noise_ratio( preds: Tensor, target: Tensor, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, ) -> Tensor: """Wrapper for deprecated import. @@ -115,7 +116,7 @@ def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: def _root_mean_squared_error_using_sliding_window( preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False -) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]: +) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]: """Wrapper for deprecated import. >>> from torch import rand @@ -156,10 +157,10 @@ def _multiscale_structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Optional[Literal["relu", "simple"]] = "relu", ) -> Tensor: """Wrapper for deprecated import. @@ -194,12 +195,12 @@ def _structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> import torch diff --git a/src/torchmetrics/functional/image/d_lambda.py b/src/torchmetrics/functional/image/d_lambda.py index 5921f51d32d..478455c0a68 100644 --- a/src/torchmetrics/functional/image/d_lambda.py +++ b/src/torchmetrics/functional/image/d_lambda.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -22,7 +21,7 @@ from torchmetrics.utilities.distributed import reduce -def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Spectral Distortion Index. Args: diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 824cd3bfa8b..1c2797a2aab 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -28,7 +28,7 @@ def _spatial_distortion_index_update( preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Optional[Tensor] = None -) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: """Update and returns variables required to compute Spatial Distortion Index. Args: diff --git a/src/torchmetrics/functional/image/ergas.py b/src/torchmetrics/functional/image/ergas.py index 41500c552e0..ae940250dbe 100644 --- a/src/torchmetrics/functional/image/ergas.py +++ b/src/torchmetrics/functional/image/ergas.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -21,7 +20,7 @@ from torchmetrics.utilities.distributed import reduce -def _ergas_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _ergas_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Erreur Relative Globale Adimensionnelle de Synthèse. Args: diff --git a/src/torchmetrics/functional/image/gradients.py b/src/torchmetrics/functional/image/gradients.py index e87d4a75198..68045663fa9 100644 --- a/src/torchmetrics/functional/image/gradients.py +++ b/src/torchmetrics/functional/image/gradients.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -25,7 +24,7 @@ def _image_gradients_validate(img: Tensor) -> None: raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor") -def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: +def _compute_image_gradients(img: Tensor) -> tuple[Tensor, Tensor]: """Compute image gradients (dy/dx) for a given image.""" batch_size, channels, height, width = img.shape @@ -43,7 +42,7 @@ def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: return dy, dx -def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: +def image_gradients(img: Tensor) -> tuple[Tensor, Tensor]: """Compute `Gradient Computation of Image`_ of a given image using finite difference. Args: diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index c557f61ead1..0e07cec3d28 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -24,7 +24,7 @@ # License under BSD 2-clause import inspect import os -from typing import List, NamedTuple, Optional, Tuple, Union +from typing import List, NamedTuple, Optional, Union import torch from torch import Tensor, nn @@ -205,7 +205,7 @@ def _spatial_average(in_tens: Tensor, keep_dim: bool = True) -> Tensor: return in_tens.mean([2, 3], keepdim=keep_dim) -def _upsample(in_tens: Tensor, out_hw: Tuple[int, ...] = (64, 64)) -> Tensor: +def _upsample(in_tens: Tensor, out_hw: tuple[int, ...] = (64, 64)) -> Tensor: """Upsample input with bilinear interpolation.""" return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens) @@ -331,7 +331,7 @@ def __init__( def forward( self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False - ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: + ) -> Union[Tensor, tuple[Tensor, List[Tensor]]]: if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 @@ -378,7 +378,7 @@ def _valid_img(img: Tensor, normalize: bool) -> bool: return img.ndim == 4 and img.shape[1] == 3 and value_check # type: ignore[return-value] -def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> Tuple[Tensor, Union[int, Tensor]]: +def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> tuple[Tensor, Union[int, Tensor]]: if not (_valid_img(img1, normalize) and _valid_img(img2, normalize)): raise ValueError( "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]." diff --git a/src/torchmetrics/functional/image/perceptual_path_length.py b/src/torchmetrics/functional/image/perceptual_path_length.py index 58b0a7bae05..035425539d2 100644 --- a/src/torchmetrics/functional/image/perceptual_path_length.py +++ b/src/torchmetrics/functional/image/perceptual_path_length.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union import torch from torch import Tensor, nn @@ -162,7 +162,7 @@ def perceptual_path_length( upper_discard: Optional[float] = 0.99, sim_net: Union[nn.Module, Literal["alex", "vgg", "squeeze"]] = "vgg", device: Union[str, torch.device] = "cpu", -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: r"""Computes the perceptual path length (`PPL`_) of a generator model. The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index 7bd93ba94e1..e058d34e7b6 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor @@ -58,8 +58,8 @@ def _psnr_compute( def _psnr_update( preds: Tensor, target: Tensor, - dim: Optional[Union[int, Tuple[int, ...]]] = None, -) -> Tuple[Tensor, Tensor]: + dim: Optional[Union[int, tuple[int, ...]]] = None, +) -> tuple[Tensor, Tensor]: """Update and return variables required to compute peak signal-to-noise ratio. Args: @@ -95,10 +95,10 @@ def _psnr_update( def peak_signal_noise_ratio( preds: Tensor, target: Tensor, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, ) -> Tensor: """Compute the peak signal-to-noise ratio. diff --git a/src/torchmetrics/functional/image/psnrb.py b/src/torchmetrics/functional/image/psnrb.py index f725ea4bd80..4a8469e19fe 100644 --- a/src/torchmetrics/functional/image/psnrb.py +++ b/src/torchmetrics/functional/image/psnrb.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple import torch from torch import Tensor, tensor @@ -86,7 +85,7 @@ def _psnrb_compute( return 10 * torch.log10(1.0 / sum_squared_error) -def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> Tuple[Tensor, Tensor, Tensor]: +def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> tuple[Tensor, Tensor, Tensor]: """Updates and returns variables required to compute peak signal-to-noise ratio. Args: diff --git a/src/torchmetrics/functional/image/rase.py b/src/torchmetrics/functional/image/rase.py index 388f2c237a3..832a759b562 100644 --- a/src/torchmetrics/functional/image/rase.py +++ b/src/torchmetrics/functional/image/rase.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -23,7 +22,7 @@ def _rase_update( preds: Tensor, target: Tensor, window_size: int, rmse_map: Tensor, target_sum: Tensor, total_images: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate the sum of RMSE map values for the batch of examples and update intermediate states. Args: diff --git a/src/torchmetrics/functional/image/rmse_sw.py b/src/torchmetrics/functional/image/rmse_sw.py index a27582bd11a..3b0eaf6221f 100644 --- a/src/torchmetrics/functional/image/rmse_sw.py +++ b/src/torchmetrics/functional/image/rmse_sw.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor @@ -28,7 +28,7 @@ def _rmse_sw_update( rmse_val_sum: Optional[Tensor], rmse_map: Optional[Tensor], total_images: Optional[Tensor], -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate the sum of RMSE values and RMSE map for the batch of examples and update intermediate states. Args: @@ -89,7 +89,7 @@ def _rmse_sw_update( def _rmse_sw_compute( rmse_val_sum: Optional[Tensor], rmse_map: Tensor, total_images: Tensor -) -> Tuple[Optional[Tensor], Tensor]: +) -> tuple[Optional[Tensor], Tensor]: """Compute RMSE from the aggregated RMSE value. Optionally also computes the mean value for RMSE map. Args: @@ -111,7 +111,7 @@ def _rmse_sw_compute( def root_mean_squared_error_using_sliding_window( preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False -) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]: +) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]: """Compute Root Mean Squared Error (RMSE) using sliding window. Args: diff --git a/src/torchmetrics/functional/image/sam.py b/src/torchmetrics/functional/image/sam.py index 71927ff6b42..21efde6f9c4 100644 --- a/src/torchmetrics/functional/image/sam.py +++ b/src/torchmetrics/functional/image/sam.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -21,7 +20,7 @@ from torchmetrics.utilities.distributed import reduce -def _sam_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _sam_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Spectral Angle Mapper. Args: diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index a4b7bd85725..d680a22b676 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor @@ -23,7 +23,7 @@ from torchmetrics.utilities.distributed import reduce -def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tuple[Tensor, Tensor, Tensor]: +def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> tuple[Tensor, Tensor, Tensor]: """Update and returns variables required to compute Spatial Correlation Coefficient. Args: @@ -73,7 +73,7 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i return preds, target, hp_filter -def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: +def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, tuple[int, ...]]) -> Tensor: """Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a).""" if isinstance(pad, int): pad = (pad, pad, pad, pad) @@ -106,7 +106,7 @@ def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor: return _signal_convolve_2d(input_img, kernel) * 2.0 -def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> tuple[Tensor, Tensor, Tensor]: """Computes local variance and covariance of the input tensors.""" # This code is inspired by # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index c61ef833fe3..ccaafe66065 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import List, Optional, Union import torch from torch import Tensor @@ -23,7 +24,7 @@ from torchmetrics.utilities.distributed import reduce -def _ssim_check_inputs(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _ssim_check_inputs(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Structural Similarity Index Measure. Args: @@ -48,12 +49,12 @@ def _ssim_update( gaussian_kernel: bool = True, sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute Structural Similarity Index Measure. Args: @@ -213,12 +214,12 @@ def structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute Structural Similarity Index Measure. Args: @@ -297,11 +298,11 @@ def _get_normalized_sim_and_cs( gaussian_kernel: bool = True, sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, normalize: Optional[Literal["relu", "simple"]] = None, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: sim, contrast_sensitivity = _ssim_update( preds, target, @@ -325,10 +326,10 @@ def _multiscale_ssim_update( gaussian_kernel: bool = True, sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Union[Tuple[float, float, float, float, float], Tuple[float, ...]] = ( + betas: Union[tuple[float, float, float, float, float], tuple[float, ...]] = ( 0.0448, 0.2856, 0.3001, @@ -452,10 +453,10 @@ def multiscale_structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Optional[Literal["relu", "simple"]] = "relu", ) -> Tensor: """Compute `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure. diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index be8e3366caa..a0f310d498d 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union from torch import Tensor from typing_extensions import Literal -def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]: +def _total_variation_update(img: Tensor) -> tuple[Tensor, int]: """Compute total variation statistics on current batch.""" if img.ndim != 4: raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}") diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index ed8bf39742b..30b5e781e55 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Optional import torch from torch import Tensor, nn @@ -22,7 +23,7 @@ from torchmetrics.utilities.distributed import reduce -def _uqi_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _uqi_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Universal Image Quality Index. Args: diff --git a/src/torchmetrics/functional/image/utils.py b/src/torchmetrics/functional/image/utils.py index bf09ff79249..24ed9cd0de8 100644 --- a/src/torchmetrics/functional/image/utils.py +++ b/src/torchmetrics/functional/image/utils.py @@ -1,4 +1,5 @@ -from typing import Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Union import torch from torch import Tensor @@ -56,7 +57,7 @@ def _gaussian_kernel_2d( return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) -def _uniform_weight_bias_conv2d(inputs: Tensor, window_size: int) -> Tuple[Tensor, Tensor]: +def _uniform_weight_bias_conv2d(inputs: Tensor, window_size: int) -> tuple[Tensor, Tensor]: """Construct uniform weight and bias for a 2d convolution. Args: diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py index 659c667fc52..99985c5b022 100644 --- a/src/torchmetrics/functional/multimodal/clip_iqa.py +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Literal, Tuple, Union +from typing import TYPE_CHECKING, Literal, Union import torch from torch import Tensor @@ -40,7 +40,7 @@ def _download_clip_for_iqa_metric() -> None: if not _PIQ_GREATER_EQUAL_0_8: __doctest_skip__ = ["clip_image_quality_assessment"] -_PROMPTS: Dict[str, Tuple[str, str]] = { +_PROMPTS: dict[str, tuple[str, str]] = { "quality": ("Good photo.", "Bad photo."), "brightness": ("Bright photo.", "Dark photo."), "noisiness": ("Clean photo.", "Noisy photo."), @@ -68,7 +68,7 @@ def _get_clip_iqa_model_and_processor( "openai/clip-vit-large-patch14-336", "openai/clip-vit-large-patch14", ], -) -> Tuple["_CLIPModel", "_CLIPProcessor"]: +) -> tuple["_CLIPModel", "_CLIPProcessor"]: """Extract the CLIP model and processor from the model name or path.""" from transformers import CLIPProcessor as _CLIPProcessor @@ -89,7 +89,7 @@ def _get_clip_iqa_model_and_processor( return _get_clip_model_and_processor(model_name_or_path) -def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",)) -> Tuple[List[str], List[str]]: +def _clip_iqa_format_prompts(prompts: tuple[Union[str, tuple[str, str]]] = ("quality",)) -> tuple[list[str], list[str]]: """Converts the provided keywords into a list of prompts for the model to calculate the anchor vectors. Args: @@ -119,8 +119,8 @@ def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("qua if not isinstance(prompts, tuple): raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings") - prompts_names: List[str] = [] - prompts_list: List[str] = [] + prompts_names: list[str] = [] + prompts_list: list[str] = [] count = 0 for p in prompts: if not isinstance(p, (str, tuple)): @@ -146,7 +146,7 @@ def _clip_iqa_get_anchor_vectors( model_name_or_path: str, model: "_CLIPModel", processor: "_CLIPProcessor", - prompts_list: List[str], + prompts_list: list[str], device: Union[str, torch.device], ) -> Tensor: """Calculates the anchor vectors for the CLIP IQA metric. @@ -202,9 +202,9 @@ def _clip_iqa_update( def _clip_iqa_compute( img_features: Tensor, anchors: Tensor, - prompts_names: List[str], + prompts_names: list[str], format_as_dict: bool = True, -) -> Union[Tensor, Dict[str, Tensor]]: +) -> Union[Tensor, dict[str, Tensor]]: """Final computation of CLIP IQA.""" logits_per_image = 100 * img_features @ anchors.t() probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0] @@ -225,8 +225,8 @@ def clip_image_quality_assessment( "openai/clip-vit-large-patch14", ] = "clip_iqa", data_range: float = 1.0, - prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",), -) -> Union[Tensor, Dict[str, Tensor]]: + prompts: tuple[Union[str, tuple[str, str]]] = ("quality",), +) -> Union[Tensor, dict[str, Tensor]]: """Calculates `CLIP-IQA`_, that can be used to measure the visual content of images. The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index f70f37d534b..070d81bf54c 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, List, Union import torch from torch import Tensor @@ -43,10 +43,10 @@ def _download_clip_for_clip_score() -> None: def _clip_score_update( images: Union[Tensor, List[Tensor]], - text: Union[str, List[str]], + text: Union[str, list[str]], model: _CLIPModel, processor: _CLIPProcessor, -) -> Tuple[Tensor, int]: +) -> tuple[Tensor, int]: if not isinstance(images, list): if images.ndim == 3: images = [images] @@ -97,7 +97,7 @@ def _get_clip_model_and_processor( "openai/clip-vit-large-patch14-336", "openai/clip-vit-large-patch14", ] = "openai/clip-vit-large-patch14", -) -> Tuple[_CLIPModel, _CLIPProcessor]: +) -> tuple[_CLIPModel, _CLIPProcessor]: if _TRANSFORMERS_GREATER_EQUAL_4_10: from transformers import CLIPModel as _CLIPModel from transformers import CLIPProcessor as _CLIPProcessor @@ -114,7 +114,7 @@ def _get_clip_model_and_processor( def clip_score( images: Union[Tensor, List[Tensor]], - text: Union[str, List[str]], + text: Union[str, list[str]], model_name_or_path: Literal[ "openai/clip-vit-base-patch16", "openai/clip-vit-base-patch32", diff --git a/src/torchmetrics/functional/nominal/utils.py b/src/torchmetrics/functional/nominal/utils.py index 8c8cc166778..9d8dd8dc4af 100644 --- a/src/torchmetrics/functional/nominal/utils.py +++ b/src/torchmetrics/functional/nominal/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -93,7 +93,7 @@ def _compute_phi_squared_corrected( ) -def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: Tensor) -> Tuple[Tensor, Tensor]: +def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: Tensor) -> tuple[Tensor, Tensor]: """Compute bias-corrected number of rows and columns.""" rows_corrected = num_rows - (num_rows - 1) ** 2 / (confmat_sum - 1) cols_corrected = num_cols - (num_cols - 1) ** 2 / (confmat_sum - 1) @@ -102,7 +102,7 @@ def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: def _compute_bias_corrected_values( phi_squared: Tensor, num_rows: int, num_cols: int, confmat_sum: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute bias-corrected Phi Squared and number of rows and columns.""" phi_squared_corrected = _compute_phi_squared_corrected(phi_squared, num_rows, num_cols, confmat_sum) rows_corrected, cols_corrected = _compute_rows_and_cols_corrected(num_rows, num_cols, confmat_sum) @@ -114,7 +114,7 @@ def _handle_nan_in_data( target: Tensor, nan_strategy: Literal["replace", "drop"] = "replace", nan_replace_value: Optional[float] = 0.0, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Handle ``NaN`` values in input data. If ``nan_strategy = 'replace'``, all ``NaN`` values are replaced with ``nan_replace_value``. diff --git a/src/torchmetrics/functional/pairwise/helpers.py b/src/torchmetrics/functional/pairwise/helpers.py index bc6ba79a81a..703b5ddb083 100644 --- a/src/torchmetrics/functional/pairwise/helpers.py +++ b/src/torchmetrics/functional/pairwise/helpers.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional from torch import Tensor def _check_input( x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None -) -> Tuple[Tensor, Tensor, bool]: +) -> tuple[Tensor, Tensor, bool]: """Check that input has the right dimensionality and sets the ``zero_diagonal`` argument if user has not set it. Args: diff --git a/src/torchmetrics/functional/regression/cosine_similarity.py b/src/torchmetrics/functional/regression/cosine_similarity.py index 8f9460010e7..c57623931a4 100644 --- a/src/torchmetrics/functional/regression/cosine_similarity.py +++ b/src/torchmetrics/functional/regression/cosine_similarity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -22,7 +22,7 @@ def _cosine_similarity_update( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Cosine Similarity. Checks for same shape of input tensors. Args: diff --git a/src/torchmetrics/functional/regression/csi.py b/src/torchmetrics/functional/regression/csi.py index 8d43c6012fd..65d38e6f573 100644 --- a/src/torchmetrics/functional/regression/csi.py +++ b/src/torchmetrics/functional/regression/csi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -22,7 +22,7 @@ def _critical_success_index_update( preds: Tensor, target: Tensor, threshold: float, keep_sequence_dim: Optional[int] = None -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute Critical Success Index. Checks for same shape of tensors. Args: diff --git a/src/torchmetrics/functional/regression/explained_variance.py b/src/torchmetrics/functional/regression/explained_variance.py index a6a6c4ff209..d401bb5a349 100644 --- a/src/torchmetrics/functional/regression/explained_variance.py +++ b/src/torchmetrics/functional/regression/explained_variance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Union import torch from torch import Tensor @@ -22,7 +23,7 @@ ALLOWED_MULTIOUTPUT = ("raw_values", "uniform_average", "variance_weighted") -def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor, Tensor, Tensor]: +def _explained_variance_update(preds: Tensor, target: Tensor) -> tuple[int, Tensor, Tensor, Tensor, Tensor]: """Update and returns variables required to compute Explained Variance. Checks for same shape of input tensors. Args: diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index c46173eae00..80afcc23b69 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -47,7 +47,7 @@ def _name() -> str: return "alternative" -def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: +def _sort_on_first_sequence(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: """Sort sequences in an ascent order according to the sequence ``x``.""" # We need to clone `y` tensor not to change an object in memory y = torch.clone(y) @@ -94,7 +94,7 @@ def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor: return _cumsum(torch.cat([_ones, (x[1:] != x[:-1]).int()], dim=0), dim=0) -def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +def _get_ties(x: Tensor) -> tuple[Tensor, Tensor, Tensor]: """Get a total number of ties and staistics for p-value calculation for a given sequence.""" ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) ties_p1 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) @@ -111,7 +111,7 @@ def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: def _get_metric_metadata( preds: Tensor, target: Tensor, variant: _MetricVariant -) -> Tuple[ +) -> tuple[ Tensor, Tensor, Optional[Tensor], @@ -228,7 +228,7 @@ def _kendall_corrcoef_update( concat_preds: Optional[List[Tensor]] = None, concat_target: Optional[List[Tensor]] = None, num_outputs: int = 1, -) -> Tuple[List[Tensor], List[Tensor]]: +) -> tuple[List[Tensor], List[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. Args: @@ -263,7 +263,7 @@ def _kendall_corrcoef_compute( target: Tensor, variant: _MetricVariant, alternative: Optional[_TestAlternative] = None, -) -> Tuple[Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Optional[Tensor]]: """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test. Args: @@ -324,7 +324,7 @@ def kendall_rank_corrcoef( variant: Literal["a", "b", "c"] = "b", t_test: bool = False, alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided", -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: r"""Compute `Kendall Rank Correlation Coefficient`_. .. math:: diff --git a/src/torchmetrics/functional/regression/kl_divergence.py b/src/torchmetrics/functional/regression/kl_divergence.py index 6e6563aee71..c4ed9efb25a 100644 --- a/src/torchmetrics/functional/regression/kl_divergence.py +++ b/src/torchmetrics/functional/regression/kl_divergence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -22,7 +22,7 @@ from torchmetrics.utilities.compute import _safe_xlogy -def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]: +def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> tuple[Tensor, int]: """Update and returns KL divergence scores for each observation and the total number of observations. Args: diff --git a/src/torchmetrics/functional/regression/log_cosh.py b/src/torchmetrics/functional/regression/log_cosh.py index ef9402c740f..e13ee2a3e7a 100644 --- a/src/torchmetrics/functional/regression/log_cosh.py +++ b/src/torchmetrics/functional/regression/log_cosh.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -20,13 +19,13 @@ from torchmetrics.utilities.checks import _check_same_shape -def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: if preds.ndim == 2: return preds, target return preds.unsqueeze(1), target.unsqueeze(1) -def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]: +def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute LogCosh error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/log_mse.py b/src/torchmetrics/functional/regression/log_mse.py index 3ba27ab9e99..7c3a3585127 100644 --- a/src/torchmetrics/functional/regression/log_mse.py +++ b/src/torchmetrics/functional/regression/log_mse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Return variables required to compute Mean Squared Log Error. Checks for same shape of tensors. Args: diff --git a/src/torchmetrics/functional/regression/mae.py b/src/torchmetrics/functional/regression/mae.py index d40b30e9017..8774d67fbe1 100644 --- a/src/torchmetrics/functional/regression/mae.py +++ b/src/torchmetrics/functional/regression/mae.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_absolute_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, int]: +def _mean_absolute_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, int]: """Update and returns variables required to compute Mean Absolute Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/mape.py b/src/torchmetrics/functional/regression/mape.py index bdcb7c87264..c109e898eee 100644 --- a/src/torchmetrics/functional/regression/mape.py +++ b/src/torchmetrics/functional/regression/mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -23,7 +23,7 @@ def _mean_absolute_percentage_error_update( preds: Tensor, target: Tensor, epsilon: float = 1.17e-06, -) -> Tuple[Tensor, int]: +) -> tuple[Tensor, int]: """Update and returns variables required to compute Mean Percentage Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py index f9649a87416..4ea9490c841 100644 --- a/src/torchmetrics/functional/regression/mse.py +++ b/src/torchmetrics/functional/regression/mse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, int]: +def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, int]: """Update and returns variables required to compute Mean Squared Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py index 52cae36adb0..ccbe81c333d 100644 --- a/src/torchmetrics/functional/regression/nrmse.py +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -22,7 +22,7 @@ def _normalized_root_mean_squared_error_update( preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std", "l2"] = "mean" -) -> Tuple[Tensor, int, Tensor]: +) -> tuple[Tensor, int, Tensor]: """Updates and returns the sum of squared errors and the number of observations for NRMSE computation. Args: diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index 47b26344163..e8fa9ef3408 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple import torch from torch import Tensor @@ -32,7 +31,7 @@ def _pearson_corrcoef_update( corr_xy: Tensor, num_prior: Tensor, num_outputs: int, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Update and returns variables required to compute Pearson Correlation Coefficient. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/r2.py b/src/torchmetrics/functional/regression/r2.py index ec6aec10a12..f52d082fc74 100644 --- a/src/torchmetrics/functional/regression/r2.py +++ b/src/torchmetrics/functional/regression/r2.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -20,7 +20,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, int]: +def _r2_score_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor, Tensor, int]: """Update and returns variables required to compute R2 score. Check for same shape and 1D/2D input tensors. diff --git a/src/torchmetrics/functional/regression/spearman.py b/src/torchmetrics/functional/regression/spearman.py index 7d412e407b0..b72bd7f03f7 100644 --- a/src/torchmetrics/functional/regression/spearman.py +++ b/src/torchmetrics/functional/regression/spearman.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -54,7 +53,7 @@ def _rank_data(data: Tensor) -> Tensor: return rank -def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]: +def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Spearman Correlation Coefficient. Check for same shape and type of input tensors. diff --git a/src/torchmetrics/functional/regression/symmetric_mape.py b/src/torchmetrics/functional/regression/symmetric_mape.py index 2ab3a55f7d5..3fca98258c7 100644 --- a/src/torchmetrics/functional/regression/symmetric_mape.py +++ b/src/torchmetrics/functional/regression/symmetric_mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor @@ -23,7 +23,7 @@ def _symmetric_mean_absolute_percentage_error_update( preds: Tensor, target: Tensor, epsilon: float = 1.17e-06, -) -> Tuple[Tensor, int]: +) -> tuple[Tensor, int]: """Update and returns variables required to compute Symmetric Mean Absolute Percentage Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/tweedie_deviance.py b/src/torchmetrics/functional/regression/tweedie_deviance.py index e3508dc04c7..328829dffe3 100644 --- a/src/torchmetrics/functional/regression/tweedie_deviance.py +++ b/src/torchmetrics/functional/regression/tweedie_deviance.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -20,7 +19,7 @@ from torchmetrics.utilities.compute import _safe_xlogy -def _tweedie_deviance_score_update(preds: Tensor, targets: Tensor, power: float = 0.0) -> Tuple[Tensor, Tensor]: +def _tweedie_deviance_score_update(preds: Tensor, targets: Tensor, power: float = 0.0) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Deviance Score for the given power. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/wmape.py b/src/torchmetrics/functional/regression/wmape.py index 443badba325..1781f306608 100644 --- a/src/torchmetrics/functional/regression/wmape.py +++ b/src/torchmetrics/functional/regression/wmape.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -22,7 +21,7 @@ def _weighted_mean_absolute_percentage_error_update( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Weighted Absolute Percentage Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/retrieval/_deprecated.py b/src/torchmetrics/functional/retrieval/_deprecated.py index 2be7c408340..6284470d1b2 100644 --- a/src/torchmetrics/functional/retrieval/_deprecated.py +++ b/src/torchmetrics/functional/retrieval/_deprecated.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional from torch import Tensor @@ -88,7 +88,7 @@ def _retrieval_precision( def _retrieval_precision_recall_curve( preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Wrapper for deprecated import. >>> from torch import tensor diff --git a/src/torchmetrics/functional/retrieval/precision_recall_curve.py b/src/torchmetrics/functional/retrieval/precision_recall_curve.py index 7c70c92ff4d..269dca04016 100644 --- a/src/torchmetrics/functional/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/functional/retrieval/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -23,7 +23,7 @@ def retrieval_precision_recall_curve( preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute precision-recall pairs for different k (from 1 to `max_k`). In a ranked retrieval context, appropriate sets of retrieved documents are naturally given by diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 87b3b699fc0..de9a8068bf9 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -46,7 +46,7 @@ def _dice_score_update( num_classes: int, include_background: bool, input_format: Literal["one-hot", "index"] = "one-hot", -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Update the state with the current prediction and target.""" _check_same_shape(preds, target) if preds.ndim < 3: diff --git a/src/torchmetrics/functional/segmentation/hausdorff_distance.py b/src/torchmetrics/functional/segmentation/hausdorff_distance.py index daadc90f6ba..ca58e5f036c 100644 --- a/src/torchmetrics/functional/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/functional/segmentation/hausdorff_distance.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union import torch from torch import Tensor @@ -28,7 +28,7 @@ def _hausdorff_distance_validate_args( num_classes: int, include_background: bool, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, directed: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", ) -> None: @@ -55,7 +55,7 @@ def hausdorff_distance( num_classes: int, include_background: bool = False, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, directed: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", ) -> Tensor: diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 184e9578cab..9cfed0fa1bf 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor @@ -45,7 +44,7 @@ def _mean_iou_update( num_classes: int, include_background: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update the intersection and union counts for the mean IoU computation.""" _check_same_shape(preds, target) diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index 59d42e16171..b6dd123a4d6 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -13,7 +13,7 @@ # limitations under the License. import functools import math -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor @@ -24,7 +24,7 @@ from torchmetrics.utilities.imports import _SCIPY_AVAILABLE -def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _ignore_background(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Ignore the background class in the computation assuming it is the first, index 0.""" preds = preds[:, 1:] if preds.shape[1] > 1 else preds target = target[:, 1:] if target.shape[1] > 1 else target @@ -44,7 +44,7 @@ def check_if_binarized(x: Tensor) -> None: raise ValueError("Input x should be binarized") -def _unfold(x: Tensor, kernel_size: Tuple[int, ...]) -> Tensor: +def _unfold(x: Tensor, kernel_size: tuple[int, ...]) -> Tensor: """Unfold the input tensor to a matrix. Function supports 3d images e.g. (B, C, D, H, W). Inspired by: @@ -112,7 +112,7 @@ def generate_binary_structure(rank: int, connectivity: int) -> Tensor: def binary_erosion( - image: Tensor, structure: Optional[Tensor] = None, origin: Optional[Tuple[int, ...]] = None, border_value: int = 0 + image: Tensor, structure: Optional[Tensor] = None, origin: Optional[tuple[int, ...]] = None, border_value: int = 0 ) -> Tensor: """Binary erosion of a tensor image. @@ -183,7 +183,7 @@ def binary_erosion( def distance_transform( x: Tensor, - sampling: Optional[Union[Tensor, List[float]]] = None, + sampling: Optional[Union[Tensor, list[float]]] = None, metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", engine: Literal["pytorch", "scipy"] = "pytorch", ) -> Tensor: @@ -285,8 +285,8 @@ def mask_edges( preds: Tensor, target: Tensor, crop: bool = True, - spacing: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, -) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]: + spacing: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None, +) -> Union[tuple[Tensor, Tensor], tuple[Tensor, Tensor, Tensor, Tensor]]: """Get the edges of binary segmentation masks. Args: @@ -343,7 +343,7 @@ def surface_distance( preds: Tensor, target: Tensor, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, ) -> Tensor: """Calculate the surface distance between two binary edge masks. @@ -393,9 +393,9 @@ def edge_surface_distance( preds: Tensor, target: Tensor, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, symmetric: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Extracts the edges from the input masks and calculates the surface distance between them. Args: @@ -423,8 +423,8 @@ def edge_surface_distance( @functools.lru_cache def get_neighbour_tables( - spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None -) -> Tuple[Tensor, Tensor]: + spacing: Union[tuple[int, int], tuple[int, int, int]], device: Optional[torch.device] = None +) -> tuple[Tensor, Tensor]: """Create a table that maps neighbour codes to the contour length or surface area of the corresponding contour. Args: @@ -443,7 +443,7 @@ def get_neighbour_tables( raise ValueError("The spacing must be a tuple of length 2 or 3.") -def table_contour_length(spacing: Tuple[int, int], device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor]: +def table_contour_length(spacing: tuple[int, int], device: Optional[torch.device] = None) -> tuple[Tensor, Tensor]: """Create a table that maps neighbour codes to the contour length of the corresponding contour. Adopted from: @@ -487,7 +487,7 @@ def table_contour_length(spacing: Tuple[int, int], device: Optional[torch.device @functools.lru_cache -def table_surface_area(spacing: Tuple[int, int, int], device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor]: +def table_surface_area(spacing: tuple[int, int, int], device: Optional[torch.device] = None) -> tuple[Tensor, Tensor]: """Create a table that maps neighbour codes to the surface area of the corresponding surface. Adopted from: diff --git a/src/torchmetrics/functional/shape/procrustes.py b/src/torchmetrics/functional/shape/procrustes.py index 08068fd2454..e72ca339306 100644 --- a/src/torchmetrics/functional/shape/procrustes.py +++ b/src/torchmetrics/functional/shape/procrustes.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor, linalg @@ -22,7 +22,7 @@ def procrustes_disparity( point_cloud1: Tensor, point_cloud2: Tensor, return_all: bool = False -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor]]: """Runs procrustrus analysis on a batch of data points. Works similar ``scipy.spatial.procrustes`` but for batches of data points. diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index 169c3d5357b..380c7048914 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, List, Literal, Optional, Union import torch from torch import Tensor @@ -31,19 +32,19 @@ if not _TRANSFORMERS_GREATER_EQUAL_4_4: __doctest_skip__ = ["_bert_score", "_infolm"] -SQUAD_SINGLE_TARGET_TYPE = Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]] -SQUAD_TARGETS_TYPE = Union[SQUAD_SINGLE_TARGET_TYPE, List[SQUAD_SINGLE_TARGET_TYPE]] +SQUAD_SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]] +SQUAD_TARGETS_TYPE = Union[SQUAD_SINGLE_TARGET_TYPE, list[SQUAD_SINGLE_TARGET_TYPE]] def _bert_score( - preds: Union[List[str], Dict[str, Tensor]], - target: Union[List[str], Dict[str, Tensor]], + preds: Union[list[str], dict[str, Tensor]], + target: Union[list[str], dict[str, Tensor]], model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, model: Optional[Module] = None, user_tokenizer: Any = None, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, verbose: bool = False, idf: bool = False, device: Optional[Union[str, torch.device]] = None, @@ -55,7 +56,7 @@ def _bert_score( rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, -) -> Dict[str, Union[Tensor, List[float], str]]: +) -> dict[str, Union[Tensor, list[float], str]]: """Wrapper for deprecated import. >>> preds = ["hello there", "general kenobi"] @@ -111,7 +112,7 @@ def _bleu_score( return bleu_score(preds=preds, target=target, n_gram=n_gram, smooth=smooth, weights=weights) -def _char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _char_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -133,7 +134,7 @@ def _chrf_score( lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] @@ -164,7 +165,7 @@ def _extended_edit_distance( rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "here is an other sample"] @@ -201,7 +202,7 @@ def _infolm( num_threads: int = 0, verbose: bool = True, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ['he read the book because he was interested in world history'] @@ -229,7 +230,7 @@ def _infolm( ) -def _match_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _match_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -264,8 +265,8 @@ def _rouge_score( use_stemmer: bool = False, normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, - rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), -) -> Dict[str, Tensor]: + rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), +) -> dict[str, Tensor]: """Wrapper for deprecated import. >>> preds = "My name is John" @@ -327,7 +328,7 @@ def _sacre_bleu_score( ) -def _squad(preds: Union[Dict[str, str], List[Dict[str, str]]], target: SQUAD_TARGETS_TYPE) -> Dict[str, Tensor]: +def _squad(preds: Union[dict[str, str], list[dict[str, str]]], target: SQUAD_TARGETS_TYPE) -> dict[str, Tensor]: """Wrapper for deprecated import. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] @@ -348,7 +349,7 @@ def _translation_edit_rate( lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, List[Tensor]]]: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] @@ -369,7 +370,7 @@ def _translation_edit_rate( ) -def _word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _word_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -382,7 +383,7 @@ def _word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]] return word_error_rate(preds=preds, target=target) -def _word_information_lost(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _word_information_lost(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -395,7 +396,7 @@ def _word_information_lost(preds: Union[str, List[str]], target: Union[str, List return word_information_lost(preds=preds, target=target) -def _word_information_preserved(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _word_information_preserved(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index 71bec857a72..9835723fae4 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -14,8 +14,9 @@ import csv import logging import urllib +from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -75,8 +76,8 @@ def _get_embeddings_and_idf_scale( all_layers: bool = False, idf: bool = False, verbose: bool = False, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, -) -> Tuple[Tensor, Tensor]: + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, +) -> tuple[Tensor, Tensor]: """Calculate sentence embeddings and the inverse-document-frequency scaling factor. Args: @@ -158,7 +159,7 @@ def _get_scaled_precision_or_recall(cos_sim: Tensor, metric: str, idf_scale: Ten def _get_precision_recall_f1( preds_embeddings: Tensor, target_embeddings: Tensor, preds_idf_scale: Tensor, target_idf_scale: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate precision, recall and F1 score over candidate and reference sentences. Args: @@ -245,7 +246,7 @@ def _rescale_metrics_with_baseline( baseline: Tensor, num_layers: Optional[int] = None, all_layers: bool = False, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Rescale the computed metrics with the pre-computed baseline.""" if num_layers is None and all_layers is False: num_layers = -1 @@ -257,14 +258,14 @@ def _rescale_metrics_with_baseline( def bert_score( - preds: Union[str, Sequence[str], Dict[str, Tensor]], - target: Union[str, Sequence[str], Dict[str, Tensor]], + preds: Union[str, Sequence[str], dict[str, Tensor]], + target: Union[str, Sequence[str], dict[str, Tensor]], model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, model: Optional[Module] = None, user_tokenizer: Any = None, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, verbose: bool = False, idf: bool = False, device: Optional[Union[str, torch.device]] = None, @@ -277,7 +278,7 @@ def bert_score( baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, truncation: bool = False, -) -> Dict[str, Union[Tensor, List[float], str]]: +) -> dict[str, Union[Tensor, list[float], str]]: """`Bert_score Evaluating Text Generation`_ for text similirity matching. This metric leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference @@ -405,7 +406,7 @@ def bert_score( ) if _are_empty_lists: rank_zero_warn("Predictions and references are empty.") - output_dict: Dict[str, Union[Tensor, List[float], str]] = { + output_dict: dict[str, Union[Tensor, list[float], str]] = { "precision": [0.0], "recall": [0.0], "f1": [0.0], diff --git a/src/torchmetrics/functional/text/bleu.py b/src/torchmetrics/functional/text/bleu.py index 032b677f182..52b8bb17432 100644 --- a/src/torchmetrics/functional/text/bleu.py +++ b/src/torchmetrics/functional/text/bleu.py @@ -17,7 +17,8 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter -from typing import Callable, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Callable, Optional, Union import torch from torch import Tensor, tensor @@ -66,7 +67,7 @@ def _bleu_score_update( target_len: Tensor, n_gram: int = 4, tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute the BLEU score. Args: diff --git a/src/torchmetrics/functional/text/cer.py b/src/torchmetrics/functional/text/cer.py index 19dc55afba9..e9ee191c449 100644 --- a/src/torchmetrics/functional/text/cer.py +++ b/src/torchmetrics/functional/text/cer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union import torch from torch import Tensor, tensor @@ -21,9 +21,9 @@ def _cer_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor]: """Update the cer score with the current set of references and predictions. Args: @@ -63,7 +63,7 @@ def _cer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -def char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def char_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Compute Character Error Rate used for performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index ca98778fade..0e5e4978f1f 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -21,8 +21,9 @@ # Reference to the approval: https://github.com/Lightning-AI/torchmetrics/pull/2701#issuecomment-2316891785 from collections import defaultdict +from collections.abc import Sequence from itertools import chain -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Union import torch from torch import Tensor, tensor @@ -36,8 +37,8 @@ def _prepare_n_grams_dicts( n_char_order: int, n_word_order: int -) -> Tuple[ - Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] +) -> tuple[ + dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor] ]: """Prepare dictionaries with default zero values for total ref, hypothesis and matching character and word n-grams. @@ -50,12 +51,12 @@ def _prepare_n_grams_dicts( n-grams. """ - total_preds_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_preds_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - total_target_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_target_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - total_matching_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_matching_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_preds_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_preds_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_target_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_target_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_matching_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_matching_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} return ( total_preds_char_n_grams, @@ -67,7 +68,7 @@ def _prepare_n_grams_dicts( ) -def _get_characters(sentence: str, whitespace: bool) -> List[str]: +def _get_characters(sentence: str, whitespace: bool) -> list[str]: """Split sentence into individual characters. Args: @@ -83,7 +84,7 @@ def _get_characters(sentence: str, whitespace: bool) -> List[str]: return list(sentence.strip().replace(" ", "")) -def _separate_word_and_punctuation(word: str) -> List[str]: +def _separate_word_and_punctuation(word: str) -> list[str]: """Separates out punctuations from beginning and end of words for chrF. Adapted from https://github.com/m-popovic/chrF and @@ -106,7 +107,7 @@ def _separate_word_and_punctuation(word: str) -> List[str]: return [word] -def _get_words_and_punctuation(sentence: str) -> List[str]: +def _get_words_and_punctuation(sentence: str) -> list[str]: """Separates out punctuations from beginning and end of words for chrF for all words in the sentence. Args: @@ -119,7 +120,7 @@ def _get_words_and_punctuation(sentence: str) -> List[str]: return list(chain.from_iterable(_separate_word_and_punctuation(word) for word in sentence.strip().split())) -def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, Dict[Tuple[str, ...], Tensor]]: +def _ngram_counts(char_or_word_list: list[str], n_gram_order: int) -> dict[int, dict[tuple[str, ...], Tensor]]: """Calculate n-gram counts. Args: @@ -130,7 +131,7 @@ def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, A dictionary of dictionaries with a counts of given n-grams. """ - ngrams: Dict[int, Dict[Tuple[str, ...], Tensor]] = defaultdict(lambda: defaultdict(lambda: tensor(0.0))) + ngrams: dict[int, dict[tuple[str, ...], Tensor]] = defaultdict(lambda: defaultdict(lambda: tensor(0.0))) for n in range(1, n_gram_order + 1): for ngram in (tuple(char_or_word_list[i : i + n]) for i in range(len(char_or_word_list) - n + 1)): ngrams[n][ngram] += tensor(1) @@ -139,11 +140,11 @@ def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, def _get_n_grams_counts_and_total_ngrams( sentence: str, n_char_order: int, n_word_order: int, lowercase: bool, whitespace: bool -) -> Tuple[ - Dict[int, Dict[Tuple[str, ...], Tensor]], - Dict[int, Dict[Tuple[str, ...], Tensor]], - Dict[int, Tensor], - Dict[int, Tensor], +) -> tuple[ + dict[int, dict[tuple[str, ...], Tensor]], + dict[int, dict[tuple[str, ...], Tensor]], + dict[int, Tensor], + dict[int, Tensor], ]: """Get n-grams and total n-grams. @@ -164,7 +165,7 @@ def _get_n_grams_counts_and_total_ngrams( def _char_and_word_ngrams_counts( sentence: str, n_char_order: int, n_word_order: int, lowercase: bool - ) -> Tuple[Dict[int, Dict[Tuple[str, ...], Tensor]], Dict[int, Dict[Tuple[str, ...], Tensor]]]: + ) -> tuple[dict[int, dict[tuple[str, ...], Tensor]], dict[int, dict[tuple[str, ...], Tensor]]]: """Get a dictionary of dictionaries with a counts of given n-grams.""" if lowercase: sentence = sentence.lower() @@ -172,9 +173,9 @@ def _char_and_word_ngrams_counts( word_n_grams_counts = _ngram_counts(_get_words_and_punctuation(sentence), n_word_order) return char_n_grams_counts, word_n_grams_counts - def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) -> Dict[int, Tensor]: + def _get_total_ngrams(n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]]) -> dict[int, Tensor]: """Get total sum of n-grams over n-grams w.r.t n.""" - total_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + total_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) for n in n_grams_counts: total_n_grams[n] = sum(n_grams_counts[n].values()).detach().clone() # type: ignore return total_n_grams @@ -189,9 +190,9 @@ def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) def _get_ngram_matches( - hyp_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - ref_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], -) -> Dict[int, Tensor]: + hyp_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], + ref_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], +) -> dict[int, Tensor]: """Get a number of n-gram matches between reference and hypothesis n-grams. Args: @@ -202,7 +203,7 @@ def _get_ngram_matches( matching_n_grams """ - matching_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + matching_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) for n in hyp_n_grams_counts: min_n_grams = [ torch.min(ref_n_grams_counts[n][n_gram], hyp_n_grams_counts[n][n_gram]) for n_gram in hyp_n_grams_counts[n] @@ -211,7 +212,7 @@ def _get_ngram_matches( return matching_n_grams -def _sum_over_dicts(total_n_grams: Dict[int, Tensor], n_grams: Dict[int, Tensor]) -> Dict[int, Tensor]: +def _sum_over_dicts(total_n_grams: dict[int, Tensor], n_grams: dict[int, Tensor]) -> dict[int, Tensor]: """Aggregate total n-grams to keep corpus-level statistics. Args: @@ -228,12 +229,12 @@ def _sum_over_dicts(total_n_grams: Dict[int, Tensor], n_grams: Dict[int, Tensor] def _calculate_fscore( - matching_char_n_grams: Dict[int, Tensor], - matching_word_n_grams: Dict[int, Tensor], - hyp_char_n_grams: Dict[int, Tensor], - hyp_word_n_grams: Dict[int, Tensor], - ref_char_n_grams: Dict[int, Tensor], - ref_word_n_grams: Dict[int, Tensor], + matching_char_n_grams: dict[int, Tensor], + matching_word_n_grams: dict[int, Tensor], + hyp_char_n_grams: dict[int, Tensor], + hyp_word_n_grams: dict[int, Tensor], + ref_char_n_grams: dict[int, Tensor], + ref_word_n_grams: dict[int, Tensor], n_order: float, beta: float, ) -> Tensor: @@ -260,19 +261,19 @@ def _calculate_fscore( """ def _get_n_gram_fscore( - matching_n_grams: Dict[int, Tensor], ref_n_grams: Dict[int, Tensor], hyp_n_grams: Dict[int, Tensor], beta: float - ) -> Dict[int, Tensor]: + matching_n_grams: dict[int, Tensor], ref_n_grams: dict[int, Tensor], hyp_n_grams: dict[int, Tensor], beta: float + ) -> dict[int, Tensor]: """Get n-gram level f-score.""" - precision: Dict[int, Tensor] = { + precision: dict[int, Tensor] = { n: matching_n_grams[n] / hyp_n_grams[n] if hyp_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams } - recall: Dict[int, Tensor] = { + recall: dict[int, Tensor] = { n: matching_n_grams[n] / ref_n_grams[n] if ref_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams } - denominator: Dict[int, Tensor] = { + denominator: dict[int, Tensor] = { n: torch.max(beta**2 * precision[n] + recall[n], _EPS_SMOOTHING) for n in matching_n_grams } - f_score: Dict[int, Tensor] = { + f_score: dict[int, Tensor] = { n: (1 + beta**2) * precision[n] * recall[n] / denominator[n] for n in matching_n_grams } @@ -285,18 +286,18 @@ def _get_n_gram_fscore( def _calculate_sentence_level_chrf_score( - targets: List[str], - pred_char_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - pred_word_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - pred_char_n_grams: Dict[int, Tensor], - pred_word_n_grams: Dict[int, Tensor], + targets: list[str], + pred_char_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], + pred_word_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], + pred_char_n_grams: dict[int, Tensor], + pred_word_n_grams: dict[int, Tensor], n_char_order: int, n_word_order: int, n_order: float, beta: float, lowercase: bool, whitespace: bool, -) -> Tuple[Tensor, Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor]]: +) -> tuple[Tensor, dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor]]: """Calculate the best sentence-level chrF/chrF++ score. For a given pre-processed hypothesis, all references are evaluated and score and statistics @@ -328,10 +329,10 @@ def _calculate_sentence_level_chrf_score( """ best_f_score = tensor(0.0) - best_matching_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_matching_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_target_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_target_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_matching_char_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_matching_word_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_char_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_word_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) for target in targets: ( @@ -373,12 +374,12 @@ def _calculate_sentence_level_chrf_score( def _chrf_score_update( preds: Union[str, Sequence[str]], target: Union[Sequence[str], Sequence[Sequence[str]]], - total_preds_char_n_grams: Dict[int, Tensor], - total_preds_word_n_grams: Dict[int, Tensor], - total_target_char_n_grams: Dict[int, Tensor], - total_target_word_n_grams: Dict[int, Tensor], - total_matching_char_n_grams: Dict[int, Tensor], - total_matching_word_n_grams: Dict[int, Tensor], + total_preds_char_n_grams: dict[int, Tensor], + total_preds_word_n_grams: dict[int, Tensor], + total_target_char_n_grams: dict[int, Tensor], + total_target_word_n_grams: dict[int, Tensor], + total_matching_char_n_grams: dict[int, Tensor], + total_matching_word_n_grams: dict[int, Tensor], n_char_order: int, n_word_order: int, n_order: float, @@ -386,13 +387,13 @@ def _chrf_score_update( lowercase: bool, whitespace: bool, sentence_chrf_score: Optional[List[Tensor]] = None, -) -> Tuple[ - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], +) -> tuple[ + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], Optional[List[Tensor]], ]: """Update function for chrf score. @@ -482,12 +483,12 @@ def _chrf_score_update( def _chrf_score_compute( - total_preds_char_n_grams: Dict[int, Tensor], - total_preds_word_n_grams: Dict[int, Tensor], - total_target_char_n_grams: Dict[int, Tensor], - total_target_word_n_grams: Dict[int, Tensor], - total_matching_char_n_grams: Dict[int, Tensor], - total_matching_word_n_grams: Dict[int, Tensor], + total_preds_char_n_grams: dict[int, Tensor], + total_preds_word_n_grams: dict[int, Tensor], + total_target_char_n_grams: dict[int, Tensor], + total_target_word_n_grams: dict[int, Tensor], + total_matching_char_n_grams: dict[int, Tensor], + total_matching_word_n_grams: dict[int, Tensor], n_order: float, beta: float, ) -> Tensor: @@ -529,7 +530,7 @@ def chrf_score( lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate `chrF score`_ of machine translated text with one or more references. This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in diff --git a/src/torchmetrics/functional/text/edit.py b/src/torchmetrics/functional/text/edit.py index 0e443bb4c04..5660f45e025 100644 --- a/src/torchmetrics/functional/text/edit.py +++ b/src/torchmetrics/functional/text/edit.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index f6f562f5338..bde77680bea 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -88,8 +88,9 @@ import re import unicodedata +from collections.abc import Sequence from math import inf -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Union from torch import Tensor, stack, tensor from typing_extensions import Literal @@ -253,7 +254,7 @@ def _preprocess_sentences( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], language: Literal["en", "ja"], -) -> Tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]: +) -> tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]: """Preprocess strings according to language requirements. Args: @@ -370,7 +371,7 @@ def extended_edit_distance( rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute extended edit distance score (`ExtendedEditDistance`_) [1] for strings or list of strings. The metric utilises the Levenshtein distance and extends it by adding a jump operation. diff --git a/src/torchmetrics/functional/text/helper.py b/src/torchmetrics/functional/text/helper.py index d4c9ff7ae04..a61f06232ab 100644 --- a/src/torchmetrics/functional/text/helper.py +++ b/src/torchmetrics/functional/text/helper.py @@ -29,8 +29,9 @@ # limitations under the License. import math +from collections.abc import Sequence from enum import Enum, unique -from typing import Dict, List, Sequence, Tuple, Union +from typing import Union # Tercom-inspired limits _BEAM_WIDTH = 25 @@ -67,12 +68,12 @@ class _LevenshteinEditDistance: """ def __init__( - self, reference_tokens: List[str], op_insert: int = 1, op_delete: int = 1, op_substitute: int = 1 + self, reference_tokens: list[str], op_insert: int = 1, op_delete: int = 1, op_substitute: int = 1 ) -> None: self.reference_tokens = reference_tokens self.reference_len = len(reference_tokens) - self.cache: Dict[str, Tuple[int, str]] = {} + self.cache: dict[str, tuple[int, str]] = {} self.cache_size = 0 self.op_insert = op_insert @@ -81,7 +82,7 @@ def __init__( self.op_nothing = 0 self.op_undefined = _INT_INFINITY - def __call__(self, prediction_tokens: List[str]) -> Tuple[int, Tuple[_EditOperations, ...]]: + def __call__(self, prediction_tokens: list[str]) -> tuple[int, tuple[_EditOperations, ...]]: """Calculate edit distance between self._words_ref and the hypothesis. Uses cache to skip some computations. Args: @@ -104,10 +105,10 @@ def __call__(self, prediction_tokens: List[str]) -> Tuple[int, Tuple[_EditOperat def _levenshtein_edit_distance( self, - prediction_tokens: List[str], + prediction_tokens: list[str], prediction_start: int, - cache: List[List[Tuple[int, _EditOperations]]], - ) -> Tuple[int, List[List[Tuple[int, _EditOperations]]], Tuple[_EditOperations, ...]]: + cache: list[list[tuple[int, _EditOperations]]], + ) -> tuple[int, list[list[tuple[int, _EditOperations]]], tuple[_EditOperations, ...]]: """Dynamic programming algorithm to compute the Levenhstein edit distance. Args: @@ -121,10 +122,10 @@ def _levenshtein_edit_distance( """ prediction_len = len(prediction_tokens) - empty_rows: List[List[Tuple[int, _EditOperations]]] = [ + empty_rows: list[list[tuple[int, _EditOperations]]] = [ list(self._get_empty_row(self.reference_len)) for _ in range(prediction_len - prediction_start) ] - edit_distance: List[List[Tuple[int, _EditOperations]]] = cache + empty_rows + edit_distance: list[list[tuple[int, _EditOperations]]] = cache + empty_rows length_ratio = self.reference_len / prediction_len if prediction_tokens else 1.0 # Ensure to not end up with zero overlaip with previous role @@ -171,8 +172,8 @@ def _levenshtein_edit_distance( return edit_distance[-1][-1][0], edit_distance[len(cache) :], trace def _get_trace( - self, prediction_len: int, edit_distance: List[List[Tuple[int, _EditOperations]]] - ) -> Tuple[_EditOperations, ...]: + self, prediction_len: int, edit_distance: list[list[tuple[int, _EditOperations]]] + ) -> tuple[_EditOperations, ...]: """Get a trace of executed operations from the edit distance matrix. Args: @@ -189,7 +190,7 @@ def _get_trace( If an unknown operation has been applied. """ - trace: Tuple[_EditOperations, ...] = () + trace: tuple[_EditOperations, ...] = () i = prediction_len j = self.reference_len @@ -208,7 +209,7 @@ def _get_trace( return trace - def _add_cache(self, prediction_tokens: List[str], edit_distance: List[List[Tuple[int, _EditOperations]]]) -> None: + def _add_cache(self, prediction_tokens: list[str], edit_distance: list[list[tuple[int, _EditOperations]]]) -> None: """Add newly computed rows to cache. Since edit distance is only calculated on the hypothesis suffix that was not in cache, the number of rows in @@ -241,7 +242,7 @@ def _add_cache(self, prediction_tokens: List[str], edit_distance: List[List[Tupl value = node[word] node = value[0] # type: ignore - def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tuple[int, _EditOperations]]]]: + def _find_cache(self, prediction_tokens: list[str]) -> tuple[int, list[list[tuple[int, _EditOperations]]]]: """Find the already calculated rows of the Levenshtein edit distance metric. Args: @@ -258,7 +259,7 @@ def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tupl """ node = self.cache start_position = 0 - edit_distance: List[List[Tuple[int, _EditOperations]]] = [self._get_initial_row(self.reference_len)] + edit_distance: list[list[tuple[int, _EditOperations]]] = [self._get_initial_row(self.reference_len)] for word in prediction_tokens: if word in node: start_position += 1 @@ -269,7 +270,7 @@ def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tupl return start_position, edit_distance - def _get_empty_row(self, length: int) -> List[Tuple[int, _EditOperations]]: + def _get_empty_row(self, length: int) -> list[tuple[int, _EditOperations]]: """Precomputed empty matrix row for Levenhstein edit distance. Args: @@ -281,7 +282,7 @@ def _get_empty_row(self, length: int) -> List[Tuple[int, _EditOperations]]: """ return [(int(self.op_undefined), _EditOperations.OP_UNDEFINED)] * (length + 1) - def _get_initial_row(self, length: int) -> List[Tuple[int, _EditOperations]]: + def _get_initial_row(self, length: int) -> list[tuple[int, _EditOperations]]: """First row corresponds to insertion operations of the reference, so 1 edit operation per reference word. Args: @@ -297,7 +298,7 @@ def _get_initial_row(self, length: int) -> List[Tuple[int, _EditOperations]]: def _validate_inputs( ref_corpus: Union[Sequence[str], Sequence[Sequence[str]]], hypothesis_corpus: Union[str, Sequence[str]], -) -> Tuple[Sequence[Sequence[str]], Sequence[str]]: +) -> tuple[Sequence[Sequence[str]], Sequence[str]]: """Check and update (if needed) the format of reference and hypothesis corpora for various text evaluation metrics. Args: @@ -326,7 +327,7 @@ def _validate_inputs( return ref_corpus, hypothesis_corpus -def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int: +def _edit_distance(prediction_tokens: list[str], reference_tokens: list[str]) -> int: """Dynamic programming algorithm to compute the edit distance. Args: @@ -350,7 +351,7 @@ def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> return dp[-1][-1] -def _flip_trace(trace: Tuple[_EditOperations, ...]) -> Tuple[_EditOperations, ...]: +def _flip_trace(trace: tuple[_EditOperations, ...]) -> tuple[_EditOperations, ...]: """Flip the trace of edit operations. Instead of rewriting a->b, get a recipe for rewriting b->a. Simply flips insertions and deletions. @@ -363,13 +364,13 @@ def _flip_trace(trace: Tuple[_EditOperations, ...]) -> Tuple[_EditOperations, .. A tuple of inverted edit operations. """ - _flip_operations: Dict[_EditOperations, _EditOperations] = { + _flip_operations: dict[_EditOperations, _EditOperations] = { _EditOperations.OP_INSERT: _EditOperations.OP_DELETE, _EditOperations.OP_DELETE: _EditOperations.OP_INSERT, } def _replace_operation_or_retain( - operation: _EditOperations, _flip_operations: Dict[_EditOperations, _EditOperations] + operation: _EditOperations, _flip_operations: dict[_EditOperations, _EditOperations] ) -> _EditOperations: if operation in _flip_operations: return _flip_operations.get(operation) # type: ignore @@ -378,7 +379,7 @@ def _replace_operation_or_retain( return tuple(_replace_operation_or_retain(operation, _flip_operations) for operation in trace) -def _trace_to_alignment(trace: Tuple[_EditOperations, ...]) -> Tuple[Dict[int, int], List[int], List[int]]: +def _trace_to_alignment(trace: tuple[_EditOperations, ...]) -> tuple[dict[int, int], list[int], list[int]]: """Transform trace of edit operations into an alignment of the sequences. Args: @@ -395,9 +396,9 @@ def _trace_to_alignment(trace: Tuple[_EditOperations, ...]) -> Tuple[Dict[int, i """ reference_position = hypothesis_position = -1 - reference_errors: List[int] = [] - hypothesis_errors: List[int] = [] - alignments: Dict[int, int] = {} + reference_errors: list[int] = [] + hypothesis_errors: list[int] = [] + alignments: dict[int, int] = {} # we are rewriting a into b for operation in trace: diff --git a/src/torchmetrics/functional/text/helper_embedding_metric.py b/src/torchmetrics/functional/text/helper_embedding_metric.py index f2b59126c7d..17c89558163 100644 --- a/src/torchmetrics/functional/text/helper_embedding_metric.py +++ b/src/torchmetrics/functional/text/helper_embedding_metric.py @@ -14,7 +14,7 @@ import math import os from collections import Counter, defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch import Tensor @@ -49,8 +49,8 @@ def _process_attention_mask_for_special_tokens(attention_mask: Tensor) -> Tensor def _input_data_collator( - batch: Dict[str, Tensor], device: Optional[Union[str, torch.device]] = None -) -> Dict[str, Tensor]: + batch: dict[str, Tensor], device: Optional[Union[str, torch.device]] = None +) -> dict[str, Tensor]: """Trim model inputs. This function trims the model inputs to the longest sequence within the batch and put the input on the proper @@ -64,7 +64,7 @@ def _input_data_collator( return batch -def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_len: int) -> Tuple[Tensor, Tensor]: +def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_len: int) -> tuple[Tensor, Tensor]: """Pad the model output and attention mask to the target length.""" zeros_shape = list(model_output.shape) zeros_shape[2] = target_len - zeros_shape[2] @@ -76,7 +76,7 @@ def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_l return model_output, attention_mask -def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> tuple[Tensor, Tensor, Tensor]: """Sort tokenized sentence from the shortest to the longest one.""" sorted_indices = attention_mask.sum(1).argsort() input_ids = input_ids[sorted_indices] @@ -85,13 +85,13 @@ def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> Tu def _preprocess_text( - text: List[str], + text: list[str], tokenizer: Any, max_length: int = 512, truncation: bool = True, sort_according_length: bool = True, own_tokenizer: bool = False, -) -> Tuple[Dict[str, Tensor], Optional[Tensor]]: +) -> tuple[dict[str, Tensor], Optional[Tensor]]: """Text pre-processing function using `transformers` `AutoTokenizer` instance. Args: @@ -164,7 +164,7 @@ def _check_shape_of_model_output(output: Tensor, input_ids: Tensor) -> None: def _load_tokenizer_and_model( model_name_or_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None -) -> Tuple["PreTrainedTokenizerBase", "PreTrainedModel"]: +) -> tuple["PreTrainedTokenizerBase", "PreTrainedModel"]: """Load HuggingFace `transformers`' tokenizer and model. This function also handle a device placement. Args: @@ -191,14 +191,14 @@ class TextDataset(Dataset): def __init__( self, - text: List[str], + text: list[str], tokenizer: Any, max_length: int = 512, preprocess_text_fn: Callable[ - [List[str], Any, int, bool], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]] + [list[str], Any, int, bool], Union[dict[str, Tensor], tuple[dict[str, Tensor], Optional[Tensor]]] ] = _preprocess_text, idf: bool = False, - tokens_idf: Optional[Dict[int, float]] = None, + tokens_idf: Optional[dict[int, float]] = None, truncation: bool = False, ) -> None: """Initialize text dataset class. @@ -225,7 +225,7 @@ def __init__( if idf: self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf() - def __getitem__(self, idx: int) -> Dict[str, Tensor]: + def __getitem__(self, idx: int) -> dict[str, Tensor]: """Get the input ids and attention mask belonging to a specific datapoint.""" input_ids = self.text["input_ids"][idx, :] attention_mask = self.text["attention_mask"][idx, :] @@ -239,7 +239,7 @@ def __len__(self) -> int: """Return the number of sentences in the dataset.""" return self.num_sentences - def _get_tokens_idf(self) -> Dict[int, float]: + def _get_tokens_idf(self) -> dict[int, float]: """Calculate token inverse document frequencies. Return: @@ -250,7 +250,7 @@ def _get_tokens_idf(self) -> Dict[int, float]: for tokens in map(self._set_of_tokens, self.text["input_ids"]): token_counter.update(tokens) - tokens_idf: Dict[int, float] = defaultdict(self._get_tokens_idf_default_value) + tokens_idf: dict[int, float] = defaultdict(self._get_tokens_idf_default_value) tokens_idf.update({ idx: math.log((self.num_sentences + 1) / (occurrence + 1)) for idx, occurrence in token_counter.items() }) @@ -261,7 +261,7 @@ def _get_tokens_idf_default_value(self) -> float: return math.log((self.num_sentences + 1) / 1) @staticmethod - def _set_of_tokens(input_ids: Tensor) -> Set: + def _set_of_tokens(input_ids: Tensor) -> set: """Return set of tokens from the `input_ids` :class:`~torch.Tensor`.""" return set(input_ids.tolist()) @@ -274,7 +274,7 @@ def __init__( input_ids: Tensor, attention_mask: Tensor, idf: bool = False, - tokens_idf: Optional[Dict[int, float]] = None, + tokens_idf: Optional[dict[int, float]] = None, ) -> None: """Initialize the dataset class. diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 0365cdf7ae6..94452f4886e 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Sequence from enum import unique -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch from torch import Tensor @@ -320,7 +321,7 @@ def _get_dataloader( return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) -def _get_special_tokens_map(tokenizer: "PreTrainedTokenizerBase") -> Dict[str, int]: +def _get_special_tokens_map(tokenizer: "PreTrainedTokenizerBase") -> dict[str, int]: """Build a dictionary of model/tokenizer special tokens. Args: @@ -366,10 +367,10 @@ def _get_token_mask(input_ids: Tensor, pad_token_id: int, sep_token_id: int, cls def _get_batch_distribution( model: "PreTrainedModel", - batch: Dict[str, Tensor], + batch: dict[str, Tensor], temperature: float, idf: bool, - special_tokens_map: Dict[str, int], + special_tokens_map: dict[str, int], ) -> Tensor: """Calculate a discrete probability distribution for a batch of examples. See `InfoLM`_ for details. @@ -427,7 +428,7 @@ def _get_data_distribution( dataloader: DataLoader, temperature: float, idf: bool, - special_tokens_map: Dict[str, int], + special_tokens_map: dict[str, int], verbose: bool, ) -> Tensor: """Calculate a discrete probability distribution according to the methodology described in `InfoLM`_. @@ -467,7 +468,7 @@ def _infolm_update( target: Union[str, Sequence[str]], tokenizer: "PreTrainedTokenizerBase", max_length: int, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Update the metric state by a tokenization of ``preds`` and ``target`` sentencens. Args: @@ -503,7 +504,7 @@ def _infolm_compute( temperature: float, idf: bool, information_measure_cls: _InformationMeasure, - special_tokens_map: Dict[str, int], + special_tokens_map: dict[str, int], verbose: bool = True, ) -> Tensor: """Calculate selected information measure using the pre-trained language model. @@ -557,7 +558,7 @@ def infolm( num_threads: int = 0, verbose: bool = True, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate `InfoLM`_ [1]. InfoML corresponds to distance/divergence between predicted and reference sentence discrete distribution using diff --git a/src/torchmetrics/functional/text/mer.py b/src/torchmetrics/functional/text/mer.py index 89d3764331d..46f30331b85 100644 --- a/src/torchmetrics/functional/text/mer.py +++ b/src/torchmetrics/functional/text/mer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union import torch from torch import Tensor, tensor @@ -21,9 +21,9 @@ def _mer_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor]: """Update the mer score with the current set of references and predictions. Args: @@ -64,7 +64,7 @@ def _mer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -def match_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def match_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Match error rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of words that were incorrectly predicted and inserted. The lower the value, the diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py index 1561ffa412a..5931b7b6944 100644 --- a/src/torchmetrics/functional/text/perplexity.py +++ b/src/torchmetrics/functional/text/perplexity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -62,7 +62,7 @@ def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None: raise TypeError(f"Input tensor `target` is expected to be of a type {torch.int64} but got {target.dtype}.") -def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tuple[Tensor, Tensor]: +def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> tuple[Tensor, Tensor]: """Compute intermediate statistics for Perplexity. Args: diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 58c9a05fecf..23a7aeb0856 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -13,7 +13,8 @@ # limitations under the License. import re from collections import Counter -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor, tensor @@ -23,7 +24,7 @@ __doctest_requires__ = {("rouge_score", "_rouge_score_update"): ["nltk"]} -ALLOWED_ROUGE_KEYS: Dict[str, Union[int, str]] = { +ALLOWED_ROUGE_KEYS: dict[str, Union[int, str]] = { "rouge1": 1, "rouge2": 2, "rouge3": 3, @@ -71,7 +72,7 @@ def _split_sentence(x: str) -> Sequence[str]: return nltk.sent_tokenize(x) -def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, Tensor]: +def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> dict[str, Tensor]: """Compute overall metrics. This function computes precision, recall and F1 score based on hits/lcs, the length of lists of tokenizer @@ -128,7 +129,7 @@ def _backtracked_lcs( """ i = len(pred_tokens) j = len(target_tokens) - backtracked_lcs: List[int] = [] + backtracked_lcs: list[int] = [] while i > 0 and j > 0: if pred_tokens[i - 1] == target_tokens[j - 1]: backtracked_lcs.insert(0, j - 1) @@ -199,7 +200,7 @@ def _normalize_and_tokenize_text( return [x for x in tokens if (isinstance(x, str) and len(x) > 0)] -def _rouge_n_score(pred: Sequence[str], target: Sequence[str], n_gram: int) -> Dict[str, Tensor]: +def _rouge_n_score(pred: Sequence[str], target: Sequence[str], n_gram: int) -> dict[str, Tensor]: """Compute precision, recall and F1 score for the Rouge-N metric. Args: @@ -225,7 +226,7 @@ def _create_ngrams(tokens: Sequence[str], n: int) -> Counter: return _compute_metrics(hits, max(pred_len, 1), max(target_len, 1)) -def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tensor]: +def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> dict[str, Tensor]: """Compute precision, recall and F1 score for the Rouge-L metric. Args: @@ -241,7 +242,7 @@ def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tens return _compute_metrics(lcs, pred_len, target_len) -def _rouge_lsum_score(pred: Sequence[Sequence[str]], target: Sequence[Sequence[str]]) -> Dict[str, Tensor]: +def _rouge_lsum_score(pred: Sequence[Sequence[str]], target: Sequence[Sequence[str]]) -> dict[str, Tensor]: r"""Compute precision, recall and F1 score for the Rouge-LSum metric. More information can be found in Section 3.2 of the referenced paper [1]. This implementation follow the official @@ -287,12 +288,12 @@ def _get_token_counts(sentences: Sequence[Sequence[str]]) -> Counter: def _rouge_score_update( preds: Sequence[str], target: Sequence[Sequence[str]], - rouge_keys_values: List[Union[int, str]], + rouge_keys_values: list[Union[int, str]], accumulate: str, stemmer: Optional[Any] = None, normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, -) -> Dict[Union[int, str], List[Dict[str, Tensor]]]: +) -> dict[Union[int, str], list[dict[str, Tensor]]]: """Update the rouge score with the current set of predicted and target sentences. Args: @@ -327,11 +328,11 @@ def _rouge_score_update( 'recall': tensor(0.5000)}]} """ - results: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} + results: dict[Union[int, str], list[dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} for pred_raw, target_raw in zip(preds, target): - result_inner: Dict[Union[int, str], Dict[str, Tensor]] = {rouge_key: {} for rouge_key in rouge_keys_values} - result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} + result_inner: dict[Union[int, str], dict[str, Tensor]] = {rouge_key: {} for rouge_key in rouge_keys_values} + result_avg: dict[Union[int, str], list[dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} list_results = [] pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer) if "Lsum" in rouge_keys_values: @@ -369,11 +370,11 @@ def _rouge_score_update( results[rouge_key].append(list_results[highest_idx][rouge_key]) # todo elif accumulate == "avg": - new_result_avg: Dict[Union[int, str], Dict[str, Tensor]] = { + new_result_avg: dict[Union[int, str], dict[str, Tensor]] = { rouge_key: {} for rouge_key in rouge_keys_values } for rouge_key, metrics in result_avg.items(): - _dict_metric_score_batch: Dict[str, List[Tensor]] = {} + _dict_metric_score_batch: dict[str, List[Tensor]] = {} for metric in metrics: for _type, value in metric.items(): if _type not in _dict_metric_score_batch: @@ -390,14 +391,14 @@ def _rouge_score_update( return results -def _rouge_score_compute(sentence_results: Dict[str, List[Tensor]]) -> Dict[str, Tensor]: +def _rouge_score_compute(sentence_results: dict[str, List[Tensor]]) -> dict[str, Tensor]: """Compute the combined ROUGE metric for all the input set of predicted and target sentences. Args: sentence_results: Rouge-N/Rouge-L/Rouge-LSum metrics calculated for single sentence. """ - results: Dict[str, Tensor] = {} + results: dict[str, Tensor] = {} # Obtain mean scores for individual rouge metrics if sentence_results == {}: return results @@ -415,8 +416,8 @@ def rouge_score( use_stemmer: bool = False, normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, - rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), -) -> Dict[str, Tensor]: + rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), +) -> dict[str, Tensor]: """Calculate `Calculate Rouge Score`_ , used for automatic summarization. Args: @@ -494,7 +495,7 @@ def rouge_score( if isinstance(target, str): target = [[target]] - sentence_results: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update( + sentence_results: dict[Union[int, str], list[dict[str, Tensor]]] = _rouge_score_update( preds, target, rouge_keys_values, @@ -504,7 +505,7 @@ def rouge_score( accumulate=accumulate, ) - output: Dict[str, List[Tensor]] = { + output: dict[str, List[Tensor]] = { f"rouge{rouge_key}_{tp}": [] for rouge_key in rouge_keys_values for tp in ["fmeasure", "precision", "recall"] } for rouge_key, metrics in sentence_results.items(): diff --git a/src/torchmetrics/functional/text/sacre_bleu.py b/src/torchmetrics/functional/text/sacre_bleu.py index 33c3afb0beb..a398acaf676 100644 --- a/src/torchmetrics/functional/text/sacre_bleu.py +++ b/src/torchmetrics/functional/text/sacre_bleu.py @@ -40,8 +40,9 @@ import os import re import tempfile +from collections.abc import Sequence from functools import partial -from typing import Any, ClassVar, Dict, Optional, Sequence, Type +from typing import Any, ClassVar, Optional import torch from torch import Tensor, tensor @@ -141,7 +142,7 @@ class _SacreBLEUTokenizer: } # Keep it as class variable to avoid initializing over and over again - sentencepiece_processors: ClassVar[Dict[str, Optional[Any]]] = {"flores101": None, "flores200": None} + sentencepiece_processors: ClassVar[dict[str, Optional[Any]]] = {"flores101": None, "flores200": None} def __init__(self, tokenize: _TokenizersLiteral, lowercase: bool = False) -> None: self._check_tokenizers_validity(tokenize) @@ -155,7 +156,7 @@ def __call__(self, line: str) -> Sequence[str]: @classmethod def tokenize( - cls: Type["_SacreBLEUTokenizer"], + cls: type["_SacreBLEUTokenizer"], line: str, tokenize: _TokenizersLiteral, lowercase: bool = False, @@ -167,7 +168,7 @@ def tokenize( return cls._lower(tokenized_line, lowercase).split() @classmethod - def _tokenize_regex(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_regex(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Post-processing tokenizer for `13a` and `zh` tokenizers. Args: @@ -196,7 +197,7 @@ def _is_chinese_char(uchar: str) -> bool: return any(start <= uchar <= end for start, end in _UCODE_RANGES) @classmethod - def _tokenize_base(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_base(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes an input line with the tokenizer. Args: @@ -209,7 +210,7 @@ def _tokenize_base(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return line @classmethod - def _tokenize_13a(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_13a(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a line using a relatively minimal tokenization that is equivalent to mteval-v13a, used by WMT. Args: @@ -233,7 +234,7 @@ def _tokenize_13a(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return cls._tokenize_regex(f" {line} ") @classmethod - def _tokenize_zh(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_zh(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenization of Chinese text. This is done in two steps: separate each Chinese characters (by utf-8 encoding) and afterwards tokenize the @@ -261,7 +262,7 @@ def _tokenize_zh(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return cls._tokenize_regex(line_in_chars) @classmethod - def _tokenize_international(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_international(cls: type["_SacreBLEUTokenizer"], line: str) -> str: r"""Tokenizes a string following the official BLEU implementation. See github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 @@ -294,7 +295,7 @@ def _tokenize_international(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return " ".join(line.split()) @classmethod - def _tokenize_char(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_char(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes all the characters in the input line. Args: @@ -307,7 +308,7 @@ def _tokenize_char(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return " ".join(char for char in line) @classmethod - def _tokenize_ja_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_ja_mecab(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a Japanese string line using MeCab morphological analyzer. Args: @@ -326,7 +327,7 @@ def _tokenize_ja_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return tagger.parse(line).strip() @classmethod - def _tokenize_ko_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_ko_mecab(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a Korean string line using MeCab-korean morphological analyzer. Args: @@ -346,7 +347,7 @@ def _tokenize_ko_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: @classmethod def _tokenize_flores( - cls: Type["_SacreBLEUTokenizer"], line: str, tokenize: Literal["flores101", "flores200"] + cls: type["_SacreBLEUTokenizer"], line: str, tokenize: Literal["flores101", "flores200"] ) -> str: """Tokenizes a string line using sentencepiece tokenizer. @@ -372,7 +373,7 @@ def _tokenize_flores( return " ".join(cls.sentencepiece_processors[tokenize].EncodeAsPieces(line)) # type: ignore[union-attr] @classmethod - def _tokenize_flores_101(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_flores_101(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-101`_ dataset. Args: @@ -385,7 +386,7 @@ def _tokenize_flores_101(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return cls._tokenize_flores(line, "flores101") @classmethod - def _tokenize_flores_200(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_flores_200(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-200`_ dataset. Args: @@ -404,7 +405,7 @@ def _lower(line: str, lowercase: bool) -> str: return line @classmethod - def _check_tokenizers_validity(cls: Type["_SacreBLEUTokenizer"], tokenize: _TokenizersLiteral) -> None: + def _check_tokenizers_validity(cls: type["_SacreBLEUTokenizer"], tokenize: _TokenizersLiteral) -> None: """Check if a supported tokenizer is chosen. Also check all dependencies of a given tokenizers are installed. diff --git a/src/torchmetrics/functional/text/squad.py b/src/torchmetrics/functional/text/squad.py index 01dfb4ec0e6..c52f0860e14 100644 --- a/src/torchmetrics/functional/text/squad.py +++ b/src/torchmetrics/functional/text/squad.py @@ -17,17 +17,17 @@ import re import string from collections import Counter -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Union from torch import Tensor, tensor from torchmetrics.utilities import rank_zero_warn -SINGLE_PRED_TYPE = Dict[str, str] -PREDS_TYPE = Union[SINGLE_PRED_TYPE, List[SINGLE_PRED_TYPE]] -SINGLE_TARGET_TYPE = Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]] -TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, List[SINGLE_TARGET_TYPE]] -UPDATE_METHOD_SINGLE_PRED_TYPE = Union[List[Dict[str, Union[str, int]]], str, Dict[str, Union[List[str], List[int]]]] +SINGLE_PRED_TYPE = dict[str, str] +PREDS_TYPE = Union[SINGLE_PRED_TYPE, list[SINGLE_PRED_TYPE]] +SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]] +TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, list[SINGLE_TARGET_TYPE]] +UPDATE_METHOD_SINGLE_PRED_TYPE = Union[list[dict[str, Union[str, int]]], str, dict[str, Union[list[str], list[int]]]] SQuAD_FORMAT = { "answers": {"answer_start": [1], "text": ["This is a test text"]}, @@ -57,7 +57,7 @@ def lower(text: str) -> str: return white_space_fix(remove_articles(remove_punc(lower(s)))) -def _get_tokens(s: str) -> List[str]: +def _get_tokens(s: str) -> list[str]: """Split a sentence into separate tokens.""" return [] if not s else _normalize_text(s).split() @@ -84,7 +84,7 @@ def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor: def _metric_max_over_ground_truths( - metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: List[str] + metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: list[str] ) -> Tensor: """Calculate maximum score for a predicted answer with all reference answers.""" return max(metric_fn(prediction, truth) for truth in ground_truths) # type: ignore[type-var] @@ -92,12 +92,12 @@ def _metric_max_over_ground_truths( def _squad_input_check( preds: PREDS_TYPE, targets: TARGETS_TYPE -) -> Tuple[Dict[str, str], List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]]]: +) -> tuple[dict[str, str], list[dict[str, list[dict[str, list[dict[str, Any]]]]]]]: """Check for types and convert the input to necessary format to compute the input.""" - if isinstance(preds, Dict): + if isinstance(preds, dict): preds = [preds] - if isinstance(targets, Dict): + if isinstance(targets, dict): targets = [targets] for pred in preds: @@ -118,7 +118,7 @@ def _squad_input_check( f"{SQuAD_FORMAT}" ) - answers: Dict[str, Union[List[str], List[int]]] = target["answers"] # type: ignore[assignment] + answers: dict[str, Union[list[str], list[int]]] = target["answers"] # type: ignore[assignment] if "text" not in answers: raise KeyError( "Expected keys in a 'answers' are 'text'." @@ -134,9 +134,9 @@ def _squad_input_check( def _squad_update( - preds: Dict[str, str], - target: List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]], -) -> Tuple[Tensor, Tensor, Tensor]: + preds: dict[str, str], + target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]], +) -> tuple[Tensor, Tensor, Tensor]: """Compute F1 Score and Exact Match for a collection of predictions and references. Args: @@ -180,7 +180,7 @@ def _squad_update( return f1, exact_match, total -def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> Dict[str, Tensor]: +def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch. Return: @@ -192,7 +192,7 @@ def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> Dict[str, return {"exact_match": exact_match, "f1": f1} -def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> Dict[str, Tensor]: +def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> dict[str, Tensor]: """Calculate `SQuAD Metric`_ . Args: diff --git a/src/torchmetrics/functional/text/ter.py b/src/torchmetrics/functional/text/ter.py index 400a4c283b6..2d7a6211e0d 100644 --- a/src/torchmetrics/functional/text/ter.py +++ b/src/torchmetrics/functional/text/ter.py @@ -34,8 +34,9 @@ # limitations under the License. import re +from collections.abc import Iterator, Sequence from functools import lru_cache -from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union +from typing import List, Optional, Union from torch import Tensor, tensor @@ -150,7 +151,7 @@ def _normalize_general_and_western(sentence: str) -> str: return sentence @classmethod - def _normalize_asian(cls: Type["_TercomTokenizer"], sentence: str) -> str: + def _normalize_asian(cls: type["_TercomTokenizer"], sentence: str) -> str: """Split Chinese chars and Japanese kanji down to character level.""" # 4E00—9FFF CJK Unified Ideographs # 3400—4DBF CJK Unified Ideographs Extension A @@ -182,7 +183,7 @@ def _remove_punct(sentence: str) -> str: return re.sub(r"[\.,\?:;!\"\(\)]", "", sentence) @classmethod - def _remove_asian_punct(cls: Type["_TercomTokenizer"], sentence: str) -> str: + def _remove_asian_punct(cls: type["_TercomTokenizer"], sentence: str) -> str: """Remove asian punctuation from an input sentence string.""" sentence = re.sub(cls._ASIAN_PUNCTUATION, r"", sentence) return re.sub(cls._FULL_WIDTH_PUNCTUATION, r"", sentence) @@ -202,7 +203,7 @@ def _preprocess_sentence(sentence: str, tokenizer: _TercomTokenizer) -> str: return tokenizer(sentence.rstrip()) -def _find_shifted_pairs(pred_words: List[str], target_words: List[str]) -> Iterator[Tuple[int, int, int]]: +def _find_shifted_pairs(pred_words: list[str], target_words: list[str]) -> Iterator[tuple[int, int, int]]: """Find matching word sub-sequences in two lists of words. Ignores sub- sequences starting at the same position. Args: @@ -242,9 +243,9 @@ def _find_shifted_pairs(pred_words: List[str], target_words: List[str]) -> Itera def _handle_corner_cases_during_shifting( - alignments: Dict[int, int], - pred_errors: List[int], - target_errors: List[int], + alignments: dict[int, int], + pred_errors: list[int], + target_errors: list[int], pred_start: int, target_start: int, length: int, @@ -275,7 +276,7 @@ def _handle_corner_cases_during_shifting( return pred_start <= alignments[target_start] < pred_start + length -def _perform_shift(words: List[str], start: int, length: int, target: int) -> List[str]: +def _perform_shift(words: list[str], start: int, length: int, target: int) -> list[str]: """Perform a shift in ``words`` from ``start`` to ``target``. Args: @@ -289,13 +290,13 @@ def _perform_shift(words: List[str], start: int, length: int, target: int) -> Li """ - def _shift_word_before_previous_position(words: List[str], start: int, target: int, length: int) -> List[str]: + def _shift_word_before_previous_position(words: list[str], start: int, target: int, length: int) -> list[str]: return words[:target] + words[start : start + length] + words[target:start] + words[start + length :] - def _shift_word_after_previous_position(words: List[str], start: int, target: int, length: int) -> List[str]: + def _shift_word_after_previous_position(words: list[str], start: int, target: int, length: int) -> list[str]: return words[:start] + words[start + length : target] + words[start : start + length] + words[target:] - def _shift_word_within_shifted_string(words: List[str], start: int, target: int, length: int) -> List[str]: + def _shift_word_within_shifted_string(words: list[str], start: int, target: int, length: int) -> list[str]: shifted_words = words[:start] shifted_words += words[start + length : length + target] shifted_words += words[start : start + length] @@ -310,11 +311,11 @@ def _shift_word_within_shifted_string(words: List[str], start: int, target: int, def _shift_words( - pred_words: List[str], - target_words: List[str], + pred_words: list[str], + target_words: list[str], cached_edit_distance: _LevenshteinEditDistance, checked_candidates: int, -) -> Tuple[int, List[str], int]: +) -> tuple[int, list[str], int]: """Attempt to shift words to match a hypothesis with a reference. It returns the lowest number of required edits between a hypothesis and a provided reference, a list of shifted @@ -342,7 +343,7 @@ def _shift_words( trace = _flip_trace(inverted_trace) alignments, target_errors, pred_errors = _trace_to_alignment(trace) - best: Optional[Tuple[int, int, int, int, List[str]]] = None + best: Optional[tuple[int, int, int, int, list[str]]] = None for pred_start, target_start, length in _find_shifted_pairs(pred_words, target_words): if _handle_corner_cases_during_shifting( @@ -390,7 +391,7 @@ def _shift_words( return best_score, shifted_words, checked_candidates -def _translation_edit_rate(pred_words: List[str], target_words: List[str]) -> Tensor: +def _translation_edit_rate(pred_words: list[str], target_words: list[str]) -> Tensor: """Compute translation edit rate between hypothesis and reference sentences. Args: @@ -425,7 +426,7 @@ def _translation_edit_rate(pred_words: List[str], target_words: List[str]) -> Te return tensor(total_edits) -def _compute_sentence_statistics(pred_words: List[str], target_words: List[List[str]]) -> Tuple[Tensor, Tensor]: +def _compute_sentence_statistics(pred_words: list[str], target_words: list[list[str]]) -> tuple[Tensor, Tensor]: """Compute sentence TER statistics between hypothesis and provided references. Args: @@ -477,7 +478,7 @@ def _ter_update( total_num_edits: Tensor, total_tgt_length: Tensor, sentence_ter: Optional[List[Tensor]] = None, -) -> Tuple[Tensor, Tensor, Optional[List[Tensor]]]: +) -> tuple[Tensor, Tensor, Optional[List[Tensor]]]: """Update TER statistics. Args: @@ -504,8 +505,8 @@ def _ter_update( target, preds = _validate_inputs(target, preds) for pred, tgt in zip(preds, target): - tgt_words_: List[List[str]] = [_preprocess_sentence(_tgt, tokenizer).split() for _tgt in tgt] - pred_words_: List[str] = _preprocess_sentence(pred, tokenizer).split() + tgt_words_: list[list[str]] = [_preprocess_sentence(_tgt, tokenizer).split() for _tgt in tgt] + pred_words_: list[str] = _preprocess_sentence(pred, tokenizer).split() num_edits, tgt_length = _compute_sentence_statistics(pred_words_, tgt_words_) total_num_edits += num_edits total_tgt_length += tgt_length @@ -536,7 +537,7 @@ def translation_edit_rate( lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, List[Tensor]]]: """Calculate Translation edit rate (`TER`_) of machine translated text with one or more references. This implementation follows the implementations from diff --git a/src/torchmetrics/functional/text/wer.py b/src/torchmetrics/functional/text/wer.py index af50d4cb289..b61bdb4c105 100644 --- a/src/torchmetrics/functional/text/wer.py +++ b/src/torchmetrics/functional/text/wer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union import torch from torch import Tensor, tensor @@ -21,9 +21,9 @@ def _wer_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor]: """Update the wer score with the current set of references and predictions. Args: @@ -63,7 +63,7 @@ def _wer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def word_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Word error rate (WordErrorRate_) is a common metric of performance of an automatic speech recognition system. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the diff --git a/src/torchmetrics/functional/text/wil.py b/src/torchmetrics/functional/text/wil.py index bb0d5f7e0d8..3d8c370facb 100644 --- a/src/torchmetrics/functional/text/wil.py +++ b/src/torchmetrics/functional/text/wil.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union from torch import Tensor, tensor @@ -20,9 +20,9 @@ def _word_info_lost_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor, Tensor]: """Update the WIL score with the current set of references and predictions. Args: @@ -69,7 +69,7 @@ def _word_info_lost_compute(errors: Tensor, target_total: Tensor, preds_total: T return 1 - ((errors / target_total) * (errors / preds_total)) -def word_information_lost(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def word_information_lost(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better diff --git a/src/torchmetrics/functional/text/wip.py b/src/torchmetrics/functional/text/wip.py index 2c77139009b..77dae42e5ed 100644 --- a/src/torchmetrics/functional/text/wip.py +++ b/src/torchmetrics/functional/text/wip.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union from torch import Tensor, tensor @@ -19,9 +19,9 @@ def _wip_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor, Tensor]: """Update the wip score with the current set of references and predictions. Args: @@ -68,7 +68,7 @@ def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T return (errors / target_total) * (errors / preds_total) -def word_information_preserved(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def word_information_preserved(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 8b382b89cf7..cab7692f5e7 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from typing_extensions import Literal @@ -54,10 +55,10 @@ def __init__( kernel_size: Union[int, Sequence[int]] = 11, sigma: Union[float, Sequence[float]] = 1.5, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Literal["relu", "simple", None] = "relu", **kwargs: Any, ) -> None: @@ -90,10 +91,10 @@ class _PeakSignalNoiseRatio(PeakSignalNoiseRatio): def __init__( self, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, **kwargs: Any, ) -> None: _deprecated_root_import_class("PeakSignalNoiseRatio", "image") @@ -115,7 +116,7 @@ class _RelativeAverageSpectralError(RelativeAverageSpectralError): def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: _deprecated_root_import_class("RelativeAverageSpectralError", "image") super().__init__(window_size=window_size, **kwargs) @@ -136,7 +137,7 @@ class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWi def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: _deprecated_root_import_class("RootMeanSquaredErrorUsingSlidingWindow", "image") super().__init__(window_size=window_size, **kwargs) @@ -200,7 +201,7 @@ def __init__( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 5b1e58de0dd..97d95ccd926 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index a8dacbaf447..9143810f545 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -127,7 +128,7 @@ def __init__( self.add_state("pan", default=[], dist_reduce_fx="cat") self.add_state("pan_lr", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Dict[str, Tensor]) -> None: + def update(self, preds: Tensor, target: dict[str, Tensor]) -> None: """Update state with preds and target. Args: diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index bf6b1c99a10..22c24b164f1 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index ac559ed0c68..c15b3302af9 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor @@ -47,7 +48,7 @@ class NoTrainInceptionV3(_FeatureExtractorInceptionV3): def __init__( self, name: str, - features_list: List[str], + features_list: list[str], feature_extractor_weights_path: Optional[str] = None, ) -> None: if not _TORCH_FIDELITY_AVAILABLE: @@ -64,7 +65,7 @@ def train(self, mode: bool) -> "NoTrainInceptionV3": """Force network to always be in evaluation mode.""" return super().train(False) - def _torch_fidelity_forward(self, x: Tensor) -> Tuple[Tensor, ...]: + def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]: """Forward method of inception net. Copy of the forward method from this file: @@ -298,7 +299,7 @@ def __init__( feature: Union[int, Module] = 2048, reset_real_features: bool = True, normalize: bool = False, - input_img_size: Tuple[int, int, int] = (3, 299, 299), + input_img_size: tuple[int, int, int] = (3, 299, 299), **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 20d53d10f2b..fd11a6afe03 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor @@ -101,7 +102,7 @@ class InceptionScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - features: List + features: list inception: Module feature_network: str = "inception" @@ -151,7 +152,7 @@ def update(self, imgs: Tensor) -> None: features = self.inception(imgs) self.features.append(features) - def compute(self) -> Tuple[Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor]: """Compute metric.""" features = dim_zero_cat(self.features) # random permute the features diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 018fc7a7511..99c2b04bf7b 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -263,7 +264,7 @@ def update(self, imgs: Tensor, real: bool) -> None: else: self.fake_features.append(features) - def compute(self) -> Tuple[Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor]: """Calculate KID score based on accumulated extracted features from the two distributions. Implementation inspired by `Fid Score`_ diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index 1893fb734ba..811476dd75d 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ClassVar, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, ClassVar, Optional, Union import torch from torch import Tensor @@ -98,7 +99,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric): feature_network: str = "net" # due to the use of named tuple in the backbone the net variable cannot be scripted - __jit_ignored_attributes__: ClassVar[List[str]] = ["net"] + __jit_ignored_attributes__: ClassVar[list[str]] = ["net"] def __init__( self, diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index a1e2d2f4c0a..5d344b57dd1 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index 117dca6f8cb..5c48b57c3fe 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Literal, Optional, Tuple, Union +from typing import Any, Literal, Optional, Union from torch import Tensor, nn @@ -166,7 +166,7 @@ def update(self, generator: GeneratorType) -> None: _validate_generator_model(generator, self.conditional) self.generator = generator - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute the perceptual path length.""" return perceptual_path_length( generator=self.generator, diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index fe774d2588b..5f00d21c7cb 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor @@ -86,10 +87,10 @@ class PeakSignalNoiseRatio(Metric): def __init__( self, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/image/psnrb.py b/src/torchmetrics/image/psnrb.py index c9fb157d6ff..0e12018c324 100644 --- a/src/torchmetrics/image/psnrb.py +++ b/src/torchmetrics/image/psnrb.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/image/qnr.py b/src/torchmetrics/image/qnr.py index b226a6b19fd..f28e61b2fbf 100644 --- a/src/torchmetrics/image/qnr.py +++ b/src/torchmetrics/image/qnr.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -130,7 +131,7 @@ def __init__( self.add_state("pan", default=[], dist_reduce_fx="cat") self.add_state("pan_lr", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Dict[str, Tensor]) -> None: + def update(self, preds: Tensor, target: dict[str, Tensor]) -> None: """Update state with preds and target. Args: diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index b1eb32141a6..bca9504c1aa 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor @@ -69,7 +70,7 @@ class RelativeAverageSpectralError(Metric): def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index c1f7c652879..6312174b1be 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor @@ -68,7 +69,7 @@ class RootMeanSquaredErrorUsingSlidingWindow(Metric): def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) if not isinstance(window_size, int) or isinstance(window_size, int) and window_size < 1: diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 7aad1782852..b313158f80b 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 648f9c26029..fd9d12d770d 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -92,7 +93,7 @@ def __init__( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, @@ -155,7 +156,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: else: self.similarity.append(similarity) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute SSIM over state.""" if self.reduction == "elementwise_mean": similarity = self.similarity / self.total @@ -293,10 +294,10 @@ def __init__( kernel_size: Union[int, Sequence[int]] = 11, sigma: Union[float, Sequence[float]] = 1.5, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Literal["relu", "simple", None] = "relu", **kwargs: Any, ) -> None: diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index ca6276b7e86..287e58a3a43 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index 3fea5e8986f..c503cc1f394 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 279537aac51..b270903eafd 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -18,9 +18,10 @@ import functools import inspect from abc import ABC, abstractmethod +from collections.abc import Generator, Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, ClassVar, List, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -83,8 +84,8 @@ class Metric(Module, ABC): """ - __jit_ignored_attributes__: ClassVar[List[str]] = ["device"] - __jit_unused_properties__: ClassVar[List[str]] = [ + __jit_ignored_attributes__: ClassVar[list[str]] = ["device"] + __jit_unused_properties__: ClassVar[list[str]] = [ "is_differentiable", "higher_is_better", "plot_lower_bound", @@ -165,13 +166,13 @@ def __init__( self._dtype_convert = False # initialize state - self._defaults: Dict[str, Union[List, Tensor]] = {} - self._persistent: Dict[str, bool] = {} - self._reductions: Dict[str, Union[str, Callable[..., Any], None]] = {} + self._defaults: dict[str, Union[list, Tensor]] = {} + self._persistent: dict[str, bool] = {} + self._reductions: dict[str, Union[str, Callable[..., Any], None]] = {} # state management self._is_synced = False - self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None + self._cache: Optional[dict[str, Union[List[Tensor], Tensor]]] = None @property def _update_called(self) -> bool: @@ -193,7 +194,7 @@ def update_count(self) -> int: return self._update_count @property - def metric_state(self) -> Dict[str, Union[List[Tensor], Tensor]]: + def metric_state(self) -> dict[str, Union[List[Tensor], Tensor]]: """Get the current state of the metric.""" return {attr: getattr(self, attr) for attr in self._defaults} @@ -401,7 +402,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: return batch_val - def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: + def merge_state(self, incoming_state: Union[dict[str, Any], "Metric"]) -> None: """Merge incoming metric state to the current state of the metric. Args: @@ -462,7 +463,7 @@ def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: self._reduce_states(incoming_state) - def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: + def _reduce_states(self, incoming_state: dict[str, Any]) -> None: """Add an incoming metric state to the current state of the metric. Args: @@ -725,7 +726,7 @@ def plot(self, *_: Any, **__: Any) -> Any: def _plot( self, - val: Optional[Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, + val: Optional[Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -776,7 +777,7 @@ def clone(self) -> "Metric": """Make a copy of the metric.""" return deepcopy(self) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: """Get the current state, including all metric states, for the metric. Used for loading and saving a metric. @@ -785,7 +786,7 @@ def __getstate__(self) -> Dict[str, Any]: # ignore update and compute functions for pickling return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute", "_update_signature"]} - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: """Set the state of the metric, based on a input state. Used for loading and saving a metric. @@ -923,10 +924,10 @@ def persistent(self, mode: bool = False) -> None: def state_dict( # type: ignore[override] # todo self, - destination: Optional[Dict[str, Any]] = None, + destination: Optional[dict[str, Any]] = None, prefix: str = "", keep_vars: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Get the current state of metric as an dictionary. Args: @@ -937,7 +938,7 @@ def state_dict( # type: ignore[override] # todo If set to ``True``, detaching will not be performed. """ - destination: Dict[str, Union[torch.Tensor, List, Any]] = super().state_dict( + destination: dict[str, Union[torch.Tensor, list, Any]] = super().state_dict( destination=destination, # type: ignore[arg-type] prefix=prefix, keep_vars=keep_vars, @@ -955,9 +956,9 @@ def state_dict( # type: ignore[override] # todo destination[prefix + key] = deepcopy(current_val) return destination - def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: + def _copy_state_dict(self) -> dict[str, Union[Tensor, list[Any]]]: """Copy the current state values.""" - cache: Dict[str, Union[Tensor, List[Any]]] = {} + cache: dict[str, Union[Tensor, list[Any]]] = {} for attr in self._defaults: current_value = getattr(self, attr) @@ -976,9 +977,9 @@ def _load_from_state_dict( prefix: str, local_metadata: dict, strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: """Load metric states from state_dict.""" for key in self._defaults: @@ -989,7 +990,7 @@ def _load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs ) - def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + def _filter_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Filter kwargs such that they match the update signature of the metric.""" # filter all parameters based on update signature except those of # types `VAR_POSITIONAL` for `* args` and `VAR_KEYWORD` for `** kwargs` @@ -1172,7 +1173,7 @@ def __getitem__(self, idx: int) -> "CompositionalMetric": """Construct compositional metric using the get item operator.""" return CompositionalMetric(lambda x: x[idx], self, None) - def __getnewargs__(self) -> Tuple: + def __getnewargs__(self) -> tuple: """Needed method for construction of new metrics __new__ method.""" return tuple( Metric.__str__(self), diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index f49113e297f..cc9c0715be6 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union import torch from torch import Tensor @@ -178,7 +179,7 @@ def __init__( "openai/clip-vit-large-patch14", ] = "clip_iqa", data_range: float = 1.0, - prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",), + prompts: tuple[Union[str, tuple[str, str]]] = ("quality",), **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -212,7 +213,7 @@ def update(self, images: Tensor) -> None: raise ValueError("Output probs should be a tensor") self.probs_list.append(probs) - def compute(self) -> Union[Tensor, Dict[str, Tensor]]: + def compute(self) -> Union[Tensor, dict[str, Tensor]]: """Compute metric.""" probs = dim_zero_cat(self.probs_list) if len(self.prompts_name) == 1: diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index 92ca7ad6b4f..c89384fbb35 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -117,7 +118,7 @@ def __init__( self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") - def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str]]) -> None: + def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, list[str]]) -> None: """Update CLIP score on a batch of images and text. Args: diff --git a/src/torchmetrics/nominal/cramers.py b/src/torchmetrics/nominal/cramers.py index df47cc079d1..5f361bfb477 100644 --- a/src/torchmetrics/nominal/cramers.py +++ b/src/torchmetrics/nominal/cramers.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/fleiss_kappa.py b/src/torchmetrics/nominal/fleiss_kappa.py index 619e8196d4c..254796e96c9 100644 --- a/src/torchmetrics/nominal/fleiss_kappa.py +++ b/src/torchmetrics/nominal/fleiss_kappa.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/nominal/pearson.py b/src/torchmetrics/nominal/pearson.py index a43cd548792..15be1bd43b4 100644 --- a/src/torchmetrics/nominal/pearson.py +++ b/src/torchmetrics/nominal/pearson.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/theils_u.py b/src/torchmetrics/nominal/theils_u.py index 7f7f22ecb1b..19695a61c2c 100644 --- a/src/torchmetrics/nominal/theils_u.py +++ b/src/torchmetrics/nominal/theils_u.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/tschuprows.py b/src/torchmetrics/nominal/tschuprows.py index 9986fa2ec6f..9744103f304 100644 --- a/src/torchmetrics/nominal/tschuprows.py +++ b/src/torchmetrics/nominal/tschuprows.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/concordance.py b/src/torchmetrics/regression/concordance.py index 6697ccadec9..c6d051b6ebe 100644 --- a/src/torchmetrics/regression/concordance.py +++ b/src/torchmetrics/regression/concordance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional, Union from torch import Tensor diff --git a/src/torchmetrics/regression/cosine_similarity.py b/src/torchmetrics/regression/cosine_similarity.py index b1d8923290e..5c86ac00cab 100644 --- a/src/torchmetrics/regression/cosine_similarity.py +++ b/src/torchmetrics/regression/cosine_similarity.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index b75762a3b0c..b5c7356aaab 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -14,6 +14,7 @@ from typing import Any, List, Optional import torch +from torch import Tensor from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update from torchmetrics.metric import Metric @@ -60,12 +61,12 @@ class CriticalSuccessIndex(Metric): is_differentiable: bool = False higher_is_better: bool = True - hits: torch.Tensor - misses: torch.Tensor - false_alarms: torch.Tensor - hits_list: List[torch.Tensor] - misses_list: List[torch.Tensor] - false_alarms_list: List[torch.Tensor] + hits: Tensor + misses: Tensor + false_alarms: Tensor + hits_list: List[Tensor] + misses_list: List[Tensor] + false_alarms_list: List[Tensor] def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -84,7 +85,7 @@ def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, ** self.add_state("misses_list", default=[], dist_reduce_fx="cat") self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat") - def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" hits, misses, false_alarms = _critical_success_index_update( preds, target, self.threshold, self.keep_sequence_dim @@ -98,7 +99,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: self.misses_list.append(misses) self.false_alarms_list.append(false_alarms) - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """Compute critical success index over state.""" if self.keep_sequence_dim is None: hits = self.hits diff --git a/src/torchmetrics/regression/explained_variance.py b/src/torchmetrics/regression/explained_variance.py index bf0ba0dc55f..833c1609e55 100644 --- a/src/torchmetrics/regression/explained_variance.py +++ b/src/torchmetrics/regression/explained_variance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index f9677f4c296..8a102dee08f 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -153,7 +154,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: num_outputs=self.num_outputs, ) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) diff --git a/src/torchmetrics/regression/kl_divergence.py b/src/torchmetrics/regression/kl_divergence.py index 253b7a05002..bd799c31d7f 100644 --- a/src/torchmetrics/regression/kl_divergence.py +++ b/src/torchmetrics/regression/kl_divergence.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/log_cosh.py b/src/torchmetrics/regression/log_cosh.py index ca9395f9a50..750e5a3497e 100644 --- a/src/torchmetrics/regression/log_cosh.py +++ b/src/torchmetrics/regression/log_cosh.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/log_mse.py b/src/torchmetrics/regression/log_mse.py index da190016844..31dbba6accb 100644 --- a/src/torchmetrics/regression/log_mse.py +++ b/src/torchmetrics/regression/log_mse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/mae.py b/src/torchmetrics/regression/mae.py index 46de2d70d21..da95b2b2d90 100644 --- a/src/torchmetrics/regression/mae.py +++ b/src/torchmetrics/regression/mae.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/mape.py b/src/torchmetrics/regression/mape.py index cbc490e686e..7ee1eaf61f7 100644 --- a/src/torchmetrics/regression/mape.py +++ b/src/torchmetrics/regression/mape.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/minkowski.py b/src/torchmetrics/regression/minkowski.py index 1c5d9e430f1..50785f5425c 100644 --- a/src/torchmetrics/regression/minkowski.py +++ b/src/torchmetrics/regression/minkowski.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/mse.py b/src/torchmetrics/regression/mse.py index 363c56d6a29..b82738ace4a 100644 --- a/src/torchmetrics/regression/mse.py +++ b/src/torchmetrics/regression/mse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 62562803542..bb44ae9c905 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 26d8ffdb5c0..8cf165471fe 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -32,7 +33,7 @@ def _final_aggregation( vars_y: Tensor, corrs_xy: Tensor, nbs: Tensor, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Aggregate the statistics from multiple devices. Formula taken from here: `Aggregate the statistics from multiple devices`_ diff --git a/src/torchmetrics/regression/r2.py b/src/torchmetrics/regression/r2.py index a54502e087d..613dc69a141 100644 --- a/src/torchmetrics/regression/r2.py +++ b/src/torchmetrics/regression/r2.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/rse.py b/src/torchmetrics/regression/rse.py index 2c9d524f03c..2776f8c9fda 100644 --- a/src/torchmetrics/regression/rse.py +++ b/src/torchmetrics/regression/rse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 62be592919d..de94903c8c0 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/regression/symmetric_mape.py b/src/torchmetrics/regression/symmetric_mape.py index 82b5702f476..a1c601a2a16 100644 --- a/src/torchmetrics/regression/symmetric_mape.py +++ b/src/torchmetrics/regression/symmetric_mape.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/tweedie_deviance.py b/src/torchmetrics/regression/tweedie_deviance.py index 70c8ae37855..3cd10070e7b 100644 --- a/src/torchmetrics/regression/tweedie_deviance.py +++ b/src/torchmetrics/regression/tweedie_deviance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/wmape.py b/src/torchmetrics/regression/wmape.py index 42fadb906ef..a5c532f145a 100644 --- a/src/torchmetrics/regression/wmape.py +++ b/src/torchmetrics/regression/wmape.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/retrieval/auroc.py b/src/torchmetrics/retrieval/auroc.py index 8d5ac12a929..6d4382894a4 100644 --- a/src/torchmetrics/retrieval/auroc.py +++ b/src/torchmetrics/retrieval/auroc.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/average_precision.py b/src/torchmetrics/retrieval/average_precision.py index b5527bd6afd..93ecc38cb51 100644 --- a/src/torchmetrics/retrieval/average_precision.py +++ b/src/torchmetrics/retrieval/average_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/fall_out.py b/src/torchmetrics/retrieval/fall_out.py index 08cda4ace08..9660f061370 100644 --- a/src/torchmetrics/retrieval/fall_out.py +++ b/src/torchmetrics/retrieval/fall_out.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/retrieval/hit_rate.py b/src/torchmetrics/retrieval/hit_rate.py index e981a126fa5..c04ff3ad14d 100644 --- a/src/torchmetrics/retrieval/hit_rate.py +++ b/src/torchmetrics/retrieval/hit_rate.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/ndcg.py b/src/torchmetrics/retrieval/ndcg.py index 2914726fde9..afd211cb142 100644 --- a/src/torchmetrics/retrieval/ndcg.py +++ b/src/torchmetrics/retrieval/ndcg.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/precision.py b/src/torchmetrics/retrieval/precision.py index e59cac4e382..70bf4d9794a 100644 --- a/src/torchmetrics/retrieval/precision.py +++ b/src/torchmetrics/retrieval/precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py index 9b4f4d38d0e..9eef71c154e 100644 --- a/src/torchmetrics/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/retrieval/precision_recall_curve.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -34,7 +35,7 @@ def _retrieval_recall_at_fixed_precision( recall: Tensor, top_k: Tensor, min_precision: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Compute maximum recall with condition that corresponding precision >= `min_precision`. Args: @@ -201,7 +202,7 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: self.preds.append(preds) self.target.append(target) - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute metric.""" # concat all data indexes = dim_zero_cat(self.indexes) @@ -256,7 +257,7 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: def plot( self, - curve: Optional[Tuple[Tensor, Tensor, Tensor]] = None, + curve: Optional[tuple[Tensor, Tensor, Tensor]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -380,7 +381,7 @@ def __init__( self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" precisions, recalls, top_k = super().compute() diff --git a/src/torchmetrics/retrieval/r_precision.py b/src/torchmetrics/retrieval/r_precision.py index c3892b76b8b..2c2466af430 100644 --- a/src/torchmetrics/retrieval/r_precision.py +++ b/src/torchmetrics/retrieval/r_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional, Union from torch import Tensor diff --git a/src/torchmetrics/retrieval/recall.py b/src/torchmetrics/retrieval/recall.py index de77abc9a85..04fe881bc99 100644 --- a/src/torchmetrics/retrieval/recall.py +++ b/src/torchmetrics/retrieval/recall.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/reciprocal_rank.py b/src/torchmetrics/retrieval/reciprocal_rank.py index 17a9a2bd7a8..01ea27ae12b 100644 --- a/src/torchmetrics/retrieval/reciprocal_rank.py +++ b/src/torchmetrics/retrieval/reciprocal_rank.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index fc8cadd8c3a..05a6e29b387 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 95da9ab26d9..aef04be1e43 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py index f1e8812ed30..4e646bb6a93 100644 --- a/src/torchmetrics/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -10,7 +10,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Literal, Optional, Union import torch from torch import Tensor @@ -88,7 +89,7 @@ def __init__( num_classes: int, include_background: bool = False, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, directed: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", **kwargs: Any, diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index 0fe831f5231..ae8dd3d2aea 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/shape/procrustes.py b/src/torchmetrics/shape/procrustes.py index a924fb48a4a..1d8396c7afb 100644 --- a/src/torchmetrics/shape/procrustes.py +++ b/src/torchmetrics/shape/procrustes.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/_deprecated.py b/src/torchmetrics/text/_deprecated.py index 0c7ffef29af..50c5091de28 100644 --- a/src/torchmetrics/text/_deprecated.py +++ b/src/torchmetrics/text/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Literal, Optional from torchmetrics.text.bleu import BLEUScore from torchmetrics.text.cer import CharErrorRate diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index 6e1bab1b9bd..103c5a47bc7 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -46,7 +47,7 @@ def _download_model_for_bert_score() -> None: __doctest_skip__ = ["BERTScore", "BERTScore.plot"] -def _get_input_dict(input_ids: List[Tensor], attention_mask: List[Tensor]) -> Dict[str, Tensor]: +def _get_input_dict(input_ids: List[Tensor], attention_mask: List[Tensor]) -> dict[str, Tensor]: """Create an input dictionary of ``input_ids`` and ``attention_mask`` for BERTScore calculation.""" return {"input_ids": torch.cat(input_ids), "attention_mask": torch.cat(attention_mask)} @@ -139,7 +140,7 @@ def __init__( all_layers: bool = False, model: Optional[Module] = None, user_tokenizer: Optional[Any] = None, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, verbose: bool = False, idf: bool = False, device: Optional[Union[str, torch.device]] = None, @@ -231,7 +232,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s self.target_input_ids.append(target_dict["input_ids"]) self.target_attention_mask.append(target_dict["attention_mask"]) - def compute(self) -> Dict[str, Union[Tensor, List[float], str]]: + def compute(self) -> dict[str, Union[Tensor, list[float], str]]: """Calculate BERT scores.""" preds = { "input_ids": dim_zero_cat(self.preds_input_ids), diff --git a/src/torchmetrics/text/bleu.py b/src/torchmetrics/text/bleu.py index d05da1e33c4..a40525bbe0c 100644 --- a/src/torchmetrics/text/bleu.py +++ b/src/torchmetrics/text/bleu.py @@ -16,7 +16,8 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/cer.py b/src/torchmetrics/text/cer.py index 0bdc5670f59..e0337801916 100644 --- a/src/torchmetrics/text/cer.py +++ b/src/torchmetrics/text/cer.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor @@ -84,7 +85,7 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, total = _cer_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 1ff412ab1a4..64eb7c6b1d4 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -18,7 +18,8 @@ # Link: import itertools -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from collections.abc import Iterator, Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor @@ -44,8 +45,8 @@ "total_matching_word_n_grams", ) -_DICT_STATES_TYPES = Tuple[ - Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] +_DICT_STATES_TYPES = tuple[ + dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor] ] @@ -156,7 +157,7 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: if self.sentence_chrf_score is not None: self.sentence_chrf_score = n_grams_dicts_tuple[-1] - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate chrF/chrF++ score.""" if self.sentence_chrf_score is not None: return ( @@ -167,7 +168,7 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: """Convert global metric states to the n-gram dictionaries to be passed in ``_chrf_score_update``.""" - n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict( + n_grams_dicts: dict[str, dict[int, Tensor]] = dict( zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)) ) @@ -200,7 +201,7 @@ def _get_state_name(text: str, n_gram_level: str, n: int) -> str: """Return a metric state name w.r.t input args.""" return f"total_{text}_{n_gram_level}_{n}_grams" - def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: + def _get_text_n_gram_iterator(self) -> Iterator[tuple[tuple[str, int], str]]: """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) diff --git a/src/torchmetrics/text/edit.py b/src/torchmetrics/text/edit.py index de14f49ae3b..947fc79cd6c 100644 --- a/src/torchmetrics/text/edit.py +++ b/src/torchmetrics/text/edit.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index 659181f8d39..c776eba2331 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor, stack from typing_extensions import Literal @@ -112,7 +113,7 @@ def update( self.sentence_eed, ) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate extended edit distance score.""" average = _eed_compute(self.sentence_eed) diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index 31fea4adc23..74e931c7cba 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, ClassVar, List, Optional, Union import torch from torch import Tensor @@ -144,7 +145,7 @@ def __init__( num_threads: int = 0, verbose: bool = True, return_sentence_level_score: bool = False, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) self.model_name_or_path = model_name_or_path @@ -188,7 +189,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s self.target_input_ids.append(target_input_ids) self.target_attention_mask.append(target_attention_mask) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate selected information measure using the pre-trained language model.""" preds_dataloader = _get_dataloader( input_ids=dim_zero_cat(self.preds_input_ids), diff --git a/src/torchmetrics/text/mer.py b/src/torchmetrics/text/mer.py index b519445c05e..a898c9c4758 100644 --- a/src/torchmetrics/text/mer.py +++ b/src/torchmetrics/text/mer.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor @@ -83,8 +84,8 @@ def __init__( def update( self, - preds: Union[str, List[str]], - target: Union[str, List[str]], + preds: Union[str, list[str]], + target: Union[str, list[str]], ) -> None: """Update state with predictions and targets.""" errors, total = _mer_update(preds, target) diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py index d13eac2f402..af8d9e70795 100644 --- a/src/torchmetrics/text/perplexity.py +++ b/src/torchmetrics/text/perplexity.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor @@ -68,7 +69,7 @@ class Perplexity(Metric): def __init__( self, ignore_index: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) if ignore_index is not None and not isinstance(ignore_index, int): diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py index 7bce72ed1c3..ec1ef711afc 100644 --- a/src/torchmetrics/text/rouge.py +++ b/src/torchmetrics/text/rouge.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -108,7 +109,7 @@ def __init__( normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, accumulate: Literal["avg", "best"] = "best", - rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), + rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -155,7 +156,7 @@ def update( if isinstance(target, str): target = [[target]] - output: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update( + output: dict[Union[int, str], list[dict[str, Tensor]]] = _rouge_score_update( preds, target, self.rouge_keys_values, @@ -169,7 +170,7 @@ def update( for tp, value in metric.items(): getattr(self, f"rouge{rouge_key}_{tp}").append(value.to(self.device)) # todo - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Calculate (Aggregate and provide confidence intervals) ROUGE score.""" update_output = {} for rouge_key in self.rouge_keys_values: diff --git a/src/torchmetrics/text/sacre_bleu.py b/src/torchmetrics/text/sacre_bleu.py index 8e2e63ad82b..3a393a399df 100644 --- a/src/torchmetrics/text/sacre_bleu.py +++ b/src/torchmetrics/text/sacre_bleu.py @@ -17,7 +17,8 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py index e4be6670d19..a545da95803 100644 --- a/src/torchmetrics/text/squad.py +++ b/src/torchmetrics/text/squad.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor @@ -118,7 +119,7 @@ def update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None: self.exact_match += exact_match self.total += total - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch.""" return _squad_compute(self.f1_score, self.exact_match, self.total) diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py index 98ec0a90235..6cdd1d02118 100644 --- a/src/torchmetrics/text/ter.py +++ b/src/torchmetrics/text/ter.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor @@ -108,7 +109,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Sequence[Union[str, S self.sentence_ter, ) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate the translate error rate (TER).""" ter = _ter_compute(self.total_num_edits, self.total_tgt_len) if self.sentence_ter is not None: diff --git a/src/torchmetrics/text/wer.py b/src/torchmetrics/text/wer.py index 0950bd4de42..93bc7b8da13 100644 --- a/src/torchmetrics/text/wer.py +++ b/src/torchmetrics/text/wer.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor @@ -83,7 +84,7 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, total = _wer_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/text/wil.py b/src/torchmetrics/text/wil.py index 16b71720c3b..edd00b7f657 100644 --- a/src/torchmetrics/text/wil.py +++ b/src/torchmetrics/text/wil.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor @@ -81,7 +82,7 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, target_total, preds_total = _word_info_lost_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/text/wip.py b/src/torchmetrics/text/wip.py index bbdd2b7a235..b9b4351f548 100644 --- a/src/torchmetrics/text/wip.py +++ b/src/torchmetrics/text/wip.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor @@ -82,7 +83,7 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, target_total, preds_total = _wip_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 449efcade2d..a1f7e47632f 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -14,9 +14,10 @@ import multiprocessing import os import sys +from collections.abc import Mapping, Sequence from functools import partial from time import perf_counter -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, no_type_check +from typing import Any, Callable, Optional, no_type_check from unittest.mock import Mock import torch @@ -70,7 +71,7 @@ def _basic_input_validation( raise ValueError("If you set `multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") -def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[DataType, int]: +def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> tuple[DataType, int]: """Check that the shape and type of inputs are consistent with each other. The input types needs to be one of allowed input types (see the documentation of docstring of @@ -301,7 +302,7 @@ def _check_classification_inputs( def _input_squeeze( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Remove excess dimensions.""" if preds.shape[0] == 1: preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) @@ -318,7 +319,7 @@ def _input_format_classification( num_classes: Optional[int] = None, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, DataType]: +) -> tuple[Tensor, Tensor, DataType]: """Convert preds and target tensors into common format. Preds and targets are supposed to fall into one of these categories (and are @@ -460,7 +461,7 @@ def _input_format_classification_one_hot( target: Tensor, threshold: float = 0.5, multilabel: bool = False, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert preds and target tensors into one hot spare label tensors. Args: @@ -508,7 +509,7 @@ def _check_retrieval_functional_inputs( preds: Tensor, target: Tensor, allow_non_binary_target: bool = False, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type. Args: @@ -541,7 +542,7 @@ def _check_retrieval_inputs( target: Tensor, allow_non_binary_target: bool = False, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct data type. Args: @@ -588,7 +589,7 @@ def _check_retrieval_target_and_prediction_types( preds: Tensor, target: Tensor, allow_non_binary_target: bool = False, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type. Args: @@ -633,8 +634,8 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-6) -> bool: @no_type_check def check_forward_full_state_property( metric_class: Metric, - init_args: Optional[Dict[str, Any]] = None, - input_args: Optional[Dict[str, Any]] = None, + init_args: Optional[dict[str, Any]] = None, + input_args: Optional[dict[str, Any]] = None, num_update_to_compare: Sequence[int] = [10, 100, 1000], reps: int = 5, ) -> None: diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index e526ecc8456..cbb648a8844 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -81,7 +81,7 @@ def _adjust_weights_safe_divide( return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) -def _auc_format_inputs(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: +def _auc_format_inputs(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: """Check that auc input is correct.""" x = x.squeeze() if x.ndim > 1 else x y = y.squeeze() if y.ndim > 1 else y diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 1a68e655c33..e5bb148a9ce 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -60,7 +61,7 @@ def _flatten(x: Sequence) -> list: return [item for sublist in x for item in sublist] -def _flatten_dict(x: Dict) -> Tuple[Dict, bool]: +def _flatten_dict(x: dict) -> tuple[dict, bool]: """Flatten dict of dicts into single dict and checking for duplicates in keys along the way.""" new_dict = {} duplicates = False diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 90239b46af0..a68cffec3f1 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -98,7 +98,7 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: - """Gather all tensors from several ddp processes onto a list that is broadcasted to all processes. + """Gather all tensors from several ddp processes onto a list that is broadcast to all processes. Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case tensors are padded, gathered and then trimmed to secure equal workload for all processes. diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index bfc2fd20190..155f1bb8f60 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type from lightning_utilities.core.enums import StrEnum from typing_extensions import Literal @@ -25,7 +24,7 @@ def _name() -> str: return "Task" @classmethod - def from_str(cls: Type["EnumStr"], value: str, source: Literal["key", "value", "any"] = "key") -> "EnumStr": + def from_str(cls: type["EnumStr"], value: str, source: Literal["key", "value", "any"] = "key") -> "EnumStr": """Load from string. Raises: diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 4c88c078050..4d14349b7f9 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator, Sequence from itertools import product from math import ceil, floor, sqrt -from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union, no_type_check +from typing import Any, List, Optional, Union, no_type_check import numpy as np import torch @@ -26,13 +27,13 @@ import matplotlib.axes import matplotlib.pyplot as plt - _PLOT_OUT_TYPE = Tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]] + _PLOT_OUT_TYPE = tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]] _AX_TYPE = matplotlib.axes.Axes _CMAP_TYPE = Union[matplotlib.colors.Colormap, str] style_change = plt.style.context else: - _PLOT_OUT_TYPE = Tuple[object, object] # type: ignore[misc] + _PLOT_OUT_TYPE = tuple[object, object] # type: ignore[misc] _AX_TYPE = object _CMAP_TYPE = object # type: ignore[misc] @@ -62,7 +63,7 @@ def _error_on_missing_matplotlib() -> None: @style_change(_style) def plot_single_or_multi_val( - val: Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]], + val: Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]], ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type] higher_is_better: Optional[bool] = None, lower_bound: Optional[float] = None, @@ -171,7 +172,7 @@ def plot_single_or_multi_val( return fig, ax -def _get_col_row_split(n: int) -> Tuple[int, int]: +def _get_col_row_split(n: int) -> tuple[int, int]: """Split `n` figures into `rows` x `cols` figures.""" nsq = sqrt(n) if int(nsq) == nsq: # square number @@ -181,7 +182,7 @@ def _get_col_row_split(n: int) -> Tuple[int, int]: return ceil(nsq), ceil(nsq) -def _get_text_color(patch_color: Tuple[float, float, float, float]) -> str: +def _get_text_color(patch_color: tuple[float, float, float, float]) -> str: """Get the text color for a given value and colormap. Following Wikipedia's recommendations: https://en.wikipedia.org/wiki/Relative_luminance. @@ -221,7 +222,7 @@ def plot_confusion_matrix( confmat: Tensor, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[Union[int, str]]] = None, + labels: Optional[list[Union[int, str]]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot an confusion matrix. @@ -294,13 +295,13 @@ def plot_confusion_matrix( @style_change(_style) def plot_curve( - curve: Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]], + curve: Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]], score: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type] - label_names: Optional[Tuple[str, str]] = None, + label_names: Optional[tuple[str, str]] = None, legend_name: Optional[str] = None, name: Optional[str] = None, - labels: Optional[List[Union[int, str]]] = None, + labels: Optional[list[Union[int, str]]] = None, ) -> _PLOT_OUT_TYPE: """Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py. diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index d59f7724c2a..566c4d4ab66 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -145,7 +146,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx) self.metrics[idx].update(*new_args, **new_kwargs) - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Compute the bootstrapped metric values. Always returns a dict of tensors, which can contain the following keys: ``mean``, ``std``, ``quantile`` and diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 217c94d6bc0..c37e3c7fa2b 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import typing -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor @@ -115,7 +116,7 @@ class ClasswiseWrapper(WrapperMetric): def __init__( self, metric: Metric, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, prefix: Optional[str] = None, postfix: Optional[str] = None, ) -> None: @@ -138,11 +139,11 @@ def __init__( self._update_count = 1 - def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + def _filter_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Filter kwargs for the metric.""" return self.metric._filter_kwargs(**kwargs) - def _convert_output(self, x: Tensor) -> Dict[str, Any]: + def _convert_output(self, x: Tensor) -> dict[str, Any]: """Convert output to dictionary with labels as keys.""" # Will set the class name as prefix if neither prefix nor postfix is given if not self._prefix and not self._postfix: @@ -163,7 +164,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: """Update state.""" self.metric.update(*args, **kwargs) - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Compute metric.""" return self._convert_output(self.metric.compute()) diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py index 1bd1b81783b..62302fff140 100644 --- a/src/torchmetrics/wrappers/feature_share.py +++ b/src/torchmetrics/wrappers/feature_share.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import lru_cache -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Optional, Union from torch.nn import Module @@ -83,7 +84,7 @@ class FeatureShare(MetricCollection): def __init__( self, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], + metrics: Union[Metric, Sequence[Metric], dict[str, Metric]], max_cache_size: Optional[int] = None, ) -> None: # disable compute groups because the feature sharing is more custom diff --git a/src/torchmetrics/wrappers/minmax.py b/src/torchmetrics/wrappers/minmax.py index 09684c55919..25300e8fbe0 100644 --- a/src/torchmetrics/wrappers/minmax.py +++ b/src/torchmetrics/wrappers/minmax.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor @@ -82,7 +83,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: """Update the underlying metric.""" self._base_metric.update(*args, **kwargs) - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Compute the underlying metric as well as max and min values for this metric. Returns a dictionary that consists of the computed value (``raw``), as well as the minimum (``min``) and maximum diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py index 7853e6257e6..12a135099a2 100644 --- a/src/torchmetrics/wrappers/multioutput.py +++ b/src/torchmetrics/wrappers/multioutput.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -103,7 +104,7 @@ def __init__( self.remove_nans = remove_nans self.squeeze_outputs = squeeze_outputs - def _get_args_kwargs_by_output(self, *args: Tensor, **kwargs: Tensor) -> List[Tuple[Tensor, Tensor]]: + def _get_args_kwargs_by_output(self, *args: Tensor, **kwargs: Tensor) -> list[tuple[Tensor, Tensor]]: """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.""" args_kwargs_by_output = [] for i in range(len(self.metrics)): diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index 04ddd87ad71..66e9516591e 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # this is just a bypass for this module name collision with built-in one +from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union from torch import Tensor, nn @@ -133,7 +134,7 @@ class MultitaskWrapper(WrapperMetric): def __init__( self, - task_metrics: Dict[str, Union[Metric, MetricCollection]], + task_metrics: dict[str, Union[Metric, MetricCollection]], prefix: Optional[str] = None, postfix: Optional[str] = None, ) -> None: @@ -159,7 +160,7 @@ def __init__( raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}") self._postfix = postfix or "" - def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]: + def items(self, flatten: bool = True) -> Iterable[tuple[str, nn.Module]]: """Iterate over task and task metrics. Args: @@ -203,7 +204,7 @@ def values(self, flatten: bool = True) -> Iterable[nn.Module]: else: yield metric - def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> None: + def update(self, task_preds: dict[str, Any], task_targets: dict[str, Any]) -> None: """Update each task's metric with its corresponding pred and target. Args: @@ -223,15 +224,15 @@ def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> No target = task_targets[task_name] metric.update(pred, target) - def _convert_output(self, output: Dict[str, Any]) -> Dict[str, Any]: + def _convert_output(self, output: dict[str, Any]) -> dict[str, Any]: """Convert the output of the underlying metrics to a dictionary with the task names as keys.""" return {f"{self._prefix}{task_name}{self._postfix}": task_output for task_name, task_output in output.items()} - def compute(self) -> Dict[str, Any]: + def compute(self) -> dict[str, Any]: """Compute metrics for all tasks.""" return self._convert_output({task_name: metric.compute() for task_name, metric in self.task_metrics.items()}) - def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]: + def forward(self, task_preds: dict[str, Tensor], task_targets: dict[str, Tensor]) -> dict[str, Any]: """Call underlying forward methods for all tasks and return the result as a dictionary.""" # This method is overridden because we do not need the complex version defined in Metric, that relies on the # value of full_state_update, and that also accumulates the results. Here, all computations are handled by the @@ -268,7 +269,7 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> return multitask_copy def plot( - self, val: Optional[Union[Dict, Sequence[Dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None + self, val: Optional[Union[dict, Sequence[dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None ) -> Sequence[_PLOT_OUT_TYPE]: """Plot a single or multiple values from the metric. @@ -352,7 +353,7 @@ def plot( fig_axs = [] for i, (task_name, task_metric) in enumerate(self.task_metrics.items()): ax = axes[i] if axes is not None else None - if isinstance(val, Dict): + if isinstance(val, dict): f, a = task_metric.plot(val[task_name], ax=ax) elif isinstance(val, Sequence): f, a = task_metric.plot([v[task_name] for v in val], ax=ax) diff --git a/src/torchmetrics/wrappers/running.py b/src/torchmetrics/wrappers/running.py index f877c65ad95..a57cb8333f0 100644 --- a/src/torchmetrics/wrappers/running.py +++ b/src/torchmetrics/wrappers/running.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index c1fe9957b1b..0a8ca7eac1d 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor @@ -103,10 +104,10 @@ class MetricTracker(ModuleList): """ - maximize: Union[bool, List[bool]] + maximize: Union[bool, list[bool]] def __init__( - self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, List[bool]]] = True + self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, list[bool]]] = True ) -> None: super().__init__() if not isinstance(metric, (Metric, MetricCollection)): @@ -220,10 +221,10 @@ def best_metric( None, float, Tensor, - Tuple[Union[int, float, Tensor], Union[int, float, Tensor]], - Tuple[None, None], - Dict[str, Union[float, None]], - Tuple[Dict[str, Union[float, None]], Dict[str, Union[int, None]]], + tuple[Union[int, float, Tensor], Union[int, float, Tensor]], + tuple[None, None], + dict[str, Union[float, None]], + tuple[dict[str, Union[float, None]], dict[str, Union[int, None]]], ]: """Return the highest metric out of all tracked. diff --git a/src/torchmetrics/wrappers/transformations.py b/src/torchmetrics/wrappers/transformations.py index f2ac106fe86..8a86e3224af 100644 --- a/src/torchmetrics/wrappers/transformations.py +++ b/src/torchmetrics/wrappers/transformations.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch @@ -28,7 +28,7 @@ class MetricInputTransformer(WrapperMetric): """ - def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: Dict[str, Any]) -> None: + def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: dict[str, Any]) -> None: super().__init__(**kwargs) if not isinstance(wrapped_metric, (Metric, MetricCollection)): raise TypeError( @@ -53,7 +53,7 @@ def transform_target(self, target: torch.Tensor) -> torch.Tensor: """ return target - def _wrap_transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: + def _wrap_transform(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: """Wrap transformation functions to dispatch args to their individual transform functions.""" if len(args) == 1: return (self.transform_pred(args[0]),) @@ -61,7 +61,7 @@ def _wrap_transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: return self.transform_pred(args[0]), self.transform_target(args[1]) return self.transform_pred(args[0]), self.transform_target(args[1]), *args[2:] - def update(self, *args: torch.Tensor, **kwargs: Dict[str, Any]) -> None: + def update(self, *args: torch.Tensor, **kwargs: dict[str, Any]) -> None: """Wrap the update call of the underlying metric.""" args = self._wrap_transform(*args) self.wrapped_metric.update(*args, **kwargs) @@ -70,7 +70,7 @@ def compute(self) -> Any: """Wrap the compute call of the underlying metric.""" return self.wrapped_metric.compute() - def forward(self, *args: torch.Tensor, **kwargs: Dict[str, Any]) -> Any: + def forward(self, *args: torch.Tensor, **kwargs: dict[str, Any]) -> Any: """Wrap the forward call of the underlying metric.""" args = self._wrap_transform(*args) return self.wrapped_metric.forward(*args, **kwargs) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index 98cc110a3ff..1622e4ad8a3 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -13,9 +13,10 @@ # limitations under the License. import pickle import sys +from collections.abc import Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import numpy as np import pytest @@ -42,7 +43,7 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O elif isinstance(tm_result, Sequence): for pl_res, ref_res in zip(tm_result, ref_result): _assert_allclose(pl_res, ref_res, atol=atol) - elif isinstance(tm_result, Dict): + elif isinstance(tm_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert np.allclose( @@ -60,7 +61,7 @@ def _assert_tensor(tm_result: Any, key: Optional[str] = None) -> None: if isinstance(tm_result, Sequence): for plr in tm_result: _assert_tensor(plr) - elif isinstance(tm_result, Dict): + elif isinstance(tm_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert isinstance(tm_result[key], Tensor) @@ -73,7 +74,7 @@ def _assert_requires_grad(metric: Metric, tm_result: Any, key: Optional[str] = N if isinstance(tm_result, Sequence): for plr in tm_result: _assert_requires_grad(metric, plr, key=key) - elif isinstance(tm_result, Dict): + elif isinstance(tm_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert metric.is_differentiable == tm_result[key].requires_grad @@ -84,8 +85,8 @@ def _assert_requires_grad(metric: Metric, tm_result: Any, key: Optional[str] = N def _class_test( rank: int, world_size: int, - preds: Union[Tensor, list, List[Dict[str, Tensor]]], - target: Union[Tensor, list, List[Dict[str, Tensor]]], + preds: Union[Tensor, list, list[dict[str, Tensor]]], + target: Union[Tensor, list, list[dict[str, Tensor]]], metric_class: Metric, reference_metric: Callable, dist_sync_on_step: bool, @@ -251,7 +252,7 @@ def _class_test( def _functional_test( preds: Union[Tensor, list], - target: Union[Tensor, list, List[Dict[str, Tensor]]], + target: Union[Tensor, list, list[dict[str, Tensor]]], metric_functional: Callable, reference_metric: Callable, metric_args: Optional[dict] = None, @@ -316,7 +317,7 @@ def _assert_dtype_support( metric_module: Optional[Metric], metric_functional: Optional[Callable], preds: Tensor, - target: Union[Tensor, List[Dict[str, Tensor]]], + target: Union[Tensor, list[dict[str, Tensor]]], device: str = "cpu", dtype: torch.dtype = torch.half, **kwargs_update: Any, @@ -419,8 +420,8 @@ def run_functional_metric_test( def run_class_metric_test( self, ddp: bool, - preds: Union[Tensor, List[Dict]], - target: Union[Tensor, List[Dict]], + preds: Union[Tensor, list[dict]], + target: Union[Tensor, list[dict]], metric_class: Metric, reference_metric: Callable, dist_sync_on_step: bool = False, @@ -684,7 +685,7 @@ def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: return x -def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> Tuple[Tensor, Tensor]: +def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> tuple[Tensor, Tensor]: """Remove samples that are equal to the ignore_index in comparison functions. Example: @@ -703,7 +704,7 @@ def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[in def remove_ignore_index_groups( target: Tensor, preds: Tensor, groups: Tensor, ignore_index: Optional[int] -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Version of the remove_ignore_index which includes groups.""" if ignore_index is not None: idx = target == ignore_index diff --git a/tests/unittests/audio/test_dnsmos.py b/tests/unittests/audio/test_dnsmos.py index c3c1df7a03b..80607057467 100644 --- a/tests/unittests/audio/test_dnsmos.py +++ b/tests/unittests/audio/test_dnsmos.py @@ -13,7 +13,7 @@ # limitations under the License. import os from functools import partial -from typing import Any, Dict, Optional +from typing import Any, Optional import numpy as np import pytest @@ -43,7 +43,7 @@ class InferenceSession: # type:ignore """Dummy InferenceSession.""" - def __init__(self, **kwargs: Dict[str, Any]) -> None: ... + def __init__(self, **kwargs: dict[str, Any]) -> None: ... SAMPLING_RATE = 16000 @@ -82,7 +82,7 @@ def _get_polyfit_val(self, sig, bak, ovr, is_personalized): return sig_poly, bak_poly, ovr_poly - def __call__(self, aud, input_fs, is_personalized) -> Dict[str, Any]: + def __call__(self, aud, input_fs, is_personalized) -> dict[str, Any]: fs = SAMPLING_RATE audio = librosa.resample(aud, orig_sr=input_fs, target_sr=fs) if input_fs != fs else aud actual_audio_len = len(audio) @@ -143,7 +143,7 @@ def _reference_metric_batch( personalized: bool, device: Optional[str] = None, # for tester reduce_mean: bool = False, - **kwargs: Dict[str, Any], # for tester + **kwargs: dict[str, Any], # for tester ): # download onnx first _load_session(f"{DNSMOS_DIR}/{'p' if personalized else ''}DNSMOS/sig_bak_ovr.onnx", torch.device("cpu")) @@ -170,7 +170,7 @@ def _reference_metric_batch( return score.reshape(*shape[:-1], 4).reshape(shape[:-1] + (4,)).numpy() -def _dnsmos_cheat(preds, target, **kwargs: Dict[str, Any]): +def _dnsmos_cheat(preds, target, **kwargs: dict[str, Any]): # cheat the MetricTester as the deep_noise_suppression_mean_opinion_score doesn't need target return deep_noise_suppression_mean_opinion_score(preds, **kwargs) diff --git a/tests/unittests/audio/test_nisqa.py b/tests/unittests/audio/test_nisqa.py index a36ad0db1c6..06eac64710c 100644 --- a/tests/unittests/audio/test_nisqa.py +++ b/tests/unittests/audio/test_nisqa.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Any, Dict, Tuple +from typing import Any import pytest import torch @@ -115,7 +115,7 @@ def _reference_metric(preds): return out.mean(dim=0) if mean else out.reshape(*preds.shape[:-1], 5) -def _nisqa_cheat(preds, target, **kwargs: Dict[str, Any]): +def _nisqa_cheat(preds, target, **kwargs: dict[str, Any]): # cheat the MetricTester as non_intrusive_speech_quality_assessment does not need a target return non_intrusive_speech_quality_assessment(preds, **kwargs) @@ -156,7 +156,7 @@ def test_nisqa_functional(self, preds: Tensor, reference: Tensor, fs: int, devic @pytest.mark.parametrize("shape", [(3000,), (2, 3000), (1, 2, 3000), (2, 3, 1, 3000)]) -def test_shape(shape: Tuple[int]): +def test_shape(shape: tuple[int]): """Test output shape.""" preds = torch.rand(*shape) out = non_intrusive_speech_quality_assessment(preds, 16000) diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 85baab5e045..70c5b44c6ab 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Tuple +from typing import Callable import numpy as np import pytest @@ -57,7 +57,7 @@ def _reference_scipy_pit( target: Tensor, metric_func: Callable, eval_func: str, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Naive implementation of `Permutation Invariant Training` based on Scipy. Args: @@ -87,7 +87,7 @@ def _reference_scipy_pit( return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms)) -def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: return _reference_scipy_pit( preds=preds, target=target, @@ -96,7 +96,7 @@ def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Ten ) -def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: return _reference_scipy_pit( preds=preds, target=target, diff --git a/tests/unittests/audio/test_srmr.py b/tests/unittests/audio/test_srmr.py index e7370546478..d3a18cca357 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Dict +from typing import Any import pytest import torch @@ -30,7 +30,7 @@ def _reference_srmr_batch( - preds: Tensor, target: Tensor, fs: int, fast: bool, norm: bool, reduce_mean: bool = False, **kwargs: Dict[str, Any] + preds: Tensor, target: Tensor, fs: int, fast: bool, norm: bool, reduce_mean: bool = False, **kwargs: dict[str, Any] ): # shape: preds [BATCH_SIZE, Time] shape = preds.shape @@ -51,7 +51,7 @@ def _reference_srmr_batch( return srmr -def _speech_reverberation_modulation_energy_ratio_cheat(preds, target, **kwargs: Dict[str, Any]): +def _speech_reverberation_modulation_energy_ratio_cheat(preds, target, **kwargs: dict[str, Any]): # cheat the MetricTester as the speech_reverberation_modulation_energy_ratio doesn't need target return speech_reverberation_modulation_energy_ratio(preds, **kwargs) diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 9e89d041fd7..5f627676f09 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect from functools import partial -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional from unittest import mock import numpy as np @@ -72,7 +72,7 @@ def _reference_fairlearn_binary(preds, target, groups, ignore_index): } -def _assert_tensor(pl_result: Dict[str, Tensor], key: Optional[str] = None) -> None: +def _assert_tensor(pl_result: dict[str, Tensor], key: Optional[str] = None) -> None: if isinstance(pl_result, dict) and key is None: for key, val in pl_result.items(): assert isinstance(val, Tensor), f"{key!r} is not a Tensor!" @@ -81,7 +81,7 @@ def _assert_tensor(pl_result: Dict[str, Tensor], key: Optional[str] = None) -> N def _assert_allclose( # todo: unify with the general assert_allclose - pl_result: Dict[str, Tensor], sk_result: Dict[str, Tensor], atol: float = 1e-8, key: Optional[str] = None + pl_result: dict[str, Tensor], sk_result: dict[str, Tensor], atol: float = 1e-8, key: Optional[str] = None ) -> None: if isinstance(pl_result, dict) and key is None: for (pl_key, pl_val), (sk_key, sk_val) in zip(pl_result.items(), sk_result.items()): diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 4c864d0e9af..f4fe1d1ee06 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import numpy as np import pytest @@ -180,7 +180,7 @@ def test_extreme_values(): (_INPUTS_1, _ARGS_2, 1), ], ) -def test_ignore_mask(inputs: _Input, args: Dict[str, Any], cat_dim: int): +def test_ignore_mask(inputs: _Input, args: dict[str, Any], cat_dim: int): """Test that the metric correctly ignores regions of the inputs that do not map to a know category ID.""" preds = inputs.preds[0] target = inputs.target[0] diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 245fb4097fc..58287aa7fcd 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import numpy as np import pytest @@ -214,7 +214,7 @@ def test_extreme_values(): (_INPUTS_1, _ARGS_2, 1), ], ) -def test_ignore_mask(inputs: _Input, args: Dict[str, Any], cat_dim: int): +def test_ignore_mask(inputs: _Input, args: dict[str, Any], cat_dim: int): """Test that the metric correctly ignores regions of the inputs that do not map to a know category ID.""" preds = inputs.preds[0] target = inputs.target[0] diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index e14cb2c96a0..09a9675d380 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Dict, List, NamedTuple +from typing import NamedTuple import numpy as np import pytest @@ -34,7 +34,7 @@ class _Input(NamedTuple): preds: Tensor - target: List[Dict[str, Tensor]] + target: list[dict[str, Tensor]] ms: Tensor pan: Tensor pan_lr: Tensor diff --git a/tests/unittests/image/test_qnr.py b/tests/unittests/image/test_qnr.py index 4cb42cf36be..52d77896235 100644 --- a/tests/unittests/image/test_qnr.py +++ b/tests/unittests/image/test_qnr.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Dict, List, NamedTuple +from typing import NamedTuple import pytest import torch @@ -32,7 +32,7 @@ class _Input(NamedTuple): preds: Tensor - target: List[Dict[str, Tensor]] + target: list[dict[str, Tensor]] ms: Tensor pan: Tensor pan_lr: Tensor diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index e2804ecebb9..9e71a30ca0a 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import List, NamedTuple +from typing import NamedTuple import matplotlib import matplotlib.pyplot as plt @@ -33,7 +33,7 @@ class _InputImagesCaptions(NamedTuple): images: Tensor - captions: List[List[str]] + captions: list[list[str]] captions = [ diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 03d5429ba47..6747be475ed 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import numpy as np import pytest @@ -59,7 +59,7 @@ def _retrieval_aggregate( return aggregation(values, dim=dim) -def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]: +def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> list[Union[Tensor, np.ndarray]]: """Extract group indexes. Given an integer :class:`~torch.Tensor` or `np.ndarray` `indexes`, return a :class:`~torch.Tensor` or @@ -151,7 +151,7 @@ def _compute_sklearn_metric( return np.array(0.0) -def _concat_tests(*tests: Tuple[Dict]) -> Dict: +def _concat_tests(*tests: tuple[dict]) -> dict: """Concat tests composed by a string and a list of arguments.""" assert len(tests), "`_concat_tests` expects at least an argument" assert all(tests[0]["argnames"] == x["argnames"] for x in tests[1:]), "the header must be the same for all tests" @@ -408,7 +408,7 @@ def _errors_test_class_metric( metric_class: Metric, message: str = "", metric_args: Optional[dict] = None, - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ): """Check types, parameters and errors. @@ -437,7 +437,7 @@ def _errors_test_functional_metric( target: Tensor, metric_functional: Metric, message: str = "", - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ): """Check types, parameters and errors. @@ -564,7 +564,7 @@ def run_metric_class_arguments_test( metric_class: Metric, message: str = "", metric_args: Optional[dict] = None, - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ) -> None: """Test that specific errors are raised for incorrect input.""" @@ -585,7 +585,7 @@ def run_functional_metric_arguments_test( target: Tensor, metric_functional: Callable, message: str = "", - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ) -> None: """Test that specific errors are raised for incorrect input.""" diff --git a/tests/unittests/retrieval/test_precision_recall_curve.py b/tests/unittests/retrieval/test_precision_recall_curve.py index 691334d135e..d8e9817bd51 100644 --- a/tests/unittests/retrieval/test_precision_recall_curve.py +++ b/tests/unittests/retrieval/test_precision_recall_curve.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np import pytest @@ -42,7 +42,7 @@ def _compute_precision_recall_curve( empty_target_action: str = "skip", reverse: bool = False, aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean", -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute metric with multiple iterations over every query predictions set. Didn't find a reliable implementation of precision-recall curve in Information Retrieval, diff --git a/tests/unittests/text/_helpers.py b/tests/unittests/text/_helpers.py index e55944a72af..2b9f8381d42 100644 --- a/tests/unittests/text/_helpers.py +++ b/tests/unittests/text/_helpers.py @@ -13,8 +13,9 @@ # limitations under the License. import pickle import sys +from collections.abc import Sequence from functools import partial -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Optional, Union import numpy as np import pytest @@ -47,7 +48,7 @@ def _assert_all_close_regardless_of_order( elif isinstance(pl_result, Sequence): for pl_res, ref_res in zip(pl_result, ref_result): _assert_allclose(pl_res, ref_res, atol=atol) - elif isinstance(pl_result, Dict): + elif isinstance(pl_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert np.allclose( diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index 1d74f0c858d..d7e1fb22609 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from torch import Tensor diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index 6ef3f7390be..99de422ba26 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.cer import char_error_rate @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_cer(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_cer(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import cer except ImportError: diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py index 233c9451381..ca3995e2519 100644 --- a/tests/unittests/text/test_chrf.py +++ b/tests/unittests/text/test_chrf.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from torch import Tensor, tensor diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index 69e595465a7..56592c34762 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.mer import match_error_rate @@ -24,7 +24,7 @@ seed_all(42) -def _reference_jiwer_mer(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_mer(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import compute_measures except ImportError: diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index a40885587e8..3a8acef5b16 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -13,8 +13,9 @@ # limitations under the License. import re +from collections.abc import Sequence from functools import partial -from typing import Callable, Sequence, Union +from typing import Callable, Union import pytest import torch diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index bf2d45fbd5a..54da47a6a8a 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from lightning_utilities.core.imports import RequirementCache diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py index 6c047730067..1f896bdbbbe 100644 --- a/tests/unittests/text/test_ter.py +++ b/tests/unittests/text/test_ter.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from torch import Tensor, tensor diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index 16b03849f84..e57c51af5fd 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.wer import word_error_rate @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_wer(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_wer(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import compute_measures except ImportError: diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index 37278b829f1..9b1615e5ee6 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.wil import word_information_lost @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_wil(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_wil(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import wil except ImportError: diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index a6523babd67..b0393d12b29 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.wip import word_information_preserved @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_wip(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_wip(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import wip except ImportError: