From c79dcb849a054bd9ba81233a2c54b384815bc349 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 5 Nov 2024 13:16:58 +0000 Subject: [PATCH 01/15] bump: drop support for python 3.8 --- .github/workflows/_focus-diff.yml | 2 -- .github/workflows/ci-checks.yml | 2 +- .github/workflows/ci-tests.yml | 2 +- .github/workflows/publish-pkg.yml | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/_focus-diff.yml b/.github/workflows/_focus-diff.yml index faa4812f375..0e68eac2406 100644 --- a/.github/workflows/_focus-diff.yml +++ b/.github/workflows/_focus-diff.yml @@ -18,8 +18,6 @@ jobs: steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 - #with: - # python-version: 3.8 - name: Get PR diff id: diff-domains diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index d7abe9bb569..fa3eb8ccdee 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -31,7 +31,7 @@ jobs: testing-matrix: | { "os": ["ubuntu-22.04", "macos-13", "windows-2022"], - "python-version": ["3.8", "3.11"] + "python-version": ["3.9", "3.11"] } check-md-links: diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 7d44adf3aab..4e79c632692 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -42,7 +42,7 @@ jobs: - "2.5.0" include: # cover additional python and PT combinations - - { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" } + - { os: "ubuntu-20.04", python-version: "3.9", pytorch-version: "2.0.1", requires: "oldest" } - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" } - { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" } # standard mac machine, not the M1 diff --git a/.github/workflows/publish-pkg.yml b/.github/workflows/publish-pkg.yml index f7b3f46997f..490eeac9c15 100644 --- a/.github/workflows/publish-pkg.yml +++ b/.github/workflows/publish-pkg.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.10" - name: Install dependencies run: >- From ea058dedd9062c425f560ed15ba6d3b93106a582 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 5 Nov 2024 13:17:38 +0000 Subject: [PATCH 02/15] setup + lint --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a765978081..ca69b1c1de3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ ] [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 120 #[tool.ruff.pycodestyle] diff --git a/setup.py b/setup.py index 2324b660cc0..c51d4915dff 100755 --- a/setup.py +++ b/setup.py @@ -215,7 +215,7 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx include_package_data=True, zip_safe=False, keywords=["deep learning", "machine learning", "pytorch", "metrics", "AI"], - python_requires=">=3.8", + python_requires=">=3.9", setup_requires=[], install_requires=BASE_REQUIREMENTS, extras_require=_prepare_extras(), From e9d901bc489066d94d232f67aae8ba1f9f8f532c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:19:21 +0000 Subject: [PATCH 03/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- _samples/rouge_score-own_normalizer_and_tokenizer.py | 2 +- setup.py | 3 ++- src/torchmetrics/aggregation.py | 3 ++- src/torchmetrics/audio/dnsmos.py | 3 ++- src/torchmetrics/audio/nisqa.py | 3 ++- src/torchmetrics/audio/pesq.py | 3 ++- src/torchmetrics/audio/pit.py | 3 ++- src/torchmetrics/audio/sdr.py | 3 ++- src/torchmetrics/audio/snr.py | 3 ++- src/torchmetrics/audio/srmr.py | 3 ++- src/torchmetrics/audio/stoi.py | 3 ++- src/torchmetrics/classification/accuracy.py | 3 ++- src/torchmetrics/classification/auroc.py | 3 ++- src/torchmetrics/classification/average_precision.py | 3 ++- src/torchmetrics/classification/calibration_error.py | 3 ++- src/torchmetrics/classification/cohen_kappa.py | 3 ++- src/torchmetrics/classification/dice.py | 3 ++- src/torchmetrics/classification/exact_match.py | 3 ++- src/torchmetrics/classification/f_beta.py | 3 ++- src/torchmetrics/classification/group_fairness.py | 3 ++- src/torchmetrics/classification/hamming.py | 3 ++- src/torchmetrics/classification/hinge.py | 3 ++- src/torchmetrics/classification/jaccard.py | 3 ++- src/torchmetrics/classification/matthews_corrcoef.py | 3 ++- src/torchmetrics/classification/negative_predictive_value.py | 3 ++- src/torchmetrics/classification/precision_fixed_recall.py | 3 ++- src/torchmetrics/classification/precision_recall.py | 3 ++- src/torchmetrics/classification/ranking.py | 3 ++- src/torchmetrics/classification/recall_fixed_precision.py | 3 ++- src/torchmetrics/classification/specificity.py | 3 ++- src/torchmetrics/clustering/adjusted_mutual_info_score.py | 3 ++- src/torchmetrics/clustering/adjusted_rand_score.py | 3 ++- src/torchmetrics/clustering/calinski_harabasz_score.py | 3 ++- src/torchmetrics/clustering/davies_bouldin_score.py | 3 ++- src/torchmetrics/clustering/dunn_index.py | 3 ++- src/torchmetrics/clustering/fowlkes_mallows_index.py | 3 ++- .../clustering/homogeneity_completeness_v_measure.py | 3 ++- src/torchmetrics/clustering/mutual_info_score.py | 3 ++- src/torchmetrics/clustering/normalized_mutual_info_score.py | 3 ++- src/torchmetrics/clustering/rand_score.py | 3 ++- src/torchmetrics/collections.py | 3 ++- src/torchmetrics/detection/_deprecated.py | 3 ++- src/torchmetrics/detection/_mean_ap.py | 3 ++- src/torchmetrics/detection/ciou.py | 3 ++- src/torchmetrics/detection/diou.py | 3 ++- src/torchmetrics/detection/giou.py | 3 ++- src/torchmetrics/detection/helpers.py | 3 ++- src/torchmetrics/detection/iou.py | 3 ++- src/torchmetrics/detection/mean_ap.py | 3 ++- src/torchmetrics/detection/panoptic_qualities.py | 3 ++- .../functional/classification/precision_recall_curve.py | 3 ++- src/torchmetrics/functional/detection/_deprecated.py | 2 +- .../functional/detection/_panoptic_quality_common.py | 3 ++- src/torchmetrics/functional/detection/panoptic_qualities.py | 2 +- src/torchmetrics/functional/image/_deprecated.py | 3 ++- src/torchmetrics/functional/image/ssim.py | 3 ++- src/torchmetrics/functional/image/uqi.py | 3 ++- src/torchmetrics/functional/image/utils.py | 3 ++- src/torchmetrics/functional/regression/explained_variance.py | 3 ++- src/torchmetrics/functional/text/_deprecated.py | 3 ++- src/torchmetrics/functional/text/bert.py | 3 ++- src/torchmetrics/functional/text/bleu.py | 3 ++- src/torchmetrics/functional/text/chrf.py | 3 ++- src/torchmetrics/functional/text/edit.py | 3 ++- src/torchmetrics/functional/text/eed.py | 3 ++- src/torchmetrics/functional/text/helper.py | 3 ++- src/torchmetrics/functional/text/infolm.py | 3 ++- src/torchmetrics/functional/text/rouge.py | 3 ++- src/torchmetrics/functional/text/sacre_bleu.py | 3 ++- src/torchmetrics/functional/text/ter.py | 3 ++- src/torchmetrics/image/_deprecated.py | 3 ++- src/torchmetrics/image/d_lambda.py | 3 ++- src/torchmetrics/image/d_s.py | 3 ++- src/torchmetrics/image/ergas.py | 3 ++- src/torchmetrics/image/fid.py | 3 ++- src/torchmetrics/image/inception.py | 3 ++- src/torchmetrics/image/kid.py | 3 ++- src/torchmetrics/image/lpip.py | 3 ++- src/torchmetrics/image/mifid.py | 3 ++- src/torchmetrics/image/psnr.py | 3 ++- src/torchmetrics/image/psnrb.py | 3 ++- src/torchmetrics/image/qnr.py | 3 ++- src/torchmetrics/image/rase.py | 3 ++- src/torchmetrics/image/rmse_sw.py | 3 ++- src/torchmetrics/image/sam.py | 3 ++- src/torchmetrics/image/ssim.py | 3 ++- src/torchmetrics/image/tv.py | 3 ++- src/torchmetrics/image/uqi.py | 3 ++- src/torchmetrics/metric.py | 3 ++- src/torchmetrics/multimodal/clip_iqa.py | 3 ++- src/torchmetrics/multimodal/clip_score.py | 3 ++- src/torchmetrics/nominal/cramers.py | 3 ++- src/torchmetrics/nominal/fleiss_kappa.py | 3 ++- src/torchmetrics/nominal/pearson.py | 3 ++- src/torchmetrics/nominal/theils_u.py | 3 ++- src/torchmetrics/nominal/tschuprows.py | 3 ++- src/torchmetrics/regression/concordance.py | 3 ++- src/torchmetrics/regression/cosine_similarity.py | 3 ++- src/torchmetrics/regression/explained_variance.py | 3 ++- src/torchmetrics/regression/kendall.py | 3 ++- src/torchmetrics/regression/kl_divergence.py | 3 ++- src/torchmetrics/regression/log_cosh.py | 3 ++- src/torchmetrics/regression/log_mse.py | 3 ++- src/torchmetrics/regression/mae.py | 3 ++- src/torchmetrics/regression/mape.py | 3 ++- src/torchmetrics/regression/minkowski.py | 3 ++- src/torchmetrics/regression/mse.py | 3 ++- src/torchmetrics/regression/nrmse.py | 3 ++- src/torchmetrics/regression/pearson.py | 3 ++- src/torchmetrics/regression/r2.py | 3 ++- src/torchmetrics/regression/rse.py | 3 ++- src/torchmetrics/regression/spearman.py | 3 ++- src/torchmetrics/regression/symmetric_mape.py | 3 ++- src/torchmetrics/regression/tweedie_deviance.py | 3 ++- src/torchmetrics/regression/wmape.py | 3 ++- src/torchmetrics/retrieval/auroc.py | 3 ++- src/torchmetrics/retrieval/average_precision.py | 3 ++- src/torchmetrics/retrieval/fall_out.py | 3 ++- src/torchmetrics/retrieval/hit_rate.py | 3 ++- src/torchmetrics/retrieval/ndcg.py | 3 ++- src/torchmetrics/retrieval/precision.py | 3 ++- src/torchmetrics/retrieval/precision_recall_curve.py | 3 ++- src/torchmetrics/retrieval/r_precision.py | 3 ++- src/torchmetrics/retrieval/recall.py | 3 ++- src/torchmetrics/retrieval/reciprocal_rank.py | 3 ++- src/torchmetrics/segmentation/dice.py | 3 ++- src/torchmetrics/segmentation/generalized_dice.py | 3 ++- src/torchmetrics/segmentation/hausdorff_distance.py | 3 ++- src/torchmetrics/segmentation/mean_iou.py | 3 ++- src/torchmetrics/shape/procrustes.py | 3 ++- src/torchmetrics/text/_deprecated.py | 3 ++- src/torchmetrics/text/bert.py | 3 ++- src/torchmetrics/text/bleu.py | 3 ++- src/torchmetrics/text/cer.py | 3 ++- src/torchmetrics/text/chrf.py | 3 ++- src/torchmetrics/text/edit.py | 3 ++- src/torchmetrics/text/eed.py | 3 ++- src/torchmetrics/text/infolm.py | 3 ++- src/torchmetrics/text/mer.py | 3 ++- src/torchmetrics/text/perplexity.py | 3 ++- src/torchmetrics/text/rouge.py | 3 ++- src/torchmetrics/text/sacre_bleu.py | 3 ++- src/torchmetrics/text/squad.py | 3 ++- src/torchmetrics/text/ter.py | 3 ++- src/torchmetrics/text/wer.py | 3 ++- src/torchmetrics/text/wil.py | 3 ++- src/torchmetrics/text/wip.py | 3 ++- src/torchmetrics/utilities/checks.py | 3 ++- src/torchmetrics/utilities/data.py | 3 ++- src/torchmetrics/utilities/plot.py | 3 ++- src/torchmetrics/wrappers/bootstrapping.py | 3 ++- src/torchmetrics/wrappers/classwise.py | 3 ++- src/torchmetrics/wrappers/feature_share.py | 3 ++- src/torchmetrics/wrappers/minmax.py | 3 ++- src/torchmetrics/wrappers/multioutput.py | 3 ++- src/torchmetrics/wrappers/multitask.py | 3 ++- src/torchmetrics/wrappers/running.py | 3 ++- src/torchmetrics/wrappers/tracker.py | 3 ++- tests/unittests/_helpers/testers.py | 3 ++- tests/unittests/text/_helpers.py | 3 ++- tests/unittests/text/test_bertscore.py | 2 +- tests/unittests/text/test_chrf.py | 2 +- tests/unittests/text/test_rouge.py | 3 ++- tests/unittests/text/test_sacre_bleu.py | 2 +- tests/unittests/text/test_ter.py | 2 +- 165 files changed, 323 insertions(+), 165 deletions(-) diff --git a/_samples/rouge_score-own_normalizer_and_tokenizer.py b/_samples/rouge_score-own_normalizer_and_tokenizer.py index 14e2252d438..28619effcba 100644 --- a/_samples/rouge_score-own_normalizer_and_tokenizer.py +++ b/_samples/rouge_score-own_normalizer_and_tokenizer.py @@ -18,8 +18,8 @@ """ import re +from collections.abc import Sequence from pprint import pprint -from typing import Sequence from torchmetrics.text.rouge import ROUGEScore diff --git a/setup.py b/setup.py index c51d4915dff..37994261b49 100755 --- a/setup.py +++ b/setup.py @@ -2,11 +2,12 @@ import glob import os import re +from collections.abc import Iterable, Iterator from functools import partial from importlib.util import module_from_spec, spec_from_file_location from itertools import chain from pathlib import Path -from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from pkg_resources import Requirement, yield_lines from setuptools import find_packages, setup diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index ee4f86ffdc3..14c6831e62a 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/audio/dnsmos.py b/src/torchmetrics/audio/dnsmos.py index 406b817eb05..6b6f8fc60de 100644 --- a/src/torchmetrics/audio/dnsmos.py +++ b/src/torchmetrics/audio/dnsmos.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/nisqa.py b/src/torchmetrics/audio/nisqa.py index c079d903101..d1be81a01da 100644 --- a/src/torchmetrics/audio/nisqa.py +++ b/src/torchmetrics/audio/nisqa.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index ee6a1751359..38146f6c943 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 2def91d4d01..ecd3aa6f1ee 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Dict, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index 9b8646aaa1f..e932e06199b 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index d8b9fd4c173..4947d5141c3 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/srmr.py b/src/torchmetrics/audio/srmr.py index 6c910738880..453f1bb7eab 100644 --- a/src/torchmetrics/audio/srmr.py +++ b/src/torchmetrics/audio/srmr.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index cea3df3514e..60d1fb9d670 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index efd337d496a..614b8d035dd 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 65e9493b14c..a757d71eb00 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 9d36774938c..9f7c1adf6fb 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 404d9089bbc..952eed47fd4 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 3531eb6b106..093919f2cd0 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index df9807cea7f..080d482d2fe 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Tuple, Union, no_type_check +from collections.abc import Sequence +from typing import Any, Callable, Optional, Tuple, Union, no_type_check import torch from torch import Tensor diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 10b9aedc2fc..c71d35df116 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 526ad1ae0da..9a042907cfb 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 8e38b24faeb..06dc3a07ae0 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index bd0bfa733c6..d29e217efe9 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 5514f98cccc..f03e24cd7fd 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 385009d5a6a..cb3e20b0d89 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 49de1f03795..b7a7ee59237 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/negative_predictive_value.py b/src/torchmetrics/classification/negative_predictive_value.py index d0b19dc1247..5f3d505872b 100644 --- a/src/torchmetrics/classification/negative_predictive_value.py +++ b/src/torchmetrics/classification/negative_predictive_value.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/precision_fixed_recall.py b/src/torchmetrics/classification/precision_fixed_recall.py index c761f9aa8a9..73466f37f94 100644 --- a/src/torchmetrics/classification/precision_fixed_recall.py +++ b/src/torchmetrics/classification/precision_fixed_recall.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 0380545b5ac..19d2117863e 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/ranking.py b/src/torchmetrics/classification/ranking.py index 9dda737030d..fca0d64ff86 100644 --- a/src/torchmetrics/classification/ranking.py +++ b/src/torchmetrics/classification/ranking.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index 58a460b7b2e..4e7f540f49a 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index caca10dfa2b..274709b546a 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Type, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/clustering/adjusted_mutual_info_score.py b/src/torchmetrics/clustering/adjusted_mutual_info_score.py index c797a5ac23b..ebcf4749d08 100644 --- a/src/torchmetrics/clustering/adjusted_mutual_info_score.py +++ b/src/torchmetrics/clustering/adjusted_mutual_info_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py index 5c1f5f49276..20278f74bc3 100644 --- a/src/torchmetrics/clustering/adjusted_rand_score.py +++ b/src/torchmetrics/clustering/adjusted_rand_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/calinski_harabasz_score.py b/src/torchmetrics/clustering/calinski_harabasz_score.py index 483e4332148..c331fba7866 100644 --- a/src/torchmetrics/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/clustering/calinski_harabasz_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/davies_bouldin_score.py b/src/torchmetrics/clustering/davies_bouldin_score.py index 40827b568cb..ddd079793cd 100644 --- a/src/torchmetrics/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/clustering/davies_bouldin_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index 5a85074443d..65d1c0c9a94 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/fowlkes_mallows_index.py b/src/torchmetrics/clustering/fowlkes_mallows_index.py index 32fcffe37a7..1317a0cee1c 100644 --- a/src/torchmetrics/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/clustering/fowlkes_mallows_index.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py index 01039a8fdcb..260ab522245 100644 --- a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index ede0c1393fc..a2be02f834e 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/normalized_mutual_info_score.py b/src/torchmetrics/clustering/normalized_mutual_info_score.py index 927a4451009..2583b0b2a9e 100644 --- a/src/torchmetrics/clustering/normalized_mutual_info_score.py +++ b/src/torchmetrics/clustering/normalized_mutual_info_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index 8ded8b27d0d..724a38b227c 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index b4ec0c4e4cc..ef6e2087e7e 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -13,8 +13,9 @@ # limitations under the License. # this is just a bypass for this module name collision with built-in one from collections import OrderedDict +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy -from typing import Any, ClassVar, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/detection/_deprecated.py b/src/torchmetrics/detection/_deprecated.py index c162c751554..f8acd23adb6 100644 --- a/src/torchmetrics/detection/_deprecated.py +++ b/src/torchmetrics/detection/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Any, Collection +from collections.abc import Collection +from typing import Any from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality from torchmetrics.utilities.prints import _deprecated_root_import_class diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index 604daef82eb..73ffea36eb2 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch diff --git a/src/torchmetrics/detection/ciou.py b/src/torchmetrics/detection/ciou.py index b6174c3b60c..1545ab62a83 100644 --- a/src/torchmetrics/detection/ciou.py +++ b/src/torchmetrics/detection/ciou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/detection/diou.py b/src/torchmetrics/detection/diou.py index 7eb3780a112..edfebb38184 100644 --- a/src/torchmetrics/detection/diou.py +++ b/src/torchmetrics/detection/diou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/detection/giou.py b/src/torchmetrics/detection/giou.py index d024adad817..cf03a73c65d 100644 --- a/src/torchmetrics/detection/giou.py +++ b/src/torchmetrics/detection/giou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index dc31a7c7497..97eb91f0ac9 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Literal, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Dict, Literal, Tuple, Union from torch import Tensor diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index d4930d905ab..4cf7fbfbd43 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index c1c63d1a9b4..3cd67cd6e2c 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -14,8 +14,9 @@ import contextlib import io import json +from collections.abc import Sequence from types import ModuleType -from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union import numpy as np import torch diff --git a/src/torchmetrics/detection/panoptic_qualities.py b/src/torchmetrics/detection/panoptic_qualities.py index b4629be8e69..ce3948f3234 100644 --- a/src/torchmetrics/detection/panoptic_qualities.py +++ b/src/torchmetrics/detection/panoptic_qualities.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Optional, Sequence, Union +from collections.abc import Collection, Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index c4607fd9489..3c01ceeed36 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import List, Optional, Tuple, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/detection/_deprecated.py b/src/torchmetrics/functional/detection/_deprecated.py index ce0e1ba6acf..46c8acdab56 100644 --- a/src/torchmetrics/functional/detection/_deprecated.py +++ b/src/torchmetrics/functional/detection/_deprecated.py @@ -1,4 +1,4 @@ -from typing import Collection +from collections.abc import Collection from torch import Tensor diff --git a/src/torchmetrics/functional/detection/_panoptic_quality_common.py b/src/torchmetrics/functional/detection/_panoptic_quality_common.py index c94978e3435..b3b0fc33295 100644 --- a/src/torchmetrics/functional/detection/_panoptic_quality_common.py +++ b/src/torchmetrics/functional/detection/_panoptic_quality_common.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, Iterator, List, Optional, Set, Tuple, cast +from collections.abc import Collection, Iterator +from typing import Dict, List, Optional, Set, Tuple, cast import torch from torch import Tensor diff --git a/src/torchmetrics/functional/detection/panoptic_qualities.py b/src/torchmetrics/functional/detection/panoptic_qualities.py index 019d243f0ba..c1b3c7f27b6 100644 --- a/src/torchmetrics/functional/detection/panoptic_qualities.py +++ b/src/torchmetrics/functional/detection/panoptic_qualities.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection +from collections.abc import Collection import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index 892d07afaa6..efaab73cded 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Optional, Tuple, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index c61ef833fe3..bfa27e7df0a 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index ed8bf39742b..366e06a3bb1 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Optional, Tuple import torch from torch import Tensor, nn diff --git a/src/torchmetrics/functional/image/utils.py b/src/torchmetrics/functional/image/utils.py index bf09ff79249..a6869d1dc88 100644 --- a/src/torchmetrics/functional/image/utils.py +++ b/src/torchmetrics/functional/image/utils.py @@ -1,4 +1,5 @@ -from typing import Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/explained_variance.py b/src/torchmetrics/functional/regression/explained_variance.py index a6a6c4ff209..ab3158bc594 100644 --- a/src/torchmetrics/functional/regression/explained_variance.py +++ b/src/torchmetrics/functional/regression/explained_variance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index 169c3d5357b..8fe94195c58 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index 71bec857a72..83cc23950d4 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -14,8 +14,9 @@ import csv import logging import urllib +from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/bleu.py b/src/torchmetrics/functional/text/bleu.py index 032b677f182..724cec16794 100644 --- a/src/torchmetrics/functional/text/bleu.py +++ b/src/torchmetrics/functional/text/bleu.py @@ -17,7 +17,8 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter -from typing import Callable, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Callable, Optional, Tuple, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index ca98778fade..7d7c552e3fb 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -21,8 +21,9 @@ # Reference to the approval: https://github.com/Lightning-AI/torchmetrics/pull/2701#issuecomment-2316891785 from collections import defaultdict +from collections.abc import Sequence from itertools import chain -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/edit.py b/src/torchmetrics/functional/text/edit.py index 0e443bb4c04..5660f45e025 100644 --- a/src/torchmetrics/functional/text/edit.py +++ b/src/torchmetrics/functional/text/edit.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index f6f562f5338..45e9c254412 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -88,8 +88,9 @@ import re import unicodedata +from collections.abc import Sequence from math import inf -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Tuple, Union from torch import Tensor, stack, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/functional/text/helper.py b/src/torchmetrics/functional/text/helper.py index d4c9ff7ae04..bf02d40f3d8 100644 --- a/src/torchmetrics/functional/text/helper.py +++ b/src/torchmetrics/functional/text/helper.py @@ -29,8 +29,9 @@ # limitations under the License. import math +from collections.abc import Sequence from enum import Enum, unique -from typing import Dict, List, Sequence, Tuple, Union +from typing import Dict, List, Tuple, Union # Tercom-inspired limits _BEAM_WIDTH = 25 diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 0365cdf7ae6..a3efadffe33 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Sequence from enum import unique -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 58c9a05fecf..a83b41a1a20 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -13,7 +13,8 @@ # limitations under the License. import re from collections import Counter -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/sacre_bleu.py b/src/torchmetrics/functional/text/sacre_bleu.py index 33c3afb0beb..6b18f4cab8e 100644 --- a/src/torchmetrics/functional/text/sacre_bleu.py +++ b/src/torchmetrics/functional/text/sacre_bleu.py @@ -40,8 +40,9 @@ import os import re import tempfile +from collections.abc import Sequence from functools import partial -from typing import Any, ClassVar, Dict, Optional, Sequence, Type +from typing import Any, ClassVar, Dict, Optional, Type import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/ter.py b/src/torchmetrics/functional/text/ter.py index 400a4c283b6..08bf6b7562f 100644 --- a/src/torchmetrics/functional/text/ter.py +++ b/src/torchmetrics/functional/text/ter.py @@ -34,8 +34,9 @@ # limitations under the License. import re +from collections.abc import Iterator, Sequence from functools import lru_cache -from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 8b382b89cf7..50b54f479ff 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Dict, Optional, Tuple, Union from typing_extensions import Literal diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 5b1e58de0dd..97d95ccd926 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index a8dacbaf447..02530f7cd90 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index bf6b1c99a10..22c24b164f1 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index ac559ed0c68..7ad1bb0b892 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 20d53d10f2b..8bf42584af2 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 018fc7a7511..6deb4efa22d 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index 1893fb734ba..8c4948b18a8 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ClassVar, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, ClassVar, List, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index a1e2d2f4c0a..5d344b57dd1 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index fe774d2588b..40c6fdab89f 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/image/psnrb.py b/src/torchmetrics/image/psnrb.py index c9fb157d6ff..0e12018c324 100644 --- a/src/torchmetrics/image/psnrb.py +++ b/src/torchmetrics/image/psnrb.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/image/qnr.py b/src/torchmetrics/image/qnr.py index b226a6b19fd..75b88379669 100644 --- a/src/torchmetrics/image/qnr.py +++ b/src/torchmetrics/image/qnr.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index b1eb32141a6..6f7b0b10346 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index c1f7c652879..75feb612802 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 7aad1782852..b313158f80b 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 648f9c26029..576ac4f0879 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index ca6276b7e86..287e58a3a43 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index 3fea5e8986f..c503cc1f394 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 279537aac51..6344d683edb 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -18,9 +18,10 @@ import functools import inspect from abc import ABC, abstractmethod +from collections.abc import Generator, Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index f49113e297f..e4abfaa89b1 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index 92ca7ad6b4f..aca26ac09e5 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/cramers.py b/src/torchmetrics/nominal/cramers.py index df47cc079d1..5f361bfb477 100644 --- a/src/torchmetrics/nominal/cramers.py +++ b/src/torchmetrics/nominal/cramers.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/fleiss_kappa.py b/src/torchmetrics/nominal/fleiss_kappa.py index 619e8196d4c..254796e96c9 100644 --- a/src/torchmetrics/nominal/fleiss_kappa.py +++ b/src/torchmetrics/nominal/fleiss_kappa.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/nominal/pearson.py b/src/torchmetrics/nominal/pearson.py index a43cd548792..15be1bd43b4 100644 --- a/src/torchmetrics/nominal/pearson.py +++ b/src/torchmetrics/nominal/pearson.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/theils_u.py b/src/torchmetrics/nominal/theils_u.py index 7f7f22ecb1b..19695a61c2c 100644 --- a/src/torchmetrics/nominal/theils_u.py +++ b/src/torchmetrics/nominal/theils_u.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/tschuprows.py b/src/torchmetrics/nominal/tschuprows.py index 9986fa2ec6f..9744103f304 100644 --- a/src/torchmetrics/nominal/tschuprows.py +++ b/src/torchmetrics/nominal/tschuprows.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/concordance.py b/src/torchmetrics/regression/concordance.py index 6697ccadec9..c6d051b6ebe 100644 --- a/src/torchmetrics/regression/concordance.py +++ b/src/torchmetrics/regression/concordance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional, Union from torch import Tensor diff --git a/src/torchmetrics/regression/cosine_similarity.py b/src/torchmetrics/regression/cosine_similarity.py index b1d8923290e..5c86ac00cab 100644 --- a/src/torchmetrics/regression/cosine_similarity.py +++ b/src/torchmetrics/regression/cosine_similarity.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/explained_variance.py b/src/torchmetrics/regression/explained_variance.py index bf0ba0dc55f..833c1609e55 100644 --- a/src/torchmetrics/regression/explained_variance.py +++ b/src/torchmetrics/regression/explained_variance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index f9677f4c296..63c2ec150b6 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/kl_divergence.py b/src/torchmetrics/regression/kl_divergence.py index 253b7a05002..bd799c31d7f 100644 --- a/src/torchmetrics/regression/kl_divergence.py +++ b/src/torchmetrics/regression/kl_divergence.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/log_cosh.py b/src/torchmetrics/regression/log_cosh.py index ca9395f9a50..750e5a3497e 100644 --- a/src/torchmetrics/regression/log_cosh.py +++ b/src/torchmetrics/regression/log_cosh.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/log_mse.py b/src/torchmetrics/regression/log_mse.py index da190016844..31dbba6accb 100644 --- a/src/torchmetrics/regression/log_mse.py +++ b/src/torchmetrics/regression/log_mse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/mae.py b/src/torchmetrics/regression/mae.py index 46de2d70d21..da95b2b2d90 100644 --- a/src/torchmetrics/regression/mae.py +++ b/src/torchmetrics/regression/mae.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/mape.py b/src/torchmetrics/regression/mape.py index cbc490e686e..7ee1eaf61f7 100644 --- a/src/torchmetrics/regression/mape.py +++ b/src/torchmetrics/regression/mape.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/minkowski.py b/src/torchmetrics/regression/minkowski.py index 1c5d9e430f1..50785f5425c 100644 --- a/src/torchmetrics/regression/minkowski.py +++ b/src/torchmetrics/regression/minkowski.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/mse.py b/src/torchmetrics/regression/mse.py index 363c56d6a29..b82738ace4a 100644 --- a/src/torchmetrics/regression/mse.py +++ b/src/torchmetrics/regression/mse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 62562803542..bb44ae9c905 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 26d8ffdb5c0..75d323a60f6 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/r2.py b/src/torchmetrics/regression/r2.py index a54502e087d..613dc69a141 100644 --- a/src/torchmetrics/regression/r2.py +++ b/src/torchmetrics/regression/r2.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/rse.py b/src/torchmetrics/regression/rse.py index 2c9d524f03c..2776f8c9fda 100644 --- a/src/torchmetrics/regression/rse.py +++ b/src/torchmetrics/regression/rse.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 62be592919d..de94903c8c0 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/regression/symmetric_mape.py b/src/torchmetrics/regression/symmetric_mape.py index 82b5702f476..a1c601a2a16 100644 --- a/src/torchmetrics/regression/symmetric_mape.py +++ b/src/torchmetrics/regression/symmetric_mape.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/regression/tweedie_deviance.py b/src/torchmetrics/regression/tweedie_deviance.py index 70c8ae37855..3cd10070e7b 100644 --- a/src/torchmetrics/regression/tweedie_deviance.py +++ b/src/torchmetrics/regression/tweedie_deviance.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/wmape.py b/src/torchmetrics/regression/wmape.py index 42fadb906ef..a5c532f145a 100644 --- a/src/torchmetrics/regression/wmape.py +++ b/src/torchmetrics/regression/wmape.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/retrieval/auroc.py b/src/torchmetrics/retrieval/auroc.py index 8d5ac12a929..6d4382894a4 100644 --- a/src/torchmetrics/retrieval/auroc.py +++ b/src/torchmetrics/retrieval/auroc.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/average_precision.py b/src/torchmetrics/retrieval/average_precision.py index b5527bd6afd..93ecc38cb51 100644 --- a/src/torchmetrics/retrieval/average_precision.py +++ b/src/torchmetrics/retrieval/average_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/fall_out.py b/src/torchmetrics/retrieval/fall_out.py index 08cda4ace08..9660f061370 100644 --- a/src/torchmetrics/retrieval/fall_out.py +++ b/src/torchmetrics/retrieval/fall_out.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/retrieval/hit_rate.py b/src/torchmetrics/retrieval/hit_rate.py index e981a126fa5..c04ff3ad14d 100644 --- a/src/torchmetrics/retrieval/hit_rate.py +++ b/src/torchmetrics/retrieval/hit_rate.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/ndcg.py b/src/torchmetrics/retrieval/ndcg.py index 2914726fde9..afd211cb142 100644 --- a/src/torchmetrics/retrieval/ndcg.py +++ b/src/torchmetrics/retrieval/ndcg.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/precision.py b/src/torchmetrics/retrieval/precision.py index e59cac4e382..70bf4d9794a 100644 --- a/src/torchmetrics/retrieval/precision.py +++ b/src/torchmetrics/retrieval/precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py index 9b4f4d38d0e..74fe1dd4c50 100644 --- a/src/torchmetrics/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/retrieval/precision_recall_curve.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/retrieval/r_precision.py b/src/torchmetrics/retrieval/r_precision.py index c3892b76b8b..2c2466af430 100644 --- a/src/torchmetrics/retrieval/r_precision.py +++ b/src/torchmetrics/retrieval/r_precision.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional, Union from torch import Tensor diff --git a/src/torchmetrics/retrieval/recall.py b/src/torchmetrics/retrieval/recall.py index de77abc9a85..04fe881bc99 100644 --- a/src/torchmetrics/retrieval/recall.py +++ b/src/torchmetrics/retrieval/recall.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/retrieval/reciprocal_rank.py b/src/torchmetrics/retrieval/reciprocal_rank.py index 17a9a2bd7a8..01ea27ae12b 100644 --- a/src/torchmetrics/retrieval/reciprocal_rank.py +++ b/src/torchmetrics/retrieval/reciprocal_rank.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index fc8cadd8c3a..05a6e29b387 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 95da9ab26d9..aef04be1e43 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py index f1e8812ed30..727f9e7bb08 100644 --- a/src/torchmetrics/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -10,7 +10,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index 0fe831f5231..ae8dd3d2aea 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/shape/procrustes.py b/src/torchmetrics/shape/procrustes.py index a924fb48a4a..1d8396c7afb 100644 --- a/src/torchmetrics/shape/procrustes.py +++ b/src/torchmetrics/shape/procrustes.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/_deprecated.py b/src/torchmetrics/text/_deprecated.py index 0c7ffef29af..50c5091de28 100644 --- a/src/torchmetrics/text/_deprecated.py +++ b/src/torchmetrics/text/_deprecated.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Literal, Optional from torchmetrics.text.bleu import BLEUScore from torchmetrics.text.cer import CharErrorRate diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index 6e1bab1b9bd..5960e16a00f 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/bleu.py b/src/torchmetrics/text/bleu.py index d05da1e33c4..a40525bbe0c 100644 --- a/src/torchmetrics/text/bleu.py +++ b/src/torchmetrics/text/bleu.py @@ -16,7 +16,8 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/cer.py b/src/torchmetrics/text/cer.py index 0bdc5670f59..9862c450ed3 100644 --- a/src/torchmetrics/text/cer.py +++ b/src/torchmetrics/text/cer.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 1ff412ab1a4..962b0813b3b 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -18,7 +18,8 @@ # Link: import itertools -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from collections.abc import Iterator, Sequence +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/edit.py b/src/torchmetrics/text/edit.py index de14f49ae3b..947fc79cd6c 100644 --- a/src/torchmetrics/text/edit.py +++ b/src/torchmetrics/text/edit.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index 659181f8d39..c0629b9ba44 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Union from torch import Tensor, stack from typing_extensions import Literal diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index 31fea4adc23..3488e2074b4 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/mer.py b/src/torchmetrics/text/mer.py index b519445c05e..37dae4cc4f6 100644 --- a/src/torchmetrics/text/mer.py +++ b/src/torchmetrics/text/mer.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py index d13eac2f402..a090d522db8 100644 --- a/src/torchmetrics/text/perplexity.py +++ b/src/torchmetrics/text/perplexity.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py index 7bce72ed1c3..d0cac0df18d 100644 --- a/src/torchmetrics/text/rouge.py +++ b/src/torchmetrics/text/rouge.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/text/sacre_bleu.py b/src/torchmetrics/text/sacre_bleu.py index 8e2e63ad82b..3a393a399df 100644 --- a/src/torchmetrics/text/sacre_bleu.py +++ b/src/torchmetrics/text/sacre_bleu.py @@ -17,7 +17,8 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py index e4be6670d19..e4d98c2a8b6 100644 --- a/src/torchmetrics/text/squad.py +++ b/src/torchmetrics/text/squad.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py index 98ec0a90235..8ded3c9b606 100644 --- a/src/torchmetrics/text/ter.py +++ b/src/torchmetrics/text/ter.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/wer.py b/src/torchmetrics/text/wer.py index 0950bd4de42..fc947ef2772 100644 --- a/src/torchmetrics/text/wer.py +++ b/src/torchmetrics/text/wer.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/wil.py b/src/torchmetrics/text/wil.py index 16b71720c3b..a0d42fbfcf2 100644 --- a/src/torchmetrics/text/wil.py +++ b/src/torchmetrics/text/wil.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/text/wip.py b/src/torchmetrics/text/wip.py index bbdd2b7a235..6d5db6b3e2c 100644 --- a/src/torchmetrics/text/wip.py +++ b/src/torchmetrics/text/wip.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, List, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 449efcade2d..79878d058b7 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -14,9 +14,10 @@ import multiprocessing import os import sys +from collections.abc import Mapping, Sequence from functools import partial from time import perf_counter -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, no_type_check +from typing import Any, Callable, Dict, Optional, Tuple, no_type_check from unittest.mock import Mock import torch diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 1a68e655c33..4428c8cc7e9 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Tuple, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 4c88c078050..dae78a873e9 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator, Sequence from itertools import product from math import ceil, floor, sqrt -from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union, no_type_check +from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check import numpy as np import torch diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index d59f7724c2a..083cafd76c6 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Dict, Optional, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 217c94d6bc0..682cfdad4b3 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import typing -from typing import Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py index 1bd1b81783b..5cbc0106beb 100644 --- a/src/torchmetrics/wrappers/feature_share.py +++ b/src/torchmetrics/wrappers/feature_share.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import lru_cache -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Dict, Optional, Union from torch.nn import Module diff --git a/src/torchmetrics/wrappers/minmax.py b/src/torchmetrics/wrappers/minmax.py index 09684c55919..f91c3992529 100644 --- a/src/torchmetrics/wrappers/minmax.py +++ b/src/torchmetrics/wrappers/minmax.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Dict, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py index 7853e6257e6..85cdb2c573f 100644 --- a/src/torchmetrics/wrappers/multioutput.py +++ b/src/torchmetrics/wrappers/multioutput.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index 04ddd87ad71..a955d938583 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # this is just a bypass for this module name collision with built-in one +from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from torch import Tensor, nn diff --git a/src/torchmetrics/wrappers/running.py b/src/torchmetrics/wrappers/running.py index f877c65ad95..a57cb8333f0 100644 --- a/src/torchmetrics/wrappers/running.py +++ b/src/torchmetrics/wrappers/running.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index c1fe9957b1b..148e16c412c 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from copy import deepcopy -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index 98cc110a3ff..42510b08747 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -13,9 +13,10 @@ # limitations under the License. import pickle import sys +from collections.abc import Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import pytest diff --git a/tests/unittests/text/_helpers.py b/tests/unittests/text/_helpers.py index e55944a72af..d580a62fa39 100644 --- a/tests/unittests/text/_helpers.py +++ b/tests/unittests/text/_helpers.py @@ -13,8 +13,9 @@ # limitations under the License. import pickle import sys +from collections.abc import Sequence from functools import partial -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np import pytest diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index 1d74f0c858d..d7e1fb22609 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from torch import Tensor diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py index 233c9451381..ca3995e2519 100644 --- a/tests/unittests/text/test_chrf.py +++ b/tests/unittests/text/test_chrf.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from torch import Tensor, tensor diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index a40885587e8..3a8acef5b16 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -13,8 +13,9 @@ # limitations under the License. import re +from collections.abc import Sequence from functools import partial -from typing import Callable, Sequence, Union +from typing import Callable, Union import pytest import torch diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index bf2d45fbd5a..54da47a6a8a 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from lightning_utilities.core.imports import RequirementCache diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py index 6c047730067..1f896bdbbbe 100644 --- a/tests/unittests/text/test_ter.py +++ b/tests/unittests/text/test_ter.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import partial -from typing import Sequence import pytest from torch import Tensor, tensor From ed9f2bb402e08569c8bc95e4831262b13869f254 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 5 Nov 2024 15:33:00 +0000 Subject: [PATCH 04/15] try a few --- src/torchmetrics/functional/text/squad.py | 6 +++--- tests/unittests/text/test_wer.py | 2 +- tests/unittests/text/test_wip.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/text/squad.py b/src/torchmetrics/functional/text/squad.py index 01dfb4ec0e6..6ea9c70c9a1 100644 --- a/src/torchmetrics/functional/text/squad.py +++ b/src/torchmetrics/functional/text/squad.py @@ -92,7 +92,7 @@ def _metric_max_over_ground_truths( def _squad_input_check( preds: PREDS_TYPE, targets: TARGETS_TYPE -) -> Tuple[Dict[str, str], List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]]]: +) -> Tuple[Dict[str, str], list[Dict[str, List[Dict[str, list[Dict[str, Any]]]]]]]: """Check for types and convert the input to necessary format to compute the input.""" if isinstance(preds, Dict): preds = [preds] @@ -118,7 +118,7 @@ def _squad_input_check( f"{SQuAD_FORMAT}" ) - answers: Dict[str, Union[List[str], List[int]]] = target["answers"] # type: ignore[assignment] + answers: dict[str, Union[List[str], list[int]]] = target["answers"] # type: ignore[assignment] if "text" not in answers: raise KeyError( "Expected keys in a 'answers' are 'text'." @@ -135,7 +135,7 @@ def _squad_input_check( def _squad_update( preds: Dict[str, str], - target: List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]], + target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]], ) -> Tuple[Tensor, Tensor, Tensor]: """Compute F1 Score and Exact Match for a collection of predictions and references. diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index 16b03849f84..bae4c91c11d 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_wer(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_wer(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import compute_measures except ImportError: diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index a6523babd67..fe4bedb482c 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_wip(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_wip(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import wip except ImportError: From e1361082d0d03c1ce80870e5c79d59bb95b05c68 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Thu, 7 Nov 2024 15:18:57 +0000 Subject: [PATCH 05/15] unsafe-fixes --- .github/assistant.py | 10 +- _samples/bert_score-own_model.py | 6 +- examples/audio/signal_to_noise_ratio.py | 2 +- setup.py | 4 +- src/torchmetrics/aggregation.py | 4 +- src/torchmetrics/audio/pit.py | 2 +- src/torchmetrics/classification/accuracy.py | 2 +- src/torchmetrics/classification/auroc.py | 10 +- .../classification/average_precision.py | 8 +- .../classification/calibration_error.py | 10 +- .../classification/cohen_kappa.py | 2 +- .../classification/confusion_matrix.py | 8 +- src/torchmetrics/classification/dice.py | 2 +- .../classification/exact_match.py | 2 +- src/torchmetrics/classification/f_beta.py | 4 +- .../classification/group_fairness.py | 6 +- src/torchmetrics/classification/hamming.py | 2 +- src/torchmetrics/classification/hinge.py | 2 +- src/torchmetrics/classification/jaccard.py | 2 +- .../classification/matthews_corrcoef.py | 2 +- .../negative_predictive_value.py | 2 +- .../classification/precision_fixed_recall.py | 16 +- .../classification/precision_recall.py | 4 +- .../classification/precision_recall_curve.py | 34 ++--- .../classification/recall_fixed_precision.py | 16 +- src/torchmetrics/classification/roc.py | 20 +-- .../classification/sensitivity_specificity.py | 16 +- .../classification/specificity.py | 2 +- .../classification/specificity_sensitivity.py | 16 +- .../classification/stat_scores.py | 12 +- .../clustering/adjusted_mutual_info_score.py | 4 +- .../clustering/adjusted_rand_score.py | 4 +- .../clustering/calinski_harabasz_score.py | 4 +- .../clustering/davies_bouldin_score.py | 4 +- src/torchmetrics/clustering/dunn_index.py | 4 +- .../clustering/fowlkes_mallows_index.py | 4 +- .../homogeneity_completeness_v_measure.py | 12 +- .../clustering/mutual_info_score.py | 4 +- .../normalized_mutual_info_score.py | 4 +- src/torchmetrics/clustering/rand_score.py | 4 +- src/torchmetrics/collections.py | 26 ++-- src/torchmetrics/detection/_mean_ap.py | 54 +++---- src/torchmetrics/detection/helpers.py | 10 +- src/torchmetrics/detection/iou.py | 10 +- src/torchmetrics/detection/mean_ap.py | 64 ++++---- .../functional/audio/_deprecated.py | 2 +- src/torchmetrics/functional/audio/dnsmos.py | 2 +- src/torchmetrics/functional/audio/nisqa.py | 26 ++-- src/torchmetrics/functional/audio/pit.py | 6 +- src/torchmetrics/functional/audio/sdr.py | 2 +- src/torchmetrics/functional/audio/srmr.py | 4 +- .../functional/classification/auroc.py | 24 +-- .../classification/average_precision.py | 22 +-- .../classification/calibration_error.py | 6 +- .../classification/confusion_matrix.py | 6 +- .../functional/classification/exact_match.py | 4 +- .../classification/group_fairness.py | 22 +-- .../functional/classification/hinge.py | 4 +- .../classification/precision_fixed_recall.py | 18 +-- .../classification/precision_recall_curve.py | 64 ++++---- .../functional/classification/ranking.py | 6 +- .../classification/recall_fixed_precision.py | 36 ++--- .../functional/classification/roc.py | 28 ++-- .../classification/sensitivity_specificity.py | 36 ++--- .../classification/specificity_sensitivity.py | 40 ++--- .../functional/classification/stat_scores.py | 20 +-- .../functional/clustering/dunn_index.py | 2 +- .../clustering/fowlkes_mallows_index.py | 2 +- .../homogeneity_completeness_v_measure.py | 4 +- .../detection/_panoptic_quality_common.py | 68 ++++----- .../functional/image/_deprecated.py | 16 +- src/torchmetrics/functional/image/d_lambda.py | 2 +- src/torchmetrics/functional/image/d_s.py | 2 +- src/torchmetrics/functional/image/ergas.py | 2 +- .../functional/image/gradients.py | 4 +- src/torchmetrics/functional/image/lpips.py | 6 +- .../image/perceptual_path_length.py | 2 +- src/torchmetrics/functional/image/psnr.py | 8 +- src/torchmetrics/functional/image/psnrb.py | 2 +- src/torchmetrics/functional/image/rase.py | 2 +- src/torchmetrics/functional/image/rmse_sw.py | 6 +- src/torchmetrics/functional/image/sam.py | 2 +- src/torchmetrics/functional/image/scc.py | 6 +- src/torchmetrics/functional/image/ssim.py | 24 +-- src/torchmetrics/functional/image/tv.py | 2 +- src/torchmetrics/functional/image/uqi.py | 2 +- src/torchmetrics/functional/image/utils.py | 2 +- .../functional/multimodal/clip_iqa.py | 20 +-- .../functional/multimodal/clip_score.py | 12 +- src/torchmetrics/functional/nominal/utils.py | 6 +- .../functional/pairwise/helpers.py | 2 +- .../regression/cosine_similarity.py | 2 +- src/torchmetrics/functional/regression/csi.py | 2 +- .../regression/explained_variance.py | 2 +- .../functional/regression/kendall.py | 16 +- .../functional/regression/kl_divergence.py | 2 +- .../functional/regression/log_cosh.py | 4 +- .../functional/regression/log_mse.py | 2 +- src/torchmetrics/functional/regression/mae.py | 2 +- .../functional/regression/mape.py | 2 +- src/torchmetrics/functional/regression/mse.py | 2 +- .../functional/regression/nrmse.py | 2 +- .../functional/regression/pearson.py | 2 +- src/torchmetrics/functional/regression/r2.py | 2 +- .../functional/regression/spearman.py | 2 +- .../functional/regression/symmetric_mape.py | 2 +- .../functional/regression/tweedie_deviance.py | 2 +- .../functional/regression/wmape.py | 2 +- .../functional/retrieval/_deprecated.py | 2 +- .../retrieval/precision_recall_curve.py | 2 +- .../functional/segmentation/dice.py | 2 +- .../segmentation/hausdorff_distance.py | 4 +- .../functional/segmentation/mean_iou.py | 2 +- .../functional/segmentation/utils.py | 26 ++-- .../functional/shape/procrustes.py | 2 +- .../functional/text/_deprecated.py | 36 ++--- src/torchmetrics/functional/text/bert.py | 22 +-- src/torchmetrics/functional/text/bleu.py | 2 +- src/torchmetrics/functional/text/cer.py | 8 +- src/torchmetrics/functional/text/chrf.py | 142 +++++++++--------- src/torchmetrics/functional/text/eed.py | 10 +- src/torchmetrics/functional/text/helper.py | 50 +++--- .../text/helper_embedding_metric.py | 30 ++-- src/torchmetrics/functional/text/infolm.py | 18 +-- src/torchmetrics/functional/text/mer.py | 8 +- .../functional/text/perplexity.py | 2 +- src/torchmetrics/functional/text/rouge.py | 38 ++--- .../functional/text/sacre_bleu.py | 28 ++-- src/torchmetrics/functional/text/squad.py | 30 ++-- src/torchmetrics/functional/text/ter.py | 44 +++--- src/torchmetrics/functional/text/wer.py | 8 +- src/torchmetrics/functional/text/wil.py | 8 +- src/torchmetrics/functional/text/wip.py | 8 +- src/torchmetrics/image/_deprecated.py | 14 +- src/torchmetrics/image/d_lambda.py | 4 +- src/torchmetrics/image/d_s.py | 10 +- src/torchmetrics/image/ergas.py | 4 +- src/torchmetrics/image/fid.py | 6 +- src/torchmetrics/image/inception.py | 4 +- src/torchmetrics/image/kid.py | 6 +- src/torchmetrics/image/lpip.py | 2 +- src/torchmetrics/image/mifid.py | 4 +- .../image/perceptual_path_length.py | 2 +- src/torchmetrics/image/psnr.py | 4 +- src/torchmetrics/image/qnr.py | 10 +- src/torchmetrics/image/rase.py | 6 +- src/torchmetrics/image/rmse_sw.py | 2 +- src/torchmetrics/image/sam.py | 4 +- src/torchmetrics/image/ssim.py | 16 +- src/torchmetrics/image/tv.py | 2 +- src/torchmetrics/image/uqi.py | 4 +- src/torchmetrics/metric.py | 44 +++--- src/torchmetrics/multimodal/clip_iqa.py | 6 +- src/torchmetrics/multimodal/clip_score.py | 2 +- src/torchmetrics/nominal/fleiss_kappa.py | 2 +- .../regression/cosine_similarity.py | 4 +- src/torchmetrics/regression/csi.py | 6 +- src/torchmetrics/regression/kendall.py | 6 +- src/torchmetrics/regression/pearson.py | 6 +- src/torchmetrics/regression/spearman.py | 4 +- src/torchmetrics/retrieval/base.py | 6 +- .../retrieval/precision_recall_curve.py | 14 +- src/torchmetrics/segmentation/dice.py | 6 +- .../segmentation/hausdorff_distance.py | 2 +- src/torchmetrics/text/bert.py | 14 +- src/torchmetrics/text/cer.py | 2 +- src/torchmetrics/text/chrf.py | 12 +- src/torchmetrics/text/edit.py | 2 +- src/torchmetrics/text/eed.py | 4 +- src/torchmetrics/text/infolm.py | 12 +- src/torchmetrics/text/mer.py | 4 +- src/torchmetrics/text/perplexity.py | 2 +- src/torchmetrics/text/rouge.py | 6 +- src/torchmetrics/text/squad.py | 2 +- src/torchmetrics/text/ter.py | 4 +- src/torchmetrics/text/wer.py | 2 +- src/torchmetrics/text/wil.py | 2 +- src/torchmetrics/text/wip.py | 2 +- src/torchmetrics/utilities/checks.py | 18 +-- src/torchmetrics/utilities/compute.py | 2 +- src/torchmetrics/utilities/data.py | 4 +- src/torchmetrics/utilities/distributed.py | 4 +- src/torchmetrics/utilities/enums.py | 2 +- src/torchmetrics/utilities/plot.py | 18 +-- src/torchmetrics/wrappers/bootstrapping.py | 2 +- src/torchmetrics/wrappers/classwise.py | 8 +- src/torchmetrics/wrappers/feature_share.py | 2 +- src/torchmetrics/wrappers/minmax.py | 2 +- src/torchmetrics/wrappers/multioutput.py | 2 +- src/torchmetrics/wrappers/multitask.py | 16 +- src/torchmetrics/wrappers/tracker.py | 12 +- src/torchmetrics/wrappers/transformations.py | 8 +- tests/unittests/_helpers/testers.py | 22 +-- tests/unittests/audio/test_dnsmos.py | 8 +- tests/unittests/audio/test_nisqa.py | 4 +- tests/unittests/audio/test_pit.py | 6 +- tests/unittests/audio/test_srmr.py | 4 +- .../classification/test_group_fairness.py | 4 +- .../test_modified_panoptic_quality.py | 2 +- .../detection/test_panoptic_quality.py | 2 +- tests/unittests/image/test_d_s.py | 2 +- tests/unittests/image/test_qnr.py | 2 +- tests/unittests/multimodal/test_clip_score.py | 2 +- tests/unittests/retrieval/helpers.py | 12 +- .../retrieval/test_precision_recall_curve.py | 2 +- tests/unittests/text/_helpers.py | 2 +- tests/unittests/text/test_cer.py | 2 +- tests/unittests/text/test_mer.py | 2 +- tests/unittests/text/test_wil.py | 2 +- 209 files changed, 1034 insertions(+), 1034 deletions(-) diff --git a/.github/assistant.py b/.github/assistant.py index ad054d96e2f..1718727c452 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -83,7 +83,7 @@ def _replace_requirement(fpath: str, old_str: str = "", new_str: str = "") -> No fp.write(req) @staticmethod - def replace_str_requirements(old_str: str, new_str: str, req_files: List[str] = REQUIREMENTS_FILES) -> None: + def replace_str_requirements(old_str: str, new_str: str, req_files: list[str] = REQUIREMENTS_FILES) -> None: """Replace a particular string in all requirements files.""" if isinstance(req_files, str): req_files = [req_files] @@ -96,7 +96,7 @@ def replace_min_requirements(fpath: str) -> None: AssistantCLI._replace_requirement(fpath, old_str=">=", new_str="==") @staticmethod - def set_oldest_versions(req_files: List[str] = REQUIREMENTS_FILES) -> None: + def set_oldest_versions(req_files: list[str] = REQUIREMENTS_FILES) -> None: """Set the oldest version for requirements.""" AssistantCLI.set_min_torch_by_python() if isinstance(req_files, str): @@ -109,8 +109,8 @@ def changed_domains( pr: Optional[int] = None, auth_token: Optional[str] = None, as_list: bool = False, - general_sub_pkgs: Tuple[str] = _PKG_WIDE_SUBPACKAGES, - ) -> Union[str, List[str]]: + general_sub_pkgs: tuple[str] = _PKG_WIDE_SUBPACKAGES, + ) -> Union[str, list[str]]: """Determine what domains were changed in particular PR.""" import github @@ -139,7 +139,7 @@ def changed_domains( return "unittests" # parse domains - def _crop_path(fname: str, paths: List[str]) -> str: + def _crop_path(fname: str, paths: list[str]) -> str: for p in paths: fname = fname.replace(p, "") return fname diff --git a/_samples/bert_score-own_model.py b/_samples/bert_score-own_model.py index d5e74078c65..74799c41acc 100644 --- a/_samples/bert_score-own_model.py +++ b/_samples/bert_score-own_model.py @@ -53,7 +53,7 @@ def __init__(self) -> None: self.PAD_TOKEN: torch.zeros(1, _MODEL_DIM), } - def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Dict[str, Tensor]: + def __call__(self, sentences: Union[str, list[str]], max_len: int = _MAX_LEN) -> dict[str, Tensor]: """Call method to tokenize user input. The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method @@ -69,7 +69,7 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding values. """ - output_dict: Dict[str, Tensor] = {} + output_dict: dict[str, Tensor] = {} if isinstance(sentences, str): sentences = [sentences] # Add special tokens @@ -96,7 +96,7 @@ def get_user_model_encoder(num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_ return nn.TransformerEncoder(encoder_layer, num_layers=num_layers) -def user_forward_fn(model: Module, batch: Dict[str, Tensor]) -> Tensor: +def user_forward_fn(model: Module, batch: dict[str, Tensor]) -> Tensor: """User forward function used for the computation of model embeddings. This function might be arbitrarily complicated inside. However, to ensure functionality, it should obey the diff --git a/examples/audio/signal_to_noise_ratio.py b/examples/audio/signal_to_noise_ratio.py index c7130a895e4..b203efb87d5 100644 --- a/examples/audio/signal_to_noise_ratio.py +++ b/examples/audio/signal_to_noise_ratio.py @@ -20,7 +20,7 @@ # Generate a clean signal (simulating a high-quality recording) -def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]: +def generate_clean_signal(length: int = 1000) -> tuple[np.ndarray, np.ndarray]: """Generate a clean signal (sine wave)""" t = np.linspace(0, 1, length) signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave, representing the clean recording diff --git a/setup.py b/setup.py index 37994261b49..dc66936b98a 100755 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen def _load_requirements( path_dir: str, file_name: str = "base.txt", unfreeze: bool = not _FREEZE_REQUIREMENTS -) -> List[str]: +) -> list[str]: """Load requirements from a file. >>> _load_requirements(_PATH_REQUIRE) @@ -162,7 +162,7 @@ def _load_py_module(fname: str, pkg: str = "torchmetrics"): BASE_REQUIREMENTS = _load_requirements(path_dir=_PATH_REQUIRE, file_name="base.txt") -def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.txt",)) -> dict: +def _prepare_extras(skip_pattern: str = "^_", skip_files: tuple[str] = ("base.txt",)) -> dict: """Preparing extras for the package listing requirements. Args: diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index 14c6831e62a..ae30429bc20 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -56,7 +56,7 @@ class BaseAggregator(Metric): def __init__( self, fn: Union[Callable, str], - default_value: Union[Tensor, List], + default_value: Union[Tensor, list], nan_strategy: Union[str, float] = "error", state_name: str = "value", **kwargs: Any, @@ -75,7 +75,7 @@ def __init__( def _cast_and_nan_check_input( self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """Convert input ``x`` to a tensor and check for Nans.""" if not isinstance(x, Tensor): x = torch.as_tensor(x, dtype=self.dtype, device=self.device) diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index ecd3aa6f1ee..56cd28b5ae0 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -88,7 +88,7 @@ def __init__( eval_func: Literal["max", "min"] = "max", **kwargs: Any, ) -> None: - base_kwargs: Dict[str, Any] = { + base_kwargs: dict[str, Any] = { "dist_sync_on_step": kwargs.pop("dist_sync_on_step", False), "process_group": kwargs.pop("process_group", None), "dist_sync_fn": kwargs.pop("dist_sync_fn", None), diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 614b8d035dd..1baec43a9e5 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -490,7 +490,7 @@ class Accuracy(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Accuracy"], + cls: type["Accuracy"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index a757d71eb00..e82498b221b 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -109,7 +109,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): def __init__( self, max_fpr: Optional[float] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -258,7 +258,7 @@ def __init__( self, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -413,7 +413,7 @@ def __init__( self, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -508,9 +508,9 @@ class AUROC(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["AUROC"], + cls: type["AUROC"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 9f7c1adf6fb..e37f6989fa3 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -256,7 +256,7 @@ def __init__( self, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -416,7 +416,7 @@ def __init__( self, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -518,9 +518,9 @@ class AveragePrecision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["AveragePrecision"], + cls: type["AveragePrecision"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 952eed47fd4..7de354e54cd 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -106,8 +106,8 @@ class BinaryCalibrationError(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - confidences: List[Tensor] - accuracies: List[Tensor] + confidences: list[Tensor] + accuracies: list[Tensor] def __init__( self, @@ -259,8 +259,8 @@ class MulticlassCalibrationError(Metric): plot_upper_bound: float = 1.0 plot_legend_name: str = "Class" - confidences: List[Tensor] - accuracies: List[Tensor] + confidences: list[Tensor] + accuracies: list[Tensor] def __init__( self, @@ -371,7 +371,7 @@ class CalibrationError(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["CalibrationError"], + cls: type["CalibrationError"], task: Literal["binary", "multiclass"], n_bins: int = 15, norm: Literal["l1", "l2", "max"] = "l1", diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 093919f2cd0..154cae505c1 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -315,7 +315,7 @@ class labels. """ def __new__( # type: ignore[misc] - cls: Type["CohenKappa"], + cls: type["CohenKappa"], task: Literal["binary", "multiclass"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index f1f870bcbfd..f713f6c1c40 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -150,7 +150,7 @@ def plot( val: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -294,7 +294,7 @@ def plot( val: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -441,7 +441,7 @@ def plot( val: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -517,7 +517,7 @@ class ConfusionMatrix(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["ConfusionMatrix"], + cls: type["ConfusionMatrix"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 080d482d2fe..281324767af 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -248,7 +248,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.fn.append(fn) @no_type_check - def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def _get_final_stats(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Perform concatenation on the stat scores if necessary, before passing them to a compute function.""" tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index c71d35df116..189b5b44822 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -395,7 +395,7 @@ class ExactMatch(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["ExactMatch"], + cls: type["ExactMatch"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 9a042907cfb..46b1dd6d297 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -1117,7 +1117,7 @@ class FBetaScore(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["FBetaScore"], + cls: type["FBetaScore"], task: Literal["binary", "multiclass", "multilabel"], beta: float = 1.0, threshold: float = 0.5, @@ -1184,7 +1184,7 @@ class F1Score(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["F1Score"], + cls: type["F1Score"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 06dc3a07ae0..0235ef123af 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -48,7 +48,7 @@ def _create_states(self, num_groups: int) -> None: self.add_state("tn", default(), dist_reduce_fx="sum") self.add_state("fn", default(), dist_reduce_fx="sum") - def _update_states(self, group_stats: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]) -> None: + def _update_states(self, group_stats: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]) -> None: for group, stats in enumerate(group_stats): tp, fp, tn, fn = stats self.tp[group] += tp @@ -148,7 +148,7 @@ def update(self, preds: Tensor, target: Tensor, groups: Tensor) -> None: def compute( self, - ) -> Dict[str, Tensor]: + ) -> dict[str, Tensor]: """Compute tp, fp, tn and fn rates based on inputs passed in to ``update`` previously.""" results = torch.stack((self.tp, self.fp, self.tn, self.fn), dim=1) @@ -268,7 +268,7 @@ def update(self, preds: Tensor, target: Tensor, groups: Tensor) -> None: def compute( self, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: """Compute fairness criteria based on inputs passed in to ``update`` previously.""" if self.task == "demographic_parity": return _compute_binary_demographic_parity(self.tp, self.fp, self.tn, self.fn) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index d29e217efe9..c0dc94d3c21 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -494,7 +494,7 @@ class HammingDistance(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["HammingDistance"], + cls: type["HammingDistance"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index f03e24cd7fd..4fed17aa0f1 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -354,7 +354,7 @@ class HingeLoss(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["HingeLoss"], + cls: type["HingeLoss"], task: Literal["binary", "multiclass"], num_classes: Optional[int] = None, squared: bool = False, diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index cb3e20b0d89..7098db7896c 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -459,7 +459,7 @@ class JaccardIndex(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["JaccardIndex"], + cls: type["JaccardIndex"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index b7a7ee59237..c84ebed69d3 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -391,7 +391,7 @@ class MatthewsCorrCoef(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["MatthewsCorrCoef"], + cls: type["MatthewsCorrCoef"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/negative_predictive_value.py b/src/torchmetrics/classification/negative_predictive_value.py index 5f3d505872b..a4bfb9bc4c4 100644 --- a/src/torchmetrics/classification/negative_predictive_value.py +++ b/src/torchmetrics/classification/negative_predictive_value.py @@ -487,7 +487,7 @@ class NegativePredictiveValue(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["NegativePredictiveValue"], + cls: type["NegativePredictiveValue"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_fixed_recall.py b/src/torchmetrics/classification/precision_fixed_recall.py index 73466f37f94..5ce97e7effc 100644 --- a/src/torchmetrics/classification/precision_fixed_recall.py +++ b/src/torchmetrics/classification/precision_fixed_recall.py @@ -116,7 +116,7 @@ class BinaryPrecisionAtFixedRecall(BinaryPrecisionRecallCurve): def __init__( self, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -127,7 +127,7 @@ def __init__( self.validate_args = validate_args self.min_recall = min_recall - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _binary_recall_at_fixed_precision_compute( @@ -259,7 +259,7 @@ def __init__( self, num_classes: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -272,7 +272,7 @@ def __init__( self.validate_args = validate_args self.min_recall = min_recall - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_recall_at_fixed_precision_arg_compute( @@ -405,7 +405,7 @@ def __init__( self, num_labels: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -418,7 +418,7 @@ def __init__( self.validate_args = validate_args self.min_recall = min_recall - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_recall_at_fixed_precision_arg_compute( @@ -486,10 +486,10 @@ class PrecisionAtFixedRecall(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["PrecisionAtFixedRecall"], + cls: type["PrecisionAtFixedRecall"], task: Literal["binary", "multiclass", "multilabel"], min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 19d2117863e..b9b790008cd 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -986,7 +986,7 @@ class Precision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Precision"], + cls: type["Precision"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, @@ -1051,7 +1051,7 @@ class Recall(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Recall"], + cls: type["Recall"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index d0f9c632c02..86f2c66967b 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -130,13 +130,13 @@ class BinaryPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] confmat: Tensor def __init__( self, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -171,14 +171,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _binary_precision_recall_curve_compute(state, self.thresholds) def plot( self, - curve: Optional[Tuple[Tensor, Tensor, Tensor]] = None, + curve: Optional[tuple[Tensor, Tensor, Tensor]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -323,14 +323,14 @@ class MulticlassPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] confmat: Tensor def __init__( self, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, @@ -374,14 +374,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average) def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -523,14 +523,14 @@ class MultilabelPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] confmat: Tensor def __init__( self, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -570,14 +570,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index) def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -667,9 +667,9 @@ class PrecisionRecallCurve(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["PrecisionRecallCurve"], + cls: type["PrecisionRecallCurve"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index 4e7f540f49a..88a4fbdd74f 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -115,7 +115,7 @@ class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): def __init__( self, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -126,7 +126,7 @@ def __init__( self.validate_args = validate_args self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _binary_recall_at_fixed_precision_compute(state, self.thresholds, self.min_precision) @@ -258,7 +258,7 @@ def __init__( self, num_classes: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -271,7 +271,7 @@ def __init__( self.validate_args = validate_args self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_recall_at_fixed_precision_arg_compute( @@ -404,7 +404,7 @@ def __init__( self, num_labels: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -417,7 +417,7 @@ def __init__( self.validate_args = validate_args self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_recall_at_fixed_precision_arg_compute( @@ -485,10 +485,10 @@ class RecallAtFixedPrecision(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["RecallAtFixedPrecision"], + cls: type["RecallAtFixedPrecision"], task: Literal["binary", "multiclass", "multilabel"], min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index fc378e53ee0..73be641c66d 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -117,14 +117,14 @@ class BinaryROC(BinaryPrecisionRecallCurve): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] def plot( self, - curve: Optional[Tuple[Tensor, Tensor, Tensor]] = None, + curve: Optional[tuple[Tensor, Tensor, Tensor]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -287,17 +287,17 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): plot_upper_bound: float = 1.0 plot_legend_name: str = "Class" - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average) # type: ignore[arg-type] def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -449,17 +449,17 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): plot_upper_bound: float = 1.0 plot_legend_name: str = "Label" - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) # type: ignore[arg-type] def plot( self, - curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -571,9 +571,9 @@ class ROC(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["ROC"], + cls: type["ROC"], task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py index c9bcd0bad6d..bce575a1014 100644 --- a/src/torchmetrics/classification/sensitivity_specificity.py +++ b/src/torchmetrics/classification/sensitivity_specificity.py @@ -111,7 +111,7 @@ class BinarySensitivityAtSpecificity(BinaryPrecisionRecallCurve): def __init__( self, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -122,7 +122,7 @@ def __init__( self.validate_args = validate_args self.min_specificity = min_specificity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _binary_sensitivity_at_specificity_compute(state, self.thresholds, self.min_specificity) @@ -208,7 +208,7 @@ def __init__( self, num_classes: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -223,7 +223,7 @@ def __init__( self.validate_args = validate_args self.min_specificity = min_specificity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_sensitivity_at_specificity_compute( @@ -309,7 +309,7 @@ def __init__( self, num_labels: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -322,7 +322,7 @@ def __init__( self.validate_args = validate_args self.min_specificity = min_specificity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_sensitivity_at_specificity_compute( @@ -346,10 +346,10 @@ class SensitivityAtSpecificity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["SensitivityAtSpecificity"], + cls: type["SensitivityAtSpecificity"], task: Literal["binary", "multiclass", "multilabel"], min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 274709b546a..59b7baac794 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -478,7 +478,7 @@ class Specificity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["Specificity"], + cls: type["Specificity"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity_sensitivity.py b/src/torchmetrics/classification/specificity_sensitivity.py index 54f38f20b06..2fd7b4c3f70 100644 --- a/src/torchmetrics/classification/specificity_sensitivity.py +++ b/src/torchmetrics/classification/specificity_sensitivity.py @@ -111,7 +111,7 @@ class BinarySpecificityAtSensitivity(BinaryPrecisionRecallCurve): def __init__( self, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -122,7 +122,7 @@ def __init__( self.validate_args = validate_args self.min_sensitivity = min_sensitivity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _binary_specificity_at_sensitivity_compute(state, self.thresholds, self.min_sensitivity) @@ -208,7 +208,7 @@ def __init__( self, num_classes: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -223,7 +223,7 @@ def __init__( self.validate_args = validate_args self.min_sensitivity = min_sensitivity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_specificity_at_sensitivity_compute( @@ -309,7 +309,7 @@ def __init__( self, num_labels: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -322,7 +322,7 @@ def __init__( self.validate_args = validate_args self.min_sensitivity = min_sensitivity - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_specificity_at_sensitivity_compute( @@ -346,10 +346,10 @@ class SpecificityAtSensitivity(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["SpecificityAtSensitivity"], + cls: type["SpecificityAtSensitivity"], task: Literal["binary", "multiclass", "multilabel"], min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 96d797fd5d6..3a18a8b481e 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -41,10 +41,10 @@ class _AbstractStatScores(Metric): - tp: Union[List[Tensor], Tensor] - fp: Union[List[Tensor], Tensor] - tn: Union[List[Tensor], Tensor] - fn: Union[List[Tensor], Tensor] + tp: Union[list[Tensor], Tensor] + fp: Union[list[Tensor], Tensor] + tn: Union[list[Tensor], Tensor] + fn: Union[list[Tensor], Tensor] # define common functions def _create_state( @@ -79,7 +79,7 @@ def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: self.tn += tn self.fn += fn - def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def _final_state(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Aggregate states that are lists and return final states.""" tp = dim_zero_cat(self.tp) fp = dim_zero_cat(self.fp) @@ -526,7 +526,7 @@ class StatScores(_ClassificationTaskWrapper): """ def __new__( # type: ignore[misc] - cls: Type["StatScores"], + cls: type["StatScores"], task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/clustering/adjusted_mutual_info_score.py b/src/torchmetrics/clustering/adjusted_mutual_info_score.py index ebcf4749d08..94d66609d14 100644 --- a/src/torchmetrics/clustering/adjusted_mutual_info_score.py +++ b/src/torchmetrics/clustering/adjusted_mutual_info_score.py @@ -73,8 +73,8 @@ class AdjustedMutualInfoScore(MutualInfoScore): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] contingency: Tensor def __init__( diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py index 20278f74bc3..5f202536bae 100644 --- a/src/torchmetrics/clustering/adjusted_rand_score.py +++ b/src/torchmetrics/clustering/adjusted_rand_score.py @@ -68,8 +68,8 @@ class AdjustedRandScore(Metric): full_state_update: bool = False plot_lower_bound: float = -0.5 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/calinski_harabasz_score.py b/src/torchmetrics/clustering/calinski_harabasz_score.py index c331fba7866..48463b94988 100644 --- a/src/torchmetrics/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/clustering/calinski_harabasz_score.py @@ -69,8 +69,8 @@ class CalinskiHarabaszScore(Metric): higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 - data: List[Tensor] - labels: List[Tensor] + data: list[Tensor] + labels: list[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/davies_bouldin_score.py b/src/torchmetrics/clustering/davies_bouldin_score.py index ddd079793cd..c2b30c93e01 100644 --- a/src/torchmetrics/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/clustering/davies_bouldin_score.py @@ -79,8 +79,8 @@ class DaviesBouldinScore(Metric): higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 - data: List[Tensor] - labels: List[Tensor] + data: list[Tensor] + labels: list[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index 65d1c0c9a94..ddc6b1867ba 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -67,8 +67,8 @@ class DunnIndex(Metric): higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 - data: List[Tensor] - labels: List[Tensor] + data: list[Tensor] + labels: list[Tensor] def __init__(self, p: float = 2, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/fowlkes_mallows_index.py b/src/torchmetrics/clustering/fowlkes_mallows_index.py index 1317a0cee1c..4c82f892da3 100644 --- a/src/torchmetrics/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/clustering/fowlkes_mallows_index.py @@ -63,8 +63,8 @@ class FowlkesMallowsIndex(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] contingency: Tensor def __init__(self, **kwargs: Any) -> None: diff --git a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py index 260ab522245..f9b7b5cc5c7 100644 --- a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py @@ -68,8 +68,8 @@ class HomogeneityScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -164,8 +164,8 @@ class CompletenessScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -267,8 +267,8 @@ class VMeasureScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__(self, beta: float = 1.0, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index a2be02f834e..8d206ed8886 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -68,8 +68,8 @@ class MutualInfoScore(Metric): higher_is_better: Optional[bool] = True full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] contingency: Tensor def __init__(self, **kwargs: Any) -> None: diff --git a/src/torchmetrics/clustering/normalized_mutual_info_score.py b/src/torchmetrics/clustering/normalized_mutual_info_score.py index 2583b0b2a9e..f829b3a3512 100644 --- a/src/torchmetrics/clustering/normalized_mutual_info_score.py +++ b/src/torchmetrics/clustering/normalized_mutual_info_score.py @@ -72,8 +72,8 @@ class NormalizedMutualInfoScore(MutualInfoScore): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 0.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] contingency: Tensor def __init__( diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index 724a38b227c..c949dd06e5c 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -66,8 +66,8 @@ class RandScore(Metric): higher_is_better = None full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] contingency: Tensor def __init__(self, **kwargs: Any) -> None: diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index ef6e2087e7e..930540473bc 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -191,17 +191,17 @@ class name of the metric: """ - _modules: Dict[str, Metric] # type: ignore[assignment] - _groups: Dict[int, List[str]] - __jit_unused_properties__: ClassVar[List[str]] = ["metric_state"] + _modules: dict[str, Metric] # type: ignore[assignment] + _groups: dict[int, list[str]] + __jit_unused_properties__: ClassVar[list[str]] = ["metric_state"] def __init__( self, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], + metrics: Union[Metric, Sequence[Metric], dict[str, Metric]], *additional_metrics: Metric, prefix: Optional[str] = None, postfix: Optional[str] = None, - compute_groups: Union[bool, List[List[str]]] = True, + compute_groups: Union[bool, list[list[str]]] = True, ) -> None: super().__init__() @@ -214,12 +214,12 @@ def __init__( self.add_metrics(metrics, *additional_metrics) @property - def metric_state(self) -> Dict[str, Dict[str, Any]]: + def metric_state(self) -> dict[str, dict[str, Any]]: """Get the current state of the metric.""" return {k: m.metric_state for k, m in self.items(keep_base=False, copy_state=False)} @torch.jit.unused - def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Call forward for each metric sequentially. Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs) @@ -342,13 +342,13 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None: mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count self._state_is_copy = copy - def compute(self) -> Dict[str, Any]: + def compute(self) -> dict[str, Any]: """Compute the result for each metric in the collection.""" return self._compute_and_reduce("compute") def _compute_and_reduce( self, method_name: Literal["compute", "forward"], *args: Any, **kwargs: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Compute result from collection and reduce into a single dictionary. Args: @@ -422,7 +422,7 @@ def persistent(self, mode: bool = True) -> None: m.persistent(mode) def add_metrics( - self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], *additional_metrics: Metric + self, metrics: Union[Metric, Sequence[Metric], dict[str, Metric]], *additional_metrics: Metric ) -> None: """Add new metrics to Metric Collection.""" if isinstance(metrics, Metric): @@ -516,7 +516,7 @@ def _init_compute_groups(self) -> None: self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))} @property - def compute_groups(self) -> Dict[int, List[str]]: + def compute_groups(self) -> dict[int, list[str]]: """Return a dict with the current compute groups in the collection.""" return self._groups @@ -548,7 +548,7 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]: return self._modules.keys() return self._to_renamed_dict().keys() - def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]: + def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[tuple[str, Metric]]: r"""Return an iterable of the ModuleDict key/value pairs. Args: @@ -617,7 +617,7 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection": def plot( self, - val: Optional[Union[Dict, Sequence[Dict]]] = None, + val: Optional[Union[dict, Sequence[dict]]] = None, ax: Optional[Union[_AX_TYPE, Sequence[_AX_TYPE]]] = None, together: bool = False, ) -> Sequence[_PLOT_OUT_TYPE]: diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index 73ffea36eb2..d9bba820185 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -35,7 +35,7 @@ log = logging.getLogger(__name__) -def compute_area(inputs: List[Any], iou_type: str = "bbox") -> Tensor: +def compute_area(inputs: list[Any], iou_type: str = "bbox") -> Tensor: """Compute area of input depending on the specified iou_type. Default output for empty input is :class:`~torch.Tensor` @@ -57,8 +57,8 @@ def compute_area(inputs: List[Any], iou_type: str = "bbox") -> Tensor: def compute_iou( - det: List[Any], - gt: List[Any], + det: list[Any], + gt: list[Any], iou_type: str = "bbox", ) -> Tensor: """Compute IOU between detections and ground-truth using the specified iou_type.""" @@ -125,7 +125,7 @@ class COCOMetricResults(BaseMetricResults): ) -def _segm_iou(det: List[Tuple[np.ndarray, np.ndarray]], gt: List[Tuple[np.ndarray, np.ndarray]]) -> Tensor: +def _segm_iou(det: list[tuple[np.ndarray, np.ndarray]], gt: list[tuple[np.ndarray, np.ndarray]]) -> Tensor: """Compute IOU between detections and ground-truths using mask-IOU. Implementation is based on pycocotools toolkit for mask_utils. @@ -306,19 +306,19 @@ class MeanAveragePrecision(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - detections: List[Tensor] - detection_scores: List[Tensor] - detection_labels: List[Tensor] - groundtruths: List[Tensor] - groundtruth_labels: List[Tensor] + detections: list[Tensor] + detection_scores: list[Tensor] + detection_labels: list[Tensor] + groundtruths: list[Tensor] + groundtruth_labels: list[Tensor] def __init__( self, box_format: str = "xyxy", iou_type: str = "bbox", - iou_thresholds: Optional[List[float]] = None, - rec_thresholds: Optional[List[float]] = None, - max_detection_thresholds: Optional[List[int]] = None, + iou_thresholds: Optional[list[float]] = None, + rec_thresholds: Optional[list[float]] = None, + max_detection_thresholds: Optional[list[int]] = None, class_metrics: bool = False, **kwargs: Any, ) -> None: @@ -365,7 +365,7 @@ def __init__( self.add_state("groundtruths", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: + def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None: """Update state with predictions and targets.""" _input_validator(preds, target, iou_type=self.iou_type) # type: ignore[arg-type] @@ -394,7 +394,7 @@ def _move_list_states_to_cpu(self) -> None: current_to_cpu.append(cur_v) setattr(self, key, current_to_cpu) - def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: + def _get_safe_item_values(self, item: dict[str, Any]) -> Union[Tensor, tuple]: import pycocotools.mask as mask_utils from torchvision.ops import box_convert @@ -411,7 +411,7 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: return tuple(masks) raise Exception(f"IOU type {self.iou_type} is not supported") - def _get_classes(self) -> List: + def _get_classes(self) -> list: """Return a list of unique classes found in ground truth and detection data.""" if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: return torch.cat(self.detection_labels + self.groundtruth_labels).unique().tolist() @@ -458,8 +458,8 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: return compute_iou(det, gt, self.iou_type).to(self.device) def __evaluate_image_gt_no_preds( - self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], num_iou_thrs: int - ) -> Dict[str, Any]: + self, gt: Tensor, gt_label_mask: Tensor, area_range: tuple[int, int], num_iou_thrs: int + ) -> dict[str, Any]: """Evaluate images with a ground truth but no predictions.""" # GTs gt = [gt[i] for i in gt_label_mask] @@ -487,9 +487,9 @@ def __evaluate_image_preds_no_gt( idx: int, det_label_mask: Tensor, max_det: int, - area_range: Tuple[int, int], + area_range: tuple[int, int], num_iou_thrs: int, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Evaluate images with a prediction but no ground truth.""" # GTs num_gt = 0 @@ -521,7 +521,7 @@ def __evaluate_image_preds_no_gt( } def _evaluate_image( - self, idx: int, class_id: int, area_range: Tuple[int, int], max_det: int, ious: dict + self, idx: int, class_id: int, area_range: tuple[int, int], max_det: int, ious: dict ) -> Optional[dict]: """Perform evaluation for single class and image. @@ -652,7 +652,7 @@ def _find_best_gt_match( def _summarize( self, - results: Dict, + results: dict, avg_prec: bool = True, iou_threshold: Optional[float] = None, area_range: str = "all", @@ -695,7 +695,7 @@ def _summarize( return torch.tensor([-1.0]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1]) - def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResults]: + def _calculate(self, class_ids: list) -> tuple[MAPMetricResults, MARMetricResults]: """Calculate the precision and recall for all supplied classes to calculate mAP/mAR. Args: @@ -753,7 +753,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult return precision, recall # type: ignore[return-value] - def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]: + def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> tuple[MAPMetricResults, MARMetricResults]: """Summarizes the precision and recall values to calculate mAP/mAR. Args: @@ -801,7 +801,7 @@ def __calculate_recall_precision_scores( max_det: int, num_imgs: int, num_bbox_areas: int, - ) -> Tuple[Tensor, Tensor, Tensor]: + ) -> tuple[Tensor, Tensor, Tensor]: num_rec_thrs = len(rec_thresholds) idx_cls_pointer = idx_cls * num_bbox_areas * num_imgs idx_bbox_area_pointer = idx_bbox_area * num_imgs @@ -917,8 +917,8 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt @staticmethod def _gather_tuple_list( - list_to_gather: List[Union[tuple, Tensor]], process_group: Optional[Any] = None - ) -> List[Any]: + list_to_gather: list[Union[tuple, Tensor]], process_group: Optional[Any] = None + ) -> list[Any]: """Gather a list of tuples over multiple devices.""" world_size = dist.get_world_size(group=process_group) dist.barrier(group=process_group) @@ -929,7 +929,7 @@ def _gather_tuple_list( return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)] # type: ignore[arg-type,index] def plot( - self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[dict[str, Tensor], Sequence[dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index 97eb91f0ac9..ae4c6c2b88c 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -18,9 +18,9 @@ def _input_validator( - preds: Sequence[Dict[str, Tensor]], - targets: Sequence[Dict[str, Tensor]], - iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox", + preds: Sequence[dict[str, Tensor]], + targets: Sequence[dict[str, Tensor]], + iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"]]] = "bbox", ignore_score: bool = False, ) -> None: """Ensure the correct input format of `preds` and `targets`.""" @@ -89,8 +89,8 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: def _validate_iou_type_arg( - iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox", -) -> Tuple[str]: + iou_type: Union[Literal["bbox", "segm"], tuple[str]] = "bbox", +) -> tuple[str]: """Validate that iou type argument is correct.""" allowed_iou_types = ("segm", "bbox") if isinstance(iou_type, str): diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index 4cf7fbfbd43..9579eeae06d 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -132,8 +132,8 @@ class IntersectionOverUnion(Metric): higher_is_better: Optional[bool] = True full_state_update: bool = True - groundtruth_labels: List[Tensor] - iou_matrix: List[Tensor] + groundtruth_labels: list[Tensor] + iou_matrix: list[Tensor] _iou_type: str = "iou" _invalid_val: float = -1.0 @@ -179,7 +179,7 @@ def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor: return _iou_compute(*args, **kwargs) - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: + def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None: """Update state with predictions and targets.""" _input_validator(preds, target, ignore_score=True) @@ -205,7 +205,7 @@ def _get_safe_item_values(self, boxes: Tensor) -> Tensor: boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy") return boxes - def _get_gt_classes(self) -> List: + def _get_gt_classes(self) -> list: """Returns a list of unique classes found in ground truth and detection data.""" if len(self.groundtruth_labels) > 0: return torch.cat(self.groundtruth_labels).unique().tolist() @@ -214,7 +214,7 @@ def _get_gt_classes(self) -> List: def compute(self) -> dict: """Computes IoU based on inputs passed in to ``update`` previously.""" score = torch.cat([mat[mat != self._invalid_val] for mat in self.iou_matrix], 0).mean() - results: Dict[str, Tensor] = {f"{self._iou_type}": score} + results: dict[str, Tensor] = {f"{self._iou_type}": score} if torch.isnan(score): # if no valid boxes are found results[f"{self._iou_type}"] = torch.tensor(0.0, device=score.device) if self.class_metrics: diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 3cd67cd6e2c..2a3c1095a53 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -48,7 +48,7 @@ ] -def _load_backend_tools(backend: Literal["pycocotools", "faster_coco_eval"]) -> Tuple[object, object, ModuleType]: +def _load_backend_tools(backend: Literal["pycocotools", "faster_coco_eval"]) -> tuple[object, object, ModuleType]: """Load the backend tools for the given backend.""" if backend == "pycocotools": if not _PYCOCOTOOLS_AVAILABLE: @@ -344,19 +344,19 @@ class MeanAveragePrecision(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - detection_box: List[Tensor] - detection_mask: List[Tensor] - detection_scores: List[Tensor] - detection_labels: List[Tensor] - groundtruth_box: List[Tensor] - groundtruth_mask: List[Tensor] - groundtruth_labels: List[Tensor] - groundtruth_crowds: List[Tensor] - groundtruth_area: List[Tensor] + detection_box: list[Tensor] + detection_mask: list[Tensor] + detection_scores: list[Tensor] + detection_labels: list[Tensor] + groundtruth_box: list[Tensor] + groundtruth_mask: list[Tensor] + groundtruth_labels: list[Tensor] + groundtruth_crowds: list[Tensor] + groundtruth_area: list[Tensor] warn_on_many_detections: bool = True - __jit_unused_properties__: ClassVar[List[str]] = [ + __jit_unused_properties__: ClassVar[list[str]] = [ "is_differentiable", "higher_is_better", "plot_lower_bound", @@ -373,10 +373,10 @@ class MeanAveragePrecision(Metric): def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", - iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox", - iou_thresholds: Optional[List[float]] = None, - rec_thresholds: Optional[List[float]] = None, - max_detection_thresholds: Optional[List[int]] = None, + iou_type: Union[Literal["bbox", "segm"], tuple[str]] = "bbox", + iou_thresholds: Optional[list[float]] = None, + rec_thresholds: Optional[list[float]] = None, + max_detection_thresholds: Optional[list[int]] = None, class_metrics: bool = False, extended_summary: bool = False, average: Literal["macro", "micro"] = "macro", @@ -475,7 +475,7 @@ def mask_utils(self) -> object: _, _, mask_utils = _load_backend_tools(self.backend) return mask_utils - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: + def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None: """Update metric state. Raises: @@ -597,7 +597,7 @@ def compute(self) -> dict: return result_dict - def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object, object]: + def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> tuple[object, object]: """Returns the coco datasets for the target and the predictions.""" if average == "micro": # for micro averaging we set everything to be the same class @@ -629,7 +629,7 @@ def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object return coco_preds, coco_target - def _coco_stats_to_tensor_dict(self, stats: List[float], prefix: str) -> Dict[str, Tensor]: + def _coco_stats_to_tensor_dict(self, stats: list[float], prefix: str) -> dict[str, Tensor]: """Converts the output of COCOeval.stats to a dict of tensors.""" mdt = self.max_detection_thresholds return { @@ -651,9 +651,9 @@ def _coco_stats_to_tensor_dict(self, stats: List[float], prefix: str) -> Dict[st def coco_to_tm( coco_preds: str, coco_target: str, - iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", + iou_type: Union[Literal["bbox", "segm"], list[str]] = "bbox", backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools", - ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]: + ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: """Utility function for converting .json coco format files to the input format of this metric. The function accepts a file for the predictions and a file for the target in coco format and converts them to @@ -825,8 +825,8 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: f.write(target_json) def _get_safe_item_values( - self, item: Dict[str, Any], warn: bool = False - ) -> Tuple[Optional[Tensor], Optional[Tuple]]: + self, item: dict[str, Any], warn: bool = False + ) -> tuple[Optional[Tensor], Optional[tuple]]: """Convert and return the boxes or masks from the item depending on the iou_type. Args: @@ -858,7 +858,7 @@ def _get_safe_item_values( _warning_on_too_many_detections(self.max_detection_thresholds[-1]) return output # type: ignore[return-value] - def _get_classes(self) -> List: + def _get_classes(self) -> list: """Return a list of unique classes found in ground truth and detection data.""" if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist() @@ -866,13 +866,13 @@ def _get_classes(self) -> List: def _get_coco_format( self, - labels: List[torch.Tensor], - boxes: Optional[List[torch.Tensor]] = None, - masks: Optional[List[torch.Tensor]] = None, - scores: Optional[List[torch.Tensor]] = None, - crowds: Optional[List[torch.Tensor]] = None, - area: Optional[List[torch.Tensor]] = None, - ) -> Dict: + labels: list[torch.Tensor], + boxes: Optional[list[torch.Tensor]] = None, + masks: Optional[list[torch.Tensor]] = None, + scores: Optional[list[torch.Tensor]] = None, + crowds: Optional[list[torch.Tensor]] = None, + area: Optional[list[torch.Tensor]] = None, + ) -> dict: """Transforms and returns all cached targets or predictions in COCO format. Format is defined at @@ -958,7 +958,7 @@ def _get_coco_format( return {"images": images, "annotations": annotations, "categories": classes} def plot( - self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[dict[str, Tensor], Sequence[dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -1043,7 +1043,7 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt self.groundtruth_mask = self._gather_tuple_list(self.groundtruth_mask, process_group) # type: ignore[arg-type] @staticmethod - def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: + def _gather_tuple_list(list_to_gather: list[tuple], process_group: Optional[Any] = None) -> list[Any]: """Gather a list of tuples over multiple devices. Args: diff --git a/src/torchmetrics/functional/audio/_deprecated.py b/src/torchmetrics/functional/audio/_deprecated.py index 8b337318f7a..4b31c5db37d 100644 --- a/src/torchmetrics/functional/audio/_deprecated.py +++ b/src/torchmetrics/functional/audio/_deprecated.py @@ -16,7 +16,7 @@ def _permutation_invariant_training( mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise", eval_func: Literal["max", "min"] = "max", **kwargs: Any, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Wrapper for deprecated import. >>> from torch import tensor diff --git a/src/torchmetrics/functional/audio/dnsmos.py b/src/torchmetrics/functional/audio/dnsmos.py index 8f74f8374c9..13935d9cdd2 100644 --- a/src/torchmetrics/functional/audio/dnsmos.py +++ b/src/torchmetrics/functional/audio/dnsmos.py @@ -33,7 +33,7 @@ class InferenceSession: # type:ignore """Dummy InferenceSession.""" - def __init__(self, **kwargs: Dict[str, Any]) -> None: ... + def __init__(self, **kwargs: dict[str, Any]) -> None: ... __doctest_requires__ = { diff --git a/src/torchmetrics/functional/audio/nisqa.py b/src/torchmetrics/functional/audio/nisqa.py index 1a1c066ebf0..a900d538d59 100644 --- a/src/torchmetrics/functional/audio/nisqa.py +++ b/src/torchmetrics/functional/audio/nisqa.py @@ -121,7 +121,7 @@ def non_intrusive_speech_quality_assessment(preds: Tensor, fs: int) -> Tensor: @lru_cache -def _load_nisqa_model() -> Tuple[nn.Module, Dict[str, Any]]: +def _load_nisqa_model() -> tuple[nn.Module, dict[str, Any]]: """Load NISQA model and its parameters. Returns: @@ -157,7 +157,7 @@ class _NISQADIM(nn.Module): # ported from https://github.com/gabrielmittag/NISQA # Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab # MIT License - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.cnn = _Framewise(args) self.time_dependency = _TimeDependency(args) @@ -173,7 +173,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _Framewise(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.model = _AdaptCNN(args) @@ -187,7 +187,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _AdaptCNN(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.pool_1 = args["cnn_pool_1"] self.pool_2 = args["cnn_pool_2"] @@ -231,7 +231,7 @@ def forward(self, x: Tensor) -> Tensor: class _TimeDependency(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.model = _SelfAttention(args) @@ -241,7 +241,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _SelfAttention(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() encoder_layer = _SelfAttentionLayer(args) self.norm1 = nn.LayerNorm(args["td_sa_d_model"]) @@ -254,7 +254,7 @@ def _reset_parameters(self) -> None: if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]: src = self.linear(src) output = src.transpose(1, 0) output = self.norm1(output) @@ -265,7 +265,7 @@ def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: class _SelfAttentionLayer(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.self_attn = nn.MultiheadAttention(args["td_sa_d_model"], args["td_sa_nhead"], args["td_sa_dropout"]) self.linear1 = nn.Linear(args["td_sa_d_model"], args["td_sa_h"]) @@ -277,7 +277,7 @@ def __init__(self, args: Dict[str, Any]) -> None: self.dropout2 = nn.Dropout(args["td_sa_dropout"]) self.activation = relu - def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]: mask = torch.arange(src.shape[0])[None, :] < n_wins[:, None] src2 = self.self_attn(src, src, src, key_padding_mask=~mask)[0] src = src + self.dropout1(src2) @@ -290,7 +290,7 @@ def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: class _Pooling(nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.model = _PoolAttFF(args) @@ -300,7 +300,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: class _PoolAttFF(torch.nn.Module): # part of NISQA model definition - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: super().__init__() self.linear1 = nn.Linear(args["td_sa_d_model"], args["pool_att_h"]) self.linear2 = nn.Linear(args["pool_att_h"], 1) @@ -319,7 +319,7 @@ def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: return self.linear3(x) -def _get_librosa_melspec(y: np.ndarray, sr: int, args: Dict[str, Any]) -> np.ndarray: +def _get_librosa_melspec(y: np.ndarray, sr: int, args: dict[str, Any]) -> np.ndarray: """Compute mel spectrogram from waveform using librosa. Args: @@ -360,7 +360,7 @@ def _get_librosa_melspec(y: np.ndarray, sr: int, args: Dict[str, Any]) -> np.nda return np.stack([librosa.amplitude_to_db(m, ref=1.0, amin=1e-4, top_db=80.0) for m in melspec]) -def _segment_specs(x: Tensor, args: Dict[str, Any]) -> Tuple[Tensor, Tensor]: +def _segment_specs(x: Tensor, args: dict[str, Any]) -> tuple[Tensor, Tensor]: """Segment mel spectrogram into overlapping windows. Args: diff --git a/src/torchmetrics/functional/audio/pit.py b/src/torchmetrics/functional/audio/pit.py index a7cc72d48f7..6fc431811c1 100644 --- a/src/torchmetrics/functional/audio/pit.py +++ b/src/torchmetrics/functional/audio/pit.py @@ -42,7 +42,7 @@ def _gen_permutations(spk_num: int, device: torch.device) -> Tensor: def _find_best_perm_by_linear_sum_assignment( metric_mtx: Tensor, eval_func: Callable, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Solves the linear sum assignment problem. This implementation uses scipy and input is therefore transferred to cpu during calculations. @@ -68,7 +68,7 @@ def _find_best_perm_by_linear_sum_assignment( def _find_best_perm_by_exhaustive_method( metric_mtx: Tensor, eval_func: Callable, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Solves the linear sum assignment problem using exhaustive method. This is done by exhaustively calculating the metric values of all possible permutations, and returns the best metric @@ -111,7 +111,7 @@ def permutation_invariant_training( mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise", eval_func: Literal["max", "min"] = "max", **kwargs: Any, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Calculate `Permutation invariant training`_ (PIT). This metric can evaluate models for speaker independent multi-talker speech separation in a permutation diff --git a/src/torchmetrics/functional/audio/sdr.py b/src/torchmetrics/functional/audio/sdr.py index f6cb0bf2a2e..a68c3e5f047 100644 --- a/src/torchmetrics/functional/audio/sdr.py +++ b/src/torchmetrics/functional/audio/sdr.py @@ -53,7 +53,7 @@ def _symmetric_toeplitz(vector: Tensor) -> Tensor: ).flip(dims=(-1,)) -def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> Tuple[Tensor, Tensor]: +def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> tuple[Tensor, Tensor]: r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds`. This calculation is done using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz metric of the diff --git a/src/torchmetrics/functional/audio/srmr.py b/src/torchmetrics/functional/audio/srmr.py index 26d27f00999..daf7befed02 100644 --- a/src/torchmetrics/functional/audio/srmr.py +++ b/src/torchmetrics/functional/audio/srmr.py @@ -56,7 +56,7 @@ def _make_erb_filters(fs: int, num_freqs: int, cutoff: float, device: torch.devi @lru_cache(maxsize=100) def _compute_modulation_filterbank_and_cutoffs( min_cf: float, max_cf: float, n: int, fs: float, q: int, device: torch.device -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: # this function is translated from the SRMRpy packaged spacing_factor = (max_cf / min_cf) ** (1.0 / (n - 1)) cfs = torch.zeros(n, dtype=torch.float64) @@ -73,7 +73,7 @@ def _make_modulation_filter(w0: Tensor, q: int) -> Tensor: mfb = torch.stack([_make_modulation_filter(w0, q) for w0 in 2 * pi * cfs / fs], dim=0) - def _calc_cutoffs(cfs: Tensor, fs: float, q: int) -> Tuple[Tensor, Tensor]: + def _calc_cutoffs(cfs: Tensor, fs: float, q: int) -> tuple[Tensor, Tensor]: # Calculates cutoff frequencies (3 dB) for 2nd order bandpass w0 = 2 * pi * cfs / fs b0 = torch.tan(w0 / 2) / q diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index fb802c05ec3..03bd8ff3aa0 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -43,8 +43,8 @@ def _reduce_auroc( - fpr: Union[Tensor, List[Tensor]], - tpr: Union[Tensor, List[Tensor]], + fpr: Union[Tensor, list[Tensor]], + tpr: Union[Tensor, list[Tensor]], average: Optional[Literal["macro", "weighted", "none"]] = "macro", weights: Optional[Tensor] = None, direction: float = 1.0, @@ -72,7 +72,7 @@ def _reduce_auroc( def _binary_auroc_arg_validation( max_fpr: Optional[float] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -81,7 +81,7 @@ def _binary_auroc_arg_validation( def _binary_auroc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], max_fpr: Optional[float] = None, pos_label: int = 1, @@ -111,7 +111,7 @@ def binary_auroc( preds: Tensor, target: Tensor, max_fpr: Optional[float] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -181,7 +181,7 @@ def binary_auroc( def _multiclass_auroc_arg_validation( num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -191,7 +191,7 @@ def _multiclass_auroc_arg_validation( def _multiclass_auroc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", thresholds: Optional[Tensor] = None, @@ -210,7 +210,7 @@ def multiclass_auroc( target: Tensor, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -296,7 +296,7 @@ def multiclass_auroc( def _multilabel_auroc_arg_validation( num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -306,7 +306,7 @@ def _multilabel_auroc_arg_validation( def _multilabel_auroc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], thresholds: Optional[Tensor], @@ -338,7 +338,7 @@ def multilabel_auroc( target: Tensor, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -429,7 +429,7 @@ def auroc( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 93002bb6d2b..cd941e0a9df 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -41,8 +41,8 @@ def _reduce_average_precision( - precision: Union[Tensor, List[Tensor]], - recall: Union[Tensor, List[Tensor]], + precision: Union[Tensor, list[Tensor]], + recall: Union[Tensor, list[Tensor]], average: Optional[Literal["macro", "weighted", "none"]] = "macro", weights: Optional[Tensor] = None, ) -> Tensor: @@ -68,7 +68,7 @@ def _reduce_average_precision( def _binary_average_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], ) -> Tensor: precision, recall, _ = _binary_precision_recall_curve_compute(state, thresholds) @@ -78,7 +78,7 @@ def _binary_average_precision_compute( def binary_average_precision( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -152,7 +152,7 @@ def binary_average_precision( def _multiclass_average_precision_arg_validation( num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -162,7 +162,7 @@ def _multiclass_average_precision_arg_validation( def _multiclass_average_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", thresholds: Optional[Tensor] = None, @@ -181,7 +181,7 @@ def multiclass_average_precision( target: Tensor, num_classes: int, average: Optional[Literal["macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -272,7 +272,7 @@ def multiclass_average_precision( def _multilabel_average_precision_arg_validation( num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -282,7 +282,7 @@ def _multilabel_average_precision_arg_validation( def _multilabel_average_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]], thresholds: Optional[Tensor], @@ -314,7 +314,7 @@ def multilabel_average_precision( target: Tensor, num_labels: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -410,7 +410,7 @@ def average_precision( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index ebb5eecbe0d..9d994679a42 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -28,7 +28,7 @@ def _binning_bucketize( confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute calibration bins using ``torch.bucketize``. Use for ``pytorch >=1.6``. Args: @@ -133,7 +133,7 @@ def _binary_calibration_error_tensor_validation( ) -def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: confidences, accuracies = preds, target return confidences, accuracies @@ -238,7 +238,7 @@ def _multiclass_calibration_error_tensor_validation( def _multiclass_calibration_error_update( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: if not torch.all((preds >= 0) * (preds <= 1)): preds = preds.softmax(1) confidences, predictions = preds.max(dim=1) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 1b93450ab69..31d9bbb7f46 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -121,7 +121,7 @@ def _binary_confusion_matrix_format( threshold: float = 0.5, ignore_index: Optional[int] = None, convert_to_labels: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - Remove all datapoints that should be ignored @@ -300,7 +300,7 @@ def _multiclass_confusion_matrix_format( target: Tensor, ignore_index: Optional[int] = None, convert_to_labels: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - Applies argmax if preds have one more dimension than target @@ -482,7 +482,7 @@ def _multilabel_confusion_matrix_format( threshold: float = 0.5, ignore_index: Optional[int] = None, should_threshold: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index 9e6f7c49df7..2b0339e0adc 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -42,7 +42,7 @@ def _multiclass_exact_match_update( target: Tensor, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Compute the statistics.""" if ignore_index is not None: preds = preds.clone() @@ -123,7 +123,7 @@ def multiclass_exact_match( def _multilabel_exact_match_update( preds: Tensor, target: Tensor, num_labels: int, multidim_average: Literal["global", "samplewise"] = "global" -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Compute the statistics.""" if multidim_average == "global": preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels) diff --git a/src/torchmetrics/functional/classification/group_fairness.py b/src/torchmetrics/functional/classification/group_fairness.py index edccd4a6b2d..2d5da8582ed 100644 --- a/src/torchmetrics/functional/classification/group_fairness.py +++ b/src/torchmetrics/functional/classification/group_fairness.py @@ -57,7 +57,7 @@ def _binary_groups_stat_scores( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: +) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """Compute the true/false positives and true/false negatives rates for binary classification by group. Related to `Type I and Type II errors`_. @@ -84,15 +84,15 @@ def _binary_groups_stat_scores( def _groups_reduce( - group_stats: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], -) -> Dict[str, torch.Tensor]: + group_stats: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], +) -> dict[str, torch.Tensor]: """Compute rates for all the group statistics.""" return {f"group_{group}": torch.stack(stats) / torch.stack(stats).sum() for group, stats in enumerate(group_stats)} def _groups_stat_transform( - group_stats: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], -) -> Dict[str, torch.Tensor]: + group_stats: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], +) -> dict[str, torch.Tensor]: """Transform group statistics by creating a tensor for each statistic.""" return { "tp": torch.stack([stat[0] for stat in group_stats]), @@ -110,7 +110,7 @@ def binary_groups_stat_rates( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""Compute the true/false positives and true/false negatives rates for binary classification by group. Related to `Type I and Type II errors`_. @@ -163,7 +163,7 @@ def binary_groups_stat_rates( def _compute_binary_demographic_parity( tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: """Compute demographic parity based on the binary stats.""" pos_rates = _safe_divide(tp + fp, tp + fp + tn + fn) min_pos_rate_id = torch.argmin(pos_rates) @@ -180,7 +180,7 @@ def demographic_parity( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""`Demographic parity`_ compares the positivity rates between all groups. If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest @@ -242,7 +242,7 @@ def demographic_parity( def _compute_binary_equal_opportunity( tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: """Compute equal opportunity based on the binary stats.""" true_pos_rates = _safe_divide(tp, tp + fn) min_pos_rate_id = torch.argmin(true_pos_rates) @@ -262,7 +262,7 @@ def equal_opportunity( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""`Equal opportunity`_ compares the true positive rates between all groups. If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest @@ -331,7 +331,7 @@ def binary_fairness( threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: r"""Compute either `Demographic parity`_ and `Equal opportunity`_ ratio for binary classification problems. This is done by setting the ``task`` argument to either ``'demographic_parity'``, ``'equal_opportunity'`` diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index c3fe40be105..d08df7d550d 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -51,7 +51,7 @@ def _binary_hinge_loss_update( preds: Tensor, target: Tensor, squared: bool, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: target = target.bool() margin = torch.zeros_like(preds) margin[target] = preds[target] @@ -152,7 +152,7 @@ def _multiclass_hinge_loss_update( target: Tensor, squared: bool, multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: if not torch.all((preds >= 0) * (preds <= 1)): preds = preds.softmax(1) diff --git a/src/torchmetrics/functional/classification/precision_fixed_recall.py b/src/torchmetrics/functional/classification/precision_fixed_recall.py index a708f703dc1..d16beba4664 100644 --- a/src/torchmetrics/functional/classification/precision_fixed_recall.py +++ b/src/torchmetrics/functional/classification/precision_fixed_recall.py @@ -44,7 +44,7 @@ def _precision_at_recall( recall: Tensor, thresholds: Tensor, min_recall: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: try: max_precision, _, best_threshold = max( (p, r, t) for p, r, t in zip(precision, recall, thresholds) if r >= min_recall @@ -64,10 +64,10 @@ def binary_precision_at_fixed_recall( preds: Tensor, target: Tensor, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible precision value given the minimum recall thresholds provided for binary tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the precision @@ -140,10 +140,10 @@ def multiclass_precision_at_fixed_recall( target: Tensor, num_classes: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible precision value given the minimum recall thresholds provided for multiclass tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the precision @@ -226,10 +226,10 @@ def multilabel_precision_at_fixed_recall( target: Tensor, num_labels: int, min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible precision value given the minimum recall thresholds provided for multilabel tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the precision @@ -311,12 +311,12 @@ def precision_at_fixed_recall( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_recall: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Optional[Tuple[Tensor, Tensor]]: +) -> Optional[tuple[Tensor, Tensor]]: r"""Compute the highest possible recall value given the minimum precision thresholds provided. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 3c01ceeed36..c498ffcc864 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -32,7 +32,7 @@ def _binary_clf_curve( target: Tensor, sample_weights: Optional[Union[Sequence, Tensor]] = None, pos_label: int = 1, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate the TPs and false positives for all unique thresholds in the preds tensor. Adapted from @@ -83,7 +83,7 @@ def _binary_clf_curve( def _adjust_threshold_arg( - thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None + thresholds: Optional[Union[int, list[float], Tensor]] = None, device: Optional[torch.device] = None ) -> Optional[Tensor]: """Convert threshold arg for list and int to tensor format.""" if isinstance(thresholds, int): @@ -94,7 +94,7 @@ def _adjust_threshold_arg( def _binary_precision_recall_curve_arg_validation( - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: """Validate non tensor input. @@ -164,9 +164,9 @@ def _binary_precision_recall_curve_tensor_validation( def _binary_precision_recall_curve_format( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. - flattens additional dimensions @@ -193,7 +193,7 @@ def _binary_precision_recall_curve_update( preds: Tensor, target: Tensor, thresholds: Optional[Tensor], -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi @@ -213,7 +213,7 @@ def _binary_precision_recall_curve_update_vectorized( preds: Tensor, target: Tensor, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the multi-threshold confusion matrix to calculate the pr-curve with. This implementation is vectorized and faster than `_binary_precision_recall_curve_update_loop` for small @@ -231,7 +231,7 @@ def _binary_precision_recall_curve_update_loop( preds: Tensor, target: Tensor, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the multi-threshold confusion matrix to calculate the pr-curve with. This implementation loops over thresholds and is more memory-efficient than @@ -253,10 +253,10 @@ def _binary_precision_recall_curve_update_loop( def _binary_precision_recall_curve_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], pos_label: int = 1, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -294,10 +294,10 @@ def _binary_precision_recall_curve_compute( def binary_precision_recall_curve( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: r"""Compute the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -369,7 +369,7 @@ def binary_precision_recall_curve( def _multiclass_precision_recall_curve_arg_validation( num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, ) -> None: @@ -432,10 +432,10 @@ def _multiclass_precision_recall_curve_format( preds: Tensor, target: Tensor, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. - flattens additional dimensions @@ -469,7 +469,7 @@ def _multiclass_precision_recall_curve_update( num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi @@ -492,7 +492,7 @@ def _multiclass_precision_recall_curve_update_vectorized( target: Tensor, num_classes: int, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the multi-threshold confusion matrix to calculate the pr-curve with. This implementation is vectorized and faster than `_binary_precision_recall_curve_update_loop` for small @@ -514,7 +514,7 @@ def _multiclass_precision_recall_curve_update_loop( target: Tensor, num_classes: int, thresholds: Tensor, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. This implementation loops over thresholds and is more memory-efficient than @@ -536,11 +536,11 @@ def _multiclass_precision_recall_curve_update_loop( def _multiclass_precision_recall_curve_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -595,11 +595,11 @@ def multiclass_precision_recall_curve( preds: Tensor, target: Tensor, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the precision-recall curve for multiclass tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -712,7 +712,7 @@ def multiclass_precision_recall_curve( def _multilabel_precision_recall_curve_arg_validation( num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: """Validate non tensor input. @@ -748,9 +748,9 @@ def _multilabel_precision_recall_curve_format( preds: Tensor, target: Tensor, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. - flattens additional dimensions @@ -781,7 +781,7 @@ def _multilabel_precision_recall_curve_update( target: Tensor, num_labels: int, thresholds: Optional[Tensor], -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi @@ -802,11 +802,11 @@ def _multilabel_precision_recall_curve_update( def _multilabel_precision_recall_curve_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -842,10 +842,10 @@ def multilabel_precision_recall_curve( preds: Tensor, target: Tensor, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the precision-recall curve for multilabel tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -947,13 +947,13 @@ def precision_recall_curve( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the precision-recall curve. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the diff --git a/src/torchmetrics/functional/classification/ranking.py b/src/torchmetrics/functional/classification/ranking.py index 87bade5e88d..d78fe807a88 100644 --- a/src/torchmetrics/functional/classification/ranking.py +++ b/src/torchmetrics/functional/classification/ranking.py @@ -45,7 +45,7 @@ def _multilabel_ranking_tensor_validation( raise ValueError(f"Expected preds tensor to be floating point, but received input with dtype {preds.dtype}") -def _multilabel_coverage_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _multilabel_coverage_error_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Accumulate state for coverage error.""" offset = torch.zeros_like(preds) offset[target == 0] = preds.min().abs() + 10 # Any number >1 works @@ -109,7 +109,7 @@ def multilabel_coverage_error( return _ranking_reduce(coverage, total) -def _multilabel_ranking_average_precision_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _multilabel_ranking_average_precision_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Accumulate state for label ranking average precision.""" # Invert so that the highest score receives rank 1 neg_preds = -preds @@ -182,7 +182,7 @@ def multilabel_ranking_average_precision( return _ranking_reduce(score, num_elements) -def _multilabel_ranking_loss_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _multilabel_ranking_loss_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Accumulate state for label ranking loss. Args: diff --git a/src/torchmetrics/functional/classification/recall_fixed_precision.py b/src/torchmetrics/functional/classification/recall_fixed_precision.py index 745a9de2c34..72faf8b8b6d 100644 --- a/src/torchmetrics/functional/classification/recall_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_fixed_precision.py @@ -60,7 +60,7 @@ def _recall_at_precision( recall: Tensor, thresholds: Tensor, min_precision: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) best_threshold = torch.tensor(0) @@ -78,7 +78,7 @@ def _recall_at_precision( def _binary_recall_at_fixed_precision_arg_validation( min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -89,12 +89,12 @@ def _binary_recall_at_fixed_precision_arg_validation( def _binary_recall_at_fixed_precision_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], min_precision: float, pos_label: int = 1, reduce_fn: Callable = _recall_at_precision, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: precision, recall, thresholds = _binary_precision_recall_curve_compute(state, thresholds, pos_label) return reduce_fn(precision, recall, thresholds, min_precision) @@ -103,10 +103,10 @@ def binary_recall_at_fixed_precision( preds: Tensor, target: Tensor, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible recall value given the minimum precision thresholds provided for binary tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the recall @@ -175,7 +175,7 @@ def binary_recall_at_fixed_precision( def _multiclass_recall_at_fixed_precision_arg_validation( num_classes: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -186,12 +186,12 @@ def _multiclass_recall_at_fixed_precision_arg_validation( def _multiclass_recall_at_fixed_precision_arg_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], min_precision: float, reduce_fn: Callable = _recall_at_precision, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: precision, recall, thresholds = _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) if isinstance(state, Tensor): res = [reduce_fn(p, r, thresholds, min_precision) for p, r in zip(precision, recall)] @@ -207,10 +207,10 @@ def multiclass_recall_at_fixed_precision( target: Tensor, num_classes: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible recall value given the minimum precision thresholds provided for multiclass tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a @@ -285,7 +285,7 @@ def multiclass_recall_at_fixed_precision( def _multilabel_recall_at_fixed_precision_arg_validation( num_labels: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -296,13 +296,13 @@ def _multilabel_recall_at_fixed_precision_arg_validation( def _multilabel_recall_at_fixed_precision_arg_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int], min_precision: float, reduce_fn: Callable = _recall_at_precision, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: precision, recall, thresholds = _multilabel_precision_recall_curve_compute( state, num_labels, thresholds, ignore_index ) @@ -320,10 +320,10 @@ def multilabel_recall_at_fixed_precision( target: Tensor, num_labels: int, min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible recall value given the minimum precision thresholds provided for multilabel tasks. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a @@ -403,12 +403,12 @@ def recall_at_fixed_precision( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_precision: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Optional[Tuple[Tensor, Tensor]]: +) -> Optional[tuple[Tensor, Tensor]]: r"""Compute the highest possible recall value given the minimum precision thresholds provided. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index d61b920aa9b..f2374778f01 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -38,10 +38,10 @@ def _binary_roc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], pos_label: int = 1, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: if isinstance(state, Tensor) and thresholds is not None: tps = state[:, 1, 1] fps = state[:, 0, 1] @@ -83,10 +83,10 @@ def _binary_roc_compute( def binary_roc( preds: Tensor, target: Tensor, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: r"""Compute the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -160,11 +160,11 @@ def binary_roc( def _multiclass_roc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: if average == "micro": return _binary_roc_compute(state, thresholds, pos_label=1) @@ -208,11 +208,11 @@ def multiclass_roc( preds: Tensor, target: Tensor, num_classes: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multiclass tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -327,11 +327,11 @@ def multiclass_roc( def _multilabel_roc_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: if isinstance(state, Tensor) and thresholds is not None: tps = state[:, :, 1, 1] fps = state[:, :, 0, 1] @@ -360,10 +360,10 @@ def multilabel_roc( preds: Tensor, target: Tensor, num_labels: int, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multilabel tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -472,13 +472,13 @@ def roc( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC). The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py index b1f7e456b06..bc590a4c6b9 100644 --- a/src/torchmetrics/functional/classification/sensitivity_specificity.py +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -49,7 +49,7 @@ def _sensitivity_at_specificity( specificity: Tensor, thresholds: Tensor, min_specificity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: # get indices where specificity is greater than min_specificity indices = specificity >= min_specificity @@ -72,7 +72,7 @@ def _sensitivity_at_specificity( def _binary_sensitivity_at_specificity_arg_validation( min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -83,11 +83,11 @@ def _binary_sensitivity_at_specificity_arg_validation( def _binary_sensitivity_at_specificity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], min_specificity: float, pos_label: int = 1, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label) specificity = _convert_fpr_to_specificity(fpr) return _sensitivity_at_specificity(sensitivity, specificity, thresholds, min_specificity) @@ -97,10 +97,10 @@ def binary_sensitivity_at_specificity( preds: Tensor, target: Tensor, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible sensitivity value given the minimum specificity levels provided for binary tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -169,7 +169,7 @@ def binary_sensitivity_at_specificity( def _multiclass_sensitivity_at_specificity_arg_validation( num_classes: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -180,11 +180,11 @@ def _multiclass_sensitivity_at_specificity_arg_validation( def _multiclass_sensitivity_at_specificity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], min_specificity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -207,10 +207,10 @@ def multiclass_sensitivity_at_specificity( target: Tensor, num_classes: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible sensitivity value given minimum specificity level provided for multiclass tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the @@ -285,7 +285,7 @@ def multiclass_sensitivity_at_specificity( def _multilabel_sensitivity_at_specificity_arg_validation( num_labels: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -296,12 +296,12 @@ def _multilabel_sensitivity_at_specificity_arg_validation( def _multilabel_sensitivity_at_specificity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int], min_specificity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -324,10 +324,10 @@ def multilabel_sensitivity_at_specificity( target: Tensor, num_labels: int, min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible sensitivity value given minimum specificity level provided for multilabel tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -408,12 +408,12 @@ def sensitivity_at_specificity( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_specificity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index d85b47eb453..3f93cea467d 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -50,7 +50,7 @@ def _specificity_at_sensitivity( sensitivity: Tensor, thresholds: Tensor, min_sensitivity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: # get indices where sensitivity is greater than min_sensitivity indices = sensitivity >= min_sensitivity @@ -73,7 +73,7 @@ def _specificity_at_sensitivity( def _binary_specificity_at_sensitivity_arg_validation( min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -84,11 +84,11 @@ def _binary_specificity_at_sensitivity_arg_validation( def _binary_specificity_at_sensitivity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], thresholds: Optional[Tensor], min_sensitivity: float, pos_label: int = 1, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label) specificity = _convert_fpr_to_specificity(fpr) return _specificity_at_sensitivity(specificity, sensitivity, thresholds, min_sensitivity) @@ -98,10 +98,10 @@ def binary_specificity_at_sensitivity( preds: Tensor, target: Tensor, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible specificity value given the minimum sensitivity levels provided for binary tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -170,7 +170,7 @@ def binary_specificity_at_sensitivity( def _multiclass_specificity_at_sensitivity_arg_validation( num_classes: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) @@ -181,11 +181,11 @@ def _multiclass_specificity_at_sensitivity_arg_validation( def _multiclass_specificity_at_sensitivity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], min_sensitivity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -208,10 +208,10 @@ def multiclass_specificity_at_sensitivity( target: Tensor, num_classes: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible specificity value given minimum sensitivity level provided for multiclass tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the @@ -286,7 +286,7 @@ def multiclass_specificity_at_sensitivity( def _multilabel_specificity_at_sensitivity_arg_validation( num_labels: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, ) -> None: _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) @@ -297,12 +297,12 @@ def _multilabel_specificity_at_sensitivity_arg_validation( def _multilabel_specificity_at_sensitivity_compute( - state: Union[Tensor, Tuple[Tensor, Tensor]], + state: Union[Tensor, tuple[Tensor, Tensor]], num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int], min_sensitivity: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] if isinstance(state, Tensor): @@ -325,10 +325,10 @@ def multilabel_specificity_at_sensitivity( target: Tensor, num_labels: int, min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: r"""Compute the highest possible specificity value given minimum sensitivity level provided for multilabel tasks. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and @@ -409,12 +409,12 @@ def specicity_at_sensitivity( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. .. warning:: @@ -445,12 +445,12 @@ def specificity_at_sensitivity( target: Tensor, task: Literal["binary", "multiclass", "multilabel"], min_sensitivity: float, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 565c212f9bd..d6079de55d3 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -96,7 +96,7 @@ def _binary_stat_scores_format( target: Tensor, threshold: float = 0.5, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range @@ -125,7 +125,7 @@ def _binary_stat_scores_update( preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global", -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute the statistics.""" sum_dim = [0, 1] if multidim_average == "global" else [1] tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() @@ -326,7 +326,7 @@ def _multiclass_stat_scores_format( preds: Tensor, target: Tensor, top_k: int = 1, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format except if ``top_k`` is not 1. - Applies argmax if preds have one more dimension than target @@ -349,7 +349,7 @@ def _multiclass_stat_scores_update( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute the statistics. - If ``multidim_average`` is equal to samplewise or ``top_k`` is not 1, we transform both preds and @@ -650,7 +650,7 @@ def _multilabel_stat_scores_tensor_validation( def _multilabel_stat_scores_format( preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert all input to label format. - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range @@ -675,7 +675,7 @@ def _multilabel_stat_scores_format( def _multilabel_stat_scores_update( preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global" -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute the statistics.""" sum_dim = [0, -1] if multidim_average == "global" else [-1] tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() @@ -828,7 +828,7 @@ def _del_column(data: Tensor, idx: int) -> Tensor: def _drop_negative_ignored_indices( preds: Tensor, target: Tensor, ignore_index: int, mode: DataType -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Remove negative ignored indices. Args: @@ -866,7 +866,7 @@ def _stat_scores( preds: Tensor, target: Tensor, reduce: Optional[str] = "micro", -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate the number of tp, fp, tn, fn. Args: @@ -892,7 +892,7 @@ def _stat_scores( - If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors """ - dim: Union[int, List[int]] = 1 # for "samples" + dim: Union[int, list[int]] = 1 # for "samples" if reduce == "micro": dim = [0, 1] if preds.ndim == 2 else [1, 2] elif reduce == "macro": @@ -921,7 +921,7 @@ def _stat_scores_update( multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, mode: Optional[DataType] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate true positives, false positives, true negatives, false negatives. Raises: diff --git a/src/torchmetrics/functional/clustering/dunn_index.py b/src/torchmetrics/functional/clustering/dunn_index.py index 51120697002..05f7c87df33 100644 --- a/src/torchmetrics/functional/clustering/dunn_index.py +++ b/src/torchmetrics/functional/clustering/dunn_index.py @@ -18,7 +18,7 @@ from torch import Tensor -def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> Tuple[Tensor, Tensor]: +def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> tuple[Tensor, Tensor]: """Update and return variables required to compute the Dunn index. Args: diff --git a/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py index e7faae5175e..c2820e9001a 100644 --- a/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py @@ -19,7 +19,7 @@ from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels -def _fowlkes_mallows_index_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _fowlkes_mallows_index_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Return contingency matrix required to compute the Fowlkes-Mallows index. Args: diff --git a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py index e98f1e26b5b..7eb7b478430 100644 --- a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py @@ -20,7 +20,7 @@ from torchmetrics.functional.clustering.utils import calculate_entropy, check_cluster_labels -def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Computes the homogeneity score of a clustering given the predicted and target cluster labels.""" check_cluster_labels(preds, target) @@ -36,7 +36,7 @@ def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, T return homogeneity, mutual_info, entropy_preds, entropy_target -def _completeness_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _completeness_score_compute(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Computes the completeness score of a clustering given the predicted and target cluster labels.""" homogeneity, mutual_info, entropy_preds, _ = _homogeneity_score_compute(preds, target) completeness = mutual_info / entropy_preds if entropy_preds else torch.ones_like(entropy_preds) diff --git a/src/torchmetrics/functional/detection/_panoptic_quality_common.py b/src/torchmetrics/functional/detection/_panoptic_quality_common.py index b3b0fc33295..9b6660e5629 100644 --- a/src/torchmetrics/functional/detection/_panoptic_quality_common.py +++ b/src/torchmetrics/functional/detection/_panoptic_quality_common.py @@ -19,10 +19,10 @@ from torchmetrics.utilities import rank_zero_warn -_Color = Tuple[int, int] # A (category_id, instance_id) tuple that uniquely identifies a panoptic segment. +_Color = tuple[int, int] # A (category_id, instance_id) tuple that uniquely identifies a panoptic segment. -def _nested_tuple(nested_list: List) -> Tuple: +def _nested_tuple(nested_list: list) -> tuple: """Construct a nested tuple from a nested list. Args: @@ -35,7 +35,7 @@ def _nested_tuple(nested_list: List) -> Tuple: return tuple(map(_nested_tuple, nested_list)) if isinstance(nested_list, list) else nested_list -def _to_tuple(t: Tensor) -> Tuple: +def _to_tuple(t: Tensor) -> tuple: """Convert a tensor into a nested tuple. Args: @@ -48,7 +48,7 @@ def _to_tuple(t: Tensor) -> Tuple: return _nested_tuple(t.tolist()) -def _get_color_areas(inputs: Tensor) -> Dict[Tuple, Tensor]: +def _get_color_areas(inputs: Tensor) -> dict[tuple, Tensor]: """Measure the size of each instance. Args: @@ -63,7 +63,7 @@ def _get_color_areas(inputs: Tensor) -> Dict[Tuple, Tensor]: return dict(zip(_to_tuple(unique_keys), unique_keys_area)) -def _parse_categories(things: Collection[int], stuffs: Collection[int]) -> Tuple[Set[int], Set[int]]: +def _parse_categories(things: Collection[int], stuffs: Collection[int]) -> tuple[set[int], set[int]]: """Parse and validate metrics arguments for `things` and `stuff`. Args: @@ -122,7 +122,7 @@ def _validate_inputs(preds: Tensor, target: torch.Tensor) -> None: ) -def _get_void_color(things: Set[int], stuffs: Set[int]) -> Tuple[int, int]: +def _get_void_color(things: set[int], stuffs: set[int]) -> tuple[int, int]: """Get an unused color ID. Args: @@ -137,7 +137,7 @@ def _get_void_color(things: Set[int], stuffs: Set[int]) -> Tuple[int, int]: return unused_category_id, 0 -def _get_category_id_to_continuous_id(things: Set[int], stuffs: Set[int]) -> Dict[int, int]: +def _get_category_id_to_continuous_id(things: set[int], stuffs: set[int]) -> dict[int, int]: """Convert original IDs to continuous IDs. Args: @@ -158,7 +158,7 @@ def _get_category_id_to_continuous_id(things: Set[int], stuffs: Set[int]) -> Dic return cat_id_to_continuous_id -def _isin(arr: Tensor, values: List) -> Tensor: +def _isin(arr: Tensor, values: list) -> Tensor: """Check if all values of an arr are in another array. Implementation of torch.isin to support pre 0.10 version. Args: @@ -174,10 +174,10 @@ def _isin(arr: Tensor, values: List) -> Tensor: def _prepocess_inputs( - things: Set[int], - stuffs: Set[int], + things: set[int], + stuffs: set[int], inputs: Tensor, - void_color: Tuple[int, int], + void_color: tuple[int, int], allow_unknown_category: bool, ) -> Tensor: """Preprocesses an input tensor for metric calculation. @@ -215,9 +215,9 @@ def _prepocess_inputs( def _calculate_iou( pred_color: _Color, target_color: _Color, - pred_areas: Dict[_Color, Tensor], - target_areas: Dict[_Color, Tensor], - intersection_areas: Dict[Tuple[_Color, _Color], Tensor], + pred_areas: dict[_Color, Tensor], + target_areas: dict[_Color, Tensor], + intersection_areas: dict[tuple[_Color, _Color], Tensor], void_color: _Color, ) -> Tensor: """Helper function that calculates the IoU from precomputed areas of segments and their intersections. @@ -253,10 +253,10 @@ def _calculate_iou( def _filter_false_negatives( - target_areas: Dict[_Color, Tensor], - target_segment_matched: Set[_Color], - intersection_areas: Dict[Tuple[_Color, _Color], Tensor], - void_color: Tuple[int, int], + target_areas: dict[_Color, Tensor], + target_segment_matched: set[_Color], + intersection_areas: dict[tuple[_Color, _Color], Tensor], + void_color: tuple[int, int], ) -> Iterator[int]: """Filter false negative segments and yield their category IDs. @@ -282,10 +282,10 @@ def _filter_false_negatives( def _filter_false_positives( - pred_areas: Dict[_Color, Tensor], - pred_segment_matched: Set[_Color], - intersection_areas: Dict[Tuple[_Color, _Color], Tensor], - void_color: Tuple[int, int], + pred_areas: dict[_Color, Tensor], + pred_segment_matched: set[_Color], + intersection_areas: dict[tuple[_Color, _Color], Tensor], + void_color: tuple[int, int], ) -> Iterator[int]: """Filter false positive segments and yield their category IDs. @@ -313,10 +313,10 @@ def _filter_false_positives( def _panoptic_quality_update_sample( flatten_preds: Tensor, flatten_target: Tensor, - cat_id_to_continuous_id: Dict[int, int], - void_color: Tuple[int, int], - stuffs_modified_metric: Optional[Set[int]] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + cat_id_to_continuous_id: dict[int, int], + void_color: tuple[int, int], + stuffs_modified_metric: Optional[set[int]] = None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate stat scores required to compute the metric **for a single sample**. Computed scores: iou sum, true positives, false positives, false negatives. @@ -352,11 +352,11 @@ def _panoptic_quality_update_sample( # calculate the area of each prediction, ground truth and pairwise intersection. # NOTE: mypy needs `cast()` because the annotation for `_get_color_areas` is too generic. - pred_areas = cast(Dict[_Color, Tensor], _get_color_areas(flatten_preds)) - target_areas = cast(Dict[_Color, Tensor], _get_color_areas(flatten_target)) + pred_areas = cast(dict[_Color, Tensor], _get_color_areas(flatten_preds)) + target_areas = cast(dict[_Color, Tensor], _get_color_areas(flatten_target)) # intersection matrix of shape [num_pixels, 2, 2] intersection_matrix = torch.transpose(torch.stack((flatten_preds, flatten_target), -1), -1, -2) - intersection_areas = cast(Dict[Tuple[_Color, _Color], Tensor], _get_color_areas(intersection_matrix)) + intersection_areas = cast(dict[tuple[_Color, _Color], Tensor], _get_color_areas(intersection_matrix)) # select intersection of things of same category with iou > 0.5 pred_segment_matched = set() @@ -398,10 +398,10 @@ def _panoptic_quality_update_sample( def _panoptic_quality_update( flatten_preds: Tensor, flatten_target: Tensor, - cat_id_to_continuous_id: Dict[int, int], - void_color: Tuple[int, int], - modified_metric_stuffs: Optional[Set[int]] = None, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + cat_id_to_continuous_id: dict[int, int], + void_color: tuple[int, int], + modified_metric_stuffs: Optional[set[int]] = None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Calculate stat scores required to compute the metric for a full batch. Computed scores: iou sum, true positives, false positives, false negatives. @@ -450,7 +450,7 @@ def _panoptic_quality_compute( true_positives: Tensor, false_positives: Tensor, false_negatives: Tensor, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Compute the final panoptic quality from interim values. Args: diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index efaab73cded..55485f47639 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -58,7 +58,7 @@ def _error_relative_global_dimensionless_synthesis( return error_relative_global_dimensionless_synthesis(preds=preds, target=target, ratio=ratio, reduction=reduction) -def _image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: +def _image_gradients(img: Tensor) -> tuple[Tensor, Tensor]: """Wrapper for deprecated import. >>> import torch @@ -80,10 +80,10 @@ def _image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: def _peak_signal_noise_ratio( preds: Tensor, target: Tensor, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, ) -> Tensor: """Wrapper for deprecated import. @@ -116,7 +116,7 @@ def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: def _root_mean_squared_error_using_sliding_window( preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False -) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]: +) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]: """Wrapper for deprecated import. >>> from torch import rand @@ -157,10 +157,10 @@ def _multiscale_structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Optional[Literal["relu", "simple"]] = "relu", ) -> Tensor: """Wrapper for deprecated import. @@ -195,12 +195,12 @@ def _structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> import torch diff --git a/src/torchmetrics/functional/image/d_lambda.py b/src/torchmetrics/functional/image/d_lambda.py index 5921f51d32d..668dde7c0fc 100644 --- a/src/torchmetrics/functional/image/d_lambda.py +++ b/src/torchmetrics/functional/image/d_lambda.py @@ -22,7 +22,7 @@ from torchmetrics.utilities.distributed import reduce -def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Spectral Distortion Index. Args: diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 824cd3bfa8b..f8f54873dbe 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -28,7 +28,7 @@ def _spatial_distortion_index_update( preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Optional[Tensor] = None -) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: """Update and returns variables required to compute Spatial Distortion Index. Args: diff --git a/src/torchmetrics/functional/image/ergas.py b/src/torchmetrics/functional/image/ergas.py index 41500c552e0..c69773b06ba 100644 --- a/src/torchmetrics/functional/image/ergas.py +++ b/src/torchmetrics/functional/image/ergas.py @@ -21,7 +21,7 @@ from torchmetrics.utilities.distributed import reduce -def _ergas_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _ergas_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Erreur Relative Globale Adimensionnelle de Synthèse. Args: diff --git a/src/torchmetrics/functional/image/gradients.py b/src/torchmetrics/functional/image/gradients.py index e87d4a75198..683c67c153d 100644 --- a/src/torchmetrics/functional/image/gradients.py +++ b/src/torchmetrics/functional/image/gradients.py @@ -25,7 +25,7 @@ def _image_gradients_validate(img: Tensor) -> None: raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor") -def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: +def _compute_image_gradients(img: Tensor) -> tuple[Tensor, Tensor]: """Compute image gradients (dy/dx) for a given image.""" batch_size, channels, height, width = img.shape @@ -43,7 +43,7 @@ def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: return dy, dx -def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: +def image_gradients(img: Tensor) -> tuple[Tensor, Tensor]: """Compute `Gradient Computation of Image`_ of a given image using finite difference. Args: diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index c557f61ead1..44006874120 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -205,7 +205,7 @@ def _spatial_average(in_tens: Tensor, keep_dim: bool = True) -> Tensor: return in_tens.mean([2, 3], keepdim=keep_dim) -def _upsample(in_tens: Tensor, out_hw: Tuple[int, ...] = (64, 64)) -> Tensor: +def _upsample(in_tens: Tensor, out_hw: tuple[int, ...] = (64, 64)) -> Tensor: """Upsample input with bilinear interpolation.""" return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens) @@ -331,7 +331,7 @@ def __init__( def forward( self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False - ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: + ) -> Union[Tensor, tuple[Tensor, list[Tensor]]]: if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 @@ -378,7 +378,7 @@ def _valid_img(img: Tensor, normalize: bool) -> bool: return img.ndim == 4 and img.shape[1] == 3 and value_check # type: ignore[return-value] -def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> Tuple[Tensor, Union[int, Tensor]]: +def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> tuple[Tensor, Union[int, Tensor]]: if not (_valid_img(img1, normalize) and _valid_img(img2, normalize)): raise ValueError( "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]." diff --git a/src/torchmetrics/functional/image/perceptual_path_length.py b/src/torchmetrics/functional/image/perceptual_path_length.py index 58b0a7bae05..b1a9e3f7857 100644 --- a/src/torchmetrics/functional/image/perceptual_path_length.py +++ b/src/torchmetrics/functional/image/perceptual_path_length.py @@ -162,7 +162,7 @@ def perceptual_path_length( upper_discard: Optional[float] = 0.99, sim_net: Union[nn.Module, Literal["alex", "vgg", "squeeze"]] = "vgg", device: Union[str, torch.device] = "cpu", -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: r"""Computes the perceptual path length (`PPL`_) of a generator model. The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index 7bd93ba94e1..01348c425c3 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -58,8 +58,8 @@ def _psnr_compute( def _psnr_update( preds: Tensor, target: Tensor, - dim: Optional[Union[int, Tuple[int, ...]]] = None, -) -> Tuple[Tensor, Tensor]: + dim: Optional[Union[int, tuple[int, ...]]] = None, +) -> tuple[Tensor, Tensor]: """Update and return variables required to compute peak signal-to-noise ratio. Args: @@ -95,10 +95,10 @@ def _psnr_update( def peak_signal_noise_ratio( preds: Tensor, target: Tensor, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, ) -> Tensor: """Compute the peak signal-to-noise ratio. diff --git a/src/torchmetrics/functional/image/psnrb.py b/src/torchmetrics/functional/image/psnrb.py index f725ea4bd80..1b67df53519 100644 --- a/src/torchmetrics/functional/image/psnrb.py +++ b/src/torchmetrics/functional/image/psnrb.py @@ -86,7 +86,7 @@ def _psnrb_compute( return 10 * torch.log10(1.0 / sum_squared_error) -def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> Tuple[Tensor, Tensor, Tensor]: +def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> tuple[Tensor, Tensor, Tensor]: """Updates and returns variables required to compute peak signal-to-noise ratio. Args: diff --git a/src/torchmetrics/functional/image/rase.py b/src/torchmetrics/functional/image/rase.py index 388f2c237a3..fd30d455724 100644 --- a/src/torchmetrics/functional/image/rase.py +++ b/src/torchmetrics/functional/image/rase.py @@ -23,7 +23,7 @@ def _rase_update( preds: Tensor, target: Tensor, window_size: int, rmse_map: Tensor, target_sum: Tensor, total_images: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate the sum of RMSE map values for the batch of examples and update intermediate states. Args: diff --git a/src/torchmetrics/functional/image/rmse_sw.py b/src/torchmetrics/functional/image/rmse_sw.py index a27582bd11a..6d9b9eae235 100644 --- a/src/torchmetrics/functional/image/rmse_sw.py +++ b/src/torchmetrics/functional/image/rmse_sw.py @@ -28,7 +28,7 @@ def _rmse_sw_update( rmse_val_sum: Optional[Tensor], rmse_map: Optional[Tensor], total_images: Optional[Tensor], -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate the sum of RMSE values and RMSE map for the batch of examples and update intermediate states. Args: @@ -89,7 +89,7 @@ def _rmse_sw_update( def _rmse_sw_compute( rmse_val_sum: Optional[Tensor], rmse_map: Tensor, total_images: Tensor -) -> Tuple[Optional[Tensor], Tensor]: +) -> tuple[Optional[Tensor], Tensor]: """Compute RMSE from the aggregated RMSE value. Optionally also computes the mean value for RMSE map. Args: @@ -111,7 +111,7 @@ def _rmse_sw_compute( def root_mean_squared_error_using_sliding_window( preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False -) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]: +) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]: """Compute Root Mean Squared Error (RMSE) using sliding window. Args: diff --git a/src/torchmetrics/functional/image/sam.py b/src/torchmetrics/functional/image/sam.py index 71927ff6b42..82c304c9543 100644 --- a/src/torchmetrics/functional/image/sam.py +++ b/src/torchmetrics/functional/image/sam.py @@ -21,7 +21,7 @@ from torchmetrics.utilities.distributed import reduce -def _sam_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _sam_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Spectral Angle Mapper. Args: diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index a4b7bd85725..f0c3db35295 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -23,7 +23,7 @@ from torchmetrics.utilities.distributed import reduce -def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tuple[Tensor, Tensor, Tensor]: +def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> tuple[Tensor, Tensor, Tensor]: """Update and returns variables required to compute Spatial Correlation Coefficient. Args: @@ -73,7 +73,7 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i return preds, target, hp_filter -def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: +def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, tuple[int, ...]]) -> Tensor: """Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a).""" if isinstance(pad, int): pad = (pad, pad, pad, pad) @@ -106,7 +106,7 @@ def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor: return _signal_convolve_2d(input_img, kernel) * 2.0 -def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> tuple[Tensor, Tensor, Tensor]: """Computes local variance and covariance of the input tensors.""" # This code is inspired by # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index bfa27e7df0a..9dbf38b0627 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -24,7 +24,7 @@ from torchmetrics.utilities.distributed import reduce -def _ssim_check_inputs(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _ssim_check_inputs(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Structural Similarity Index Measure. Args: @@ -49,12 +49,12 @@ def _ssim_update( gaussian_kernel: bool = True, sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute Structural Similarity Index Measure. Args: @@ -214,12 +214,12 @@ def structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute Structural Similarity Index Measure. Args: @@ -298,11 +298,11 @@ def _get_normalized_sim_and_cs( gaussian_kernel: bool = True, sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, normalize: Optional[Literal["relu", "simple"]] = None, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: sim, contrast_sensitivity = _ssim_update( preds, target, @@ -326,10 +326,10 @@ def _multiscale_ssim_update( gaussian_kernel: bool = True, sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Union[Tuple[float, float, float, float, float], Tuple[float, ...]] = ( + betas: Union[tuple[float, float, float, float, float], tuple[float, ...]] = ( 0.0448, 0.2856, 0.3001, @@ -372,7 +372,7 @@ def _multiscale_ssim_update( If the image width is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``. """ - mcs_list: List[Tensor] = [] + mcs_list: list[Tensor] = [] is_3d = preds.ndim == 5 @@ -453,10 +453,10 @@ def multiscale_structural_similarity_index_measure( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Optional[Literal["relu", "simple"]] = "relu", ) -> Tensor: """Compute `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure. diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index be8e3366caa..21c7b5f6f31 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -17,7 +17,7 @@ from typing_extensions import Literal -def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]: +def _total_variation_update(img: Tensor) -> tuple[Tensor, int]: """Compute total variation statistics on current batch.""" if img.ndim != 4: raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}") diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index 366e06a3bb1..1b0711bb969 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -23,7 +23,7 @@ from torchmetrics.utilities.distributed import reduce -def _uqi_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _uqi_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Universal Image Quality Index. Args: diff --git a/src/torchmetrics/functional/image/utils.py b/src/torchmetrics/functional/image/utils.py index a6869d1dc88..13dee62d2ae 100644 --- a/src/torchmetrics/functional/image/utils.py +++ b/src/torchmetrics/functional/image/utils.py @@ -57,7 +57,7 @@ def _gaussian_kernel_2d( return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) -def _uniform_weight_bias_conv2d(inputs: Tensor, window_size: int) -> Tuple[Tensor, Tensor]: +def _uniform_weight_bias_conv2d(inputs: Tensor, window_size: int) -> tuple[Tensor, Tensor]: """Construct uniform weight and bias for a 2d convolution. Args: diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py index 659c667fc52..b22b3472b82 100644 --- a/src/torchmetrics/functional/multimodal/clip_iqa.py +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -40,7 +40,7 @@ def _download_clip_for_iqa_metric() -> None: if not _PIQ_GREATER_EQUAL_0_8: __doctest_skip__ = ["clip_image_quality_assessment"] -_PROMPTS: Dict[str, Tuple[str, str]] = { +_PROMPTS: dict[str, tuple[str, str]] = { "quality": ("Good photo.", "Bad photo."), "brightness": ("Bright photo.", "Dark photo."), "noisiness": ("Clean photo.", "Noisy photo."), @@ -68,7 +68,7 @@ def _get_clip_iqa_model_and_processor( "openai/clip-vit-large-patch14-336", "openai/clip-vit-large-patch14", ], -) -> Tuple["_CLIPModel", "_CLIPProcessor"]: +) -> tuple["_CLIPModel", "_CLIPProcessor"]: """Extract the CLIP model and processor from the model name or path.""" from transformers import CLIPProcessor as _CLIPProcessor @@ -89,7 +89,7 @@ def _get_clip_iqa_model_and_processor( return _get_clip_model_and_processor(model_name_or_path) -def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",)) -> Tuple[List[str], List[str]]: +def _clip_iqa_format_prompts(prompts: tuple[Union[str, tuple[str, str]]] = ("quality",)) -> tuple[list[str], list[str]]: """Converts the provided keywords into a list of prompts for the model to calculate the anchor vectors. Args: @@ -119,8 +119,8 @@ def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("qua if not isinstance(prompts, tuple): raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings") - prompts_names: List[str] = [] - prompts_list: List[str] = [] + prompts_names: list[str] = [] + prompts_list: list[str] = [] count = 0 for p in prompts: if not isinstance(p, (str, tuple)): @@ -146,7 +146,7 @@ def _clip_iqa_get_anchor_vectors( model_name_or_path: str, model: "_CLIPModel", processor: "_CLIPProcessor", - prompts_list: List[str], + prompts_list: list[str], device: Union[str, torch.device], ) -> Tensor: """Calculates the anchor vectors for the CLIP IQA metric. @@ -202,9 +202,9 @@ def _clip_iqa_update( def _clip_iqa_compute( img_features: Tensor, anchors: Tensor, - prompts_names: List[str], + prompts_names: list[str], format_as_dict: bool = True, -) -> Union[Tensor, Dict[str, Tensor]]: +) -> Union[Tensor, dict[str, Tensor]]: """Final computation of CLIP IQA.""" logits_per_image = 100 * img_features @ anchors.t() probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0] @@ -225,8 +225,8 @@ def clip_image_quality_assessment( "openai/clip-vit-large-patch14", ] = "clip_iqa", data_range: float = 1.0, - prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",), -) -> Union[Tensor, Dict[str, Tensor]]: + prompts: tuple[Union[str, tuple[str, str]]] = ("quality",), +) -> Union[Tensor, dict[str, Tensor]]: """Calculates `CLIP-IQA`_, that can be used to measure the visual content of images. The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index f70f37d534b..ead25db6f5a 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -42,11 +42,11 @@ def _download_clip_for_clip_score() -> None: def _clip_score_update( - images: Union[Tensor, List[Tensor]], - text: Union[str, List[str]], + images: Union[Tensor, list[Tensor]], + text: Union[str, list[str]], model: _CLIPModel, processor: _CLIPProcessor, -) -> Tuple[Tensor, int]: +) -> tuple[Tensor, int]: if not isinstance(images, list): if images.ndim == 3: images = [images] @@ -97,7 +97,7 @@ def _get_clip_model_and_processor( "openai/clip-vit-large-patch14-336", "openai/clip-vit-large-patch14", ] = "openai/clip-vit-large-patch14", -) -> Tuple[_CLIPModel, _CLIPProcessor]: +) -> tuple[_CLIPModel, _CLIPProcessor]: if _TRANSFORMERS_GREATER_EQUAL_4_10: from transformers import CLIPModel as _CLIPModel from transformers import CLIPProcessor as _CLIPProcessor @@ -113,8 +113,8 @@ def _get_clip_model_and_processor( def clip_score( - images: Union[Tensor, List[Tensor]], - text: Union[str, List[str]], + images: Union[Tensor, list[Tensor]], + text: Union[str, list[str]], model_name_or_path: Literal[ "openai/clip-vit-base-patch16", "openai/clip-vit-base-patch32", diff --git a/src/torchmetrics/functional/nominal/utils.py b/src/torchmetrics/functional/nominal/utils.py index 8c8cc166778..d98d7e8806c 100644 --- a/src/torchmetrics/functional/nominal/utils.py +++ b/src/torchmetrics/functional/nominal/utils.py @@ -93,7 +93,7 @@ def _compute_phi_squared_corrected( ) -def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: Tensor) -> Tuple[Tensor, Tensor]: +def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: Tensor) -> tuple[Tensor, Tensor]: """Compute bias-corrected number of rows and columns.""" rows_corrected = num_rows - (num_rows - 1) ** 2 / (confmat_sum - 1) cols_corrected = num_cols - (num_cols - 1) ** 2 / (confmat_sum - 1) @@ -102,7 +102,7 @@ def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: def _compute_bias_corrected_values( phi_squared: Tensor, num_rows: int, num_cols: int, confmat_sum: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute bias-corrected Phi Squared and number of rows and columns.""" phi_squared_corrected = _compute_phi_squared_corrected(phi_squared, num_rows, num_cols, confmat_sum) rows_corrected, cols_corrected = _compute_rows_and_cols_corrected(num_rows, num_cols, confmat_sum) @@ -114,7 +114,7 @@ def _handle_nan_in_data( target: Tensor, nan_strategy: Literal["replace", "drop"] = "replace", nan_replace_value: Optional[float] = 0.0, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Handle ``NaN`` values in input data. If ``nan_strategy = 'replace'``, all ``NaN`` values are replaced with ``nan_replace_value``. diff --git a/src/torchmetrics/functional/pairwise/helpers.py b/src/torchmetrics/functional/pairwise/helpers.py index bc6ba79a81a..a0a0dbf2f9f 100644 --- a/src/torchmetrics/functional/pairwise/helpers.py +++ b/src/torchmetrics/functional/pairwise/helpers.py @@ -18,7 +18,7 @@ def _check_input( x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None -) -> Tuple[Tensor, Tensor, bool]: +) -> tuple[Tensor, Tensor, bool]: """Check that input has the right dimensionality and sets the ``zero_diagonal`` argument if user has not set it. Args: diff --git a/src/torchmetrics/functional/regression/cosine_similarity.py b/src/torchmetrics/functional/regression/cosine_similarity.py index 8f9460010e7..c90885aab86 100644 --- a/src/torchmetrics/functional/regression/cosine_similarity.py +++ b/src/torchmetrics/functional/regression/cosine_similarity.py @@ -22,7 +22,7 @@ def _cosine_similarity_update( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Cosine Similarity. Checks for same shape of input tensors. Args: diff --git a/src/torchmetrics/functional/regression/csi.py b/src/torchmetrics/functional/regression/csi.py index 8d43c6012fd..f58f615e62b 100644 --- a/src/torchmetrics/functional/regression/csi.py +++ b/src/torchmetrics/functional/regression/csi.py @@ -22,7 +22,7 @@ def _critical_success_index_update( preds: Tensor, target: Tensor, threshold: float, keep_sequence_dim: Optional[int] = None -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute Critical Success Index. Checks for same shape of tensors. Args: diff --git a/src/torchmetrics/functional/regression/explained_variance.py b/src/torchmetrics/functional/regression/explained_variance.py index ab3158bc594..bf93bf94112 100644 --- a/src/torchmetrics/functional/regression/explained_variance.py +++ b/src/torchmetrics/functional/regression/explained_variance.py @@ -23,7 +23,7 @@ ALLOWED_MULTIOUTPUT = ("raw_values", "uniform_average", "variance_weighted") -def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor, Tensor, Tensor]: +def _explained_variance_update(preds: Tensor, target: Tensor) -> tuple[int, Tensor, Tensor, Tensor, Tensor]: """Update and returns variables required to compute Explained Variance. Checks for same shape of input tensors. Args: diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index c46173eae00..804c4cb10ef 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -47,7 +47,7 @@ def _name() -> str: return "alternative" -def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: +def _sort_on_first_sequence(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: """Sort sequences in an ascent order according to the sequence ``x``.""" # We need to clone `y` tensor not to change an object in memory y = torch.clone(y) @@ -94,7 +94,7 @@ def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor: return _cumsum(torch.cat([_ones, (x[1:] != x[:-1]).int()], dim=0), dim=0) -def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +def _get_ties(x: Tensor) -> tuple[Tensor, Tensor, Tensor]: """Get a total number of ties and staistics for p-value calculation for a given sequence.""" ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) ties_p1 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) @@ -111,7 +111,7 @@ def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: def _get_metric_metadata( preds: Tensor, target: Tensor, variant: _MetricVariant -) -> Tuple[ +) -> tuple[ Tensor, Tensor, Optional[Tensor], @@ -225,10 +225,10 @@ def _calculate_p_value( def _kendall_corrcoef_update( preds: Tensor, target: Tensor, - concat_preds: Optional[List[Tensor]] = None, - concat_target: Optional[List[Tensor]] = None, + concat_preds: Optional[list[Tensor]] = None, + concat_target: Optional[list[Tensor]] = None, num_outputs: int = 1, -) -> Tuple[List[Tensor], List[Tensor]]: +) -> tuple[list[Tensor], list[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. Args: @@ -263,7 +263,7 @@ def _kendall_corrcoef_compute( target: Tensor, variant: _MetricVariant, alternative: Optional[_TestAlternative] = None, -) -> Tuple[Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Optional[Tensor]]: """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test. Args: @@ -324,7 +324,7 @@ def kendall_rank_corrcoef( variant: Literal["a", "b", "c"] = "b", t_test: bool = False, alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided", -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: r"""Compute `Kendall Rank Correlation Coefficient`_. .. math:: diff --git a/src/torchmetrics/functional/regression/kl_divergence.py b/src/torchmetrics/functional/regression/kl_divergence.py index 6e6563aee71..8d6aee5c001 100644 --- a/src/torchmetrics/functional/regression/kl_divergence.py +++ b/src/torchmetrics/functional/regression/kl_divergence.py @@ -22,7 +22,7 @@ from torchmetrics.utilities.compute import _safe_xlogy -def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]: +def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> tuple[Tensor, int]: """Update and returns KL divergence scores for each observation and the total number of observations. Args: diff --git a/src/torchmetrics/functional/regression/log_cosh.py b/src/torchmetrics/functional/regression/log_cosh.py index ef9402c740f..a0931bcaecb 100644 --- a/src/torchmetrics/functional/regression/log_cosh.py +++ b/src/torchmetrics/functional/regression/log_cosh.py @@ -20,13 +20,13 @@ from torchmetrics.utilities.checks import _check_same_shape -def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: if preds.ndim == 2: return preds, target return preds.unsqueeze(1), target.unsqueeze(1) -def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]: +def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute LogCosh error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/log_mse.py b/src/torchmetrics/functional/regression/log_mse.py index 3ba27ab9e99..34f3d9d71de 100644 --- a/src/torchmetrics/functional/regression/log_mse.py +++ b/src/torchmetrics/functional/regression/log_mse.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]: """Return variables required to compute Mean Squared Log Error. Checks for same shape of tensors. Args: diff --git a/src/torchmetrics/functional/regression/mae.py b/src/torchmetrics/functional/regression/mae.py index d40b30e9017..17db54f242c 100644 --- a/src/torchmetrics/functional/regression/mae.py +++ b/src/torchmetrics/functional/regression/mae.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_absolute_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, int]: +def _mean_absolute_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, int]: """Update and returns variables required to compute Mean Absolute Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/mape.py b/src/torchmetrics/functional/regression/mape.py index bdcb7c87264..89cbd865c43 100644 --- a/src/torchmetrics/functional/regression/mape.py +++ b/src/torchmetrics/functional/regression/mape.py @@ -23,7 +23,7 @@ def _mean_absolute_percentage_error_update( preds: Tensor, target: Tensor, epsilon: float = 1.17e-06, -) -> Tuple[Tensor, int]: +) -> tuple[Tensor, int]: """Update and returns variables required to compute Mean Percentage Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py index f9649a87416..09f4fadb490 100644 --- a/src/torchmetrics/functional/regression/mse.py +++ b/src/torchmetrics/functional/regression/mse.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, int]: +def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, int]: """Update and returns variables required to compute Mean Squared Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py index 52cae36adb0..a7e8c28a1af 100644 --- a/src/torchmetrics/functional/regression/nrmse.py +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -22,7 +22,7 @@ def _normalized_root_mean_squared_error_update( preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std", "l2"] = "mean" -) -> Tuple[Tensor, int, Tensor]: +) -> tuple[Tensor, int, Tensor]: """Updates and returns the sum of squared errors and the number of observations for NRMSE computation. Args: diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index 47b26344163..6dffb193340 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -32,7 +32,7 @@ def _pearson_corrcoef_update( corr_xy: Tensor, num_prior: Tensor, num_outputs: int, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Update and returns variables required to compute Pearson Correlation Coefficient. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/r2.py b/src/torchmetrics/functional/regression/r2.py index ec6aec10a12..f8d036ece99 100644 --- a/src/torchmetrics/functional/regression/r2.py +++ b/src/torchmetrics/functional/regression/r2.py @@ -20,7 +20,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, int]: +def _r2_score_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor, Tensor, int]: """Update and returns variables required to compute R2 score. Check for same shape and 1D/2D input tensors. diff --git a/src/torchmetrics/functional/regression/spearman.py b/src/torchmetrics/functional/regression/spearman.py index 7d412e407b0..7e57b93c06f 100644 --- a/src/torchmetrics/functional/regression/spearman.py +++ b/src/torchmetrics/functional/regression/spearman.py @@ -54,7 +54,7 @@ def _rank_data(data: Tensor) -> Tensor: return rank -def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]: +def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Spearman Correlation Coefficient. Check for same shape and type of input tensors. diff --git a/src/torchmetrics/functional/regression/symmetric_mape.py b/src/torchmetrics/functional/regression/symmetric_mape.py index 2ab3a55f7d5..9d919c13b2a 100644 --- a/src/torchmetrics/functional/regression/symmetric_mape.py +++ b/src/torchmetrics/functional/regression/symmetric_mape.py @@ -23,7 +23,7 @@ def _symmetric_mean_absolute_percentage_error_update( preds: Tensor, target: Tensor, epsilon: float = 1.17e-06, -) -> Tuple[Tensor, int]: +) -> tuple[Tensor, int]: """Update and returns variables required to compute Symmetric Mean Absolute Percentage Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/tweedie_deviance.py b/src/torchmetrics/functional/regression/tweedie_deviance.py index e3508dc04c7..e369d7d5ad5 100644 --- a/src/torchmetrics/functional/regression/tweedie_deviance.py +++ b/src/torchmetrics/functional/regression/tweedie_deviance.py @@ -20,7 +20,7 @@ from torchmetrics.utilities.compute import _safe_xlogy -def _tweedie_deviance_score_update(preds: Tensor, targets: Tensor, power: float = 0.0) -> Tuple[Tensor, Tensor]: +def _tweedie_deviance_score_update(preds: Tensor, targets: Tensor, power: float = 0.0) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Deviance Score for the given power. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/regression/wmape.py b/src/torchmetrics/functional/regression/wmape.py index 443badba325..c3834047777 100644 --- a/src/torchmetrics/functional/regression/wmape.py +++ b/src/torchmetrics/functional/regression/wmape.py @@ -22,7 +22,7 @@ def _weighted_mean_absolute_percentage_error_update( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute Weighted Absolute Percentage Error. Check for same shape of input tensors. diff --git a/src/torchmetrics/functional/retrieval/_deprecated.py b/src/torchmetrics/functional/retrieval/_deprecated.py index 2be7c408340..4621b7f5ed2 100644 --- a/src/torchmetrics/functional/retrieval/_deprecated.py +++ b/src/torchmetrics/functional/retrieval/_deprecated.py @@ -88,7 +88,7 @@ def _retrieval_precision( def _retrieval_precision_recall_curve( preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Wrapper for deprecated import. >>> from torch import tensor diff --git a/src/torchmetrics/functional/retrieval/precision_recall_curve.py b/src/torchmetrics/functional/retrieval/precision_recall_curve.py index 7c70c92ff4d..ed204d136cd 100644 --- a/src/torchmetrics/functional/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/functional/retrieval/precision_recall_curve.py @@ -23,7 +23,7 @@ def retrieval_precision_recall_curve( preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute precision-recall pairs for different k (from 1 to `max_k`). In a ranked retrieval context, appropriate sets of retrieved documents are naturally given by diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 87b3b699fc0..a9ecf99f69e 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -46,7 +46,7 @@ def _dice_score_update( num_classes: int, include_background: bool, input_format: Literal["one-hot", "index"] = "one-hot", -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Update the state with the current prediction and target.""" _check_same_shape(preds, target) if preds.ndim < 3: diff --git a/src/torchmetrics/functional/segmentation/hausdorff_distance.py b/src/torchmetrics/functional/segmentation/hausdorff_distance.py index daadc90f6ba..18d7c45ddff 100644 --- a/src/torchmetrics/functional/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/functional/segmentation/hausdorff_distance.py @@ -28,7 +28,7 @@ def _hausdorff_distance_validate_args( num_classes: int, include_background: bool, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, directed: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", ) -> None: @@ -55,7 +55,7 @@ def hausdorff_distance( num_classes: int, include_background: bool = False, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, directed: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", ) -> Tensor: diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 184e9578cab..5acbd445851 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -45,7 +45,7 @@ def _mean_iou_update( num_classes: int, include_background: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update the intersection and union counts for the mean IoU computation.""" _check_same_shape(preds, target) diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index 59d42e16171..b6ee4c1195e 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -24,7 +24,7 @@ from torchmetrics.utilities.imports import _SCIPY_AVAILABLE -def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _ignore_background(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: """Ignore the background class in the computation assuming it is the first, index 0.""" preds = preds[:, 1:] if preds.shape[1] > 1 else preds target = target[:, 1:] if target.shape[1] > 1 else target @@ -44,7 +44,7 @@ def check_if_binarized(x: Tensor) -> None: raise ValueError("Input x should be binarized") -def _unfold(x: Tensor, kernel_size: Tuple[int, ...]) -> Tensor: +def _unfold(x: Tensor, kernel_size: tuple[int, ...]) -> Tensor: """Unfold the input tensor to a matrix. Function supports 3d images e.g. (B, C, D, H, W). Inspired by: @@ -112,7 +112,7 @@ def generate_binary_structure(rank: int, connectivity: int) -> Tensor: def binary_erosion( - image: Tensor, structure: Optional[Tensor] = None, origin: Optional[Tuple[int, ...]] = None, border_value: int = 0 + image: Tensor, structure: Optional[Tensor] = None, origin: Optional[tuple[int, ...]] = None, border_value: int = 0 ) -> Tensor: """Binary erosion of a tensor image. @@ -183,7 +183,7 @@ def binary_erosion( def distance_transform( x: Tensor, - sampling: Optional[Union[Tensor, List[float]]] = None, + sampling: Optional[Union[Tensor, list[float]]] = None, metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", engine: Literal["pytorch", "scipy"] = "pytorch", ) -> Tensor: @@ -285,8 +285,8 @@ def mask_edges( preds: Tensor, target: Tensor, crop: bool = True, - spacing: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, -) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]: + spacing: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None, +) -> Union[tuple[Tensor, Tensor], tuple[Tensor, Tensor, Tensor, Tensor]]: """Get the edges of binary segmentation masks. Args: @@ -343,7 +343,7 @@ def surface_distance( preds: Tensor, target: Tensor, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, ) -> Tensor: """Calculate the surface distance between two binary edge masks. @@ -393,9 +393,9 @@ def edge_surface_distance( preds: Tensor, target: Tensor, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, symmetric: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Extracts the edges from the input masks and calculates the surface distance between them. Args: @@ -423,8 +423,8 @@ def edge_surface_distance( @functools.lru_cache def get_neighbour_tables( - spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None -) -> Tuple[Tensor, Tensor]: + spacing: Union[tuple[int, int], tuple[int, int, int]], device: Optional[torch.device] = None +) -> tuple[Tensor, Tensor]: """Create a table that maps neighbour codes to the contour length or surface area of the corresponding contour. Args: @@ -443,7 +443,7 @@ def get_neighbour_tables( raise ValueError("The spacing must be a tuple of length 2 or 3.") -def table_contour_length(spacing: Tuple[int, int], device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor]: +def table_contour_length(spacing: tuple[int, int], device: Optional[torch.device] = None) -> tuple[Tensor, Tensor]: """Create a table that maps neighbour codes to the contour length of the corresponding contour. Adopted from: @@ -487,7 +487,7 @@ def table_contour_length(spacing: Tuple[int, int], device: Optional[torch.device @functools.lru_cache -def table_surface_area(spacing: Tuple[int, int, int], device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor]: +def table_surface_area(spacing: tuple[int, int, int], device: Optional[torch.device] = None) -> tuple[Tensor, Tensor]: """Create a table that maps neighbour codes to the surface area of the corresponding surface. Adopted from: diff --git a/src/torchmetrics/functional/shape/procrustes.py b/src/torchmetrics/functional/shape/procrustes.py index 08068fd2454..c17871ed251 100644 --- a/src/torchmetrics/functional/shape/procrustes.py +++ b/src/torchmetrics/functional/shape/procrustes.py @@ -22,7 +22,7 @@ def procrustes_disparity( point_cloud1: Tensor, point_cloud2: Tensor, return_all: bool = False -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor]]: """Runs procrustrus analysis on a batch of data points. Works similar ``scipy.spatial.procrustes`` but for batches of data points. diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index 8fe94195c58..8ff582f52fc 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -32,19 +32,19 @@ if not _TRANSFORMERS_GREATER_EQUAL_4_4: __doctest_skip__ = ["_bert_score", "_infolm"] -SQUAD_SINGLE_TARGET_TYPE = Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]] -SQUAD_TARGETS_TYPE = Union[SQUAD_SINGLE_TARGET_TYPE, List[SQUAD_SINGLE_TARGET_TYPE]] +SQUAD_SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]] +SQUAD_TARGETS_TYPE = Union[SQUAD_SINGLE_TARGET_TYPE, list[SQUAD_SINGLE_TARGET_TYPE]] def _bert_score( - preds: Union[List[str], Dict[str, Tensor]], - target: Union[List[str], Dict[str, Tensor]], + preds: Union[list[str], dict[str, Tensor]], + target: Union[list[str], dict[str, Tensor]], model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, model: Optional[Module] = None, user_tokenizer: Any = None, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, verbose: bool = False, idf: bool = False, device: Optional[Union[str, torch.device]] = None, @@ -56,7 +56,7 @@ def _bert_score( rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, -) -> Dict[str, Union[Tensor, List[float], str]]: +) -> dict[str, Union[Tensor, list[float], str]]: """Wrapper for deprecated import. >>> preds = ["hello there", "general kenobi"] @@ -112,7 +112,7 @@ def _bleu_score( return bleu_score(preds=preds, target=target, n_gram=n_gram, smooth=smooth, weights=weights) -def _char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _char_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -134,7 +134,7 @@ def _chrf_score( lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] @@ -165,7 +165,7 @@ def _extended_edit_distance( rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "here is an other sample"] @@ -202,7 +202,7 @@ def _infolm( num_threads: int = 0, verbose: bool = True, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ['he read the book because he was interested in world history'] @@ -230,7 +230,7 @@ def _infolm( ) -def _match_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _match_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -265,8 +265,8 @@ def _rouge_score( use_stemmer: bool = False, normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, - rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), -) -> Dict[str, Tensor]: + rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), +) -> dict[str, Tensor]: """Wrapper for deprecated import. >>> preds = "My name is John" @@ -328,7 +328,7 @@ def _sacre_bleu_score( ) -def _squad(preds: Union[Dict[str, str], List[Dict[str, str]]], target: SQUAD_TARGETS_TYPE) -> Dict[str, Tensor]: +def _squad(preds: Union[dict[str, str], list[dict[str, str]]], target: SQUAD_TARGETS_TYPE) -> dict[str, Tensor]: """Wrapper for deprecated import. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] @@ -349,7 +349,7 @@ def _translation_edit_rate( lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, list[Tensor]]]: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] @@ -370,7 +370,7 @@ def _translation_edit_rate( ) -def _word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _word_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -383,7 +383,7 @@ def _word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]] return word_error_rate(preds=preds, target=target) -def _word_information_lost(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _word_information_lost(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] @@ -396,7 +396,7 @@ def _word_information_lost(preds: Union[str, List[str]], target: Union[str, List return word_information_lost(preds=preds, target=target) -def _word_information_preserved(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def _word_information_preserved(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index 83cc23950d4..4380f447da7 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -76,8 +76,8 @@ def _get_embeddings_and_idf_scale( all_layers: bool = False, idf: bool = False, verbose: bool = False, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, -) -> Tuple[Tensor, Tensor]: + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, +) -> tuple[Tensor, Tensor]: """Calculate sentence embeddings and the inverse-document-frequency scaling factor. Args: @@ -106,8 +106,8 @@ def _get_embeddings_and_idf_scale( If ``all_layers = True`` and a model, which is not from the ``transformers`` package, is used. """ - embeddings_list: List[Tensor] = [] - idf_scale_list: List[Tensor] = [] + embeddings_list: list[Tensor] = [] + idf_scale_list: list[Tensor] = [] for batch in _get_progress_bar(dataloader, verbose): with torch.no_grad(): batch = _input_data_collator(batch, device) @@ -159,7 +159,7 @@ def _get_scaled_precision_or_recall(cos_sim: Tensor, metric: str, idf_scale: Ten def _get_precision_recall_f1( preds_embeddings: Tensor, target_embeddings: Tensor, preds_idf_scale: Tensor, target_idf_scale: Tensor -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Calculate precision, recall and F1 score over candidate and reference sentences. Args: @@ -246,7 +246,7 @@ def _rescale_metrics_with_baseline( baseline: Tensor, num_layers: Optional[int] = None, all_layers: bool = False, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Rescale the computed metrics with the pre-computed baseline.""" if num_layers is None and all_layers is False: num_layers = -1 @@ -258,14 +258,14 @@ def _rescale_metrics_with_baseline( def bert_score( - preds: Union[str, Sequence[str], Dict[str, Tensor]], - target: Union[str, Sequence[str], Dict[str, Tensor]], + preds: Union[str, Sequence[str], dict[str, Tensor]], + target: Union[str, Sequence[str], dict[str, Tensor]], model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, model: Optional[Module] = None, user_tokenizer: Any = None, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, verbose: bool = False, idf: bool = False, device: Optional[Union[str, torch.device]] = None, @@ -278,7 +278,7 @@ def bert_score( baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, truncation: bool = False, -) -> Dict[str, Union[Tensor, List[float], str]]: +) -> dict[str, Union[Tensor, list[float], str]]: """`Bert_score Evaluating Text Generation`_ for text similirity matching. This metric leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference @@ -406,7 +406,7 @@ def bert_score( ) if _are_empty_lists: rank_zero_warn("Predictions and references are empty.") - output_dict: Dict[str, Union[Tensor, List[float], str]] = { + output_dict: dict[str, Union[Tensor, list[float], str]] = { "precision": [0.0], "recall": [0.0], "f1": [0.0], diff --git a/src/torchmetrics/functional/text/bleu.py b/src/torchmetrics/functional/text/bleu.py index 724cec16794..6ed88ae05c8 100644 --- a/src/torchmetrics/functional/text/bleu.py +++ b/src/torchmetrics/functional/text/bleu.py @@ -67,7 +67,7 @@ def _bleu_score_update( target_len: Tensor, n_gram: int = 4, tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Update and returns variables required to compute the BLEU score. Args: diff --git a/src/torchmetrics/functional/text/cer.py b/src/torchmetrics/functional/text/cer.py index 19dc55afba9..2b5e10f6c55 100644 --- a/src/torchmetrics/functional/text/cer.py +++ b/src/torchmetrics/functional/text/cer.py @@ -21,9 +21,9 @@ def _cer_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor]: """Update the cer score with the current set of references and predictions. Args: @@ -63,7 +63,7 @@ def _cer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -def char_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def char_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Compute Character Error Rate used for performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 7d7c552e3fb..70f09142d7c 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -37,8 +37,8 @@ def _prepare_n_grams_dicts( n_char_order: int, n_word_order: int -) -> Tuple[ - Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] +) -> tuple[ + dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor] ]: """Prepare dictionaries with default zero values for total ref, hypothesis and matching character and word n-grams. @@ -51,12 +51,12 @@ def _prepare_n_grams_dicts( n-grams. """ - total_preds_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_preds_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - total_target_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_target_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - total_matching_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_matching_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_preds_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_preds_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_target_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_target_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_matching_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_matching_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} return ( total_preds_char_n_grams, @@ -68,7 +68,7 @@ def _prepare_n_grams_dicts( ) -def _get_characters(sentence: str, whitespace: bool) -> List[str]: +def _get_characters(sentence: str, whitespace: bool) -> list[str]: """Split sentence into individual characters. Args: @@ -84,7 +84,7 @@ def _get_characters(sentence: str, whitespace: bool) -> List[str]: return list(sentence.strip().replace(" ", "")) -def _separate_word_and_punctuation(word: str) -> List[str]: +def _separate_word_and_punctuation(word: str) -> list[str]: """Separates out punctuations from beginning and end of words for chrF. Adapted from https://github.com/m-popovic/chrF and @@ -107,7 +107,7 @@ def _separate_word_and_punctuation(word: str) -> List[str]: return [word] -def _get_words_and_punctuation(sentence: str) -> List[str]: +def _get_words_and_punctuation(sentence: str) -> list[str]: """Separates out punctuations from beginning and end of words for chrF for all words in the sentence. Args: @@ -120,7 +120,7 @@ def _get_words_and_punctuation(sentence: str) -> List[str]: return list(chain.from_iterable(_separate_word_and_punctuation(word) for word in sentence.strip().split())) -def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, Dict[Tuple[str, ...], Tensor]]: +def _ngram_counts(char_or_word_list: list[str], n_gram_order: int) -> dict[int, dict[tuple[str, ...], Tensor]]: """Calculate n-gram counts. Args: @@ -131,7 +131,7 @@ def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, A dictionary of dictionaries with a counts of given n-grams. """ - ngrams: Dict[int, Dict[Tuple[str, ...], Tensor]] = defaultdict(lambda: defaultdict(lambda: tensor(0.0))) + ngrams: dict[int, dict[tuple[str, ...], Tensor]] = defaultdict(lambda: defaultdict(lambda: tensor(0.0))) for n in range(1, n_gram_order + 1): for ngram in (tuple(char_or_word_list[i : i + n]) for i in range(len(char_or_word_list) - n + 1)): ngrams[n][ngram] += tensor(1) @@ -140,11 +140,11 @@ def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, def _get_n_grams_counts_and_total_ngrams( sentence: str, n_char_order: int, n_word_order: int, lowercase: bool, whitespace: bool -) -> Tuple[ - Dict[int, Dict[Tuple[str, ...], Tensor]], - Dict[int, Dict[Tuple[str, ...], Tensor]], - Dict[int, Tensor], - Dict[int, Tensor], +) -> tuple[ + dict[int, dict[tuple[str, ...], Tensor]], + dict[int, dict[tuple[str, ...], Tensor]], + dict[int, Tensor], + dict[int, Tensor], ]: """Get n-grams and total n-grams. @@ -165,7 +165,7 @@ def _get_n_grams_counts_and_total_ngrams( def _char_and_word_ngrams_counts( sentence: str, n_char_order: int, n_word_order: int, lowercase: bool - ) -> Tuple[Dict[int, Dict[Tuple[str, ...], Tensor]], Dict[int, Dict[Tuple[str, ...], Tensor]]]: + ) -> tuple[dict[int, dict[tuple[str, ...], Tensor]], dict[int, dict[tuple[str, ...], Tensor]]]: """Get a dictionary of dictionaries with a counts of given n-grams.""" if lowercase: sentence = sentence.lower() @@ -173,9 +173,9 @@ def _char_and_word_ngrams_counts( word_n_grams_counts = _ngram_counts(_get_words_and_punctuation(sentence), n_word_order) return char_n_grams_counts, word_n_grams_counts - def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) -> Dict[int, Tensor]: + def _get_total_ngrams(n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]]) -> dict[int, Tensor]: """Get total sum of n-grams over n-grams w.r.t n.""" - total_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + total_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) for n in n_grams_counts: total_n_grams[n] = sum(n_grams_counts[n].values()).detach().clone() # type: ignore return total_n_grams @@ -190,9 +190,9 @@ def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) def _get_ngram_matches( - hyp_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - ref_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], -) -> Dict[int, Tensor]: + hyp_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], + ref_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], +) -> dict[int, Tensor]: """Get a number of n-gram matches between reference and hypothesis n-grams. Args: @@ -203,7 +203,7 @@ def _get_ngram_matches( matching_n_grams """ - matching_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + matching_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) for n in hyp_n_grams_counts: min_n_grams = [ torch.min(ref_n_grams_counts[n][n_gram], hyp_n_grams_counts[n][n_gram]) for n_gram in hyp_n_grams_counts[n] @@ -212,7 +212,7 @@ def _get_ngram_matches( return matching_n_grams -def _sum_over_dicts(total_n_grams: Dict[int, Tensor], n_grams: Dict[int, Tensor]) -> Dict[int, Tensor]: +def _sum_over_dicts(total_n_grams: dict[int, Tensor], n_grams: dict[int, Tensor]) -> dict[int, Tensor]: """Aggregate total n-grams to keep corpus-level statistics. Args: @@ -229,12 +229,12 @@ def _sum_over_dicts(total_n_grams: Dict[int, Tensor], n_grams: Dict[int, Tensor] def _calculate_fscore( - matching_char_n_grams: Dict[int, Tensor], - matching_word_n_grams: Dict[int, Tensor], - hyp_char_n_grams: Dict[int, Tensor], - hyp_word_n_grams: Dict[int, Tensor], - ref_char_n_grams: Dict[int, Tensor], - ref_word_n_grams: Dict[int, Tensor], + matching_char_n_grams: dict[int, Tensor], + matching_word_n_grams: dict[int, Tensor], + hyp_char_n_grams: dict[int, Tensor], + hyp_word_n_grams: dict[int, Tensor], + ref_char_n_grams: dict[int, Tensor], + ref_word_n_grams: dict[int, Tensor], n_order: float, beta: float, ) -> Tensor: @@ -261,19 +261,19 @@ def _calculate_fscore( """ def _get_n_gram_fscore( - matching_n_grams: Dict[int, Tensor], ref_n_grams: Dict[int, Tensor], hyp_n_grams: Dict[int, Tensor], beta: float - ) -> Dict[int, Tensor]: + matching_n_grams: dict[int, Tensor], ref_n_grams: dict[int, Tensor], hyp_n_grams: dict[int, Tensor], beta: float + ) -> dict[int, Tensor]: """Get n-gram level f-score.""" - precision: Dict[int, Tensor] = { + precision: dict[int, Tensor] = { n: matching_n_grams[n] / hyp_n_grams[n] if hyp_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams } - recall: Dict[int, Tensor] = { + recall: dict[int, Tensor] = { n: matching_n_grams[n] / ref_n_grams[n] if ref_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams } - denominator: Dict[int, Tensor] = { + denominator: dict[int, Tensor] = { n: torch.max(beta**2 * precision[n] + recall[n], _EPS_SMOOTHING) for n in matching_n_grams } - f_score: Dict[int, Tensor] = { + f_score: dict[int, Tensor] = { n: (1 + beta**2) * precision[n] * recall[n] / denominator[n] for n in matching_n_grams } @@ -286,18 +286,18 @@ def _get_n_gram_fscore( def _calculate_sentence_level_chrf_score( - targets: List[str], - pred_char_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - pred_word_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - pred_char_n_grams: Dict[int, Tensor], - pred_word_n_grams: Dict[int, Tensor], + targets: list[str], + pred_char_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], + pred_word_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]], + pred_char_n_grams: dict[int, Tensor], + pred_word_n_grams: dict[int, Tensor], n_char_order: int, n_word_order: int, n_order: float, beta: float, lowercase: bool, whitespace: bool, -) -> Tuple[Tensor, Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor]]: +) -> tuple[Tensor, dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor]]: """Calculate the best sentence-level chrF/chrF++ score. For a given pre-processed hypothesis, all references are evaluated and score and statistics @@ -329,10 +329,10 @@ def _calculate_sentence_level_chrf_score( """ best_f_score = tensor(0.0) - best_matching_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_matching_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_target_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_target_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_matching_char_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_matching_word_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_char_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_word_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) for target in targets: ( @@ -374,27 +374,27 @@ def _calculate_sentence_level_chrf_score( def _chrf_score_update( preds: Union[str, Sequence[str]], target: Union[Sequence[str], Sequence[Sequence[str]]], - total_preds_char_n_grams: Dict[int, Tensor], - total_preds_word_n_grams: Dict[int, Tensor], - total_target_char_n_grams: Dict[int, Tensor], - total_target_word_n_grams: Dict[int, Tensor], - total_matching_char_n_grams: Dict[int, Tensor], - total_matching_word_n_grams: Dict[int, Tensor], + total_preds_char_n_grams: dict[int, Tensor], + total_preds_word_n_grams: dict[int, Tensor], + total_target_char_n_grams: dict[int, Tensor], + total_target_word_n_grams: dict[int, Tensor], + total_matching_char_n_grams: dict[int, Tensor], + total_matching_word_n_grams: dict[int, Tensor], n_char_order: int, n_word_order: int, n_order: float, beta: float, lowercase: bool, whitespace: bool, - sentence_chrf_score: Optional[List[Tensor]] = None, -) -> Tuple[ - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Optional[List[Tensor]], + sentence_chrf_score: Optional[list[Tensor]] = None, +) -> tuple[ + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], + dict[int, Tensor], + Optional[list[Tensor]], ]: """Update function for chrf score. @@ -483,12 +483,12 @@ def _chrf_score_update( def _chrf_score_compute( - total_preds_char_n_grams: Dict[int, Tensor], - total_preds_word_n_grams: Dict[int, Tensor], - total_target_char_n_grams: Dict[int, Tensor], - total_target_word_n_grams: Dict[int, Tensor], - total_matching_char_n_grams: Dict[int, Tensor], - total_matching_word_n_grams: Dict[int, Tensor], + total_preds_char_n_grams: dict[int, Tensor], + total_preds_word_n_grams: dict[int, Tensor], + total_target_char_n_grams: dict[int, Tensor], + total_target_word_n_grams: dict[int, Tensor], + total_matching_char_n_grams: dict[int, Tensor], + total_matching_word_n_grams: dict[int, Tensor], n_order: float, beta: float, ) -> Tensor: @@ -530,7 +530,7 @@ def chrf_score( lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate `chrF score`_ of machine translated text with one or more references. This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in @@ -594,7 +594,7 @@ def chrf_score( total_matching_word_n_grams, ) = _prepare_n_grams_dicts(n_char_order, n_word_order) - sentence_chrf_score: Optional[List[Tensor]] = [] if return_sentence_level_score else None + sentence_chrf_score: Optional[list[Tensor]] = [] if return_sentence_level_score else None ( total_preds_char_n_grams, diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index 45e9c254412..abdc4e01de8 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -234,7 +234,7 @@ def _preprocess_ja(sentence: str) -> str: return unicodedata.normalize("NFKC", sentence) -def _eed_compute(sentence_level_scores: List[Tensor]) -> Tensor: +def _eed_compute(sentence_level_scores: list[Tensor]) -> Tensor: """Reduction for extended edit distance. Args: @@ -254,7 +254,7 @@ def _preprocess_sentences( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], language: Literal["en", "ja"], -) -> Tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]: +) -> tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]: """Preprocess strings according to language requirements. Args: @@ -328,8 +328,8 @@ def _eed_update( rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, - sentence_eed: Optional[List[Tensor]] = None, -) -> List[Tensor]: + sentence_eed: Optional[list[Tensor]] = None, +) -> list[Tensor]: """Compute scores for ExtendedEditDistance. Args: @@ -371,7 +371,7 @@ def extended_edit_distance( rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute extended edit distance score (`ExtendedEditDistance`_) [1] for strings or list of strings. The metric utilises the Levenshtein distance and extends it by adding a jump operation. diff --git a/src/torchmetrics/functional/text/helper.py b/src/torchmetrics/functional/text/helper.py index bf02d40f3d8..3266876d129 100644 --- a/src/torchmetrics/functional/text/helper.py +++ b/src/torchmetrics/functional/text/helper.py @@ -68,12 +68,12 @@ class _LevenshteinEditDistance: """ def __init__( - self, reference_tokens: List[str], op_insert: int = 1, op_delete: int = 1, op_substitute: int = 1 + self, reference_tokens: list[str], op_insert: int = 1, op_delete: int = 1, op_substitute: int = 1 ) -> None: self.reference_tokens = reference_tokens self.reference_len = len(reference_tokens) - self.cache: Dict[str, Tuple[int, str]] = {} + self.cache: dict[str, tuple[int, str]] = {} self.cache_size = 0 self.op_insert = op_insert @@ -82,7 +82,7 @@ def __init__( self.op_nothing = 0 self.op_undefined = _INT_INFINITY - def __call__(self, prediction_tokens: List[str]) -> Tuple[int, Tuple[_EditOperations, ...]]: + def __call__(self, prediction_tokens: list[str]) -> tuple[int, tuple[_EditOperations, ...]]: """Calculate edit distance between self._words_ref and the hypothesis. Uses cache to skip some computations. Args: @@ -105,10 +105,10 @@ def __call__(self, prediction_tokens: List[str]) -> Tuple[int, Tuple[_EditOperat def _levenshtein_edit_distance( self, - prediction_tokens: List[str], + prediction_tokens: list[str], prediction_start: int, - cache: List[List[Tuple[int, _EditOperations]]], - ) -> Tuple[int, List[List[Tuple[int, _EditOperations]]], Tuple[_EditOperations, ...]]: + cache: list[list[tuple[int, _EditOperations]]], + ) -> tuple[int, list[list[tuple[int, _EditOperations]]], tuple[_EditOperations, ...]]: """Dynamic programming algorithm to compute the Levenhstein edit distance. Args: @@ -122,10 +122,10 @@ def _levenshtein_edit_distance( """ prediction_len = len(prediction_tokens) - empty_rows: List[List[Tuple[int, _EditOperations]]] = [ + empty_rows: list[list[tuple[int, _EditOperations]]] = [ list(self._get_empty_row(self.reference_len)) for _ in range(prediction_len - prediction_start) ] - edit_distance: List[List[Tuple[int, _EditOperations]]] = cache + empty_rows + edit_distance: list[list[tuple[int, _EditOperations]]] = cache + empty_rows length_ratio = self.reference_len / prediction_len if prediction_tokens else 1.0 # Ensure to not end up with zero overlaip with previous role @@ -172,8 +172,8 @@ def _levenshtein_edit_distance( return edit_distance[-1][-1][0], edit_distance[len(cache) :], trace def _get_trace( - self, prediction_len: int, edit_distance: List[List[Tuple[int, _EditOperations]]] - ) -> Tuple[_EditOperations, ...]: + self, prediction_len: int, edit_distance: list[list[tuple[int, _EditOperations]]] + ) -> tuple[_EditOperations, ...]: """Get a trace of executed operations from the edit distance matrix. Args: @@ -190,7 +190,7 @@ def _get_trace( If an unknown operation has been applied. """ - trace: Tuple[_EditOperations, ...] = () + trace: tuple[_EditOperations, ...] = () i = prediction_len j = self.reference_len @@ -209,7 +209,7 @@ def _get_trace( return trace - def _add_cache(self, prediction_tokens: List[str], edit_distance: List[List[Tuple[int, _EditOperations]]]) -> None: + def _add_cache(self, prediction_tokens: list[str], edit_distance: list[list[tuple[int, _EditOperations]]]) -> None: """Add newly computed rows to cache. Since edit distance is only calculated on the hypothesis suffix that was not in cache, the number of rows in @@ -242,7 +242,7 @@ def _add_cache(self, prediction_tokens: List[str], edit_distance: List[List[Tupl value = node[word] node = value[0] # type: ignore - def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tuple[int, _EditOperations]]]]: + def _find_cache(self, prediction_tokens: list[str]) -> tuple[int, list[list[tuple[int, _EditOperations]]]]: """Find the already calculated rows of the Levenshtein edit distance metric. Args: @@ -259,7 +259,7 @@ def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tupl """ node = self.cache start_position = 0 - edit_distance: List[List[Tuple[int, _EditOperations]]] = [self._get_initial_row(self.reference_len)] + edit_distance: list[list[tuple[int, _EditOperations]]] = [self._get_initial_row(self.reference_len)] for word in prediction_tokens: if word in node: start_position += 1 @@ -270,7 +270,7 @@ def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tupl return start_position, edit_distance - def _get_empty_row(self, length: int) -> List[Tuple[int, _EditOperations]]: + def _get_empty_row(self, length: int) -> list[tuple[int, _EditOperations]]: """Precomputed empty matrix row for Levenhstein edit distance. Args: @@ -282,7 +282,7 @@ def _get_empty_row(self, length: int) -> List[Tuple[int, _EditOperations]]: """ return [(int(self.op_undefined), _EditOperations.OP_UNDEFINED)] * (length + 1) - def _get_initial_row(self, length: int) -> List[Tuple[int, _EditOperations]]: + def _get_initial_row(self, length: int) -> list[tuple[int, _EditOperations]]: """First row corresponds to insertion operations of the reference, so 1 edit operation per reference word. Args: @@ -298,7 +298,7 @@ def _get_initial_row(self, length: int) -> List[Tuple[int, _EditOperations]]: def _validate_inputs( ref_corpus: Union[Sequence[str], Sequence[Sequence[str]]], hypothesis_corpus: Union[str, Sequence[str]], -) -> Tuple[Sequence[Sequence[str]], Sequence[str]]: +) -> tuple[Sequence[Sequence[str]], Sequence[str]]: """Check and update (if needed) the format of reference and hypothesis corpora for various text evaluation metrics. Args: @@ -327,7 +327,7 @@ def _validate_inputs( return ref_corpus, hypothesis_corpus -def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int: +def _edit_distance(prediction_tokens: list[str], reference_tokens: list[str]) -> int: """Dynamic programming algorithm to compute the edit distance. Args: @@ -351,7 +351,7 @@ def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> return dp[-1][-1] -def _flip_trace(trace: Tuple[_EditOperations, ...]) -> Tuple[_EditOperations, ...]: +def _flip_trace(trace: tuple[_EditOperations, ...]) -> tuple[_EditOperations, ...]: """Flip the trace of edit operations. Instead of rewriting a->b, get a recipe for rewriting b->a. Simply flips insertions and deletions. @@ -364,13 +364,13 @@ def _flip_trace(trace: Tuple[_EditOperations, ...]) -> Tuple[_EditOperations, .. A tuple of inverted edit operations. """ - _flip_operations: Dict[_EditOperations, _EditOperations] = { + _flip_operations: dict[_EditOperations, _EditOperations] = { _EditOperations.OP_INSERT: _EditOperations.OP_DELETE, _EditOperations.OP_DELETE: _EditOperations.OP_INSERT, } def _replace_operation_or_retain( - operation: _EditOperations, _flip_operations: Dict[_EditOperations, _EditOperations] + operation: _EditOperations, _flip_operations: dict[_EditOperations, _EditOperations] ) -> _EditOperations: if operation in _flip_operations: return _flip_operations.get(operation) # type: ignore @@ -379,7 +379,7 @@ def _replace_operation_or_retain( return tuple(_replace_operation_or_retain(operation, _flip_operations) for operation in trace) -def _trace_to_alignment(trace: Tuple[_EditOperations, ...]) -> Tuple[Dict[int, int], List[int], List[int]]: +def _trace_to_alignment(trace: tuple[_EditOperations, ...]) -> tuple[dict[int, int], list[int], list[int]]: """Transform trace of edit operations into an alignment of the sequences. Args: @@ -396,9 +396,9 @@ def _trace_to_alignment(trace: Tuple[_EditOperations, ...]) -> Tuple[Dict[int, i """ reference_position = hypothesis_position = -1 - reference_errors: List[int] = [] - hypothesis_errors: List[int] = [] - alignments: Dict[int, int] = {} + reference_errors: list[int] = [] + hypothesis_errors: list[int] = [] + alignments: dict[int, int] = {} # we are rewriting a into b for operation in trace: diff --git a/src/torchmetrics/functional/text/helper_embedding_metric.py b/src/torchmetrics/functional/text/helper_embedding_metric.py index f2b59126c7d..ad57c13a2f4 100644 --- a/src/torchmetrics/functional/text/helper_embedding_metric.py +++ b/src/torchmetrics/functional/text/helper_embedding_metric.py @@ -49,8 +49,8 @@ def _process_attention_mask_for_special_tokens(attention_mask: Tensor) -> Tensor def _input_data_collator( - batch: Dict[str, Tensor], device: Optional[Union[str, torch.device]] = None -) -> Dict[str, Tensor]: + batch: dict[str, Tensor], device: Optional[Union[str, torch.device]] = None +) -> dict[str, Tensor]: """Trim model inputs. This function trims the model inputs to the longest sequence within the batch and put the input on the proper @@ -64,7 +64,7 @@ def _input_data_collator( return batch -def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_len: int) -> Tuple[Tensor, Tensor]: +def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_len: int) -> tuple[Tensor, Tensor]: """Pad the model output and attention mask to the target length.""" zeros_shape = list(model_output.shape) zeros_shape[2] = target_len - zeros_shape[2] @@ -76,7 +76,7 @@ def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_l return model_output, attention_mask -def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]: +def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> tuple[Tensor, Tensor, Tensor]: """Sort tokenized sentence from the shortest to the longest one.""" sorted_indices = attention_mask.sum(1).argsort() input_ids = input_ids[sorted_indices] @@ -85,13 +85,13 @@ def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> Tu def _preprocess_text( - text: List[str], + text: list[str], tokenizer: Any, max_length: int = 512, truncation: bool = True, sort_according_length: bool = True, own_tokenizer: bool = False, -) -> Tuple[Dict[str, Tensor], Optional[Tensor]]: +) -> tuple[dict[str, Tensor], Optional[Tensor]]: """Text pre-processing function using `transformers` `AutoTokenizer` instance. Args: @@ -164,7 +164,7 @@ def _check_shape_of_model_output(output: Tensor, input_ids: Tensor) -> None: def _load_tokenizer_and_model( model_name_or_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None -) -> Tuple["PreTrainedTokenizerBase", "PreTrainedModel"]: +) -> tuple["PreTrainedTokenizerBase", "PreTrainedModel"]: """Load HuggingFace `transformers`' tokenizer and model. This function also handle a device placement. Args: @@ -191,14 +191,14 @@ class TextDataset(Dataset): def __init__( self, - text: List[str], + text: list[str], tokenizer: Any, max_length: int = 512, preprocess_text_fn: Callable[ - [List[str], Any, int, bool], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]] + [list[str], Any, int, bool], Union[dict[str, Tensor], tuple[dict[str, Tensor], Optional[Tensor]]] ] = _preprocess_text, idf: bool = False, - tokens_idf: Optional[Dict[int, float]] = None, + tokens_idf: Optional[dict[int, float]] = None, truncation: bool = False, ) -> None: """Initialize text dataset class. @@ -225,7 +225,7 @@ def __init__( if idf: self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf() - def __getitem__(self, idx: int) -> Dict[str, Tensor]: + def __getitem__(self, idx: int) -> dict[str, Tensor]: """Get the input ids and attention mask belonging to a specific datapoint.""" input_ids = self.text["input_ids"][idx, :] attention_mask = self.text["attention_mask"][idx, :] @@ -239,7 +239,7 @@ def __len__(self) -> int: """Return the number of sentences in the dataset.""" return self.num_sentences - def _get_tokens_idf(self) -> Dict[int, float]: + def _get_tokens_idf(self) -> dict[int, float]: """Calculate token inverse document frequencies. Return: @@ -250,7 +250,7 @@ def _get_tokens_idf(self) -> Dict[int, float]: for tokens in map(self._set_of_tokens, self.text["input_ids"]): token_counter.update(tokens) - tokens_idf: Dict[int, float] = defaultdict(self._get_tokens_idf_default_value) + tokens_idf: dict[int, float] = defaultdict(self._get_tokens_idf_default_value) tokens_idf.update({ idx: math.log((self.num_sentences + 1) / (occurrence + 1)) for idx, occurrence in token_counter.items() }) @@ -261,7 +261,7 @@ def _get_tokens_idf_default_value(self) -> float: return math.log((self.num_sentences + 1) / 1) @staticmethod - def _set_of_tokens(input_ids: Tensor) -> Set: + def _set_of_tokens(input_ids: Tensor) -> set: """Return set of tokens from the `input_ids` :class:`~torch.Tensor`.""" return set(input_ids.tolist()) @@ -274,7 +274,7 @@ def __init__( input_ids: Tensor, attention_mask: Tensor, idf: bool = False, - tokens_idf: Optional[Dict[int, float]] = None, + tokens_idf: Optional[dict[int, float]] = None, ) -> None: """Initialize the dataset class. diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index a3efadffe33..33c20f97e53 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -321,7 +321,7 @@ def _get_dataloader( return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) -def _get_special_tokens_map(tokenizer: "PreTrainedTokenizerBase") -> Dict[str, int]: +def _get_special_tokens_map(tokenizer: "PreTrainedTokenizerBase") -> dict[str, int]: """Build a dictionary of model/tokenizer special tokens. Args: @@ -367,10 +367,10 @@ def _get_token_mask(input_ids: Tensor, pad_token_id: int, sep_token_id: int, cls def _get_batch_distribution( model: "PreTrainedModel", - batch: Dict[str, Tensor], + batch: dict[str, Tensor], temperature: float, idf: bool, - special_tokens_map: Dict[str, int], + special_tokens_map: dict[str, int], ) -> Tensor: """Calculate a discrete probability distribution for a batch of examples. See `InfoLM`_ for details. @@ -393,7 +393,7 @@ def _get_batch_distribution( """ seq_len = batch["input_ids"].shape[1] - prob_distribution_batch_list: List[Tensor] = [] + prob_distribution_batch_list: list[Tensor] = [] token_mask = _get_token_mask( batch["input_ids"], special_tokens_map["pad_token_id"], @@ -428,7 +428,7 @@ def _get_data_distribution( dataloader: DataLoader, temperature: float, idf: bool, - special_tokens_map: Dict[str, int], + special_tokens_map: dict[str, int], verbose: bool, ) -> Tensor: """Calculate a discrete probability distribution according to the methodology described in `InfoLM`_. @@ -454,7 +454,7 @@ def _get_data_distribution( """ device = model.device - prob_distribution: List[Tensor] = [] + prob_distribution: list[Tensor] = [] for batch in _get_progress_bar(dataloader, verbose): batch = _input_data_collator(batch, device) @@ -468,7 +468,7 @@ def _infolm_update( target: Union[str, Sequence[str]], tokenizer: "PreTrainedTokenizerBase", max_length: int, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Update the metric state by a tokenization of ``preds`` and ``target`` sentencens. Args: @@ -504,7 +504,7 @@ def _infolm_compute( temperature: float, idf: bool, information_measure_cls: _InformationMeasure, - special_tokens_map: Dict[str, int], + special_tokens_map: dict[str, int], verbose: bool = True, ) -> Tensor: """Calculate selected information measure using the pre-trained language model. @@ -558,7 +558,7 @@ def infolm( num_threads: int = 0, verbose: bool = True, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: +) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate `InfoLM`_ [1]. InfoML corresponds to distance/divergence between predicted and reference sentence discrete distribution using diff --git a/src/torchmetrics/functional/text/mer.py b/src/torchmetrics/functional/text/mer.py index 89d3764331d..34d0eaff0ed 100644 --- a/src/torchmetrics/functional/text/mer.py +++ b/src/torchmetrics/functional/text/mer.py @@ -21,9 +21,9 @@ def _mer_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor]: """Update the mer score with the current set of references and predictions. Args: @@ -64,7 +64,7 @@ def _mer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -def match_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def match_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Match error rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of words that were incorrectly predicted and inserted. The lower the value, the diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py index 1561ffa412a..5a58a8da4c0 100644 --- a/src/torchmetrics/functional/text/perplexity.py +++ b/src/torchmetrics/functional/text/perplexity.py @@ -62,7 +62,7 @@ def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None: raise TypeError(f"Input tensor `target` is expected to be of a type {torch.int64} but got {target.dtype}.") -def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tuple[Tensor, Tensor]: +def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> tuple[Tensor, Tensor]: """Compute intermediate statistics for Perplexity. Args: diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index a83b41a1a20..e9f926b4d66 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -24,7 +24,7 @@ __doctest_requires__ = {("rouge_score", "_rouge_score_update"): ["nltk"]} -ALLOWED_ROUGE_KEYS: Dict[str, Union[int, str]] = { +ALLOWED_ROUGE_KEYS: dict[str, Union[int, str]] = { "rouge1": 1, "rouge2": 2, "rouge3": 3, @@ -72,7 +72,7 @@ def _split_sentence(x: str) -> Sequence[str]: return nltk.sent_tokenize(x) -def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, Tensor]: +def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> dict[str, Tensor]: """Compute overall metrics. This function computes precision, recall and F1 score based on hits/lcs, the length of lists of tokenizer @@ -129,7 +129,7 @@ def _backtracked_lcs( """ i = len(pred_tokens) j = len(target_tokens) - backtracked_lcs: List[int] = [] + backtracked_lcs: list[int] = [] while i > 0 and j > 0: if pred_tokens[i - 1] == target_tokens[j - 1]: backtracked_lcs.insert(0, j - 1) @@ -200,7 +200,7 @@ def _normalize_and_tokenize_text( return [x for x in tokens if (isinstance(x, str) and len(x) > 0)] -def _rouge_n_score(pred: Sequence[str], target: Sequence[str], n_gram: int) -> Dict[str, Tensor]: +def _rouge_n_score(pred: Sequence[str], target: Sequence[str], n_gram: int) -> dict[str, Tensor]: """Compute precision, recall and F1 score for the Rouge-N metric. Args: @@ -226,7 +226,7 @@ def _create_ngrams(tokens: Sequence[str], n: int) -> Counter: return _compute_metrics(hits, max(pred_len, 1), max(target_len, 1)) -def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tensor]: +def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> dict[str, Tensor]: """Compute precision, recall and F1 score for the Rouge-L metric. Args: @@ -242,7 +242,7 @@ def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tens return _compute_metrics(lcs, pred_len, target_len) -def _rouge_lsum_score(pred: Sequence[Sequence[str]], target: Sequence[Sequence[str]]) -> Dict[str, Tensor]: +def _rouge_lsum_score(pred: Sequence[Sequence[str]], target: Sequence[Sequence[str]]) -> dict[str, Tensor]: r"""Compute precision, recall and F1 score for the Rouge-LSum metric. More information can be found in Section 3.2 of the referenced paper [1]. This implementation follow the official @@ -288,12 +288,12 @@ def _get_token_counts(sentences: Sequence[Sequence[str]]) -> Counter: def _rouge_score_update( preds: Sequence[str], target: Sequence[Sequence[str]], - rouge_keys_values: List[Union[int, str]], + rouge_keys_values: list[Union[int, str]], accumulate: str, stemmer: Optional[Any] = None, normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, -) -> Dict[Union[int, str], List[Dict[str, Tensor]]]: +) -> dict[Union[int, str], list[dict[str, Tensor]]]: """Update the rouge score with the current set of predicted and target sentences. Args: @@ -328,11 +328,11 @@ def _rouge_score_update( 'recall': tensor(0.5000)}]} """ - results: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} + results: dict[Union[int, str], list[dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} for pred_raw, target_raw in zip(preds, target): - result_inner: Dict[Union[int, str], Dict[str, Tensor]] = {rouge_key: {} for rouge_key in rouge_keys_values} - result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} + result_inner: dict[Union[int, str], dict[str, Tensor]] = {rouge_key: {} for rouge_key in rouge_keys_values} + result_avg: dict[Union[int, str], list[dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} list_results = [] pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer) if "Lsum" in rouge_keys_values: @@ -370,11 +370,11 @@ def _rouge_score_update( results[rouge_key].append(list_results[highest_idx][rouge_key]) # todo elif accumulate == "avg": - new_result_avg: Dict[Union[int, str], Dict[str, Tensor]] = { + new_result_avg: dict[Union[int, str], dict[str, Tensor]] = { rouge_key: {} for rouge_key in rouge_keys_values } for rouge_key, metrics in result_avg.items(): - _dict_metric_score_batch: Dict[str, List[Tensor]] = {} + _dict_metric_score_batch: dict[str, list[Tensor]] = {} for metric in metrics: for _type, value in metric.items(): if _type not in _dict_metric_score_batch: @@ -391,14 +391,14 @@ def _rouge_score_update( return results -def _rouge_score_compute(sentence_results: Dict[str, List[Tensor]]) -> Dict[str, Tensor]: +def _rouge_score_compute(sentence_results: dict[str, list[Tensor]]) -> dict[str, Tensor]: """Compute the combined ROUGE metric for all the input set of predicted and target sentences. Args: sentence_results: Rouge-N/Rouge-L/Rouge-LSum metrics calculated for single sentence. """ - results: Dict[str, Tensor] = {} + results: dict[str, Tensor] = {} # Obtain mean scores for individual rouge metrics if sentence_results == {}: return results @@ -416,8 +416,8 @@ def rouge_score( use_stemmer: bool = False, normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, - rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), -) -> Dict[str, Tensor]: + rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), +) -> dict[str, Tensor]: """Calculate `Calculate Rouge Score`_ , used for automatic summarization. Args: @@ -495,7 +495,7 @@ def rouge_score( if isinstance(target, str): target = [[target]] - sentence_results: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update( + sentence_results: dict[Union[int, str], list[dict[str, Tensor]]] = _rouge_score_update( preds, target, rouge_keys_values, @@ -505,7 +505,7 @@ def rouge_score( accumulate=accumulate, ) - output: Dict[str, List[Tensor]] = { + output: dict[str, list[Tensor]] = { f"rouge{rouge_key}_{tp}": [] for rouge_key in rouge_keys_values for tp in ["fmeasure", "precision", "recall"] } for rouge_key, metrics in sentence_results.items(): diff --git a/src/torchmetrics/functional/text/sacre_bleu.py b/src/torchmetrics/functional/text/sacre_bleu.py index 6b18f4cab8e..28e36b2ea3b 100644 --- a/src/torchmetrics/functional/text/sacre_bleu.py +++ b/src/torchmetrics/functional/text/sacre_bleu.py @@ -142,7 +142,7 @@ class _SacreBLEUTokenizer: } # Keep it as class variable to avoid initializing over and over again - sentencepiece_processors: ClassVar[Dict[str, Optional[Any]]] = {"flores101": None, "flores200": None} + sentencepiece_processors: ClassVar[dict[str, Optional[Any]]] = {"flores101": None, "flores200": None} def __init__(self, tokenize: _TokenizersLiteral, lowercase: bool = False) -> None: self._check_tokenizers_validity(tokenize) @@ -156,7 +156,7 @@ def __call__(self, line: str) -> Sequence[str]: @classmethod def tokenize( - cls: Type["_SacreBLEUTokenizer"], + cls: type["_SacreBLEUTokenizer"], line: str, tokenize: _TokenizersLiteral, lowercase: bool = False, @@ -168,7 +168,7 @@ def tokenize( return cls._lower(tokenized_line, lowercase).split() @classmethod - def _tokenize_regex(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_regex(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Post-processing tokenizer for `13a` and `zh` tokenizers. Args: @@ -197,7 +197,7 @@ def _is_chinese_char(uchar: str) -> bool: return any(start <= uchar <= end for start, end in _UCODE_RANGES) @classmethod - def _tokenize_base(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_base(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes an input line with the tokenizer. Args: @@ -210,7 +210,7 @@ def _tokenize_base(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return line @classmethod - def _tokenize_13a(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_13a(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a line using a relatively minimal tokenization that is equivalent to mteval-v13a, used by WMT. Args: @@ -234,7 +234,7 @@ def _tokenize_13a(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return cls._tokenize_regex(f" {line} ") @classmethod - def _tokenize_zh(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_zh(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenization of Chinese text. This is done in two steps: separate each Chinese characters (by utf-8 encoding) and afterwards tokenize the @@ -262,7 +262,7 @@ def _tokenize_zh(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return cls._tokenize_regex(line_in_chars) @classmethod - def _tokenize_international(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_international(cls: type["_SacreBLEUTokenizer"], line: str) -> str: r"""Tokenizes a string following the official BLEU implementation. See github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 @@ -295,7 +295,7 @@ def _tokenize_international(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return " ".join(line.split()) @classmethod - def _tokenize_char(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_char(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes all the characters in the input line. Args: @@ -308,7 +308,7 @@ def _tokenize_char(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return " ".join(char for char in line) @classmethod - def _tokenize_ja_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_ja_mecab(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a Japanese string line using MeCab morphological analyzer. Args: @@ -327,7 +327,7 @@ def _tokenize_ja_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return tagger.parse(line).strip() @classmethod - def _tokenize_ko_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_ko_mecab(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a Korean string line using MeCab-korean morphological analyzer. Args: @@ -347,7 +347,7 @@ def _tokenize_ko_mecab(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: @classmethod def _tokenize_flores( - cls: Type["_SacreBLEUTokenizer"], line: str, tokenize: Literal["flores101", "flores200"] + cls: type["_SacreBLEUTokenizer"], line: str, tokenize: Literal["flores101", "flores200"] ) -> str: """Tokenizes a string line using sentencepiece tokenizer. @@ -373,7 +373,7 @@ def _tokenize_flores( return " ".join(cls.sentencepiece_processors[tokenize].EncodeAsPieces(line)) # type: ignore[union-attr] @classmethod - def _tokenize_flores_101(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_flores_101(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-101`_ dataset. Args: @@ -386,7 +386,7 @@ def _tokenize_flores_101(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: return cls._tokenize_flores(line, "flores101") @classmethod - def _tokenize_flores_200(cls: Type["_SacreBLEUTokenizer"], line: str) -> str: + def _tokenize_flores_200(cls: type["_SacreBLEUTokenizer"], line: str) -> str: """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-200`_ dataset. Args: @@ -405,7 +405,7 @@ def _lower(line: str, lowercase: bool) -> str: return line @classmethod - def _check_tokenizers_validity(cls: Type["_SacreBLEUTokenizer"], tokenize: _TokenizersLiteral) -> None: + def _check_tokenizers_validity(cls: type["_SacreBLEUTokenizer"], tokenize: _TokenizersLiteral) -> None: """Check if a supported tokenizer is chosen. Also check all dependencies of a given tokenizers are installed. diff --git a/src/torchmetrics/functional/text/squad.py b/src/torchmetrics/functional/text/squad.py index 6ea9c70c9a1..d317a7ce806 100644 --- a/src/torchmetrics/functional/text/squad.py +++ b/src/torchmetrics/functional/text/squad.py @@ -23,11 +23,11 @@ from torchmetrics.utilities import rank_zero_warn -SINGLE_PRED_TYPE = Dict[str, str] -PREDS_TYPE = Union[SINGLE_PRED_TYPE, List[SINGLE_PRED_TYPE]] -SINGLE_TARGET_TYPE = Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]] -TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, List[SINGLE_TARGET_TYPE]] -UPDATE_METHOD_SINGLE_PRED_TYPE = Union[List[Dict[str, Union[str, int]]], str, Dict[str, Union[List[str], List[int]]]] +SINGLE_PRED_TYPE = dict[str, str] +PREDS_TYPE = Union[SINGLE_PRED_TYPE, list[SINGLE_PRED_TYPE]] +SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]] +TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, list[SINGLE_TARGET_TYPE]] +UPDATE_METHOD_SINGLE_PRED_TYPE = Union[list[dict[str, Union[str, int]]], str, dict[str, Union[list[str], list[int]]]] SQuAD_FORMAT = { "answers": {"answer_start": [1], "text": ["This is a test text"]}, @@ -57,7 +57,7 @@ def lower(text: str) -> str: return white_space_fix(remove_articles(remove_punc(lower(s)))) -def _get_tokens(s: str) -> List[str]: +def _get_tokens(s: str) -> list[str]: """Split a sentence into separate tokens.""" return [] if not s else _normalize_text(s).split() @@ -84,7 +84,7 @@ def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor: def _metric_max_over_ground_truths( - metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: List[str] + metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: list[str] ) -> Tensor: """Calculate maximum score for a predicted answer with all reference answers.""" return max(metric_fn(prediction, truth) for truth in ground_truths) # type: ignore[type-var] @@ -92,12 +92,12 @@ def _metric_max_over_ground_truths( def _squad_input_check( preds: PREDS_TYPE, targets: TARGETS_TYPE -) -> Tuple[Dict[str, str], list[Dict[str, List[Dict[str, list[Dict[str, Any]]]]]]]: +) -> tuple[dict[str, str], list[dict[str, list[dict[str, list[dict[str, Any]]]]]]]: """Check for types and convert the input to necessary format to compute the input.""" - if isinstance(preds, Dict): + if isinstance(preds, dict): preds = [preds] - if isinstance(targets, Dict): + if isinstance(targets, dict): targets = [targets] for pred in preds: @@ -118,7 +118,7 @@ def _squad_input_check( f"{SQuAD_FORMAT}" ) - answers: dict[str, Union[List[str], list[int]]] = target["answers"] # type: ignore[assignment] + answers: dict[str, Union[list[str], list[int]]] = target["answers"] # type: ignore[assignment] if "text" not in answers: raise KeyError( "Expected keys in a 'answers' are 'text'." @@ -134,9 +134,9 @@ def _squad_input_check( def _squad_update( - preds: Dict[str, str], + preds: dict[str, str], target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]], -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute F1 Score and Exact Match for a collection of predictions and references. Args: @@ -180,7 +180,7 @@ def _squad_update( return f1, exact_match, total -def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> Dict[str, Tensor]: +def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch. Return: @@ -192,7 +192,7 @@ def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> Dict[str, return {"exact_match": exact_match, "f1": f1} -def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> Dict[str, Tensor]: +def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> dict[str, Tensor]: """Calculate `SQuAD Metric`_ . Args: diff --git a/src/torchmetrics/functional/text/ter.py b/src/torchmetrics/functional/text/ter.py index 08bf6b7562f..3cb862d046a 100644 --- a/src/torchmetrics/functional/text/ter.py +++ b/src/torchmetrics/functional/text/ter.py @@ -151,7 +151,7 @@ def _normalize_general_and_western(sentence: str) -> str: return sentence @classmethod - def _normalize_asian(cls: Type["_TercomTokenizer"], sentence: str) -> str: + def _normalize_asian(cls: type["_TercomTokenizer"], sentence: str) -> str: """Split Chinese chars and Japanese kanji down to character level.""" # 4E00—9FFF CJK Unified Ideographs # 3400—4DBF CJK Unified Ideographs Extension A @@ -183,7 +183,7 @@ def _remove_punct(sentence: str) -> str: return re.sub(r"[\.,\?:;!\"\(\)]", "", sentence) @classmethod - def _remove_asian_punct(cls: Type["_TercomTokenizer"], sentence: str) -> str: + def _remove_asian_punct(cls: type["_TercomTokenizer"], sentence: str) -> str: """Remove asian punctuation from an input sentence string.""" sentence = re.sub(cls._ASIAN_PUNCTUATION, r"", sentence) return re.sub(cls._FULL_WIDTH_PUNCTUATION, r"", sentence) @@ -203,7 +203,7 @@ def _preprocess_sentence(sentence: str, tokenizer: _TercomTokenizer) -> str: return tokenizer(sentence.rstrip()) -def _find_shifted_pairs(pred_words: List[str], target_words: List[str]) -> Iterator[Tuple[int, int, int]]: +def _find_shifted_pairs(pred_words: list[str], target_words: list[str]) -> Iterator[tuple[int, int, int]]: """Find matching word sub-sequences in two lists of words. Ignores sub- sequences starting at the same position. Args: @@ -243,9 +243,9 @@ def _find_shifted_pairs(pred_words: List[str], target_words: List[str]) -> Itera def _handle_corner_cases_during_shifting( - alignments: Dict[int, int], - pred_errors: List[int], - target_errors: List[int], + alignments: dict[int, int], + pred_errors: list[int], + target_errors: list[int], pred_start: int, target_start: int, length: int, @@ -276,7 +276,7 @@ def _handle_corner_cases_during_shifting( return pred_start <= alignments[target_start] < pred_start + length -def _perform_shift(words: List[str], start: int, length: int, target: int) -> List[str]: +def _perform_shift(words: list[str], start: int, length: int, target: int) -> list[str]: """Perform a shift in ``words`` from ``start`` to ``target``. Args: @@ -290,13 +290,13 @@ def _perform_shift(words: List[str], start: int, length: int, target: int) -> Li """ - def _shift_word_before_previous_position(words: List[str], start: int, target: int, length: int) -> List[str]: + def _shift_word_before_previous_position(words: list[str], start: int, target: int, length: int) -> list[str]: return words[:target] + words[start : start + length] + words[target:start] + words[start + length :] - def _shift_word_after_previous_position(words: List[str], start: int, target: int, length: int) -> List[str]: + def _shift_word_after_previous_position(words: list[str], start: int, target: int, length: int) -> list[str]: return words[:start] + words[start + length : target] + words[start : start + length] + words[target:] - def _shift_word_within_shifted_string(words: List[str], start: int, target: int, length: int) -> List[str]: + def _shift_word_within_shifted_string(words: list[str], start: int, target: int, length: int) -> list[str]: shifted_words = words[:start] shifted_words += words[start + length : length + target] shifted_words += words[start : start + length] @@ -311,11 +311,11 @@ def _shift_word_within_shifted_string(words: List[str], start: int, target: int, def _shift_words( - pred_words: List[str], - target_words: List[str], + pred_words: list[str], + target_words: list[str], cached_edit_distance: _LevenshteinEditDistance, checked_candidates: int, -) -> Tuple[int, List[str], int]: +) -> tuple[int, list[str], int]: """Attempt to shift words to match a hypothesis with a reference. It returns the lowest number of required edits between a hypothesis and a provided reference, a list of shifted @@ -343,7 +343,7 @@ def _shift_words( trace = _flip_trace(inverted_trace) alignments, target_errors, pred_errors = _trace_to_alignment(trace) - best: Optional[Tuple[int, int, int, int, List[str]]] = None + best: Optional[tuple[int, int, int, int, list[str]]] = None for pred_start, target_start, length in _find_shifted_pairs(pred_words, target_words): if _handle_corner_cases_during_shifting( @@ -391,7 +391,7 @@ def _shift_words( return best_score, shifted_words, checked_candidates -def _translation_edit_rate(pred_words: List[str], target_words: List[str]) -> Tensor: +def _translation_edit_rate(pred_words: list[str], target_words: list[str]) -> Tensor: """Compute translation edit rate between hypothesis and reference sentences. Args: @@ -426,7 +426,7 @@ def _translation_edit_rate(pred_words: List[str], target_words: List[str]) -> Te return tensor(total_edits) -def _compute_sentence_statistics(pred_words: List[str], target_words: List[List[str]]) -> Tuple[Tensor, Tensor]: +def _compute_sentence_statistics(pred_words: list[str], target_words: list[list[str]]) -> tuple[Tensor, Tensor]: """Compute sentence TER statistics between hypothesis and provided references. Args: @@ -477,8 +477,8 @@ def _ter_update( tokenizer: _TercomTokenizer, total_num_edits: Tensor, total_tgt_length: Tensor, - sentence_ter: Optional[List[Tensor]] = None, -) -> Tuple[Tensor, Tensor, Optional[List[Tensor]]]: + sentence_ter: Optional[list[Tensor]] = None, +) -> tuple[Tensor, Tensor, Optional[list[Tensor]]]: """Update TER statistics. Args: @@ -505,8 +505,8 @@ def _ter_update( target, preds = _validate_inputs(target, preds) for pred, tgt in zip(preds, target): - tgt_words_: List[List[str]] = [_preprocess_sentence(_tgt, tokenizer).split() for _tgt in tgt] - pred_words_: List[str] = _preprocess_sentence(pred, tokenizer).split() + tgt_words_: list[list[str]] = [_preprocess_sentence(_tgt, tokenizer).split() for _tgt in tgt] + pred_words_: list[str] = _preprocess_sentence(pred, tokenizer).split() num_edits, tgt_length = _compute_sentence_statistics(pred_words_, tgt_words_) total_num_edits += num_edits total_tgt_length += tgt_length @@ -537,7 +537,7 @@ def translation_edit_rate( lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, list[Tensor]]]: """Calculate Translation edit rate (`TER`_) of machine translated text with one or more references. This implementation follows the implementations from @@ -581,7 +581,7 @@ def translation_edit_rate( total_num_edits = tensor(0.0) total_tgt_length = tensor(0.0) - sentence_ter: Optional[List[Tensor]] = [] if return_sentence_level_score else None + sentence_ter: Optional[list[Tensor]] = [] if return_sentence_level_score else None total_num_edits, total_tgt_length, sentence_ter = _ter_update( preds, diff --git a/src/torchmetrics/functional/text/wer.py b/src/torchmetrics/functional/text/wer.py index af50d4cb289..0479ad0d945 100644 --- a/src/torchmetrics/functional/text/wer.py +++ b/src/torchmetrics/functional/text/wer.py @@ -21,9 +21,9 @@ def _wer_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor]: """Update the wer score with the current set of references and predictions. Args: @@ -63,7 +63,7 @@ def _wer_compute(errors: Tensor, total: Tensor) -> Tensor: return errors / total -def word_error_rate(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def word_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Word error rate (WordErrorRate_) is a common metric of performance of an automatic speech recognition system. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the diff --git a/src/torchmetrics/functional/text/wil.py b/src/torchmetrics/functional/text/wil.py index bb0d5f7e0d8..e7ca50abec4 100644 --- a/src/torchmetrics/functional/text/wil.py +++ b/src/torchmetrics/functional/text/wil.py @@ -20,9 +20,9 @@ def _word_info_lost_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor, Tensor]: """Update the WIL score with the current set of references and predictions. Args: @@ -69,7 +69,7 @@ def _word_info_lost_compute(errors: Tensor, target_total: Tensor, preds_total: T return 1 - ((errors / target_total) * (errors / preds_total)) -def word_information_lost(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def word_information_lost(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better diff --git a/src/torchmetrics/functional/text/wip.py b/src/torchmetrics/functional/text/wip.py index 2c77139009b..2f6f635b053 100644 --- a/src/torchmetrics/functional/text/wip.py +++ b/src/torchmetrics/functional/text/wip.py @@ -19,9 +19,9 @@ def _wip_update( - preds: Union[str, List[str]], - target: Union[str, List[str]], -) -> Tuple[Tensor, Tensor, Tensor]: + preds: Union[str, list[str]], + target: Union[str, list[str]], +) -> tuple[Tensor, Tensor, Tensor]: """Update the wip score with the current set of references and predictions. Args: @@ -68,7 +68,7 @@ def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T return (errors / target_total) * (errors / preds_total) -def word_information_preserved(preds: Union[str, List[str]], target: Union[str, List[str]]) -> Tensor: +def word_information_preserved(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 50b54f479ff..18f3e1840ba 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -55,10 +55,10 @@ def __init__( kernel_size: Union[int, Sequence[int]] = 11, sigma: Union[float, Sequence[float]] = 1.5, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Literal["relu", "simple", None] = "relu", **kwargs: Any, ) -> None: @@ -91,10 +91,10 @@ class _PeakSignalNoiseRatio(PeakSignalNoiseRatio): def __init__( self, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, **kwargs: Any, ) -> None: _deprecated_root_import_class("PeakSignalNoiseRatio", "image") @@ -116,7 +116,7 @@ class _RelativeAverageSpectralError(RelativeAverageSpectralError): def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: _deprecated_root_import_class("RelativeAverageSpectralError", "image") super().__init__(window_size=window_size, **kwargs) @@ -137,7 +137,7 @@ class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWi def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: _deprecated_root_import_class("RootMeanSquaredErrorUsingSlidingWindow", "image") super().__init__(window_size=window_size, **kwargs) @@ -201,7 +201,7 @@ def __init__( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 97d95ccd926..0330d203e33 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -70,8 +70,8 @@ class SpectralDistortionIndex(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, p: int = 1, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 02530f7cd90..926929fe2be 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -94,10 +94,10 @@ class SpatialDistortionIndex(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - ms: List[Tensor] - pan: List[Tensor] - pan_lr: List[Tensor] + preds: list[Tensor] + ms: list[Tensor] + pan: list[Tensor] + pan_lr: list[Tensor] def __init__( self, @@ -128,7 +128,7 @@ def __init__( self.add_state("pan", default=[], dist_reduce_fx="cat") self.add_state("pan_lr", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Dict[str, Tensor]) -> None: + def update(self, preds: Tensor, target: dict[str, Tensor]) -> None: """Update state with preds and target. Args: diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 22c24b164f1..6e8ba2624d8 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -78,8 +78,8 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 7ad1bb0b892..1eac1bc9fcf 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -48,7 +48,7 @@ class NoTrainInceptionV3(_FeatureExtractorInceptionV3): def __init__( self, name: str, - features_list: List[str], + features_list: list[str], feature_extractor_weights_path: Optional[str] = None, ) -> None: if not _TORCH_FIDELITY_AVAILABLE: @@ -65,7 +65,7 @@ def train(self, mode: bool) -> "NoTrainInceptionV3": """Force network to always be in evaluation mode.""" return super().train(False) - def _torch_fidelity_forward(self, x: Tensor) -> Tuple[Tensor, ...]: + def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]: """Forward method of inception net. Copy of the forward method from this file: @@ -299,7 +299,7 @@ def __init__( feature: Union[int, Module] = 2048, reset_real_features: bool = True, normalize: bool = False, - input_img_size: Tuple[int, int, int] = (3, 299, 299), + input_img_size: tuple[int, int, int] = (3, 299, 299), **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 8bf42584af2..02c5cf04a35 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -102,7 +102,7 @@ class InceptionScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - features: List + features: list inception: Module feature_network: str = "inception" @@ -152,7 +152,7 @@ def update(self, imgs: Tensor) -> None: features = self.inception(imgs) self.features.append(features) - def compute(self) -> Tuple[Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor]: """Compute metric.""" features = dim_zero_cat(self.features) # random permute the features diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 6deb4efa22d..5aa3645bc77 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -169,8 +169,8 @@ class KernelInceptionDistance(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - real_features: List[Tensor] - fake_features: List[Tensor] + real_features: list[Tensor] + fake_features: list[Tensor] inception: Module feature_network: str = "inception" @@ -264,7 +264,7 @@ def update(self, imgs: Tensor, real: bool) -> None: else: self.fake_features.append(features) - def compute(self) -> Tuple[Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor]: """Calculate KID score based on accumulated extracted features from the two distributions. Implementation inspired by `Fid Score`_ diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index 8c4948b18a8..ba792ab5d19 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -99,7 +99,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric): feature_network: str = "net" # due to the use of named tuple in the backbone the net variable cannot be scripted - __jit_ignored_attributes__: ClassVar[List[str]] = ["net"] + __jit_ignored_attributes__: ClassVar[list[str]] = ["net"] def __init__( self, diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index 5d344b57dd1..31105fd290b 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -149,8 +149,8 @@ class MemorizationInformedFrechetInceptionDistance(Metric): is_differentiable: bool = False full_state_update: bool = False - real_features: List[Tensor] - fake_features: List[Tensor] + real_features: list[Tensor] + fake_features: list[Tensor] inception: Module feature_network: str = "inception" diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index 117dca6f8cb..f6d909bf477 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -166,7 +166,7 @@ def update(self, generator: GeneratorType) -> None: _validate_generator_model(generator, self.conditional) self.generator = generator - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute the perceptual path length.""" return perceptual_path_length( generator=self.generator, diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index 40c6fdab89f..fa76d677133 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -87,10 +87,10 @@ class PeakSignalNoiseRatio(Metric): def __init__( self, - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, Tuple[int, ...]]] = None, + dim: Optional[Union[int, tuple[int, ...]]] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/image/qnr.py b/src/torchmetrics/image/qnr.py index 75b88379669..4d99396d9de 100644 --- a/src/torchmetrics/image/qnr.py +++ b/src/torchmetrics/image/qnr.py @@ -90,10 +90,10 @@ class QualityWithNoReference(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - ms: List[Tensor] - pan: List[Tensor] - pan_lr: List[Tensor] + preds: list[Tensor] + ms: list[Tensor] + pan: list[Tensor] + pan_lr: list[Tensor] def __init__( self, @@ -131,7 +131,7 @@ def __init__( self.add_state("pan", default=[], dist_reduce_fx="cat") self.add_state("pan_lr", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Dict[str, Tensor]) -> None: + def update(self, preds: Tensor, target: dict[str, Tensor]) -> None: """Update state with preds and target. Args: diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index 6f7b0b10346..dbe51d5d969 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -64,13 +64,13 @@ class RelativeAverageSpectralError(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index 75feb612802..fdae49a72ba 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -69,7 +69,7 @@ class RootMeanSquaredErrorUsingSlidingWindow(Metric): def __init__( self, window_size: int = 8, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) if not isinstance(window_size, int) or isinstance(window_size, int) and window_size < 1: diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index b313158f80b..e1407120b2f 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -73,8 +73,8 @@ class SpectralAngleMapper(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] sum_sam: Tensor numel: Tensor diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 576ac4f0879..1a2a6f858c8 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -84,8 +84,8 @@ class StructuralSimilarityIndexMeasure(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, @@ -93,7 +93,7 @@ def __init__( sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, @@ -156,7 +156,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: else: self.similarity.append(similarity) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute SSIM over state.""" if self.reduction == "elementwise_mean": similarity = self.similarity / self.total @@ -285,8 +285,8 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, @@ -294,10 +294,10 @@ def __init__( kernel_size: Union[int, Sequence[int]] = 11, sigma: Union[float, Sequence[float]] = 1.5, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - data_range: Optional[Union[float, Tuple[float, float]]] = None, + data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, - betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Literal["relu", "simple", None] = "relu", **kwargs: Any, ) -> None: diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 287e58a3a43..087b6a06cb4 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -69,7 +69,7 @@ class TotalVariation(Metric): plot_lower_bound: float = 0.0 num_elements: Tensor - score_list: List[Tensor] + score_list: list[Tensor] score: Tensor def __init__(self, reduction: Optional[Literal["mean", "sum", "none"]] = "sum", **kwargs: Any) -> None: diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index c503cc1f394..c2cf917301d 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -72,8 +72,8 @@ class UniversalImageQualityIndex(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] sum_uqi: Tensor numel: Tensor diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 6344d683edb..3b8868adb54 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -84,8 +84,8 @@ class Metric(Module, ABC): """ - __jit_ignored_attributes__: ClassVar[List[str]] = ["device"] - __jit_unused_properties__: ClassVar[List[str]] = [ + __jit_ignored_attributes__: ClassVar[list[str]] = ["device"] + __jit_unused_properties__: ClassVar[list[str]] = [ "is_differentiable", "higher_is_better", "plot_lower_bound", @@ -166,13 +166,13 @@ def __init__( self._dtype_convert = False # initialize state - self._defaults: Dict[str, Union[List, Tensor]] = {} - self._persistent: Dict[str, bool] = {} - self._reductions: Dict[str, Union[str, Callable[..., Any], None]] = {} + self._defaults: dict[str, Union[list, Tensor]] = {} + self._persistent: dict[str, bool] = {} + self._reductions: dict[str, Union[str, Callable[..., Any], None]] = {} # state management self._is_synced = False - self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None + self._cache: Optional[dict[str, Union[list[Tensor], Tensor]]] = None @property def _update_called(self) -> bool: @@ -194,7 +194,7 @@ def update_count(self) -> int: return self._update_count @property - def metric_state(self) -> Dict[str, Union[List[Tensor], Tensor]]: + def metric_state(self) -> dict[str, Union[list[Tensor], Tensor]]: """Get the current state of the metric.""" return {attr: getattr(self, attr) for attr in self._defaults} @@ -402,7 +402,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: return batch_val - def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: + def merge_state(self, incoming_state: Union[dict[str, Any], "Metric"]) -> None: """Merge incoming metric state to the current state of the metric. Args: @@ -463,7 +463,7 @@ def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: self._reduce_states(incoming_state) - def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: + def _reduce_states(self, incoming_state: dict[str, Any]) -> None: """Add an incoming metric state to the current state of the metric. Args: @@ -726,7 +726,7 @@ def plot(self, *_: Any, **__: Any) -> Any: def _plot( self, - val: Optional[Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, + val: Optional[Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -777,7 +777,7 @@ def clone(self) -> "Metric": """Make a copy of the metric.""" return deepcopy(self) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: """Get the current state, including all metric states, for the metric. Used for loading and saving a metric. @@ -786,7 +786,7 @@ def __getstate__(self) -> Dict[str, Any]: # ignore update and compute functions for pickling return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute", "_update_signature"]} - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: """Set the state of the metric, based on a input state. Used for loading and saving a metric. @@ -924,10 +924,10 @@ def persistent(self, mode: bool = False) -> None: def state_dict( # type: ignore[override] # todo self, - destination: Optional[Dict[str, Any]] = None, + destination: Optional[dict[str, Any]] = None, prefix: str = "", keep_vars: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Get the current state of metric as an dictionary. Args: @@ -938,7 +938,7 @@ def state_dict( # type: ignore[override] # todo If set to ``True``, detaching will not be performed. """ - destination: Dict[str, Union[torch.Tensor, List, Any]] = super().state_dict( + destination: dict[str, Union[torch.Tensor, list, Any]] = super().state_dict( destination=destination, # type: ignore[arg-type] prefix=prefix, keep_vars=keep_vars, @@ -956,9 +956,9 @@ def state_dict( # type: ignore[override] # todo destination[prefix + key] = deepcopy(current_val) return destination - def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: + def _copy_state_dict(self) -> dict[str, Union[Tensor, list[Any]]]: """Copy the current state values.""" - cache: Dict[str, Union[Tensor, List[Any]]] = {} + cache: dict[str, Union[Tensor, list[Any]]] = {} for attr in self._defaults: current_value = getattr(self, attr) @@ -977,9 +977,9 @@ def _load_from_state_dict( prefix: str, local_metadata: dict, strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: """Load metric states from state_dict.""" for key in self._defaults: @@ -990,7 +990,7 @@ def _load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs ) - def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + def _filter_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Filter kwargs such that they match the update signature of the metric.""" # filter all parameters based on update signature except those of # types `VAR_POSITIONAL` for `* args` and `VAR_KEYWORD` for `** kwargs` @@ -1173,7 +1173,7 @@ def __getitem__(self, idx: int) -> "CompositionalMetric": """Construct compositional metric using the get item operator.""" return CompositionalMetric(lambda x: x[idx], self, None) - def __getnewargs__(self) -> Tuple: + def __getnewargs__(self) -> tuple: """Needed method for construction of new metrics __new__ method.""" return tuple( Metric.__str__(self), diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index e4abfaa89b1..947972ff02e 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -166,7 +166,7 @@ class CLIPImageQualityAssessment(Metric): plot_upper_bound = 100.0 anchors: Tensor - probs_list: List[Tensor] + probs_list: list[Tensor] feature_network: str = "model" def __init__( @@ -179,7 +179,7 @@ def __init__( "openai/clip-vit-large-patch14", ] = "clip_iqa", data_range: float = 1.0, - prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",), + prompts: tuple[Union[str, tuple[str, str]]] = ("quality",), **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -213,7 +213,7 @@ def update(self, images: Tensor) -> None: raise ValueError("Output probs should be a tensor") self.probs_list.append(probs) - def compute(self) -> Union[Tensor, Dict[str, Tensor]]: + def compute(self) -> Union[Tensor, dict[str, Tensor]]: """Compute metric.""" probs = dim_zero_cat(self.probs_list) if len(self.prompts_name) == 1: diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index aca26ac09e5..f4cf857bbf2 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -118,7 +118,7 @@ def __init__( self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") - def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str]]) -> None: + def update(self, images: Union[Tensor, list[Tensor]], text: Union[str, list[str]]) -> None: """Update CLIP score on a batch of images and text. Args: diff --git a/src/torchmetrics/nominal/fleiss_kappa.py b/src/torchmetrics/nominal/fleiss_kappa.py index 254796e96c9..cf6f9058326 100644 --- a/src/torchmetrics/nominal/fleiss_kappa.py +++ b/src/torchmetrics/nominal/fleiss_kappa.py @@ -77,7 +77,7 @@ class FleissKappa(Metric): is_differentiable: bool = False higher_is_better: bool = True plot_upper_bound: float = 1.0 - counts: List[Tensor] + counts: list[Tensor] def __init__(self, mode: Literal["counts", "probs"] = "counts", **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/regression/cosine_similarity.py b/src/torchmetrics/regression/cosine_similarity.py index 5c86ac00cab..f7bf201d7ea 100644 --- a/src/torchmetrics/regression/cosine_similarity.py +++ b/src/torchmetrics/regression/cosine_similarity.py @@ -66,8 +66,8 @@ class CosineSimilarity(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index b75762a3b0c..9014cb2daab 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -63,9 +63,9 @@ class CriticalSuccessIndex(Metric): hits: torch.Tensor misses: torch.Tensor false_alarms: torch.Tensor - hits_list: List[torch.Tensor] - misses_list: List[torch.Tensor] - false_alarms_list: List[torch.Tensor] + hits_list: list[torch.Tensor] + misses_list: list[torch.Tensor] + false_alarms_list: list[torch.Tensor] def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 63c2ec150b6..8b651737e8a 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -120,8 +120,8 @@ class KendallRankCorrCoef(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, @@ -154,7 +154,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: num_outputs=self.num_outputs, ) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 75d323a60f6..8ae82e17662 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -33,7 +33,7 @@ def _final_aggregation( vars_y: Tensor, corrs_xy: Tensor, nbs: Tensor, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Aggregate the statistics from multiple devices. Formula taken from here: `Aggregate the statistics from multiple devices`_ @@ -117,8 +117,8 @@ class PearsonCorrCoef(Metric): full_state_update: bool = True plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] mean_x: Tensor mean_y: Tensor var_x: Tensor diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index de94903c8c0..59755846cc9 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -75,8 +75,8 @@ class SpearmanCorrCoef(Metric): plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 - preds: List[Tensor] - target: List[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, diff --git a/src/torchmetrics/retrieval/base.py b/src/torchmetrics/retrieval/base.py index f9a0a4f8cc4..94bb49982f7 100644 --- a/src/torchmetrics/retrieval/base.py +++ b/src/torchmetrics/retrieval/base.py @@ -98,9 +98,9 @@ class RetrievalMetric(Metric, ABC): higher_is_better: bool = True full_state_update: bool = False - indexes: List[Tensor] - preds: List[Tensor] - target: List[Tensor] + indexes: list[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py index 74fe1dd4c50..5b6af216843 100644 --- a/src/torchmetrics/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/retrieval/precision_recall_curve.py @@ -35,7 +35,7 @@ def _retrieval_recall_at_fixed_precision( recall: Tensor, top_k: Tensor, min_precision: float, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Compute maximum recall with condition that corresponding precision >= `min_precision`. Args: @@ -143,9 +143,9 @@ class RetrievalPrecisionRecallCurve(Metric): higher_is_better: bool = True full_state_update: bool = False - indexes: List[Tensor] - preds: List[Tensor] - target: List[Tensor] + indexes: list[Tensor] + preds: list[Tensor] + target: list[Tensor] def __init__( self, @@ -202,7 +202,7 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: self.preds.append(preds) self.target.append(target) - def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor, Tensor]: """Compute metric.""" # concat all data indexes = dim_zero_cat(self.indexes) @@ -257,7 +257,7 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: def plot( self, - curve: Optional[Tuple[Tensor, Tensor, Tensor]] = None, + curve: Optional[tuple[Tensor, Tensor, Tensor]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. @@ -381,7 +381,7 @@ def __init__( self.min_precision = min_precision - def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override] """Compute metric.""" precisions, recalls, top_k = super().compute() diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 05a6e29b387..b28ade07f1e 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -100,9 +100,9 @@ class DiceScore(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - numerator: List[Tensor] - denominator: List[Tensor] - support: List[Tensor] + numerator: list[Tensor] + denominator: list[Tensor] + support: list[Tensor] def __init__( self, diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py index 727f9e7bb08..666790f914d 100644 --- a/src/torchmetrics/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -89,7 +89,7 @@ def __init__( num_classes: int, include_background: bool = False, distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", - spacing: Optional[Union[Tensor, List[float]]] = None, + spacing: Optional[Union[Tensor, list[float]]] = None, directed: bool = False, input_format: Literal["one-hot", "index"] = "one-hot", **kwargs: Any, diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index 5960e16a00f..6df6f20eb6b 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -47,7 +47,7 @@ def _download_model_for_bert_score() -> None: __doctest_skip__ = ["BERTScore", "BERTScore.plot"] -def _get_input_dict(input_ids: List[Tensor], attention_mask: List[Tensor]) -> Dict[str, Tensor]: +def _get_input_dict(input_ids: list[Tensor], attention_mask: list[Tensor]) -> dict[str, Tensor]: """Create an input dictionary of ``input_ids`` and ``attention_mask`` for BERTScore calculation.""" return {"input_ids": torch.cat(input_ids), "attention_mask": torch.cat(attention_mask)} @@ -128,10 +128,10 @@ class BERTScore(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds_input_ids: List[Tensor] - preds_attention_mask: List[Tensor] - target_input_ids: List[Tensor] - target_attention_mask: List[Tensor] + preds_input_ids: list[Tensor] + preds_attention_mask: list[Tensor] + target_input_ids: list[Tensor] + target_attention_mask: list[Tensor] def __init__( self, @@ -140,7 +140,7 @@ def __init__( all_layers: bool = False, model: Optional[Module] = None, user_tokenizer: Optional[Any] = None, - user_forward_fn: Optional[Callable[[Module, Dict[str, Tensor]], Tensor]] = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, verbose: bool = False, idf: bool = False, device: Optional[Union[str, torch.device]] = None, @@ -232,7 +232,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s self.target_input_ids.append(target_dict["input_ids"]) self.target_attention_mask.append(target_dict["attention_mask"]) - def compute(self) -> Dict[str, Union[Tensor, List[float], str]]: + def compute(self) -> dict[str, Union[Tensor, list[float], str]]: """Calculate BERT scores.""" preds = { "input_ids": dim_zero_cat(self.preds_input_ids), diff --git a/src/torchmetrics/text/cer.py b/src/torchmetrics/text/cer.py index 9862c450ed3..47370328dbc 100644 --- a/src/torchmetrics/text/cer.py +++ b/src/torchmetrics/text/cer.py @@ -85,7 +85,7 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, total = _cer_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 962b0813b3b..88a791e9714 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -45,8 +45,8 @@ "total_matching_word_n_grams", ) -_DICT_STATES_TYPES = Tuple[ - Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] +_DICT_STATES_TYPES = tuple[ + dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor] ] @@ -101,7 +101,7 @@ class CHRFScore(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - sentence_chrf_score: Optional[List[Tensor]] = None + sentence_chrf_score: Optional[list[Tensor]] = None def __init__( self, @@ -157,7 +157,7 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: if self.sentence_chrf_score is not None: self.sentence_chrf_score = n_grams_dicts_tuple[-1] - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate chrF/chrF++ score.""" if self.sentence_chrf_score is not None: return ( @@ -168,7 +168,7 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: """Convert global metric states to the n-gram dictionaries to be passed in ``_chrf_score_update``.""" - n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict( + n_grams_dicts: dict[str, dict[int, Tensor]] = dict( zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)) ) @@ -201,7 +201,7 @@ def _get_state_name(text: str, n_gram_level: str, n: int) -> str: """Return a metric state name w.r.t input args.""" return f"total_{text}_{n_gram_level}_{n}_grams" - def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: + def _get_text_n_gram_iterator(self) -> Iterator[tuple[tuple[str, int], str]]: """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) diff --git a/src/torchmetrics/text/edit.py b/src/torchmetrics/text/edit.py index 947fc79cd6c..060a8fc6b26 100644 --- a/src/torchmetrics/text/edit.py +++ b/src/torchmetrics/text/edit.py @@ -90,7 +90,7 @@ class EditDistance(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - edit_scores_list: List[Tensor] + edit_scores_list: list[Tensor] edit_scores: Tensor num_elements: Tensor diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index c0629b9ba44..9dfaaa9edcf 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -65,7 +65,7 @@ class ExtendedEditDistance(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - sentence_eed: List[Tensor] + sentence_eed: list[Tensor] def __init__( self, @@ -113,7 +113,7 @@ def update( self.sentence_eed, ) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate extended edit distance score.""" average = _eed_compute(self.sentence_eed) diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index 3488e2074b4..6c1320248d7 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -112,10 +112,10 @@ class InfoLM(Metric): """ is_differentiable = False - preds_input_ids: List[Tensor] - preds_attention_mask: List[Tensor] - target_input_ids: List[Tensor] - target_attention_mask: List[Tensor] + preds_input_ids: list[Tensor] + preds_attention_mask: list[Tensor] + target_input_ids: list[Tensor] + target_attention_mask: list[Tensor] _information_measure_higher_is_better: ClassVar = { # following values are <0 @@ -145,7 +145,7 @@ def __init__( num_threads: int = 0, verbose: bool = True, return_sentence_level_score: bool = False, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) self.model_name_or_path = model_name_or_path @@ -189,7 +189,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s self.target_input_ids.append(target_input_ids) self.target_attention_mask.append(target_attention_mask) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate selected information measure using the pre-trained language model.""" preds_dataloader = _get_dataloader( input_ids=dim_zero_cat(self.preds_input_ids), diff --git a/src/torchmetrics/text/mer.py b/src/torchmetrics/text/mer.py index 37dae4cc4f6..b3f086e0d00 100644 --- a/src/torchmetrics/text/mer.py +++ b/src/torchmetrics/text/mer.py @@ -84,8 +84,8 @@ def __init__( def update( self, - preds: Union[str, List[str]], - target: Union[str, List[str]], + preds: Union[str, list[str]], + target: Union[str, list[str]], ) -> None: """Update state with predictions and targets.""" errors, total = _mer_update(preds, target) diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py index a090d522db8..950bb9ad449 100644 --- a/src/torchmetrics/text/perplexity.py +++ b/src/torchmetrics/text/perplexity.py @@ -69,7 +69,7 @@ class Perplexity(Metric): def __init__( self, ignore_index: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) if ignore_index is not None and not isinstance(ignore_index, int): diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py index d0cac0df18d..be8399f6fad 100644 --- a/src/torchmetrics/text/rouge.py +++ b/src/torchmetrics/text/rouge.py @@ -109,7 +109,7 @@ def __init__( normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, accumulate: Literal["avg", "best"] = "best", - rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), + rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -156,7 +156,7 @@ def update( if isinstance(target, str): target = [[target]] - output: Dict[Union[int, str], List[Dict[str, Tensor]]] = _rouge_score_update( + output: dict[Union[int, str], list[dict[str, Tensor]]] = _rouge_score_update( preds, target, self.rouge_keys_values, @@ -170,7 +170,7 @@ def update( for tp, value in metric.items(): getattr(self, f"rouge{rouge_key}_{tp}").append(value.to(self.device)) # todo - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Calculate (Aggregate and provide confidence intervals) ROUGE score.""" update_output = {} for rouge_key in self.rouge_keys_values: diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py index e4d98c2a8b6..124c87f06ce 100644 --- a/src/torchmetrics/text/squad.py +++ b/src/torchmetrics/text/squad.py @@ -119,7 +119,7 @@ def update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None: self.exact_match += exact_match self.total += total - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch.""" return _squad_compute(self.f1_score, self.exact_match, self.total) diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py index 8ded3c9b606..bb477a65bfb 100644 --- a/src/torchmetrics/text/ter.py +++ b/src/torchmetrics/text/ter.py @@ -69,7 +69,7 @@ class TranslationEditRate(Metric): total_num_edits: Tensor total_tgt_len: Tensor - sentence_ter: Optional[List[Tensor]] = None + sentence_ter: Optional[list[Tensor]] = None def __init__( self, @@ -109,7 +109,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Sequence[Union[str, S self.sentence_ter, ) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]: """Calculate the translate error rate (TER).""" ter = _ter_compute(self.total_num_edits, self.total_tgt_len) if self.sentence_ter is not None: diff --git a/src/torchmetrics/text/wer.py b/src/torchmetrics/text/wer.py index fc947ef2772..b96044ef1bc 100644 --- a/src/torchmetrics/text/wer.py +++ b/src/torchmetrics/text/wer.py @@ -84,7 +84,7 @@ def __init__( self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, total = _wer_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/text/wil.py b/src/torchmetrics/text/wil.py index a0d42fbfcf2..0a0c9f7cd13 100644 --- a/src/torchmetrics/text/wil.py +++ b/src/torchmetrics/text/wil.py @@ -82,7 +82,7 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, target_total, preds_total = _word_info_lost_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/text/wip.py b/src/torchmetrics/text/wip.py index 6d5db6b3e2c..fc4ea5b8f15 100644 --- a/src/torchmetrics/text/wip.py +++ b/src/torchmetrics/text/wip.py @@ -83,7 +83,7 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: + def update(self, preds: Union[str, list[str]], target: Union[str, list[str]]) -> None: """Update state with predictions and targets.""" errors, target_total, preds_total = _wip_update(preds, target) self.errors += errors diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 79878d058b7..6da328c01e7 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -71,7 +71,7 @@ def _basic_input_validation( raise ValueError("If you set `multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") -def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[DataType, int]: +def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> tuple[DataType, int]: """Check that the shape and type of inputs are consistent with each other. The input types needs to be one of allowed input types (see the documentation of docstring of @@ -302,7 +302,7 @@ def _check_classification_inputs( def _input_squeeze( preds: Tensor, target: Tensor, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Remove excess dimensions.""" if preds.shape[0] == 1: preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) @@ -319,7 +319,7 @@ def _input_format_classification( num_classes: Optional[int] = None, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, DataType]: +) -> tuple[Tensor, Tensor, DataType]: """Convert preds and target tensors into common format. Preds and targets are supposed to fall into one of these categories (and are @@ -461,7 +461,7 @@ def _input_format_classification_one_hot( target: Tensor, threshold: float = 0.5, multilabel: bool = False, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Convert preds and target tensors into one hot spare label tensors. Args: @@ -509,7 +509,7 @@ def _check_retrieval_functional_inputs( preds: Tensor, target: Tensor, allow_non_binary_target: bool = False, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type. Args: @@ -542,7 +542,7 @@ def _check_retrieval_inputs( target: Tensor, allow_non_binary_target: bool = False, ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct data type. Args: @@ -589,7 +589,7 @@ def _check_retrieval_target_and_prediction_types( preds: Tensor, target: Tensor, allow_non_binary_target: bool = False, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type. Args: @@ -634,8 +634,8 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-6) -> bool: @no_type_check def check_forward_full_state_property( metric_class: Metric, - init_args: Optional[Dict[str, Any]] = None, - input_args: Optional[Dict[str, Any]] = None, + init_args: Optional[dict[str, Any]] = None, + input_args: Optional[dict[str, Any]] = None, num_update_to_compare: Sequence[int] = [10, 100, 1000], reps: int = 5, ) -> None: diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index e526ecc8456..f16600f35b6 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -81,7 +81,7 @@ def _adjust_weights_safe_divide( return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) -def _auc_format_inputs(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: +def _auc_format_inputs(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: """Check that auc input is correct.""" x = x.squeeze() if x.ndim > 1 else x y = y.squeeze() if y.ndim > 1 else y diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 4428c8cc7e9..d3ec64d621a 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -26,7 +26,7 @@ METRIC_EPS = 1e-6 -def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: +def dim_zero_cat(x: Union[Tensor, list[Tensor]]) -> Tensor: """Concatenation along the zero dimension.""" if isinstance(x, torch.Tensor): return x @@ -61,7 +61,7 @@ def _flatten(x: Sequence) -> list: return [item for sublist in x for item in sublist] -def _flatten_dict(x: Dict) -> Tuple[Dict, bool]: +def _flatten_dict(x: dict) -> tuple[dict, bool]: """Flatten dict of dicts into single dict and checking for duplicates in keys along the way.""" new_dict = {} duplicates = False diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 90239b46af0..150138e2198 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -88,7 +88,7 @@ def class_reduce( raise ValueError(f"Reduction parameter {class_reduction} unknown. Choose between one of these: {valid_reduction}") -def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: +def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> list[Tensor]: with torch.no_grad(): gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) @@ -97,7 +97,7 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L return gathered_result -def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: +def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Tensor]: """Gather all tensors from several ddp processes onto a list that is broadcasted to all processes. Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index bfc2fd20190..14ec7135a86 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -25,7 +25,7 @@ def _name() -> str: return "Task" @classmethod - def from_str(cls: Type["EnumStr"], value: str, source: Literal["key", "value", "any"] = "key") -> "EnumStr": + def from_str(cls: type["EnumStr"], value: str, source: Literal["key", "value", "any"] = "key") -> "EnumStr": """Load from string. Raises: diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index dae78a873e9..98ef71aa4d2 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -27,13 +27,13 @@ import matplotlib.axes import matplotlib.pyplot as plt - _PLOT_OUT_TYPE = Tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]] + _PLOT_OUT_TYPE = tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]] _AX_TYPE = matplotlib.axes.Axes _CMAP_TYPE = Union[matplotlib.colors.Colormap, str] style_change = plt.style.context else: - _PLOT_OUT_TYPE = Tuple[object, object] # type: ignore[misc] + _PLOT_OUT_TYPE = tuple[object, object] # type: ignore[misc] _AX_TYPE = object _CMAP_TYPE = object # type: ignore[misc] @@ -63,7 +63,7 @@ def _error_on_missing_matplotlib() -> None: @style_change(_style) def plot_single_or_multi_val( - val: Union[Tensor, Sequence[Tensor], Dict[str, Tensor], Sequence[Dict[str, Tensor]]], + val: Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]], ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type] higher_is_better: Optional[bool] = None, lower_bound: Optional[float] = None, @@ -172,7 +172,7 @@ def plot_single_or_multi_val( return fig, ax -def _get_col_row_split(n: int) -> Tuple[int, int]: +def _get_col_row_split(n: int) -> tuple[int, int]: """Split `n` figures into `rows` x `cols` figures.""" nsq = sqrt(n) if int(nsq) == nsq: # square number @@ -182,7 +182,7 @@ def _get_col_row_split(n: int) -> Tuple[int, int]: return ceil(nsq), ceil(nsq) -def _get_text_color(patch_color: Tuple[float, float, float, float]) -> str: +def _get_text_color(patch_color: tuple[float, float, float, float]) -> str: """Get the text color for a given value and colormap. Following Wikipedia's recommendations: https://en.wikipedia.org/wiki/Relative_luminance. @@ -222,7 +222,7 @@ def plot_confusion_matrix( confmat: Tensor, ax: Optional[_AX_TYPE] = None, add_text: bool = True, - labels: Optional[List[Union[int, str]]] = None, + labels: Optional[list[Union[int, str]]] = None, cmap: Optional[_CMAP_TYPE] = None, ) -> _PLOT_OUT_TYPE: """Plot an confusion matrix. @@ -295,13 +295,13 @@ def plot_confusion_matrix( @style_change(_style) def plot_curve( - curve: Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]], + curve: Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]], score: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type] - label_names: Optional[Tuple[str, str]] = None, + label_names: Optional[tuple[str, str]] = None, legend_name: Optional[str] = None, name: Optional[str] = None, - labels: Optional[List[Union[int, str]]] = None, + labels: Optional[list[Union[int, str]]] = None, ) -> _PLOT_OUT_TYPE: """Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py. diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index 083cafd76c6..47192dad5f7 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -146,7 +146,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx) self.metrics[idx].update(*new_args, **new_kwargs) - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Compute the bootstrapped metric values. Always returns a dict of tensors, which can contain the following keys: ``mean``, ``std``, ``quantile`` and diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 682cfdad4b3..517d4c060d0 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -116,7 +116,7 @@ class ClasswiseWrapper(WrapperMetric): def __init__( self, metric: Metric, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, prefix: Optional[str] = None, postfix: Optional[str] = None, ) -> None: @@ -139,11 +139,11 @@ def __init__( self._update_count = 1 - def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + def _filter_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Filter kwargs for the metric.""" return self.metric._filter_kwargs(**kwargs) - def _convert_output(self, x: Tensor) -> Dict[str, Any]: + def _convert_output(self, x: Tensor) -> dict[str, Any]: """Convert output to dictionary with labels as keys.""" # Will set the class name as prefix if neither prefix nor postfix is given if not self._prefix and not self._postfix: @@ -164,7 +164,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: """Update state.""" self.metric.update(*args, **kwargs) - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Compute metric.""" return self._convert_output(self.metric.compute()) diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py index 5cbc0106beb..caeb2e93217 100644 --- a/src/torchmetrics/wrappers/feature_share.py +++ b/src/torchmetrics/wrappers/feature_share.py @@ -84,7 +84,7 @@ class FeatureShare(MetricCollection): def __init__( self, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], + metrics: Union[Metric, Sequence[Metric], dict[str, Metric]], max_cache_size: Optional[int] = None, ) -> None: # disable compute groups because the feature sharing is more custom diff --git a/src/torchmetrics/wrappers/minmax.py b/src/torchmetrics/wrappers/minmax.py index f91c3992529..236179a6cbe 100644 --- a/src/torchmetrics/wrappers/minmax.py +++ b/src/torchmetrics/wrappers/minmax.py @@ -83,7 +83,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: """Update the underlying metric.""" self._base_metric.update(*args, **kwargs) - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> dict[str, Tensor]: """Compute the underlying metric as well as max and min values for this metric. Returns a dictionary that consists of the computed value (``raw``), as well as the minimum (``min``) and maximum diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py index 85cdb2c573f..86f53af6908 100644 --- a/src/torchmetrics/wrappers/multioutput.py +++ b/src/torchmetrics/wrappers/multioutput.py @@ -104,7 +104,7 @@ def __init__( self.remove_nans = remove_nans self.squeeze_outputs = squeeze_outputs - def _get_args_kwargs_by_output(self, *args: Tensor, **kwargs: Tensor) -> List[Tuple[Tensor, Tensor]]: + def _get_args_kwargs_by_output(self, *args: Tensor, **kwargs: Tensor) -> list[tuple[Tensor, Tensor]]: """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.""" args_kwargs_by_output = [] for i in range(len(self.metrics)): diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index a955d938583..c0b312cbedd 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -134,7 +134,7 @@ class MultitaskWrapper(WrapperMetric): def __init__( self, - task_metrics: Dict[str, Union[Metric, MetricCollection]], + task_metrics: dict[str, Union[Metric, MetricCollection]], prefix: Optional[str] = None, postfix: Optional[str] = None, ) -> None: @@ -160,7 +160,7 @@ def __init__( raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}") self._postfix = postfix or "" - def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]: + def items(self, flatten: bool = True) -> Iterable[tuple[str, nn.Module]]: """Iterate over task and task metrics. Args: @@ -204,7 +204,7 @@ def values(self, flatten: bool = True) -> Iterable[nn.Module]: else: yield metric - def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> None: + def update(self, task_preds: dict[str, Any], task_targets: dict[str, Any]) -> None: """Update each task's metric with its corresponding pred and target. Args: @@ -224,15 +224,15 @@ def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> No target = task_targets[task_name] metric.update(pred, target) - def _convert_output(self, output: Dict[str, Any]) -> Dict[str, Any]: + def _convert_output(self, output: dict[str, Any]) -> dict[str, Any]: """Convert the output of the underlying metrics to a dictionary with the task names as keys.""" return {f"{self._prefix}{task_name}{self._postfix}": task_output for task_name, task_output in output.items()} - def compute(self) -> Dict[str, Any]: + def compute(self) -> dict[str, Any]: """Compute metrics for all tasks.""" return self._convert_output({task_name: metric.compute() for task_name, metric in self.task_metrics.items()}) - def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]: + def forward(self, task_preds: dict[str, Tensor], task_targets: dict[str, Tensor]) -> dict[str, Any]: """Call underlying forward methods for all tasks and return the result as a dictionary.""" # This method is overridden because we do not need the complex version defined in Metric, that relies on the # value of full_state_update, and that also accumulates the results. Here, all computations are handled by the @@ -269,7 +269,7 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> return multitask_copy def plot( - self, val: Optional[Union[Dict, Sequence[Dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None + self, val: Optional[Union[dict, Sequence[dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None ) -> Sequence[_PLOT_OUT_TYPE]: """Plot a single or multiple values from the metric. @@ -353,7 +353,7 @@ def plot( fig_axs = [] for i, (task_name, task_metric) in enumerate(self.task_metrics.items()): ax = axes[i] if axes is not None else None - if isinstance(val, Dict): + if isinstance(val, dict): f, a = task_metric.plot(val[task_name], ax=ax) elif isinstance(val, Sequence): f, a = task_metric.plot([v[task_name] for v in val], ax=ax) diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index 148e16c412c..a223f803397 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -104,10 +104,10 @@ class MetricTracker(ModuleList): """ - maximize: Union[bool, List[bool]] + maximize: Union[bool, list[bool]] def __init__( - self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, List[bool]]] = True + self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, list[bool]]] = True ) -> None: super().__init__() if not isinstance(metric, (Metric, MetricCollection)): @@ -221,10 +221,10 @@ def best_metric( None, float, Tensor, - Tuple[Union[int, float, Tensor], Union[int, float, Tensor]], - Tuple[None, None], - Dict[str, Union[float, None]], - Tuple[Dict[str, Union[float, None]], Dict[str, Union[int, None]]], + tuple[Union[int, float, Tensor], Union[int, float, Tensor]], + tuple[None, None], + dict[str, Union[float, None]], + tuple[dict[str, Union[float, None]], dict[str, Union[int, None]]], ]: """Return the highest metric out of all tracked. diff --git a/src/torchmetrics/wrappers/transformations.py b/src/torchmetrics/wrappers/transformations.py index f2ac106fe86..d4fb0f270d0 100644 --- a/src/torchmetrics/wrappers/transformations.py +++ b/src/torchmetrics/wrappers/transformations.py @@ -28,7 +28,7 @@ class MetricInputTransformer(WrapperMetric): """ - def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: Dict[str, Any]) -> None: + def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: dict[str, Any]) -> None: super().__init__(**kwargs) if not isinstance(wrapped_metric, (Metric, MetricCollection)): raise TypeError( @@ -53,7 +53,7 @@ def transform_target(self, target: torch.Tensor) -> torch.Tensor: """ return target - def _wrap_transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: + def _wrap_transform(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: """Wrap transformation functions to dispatch args to their individual transform functions.""" if len(args) == 1: return (self.transform_pred(args[0]),) @@ -61,7 +61,7 @@ def _wrap_transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: return self.transform_pred(args[0]), self.transform_target(args[1]) return self.transform_pred(args[0]), self.transform_target(args[1]), *args[2:] - def update(self, *args: torch.Tensor, **kwargs: Dict[str, Any]) -> None: + def update(self, *args: torch.Tensor, **kwargs: dict[str, Any]) -> None: """Wrap the update call of the underlying metric.""" args = self._wrap_transform(*args) self.wrapped_metric.update(*args, **kwargs) @@ -70,7 +70,7 @@ def compute(self) -> Any: """Wrap the compute call of the underlying metric.""" return self.wrapped_metric.compute() - def forward(self, *args: torch.Tensor, **kwargs: Dict[str, Any]) -> Any: + def forward(self, *args: torch.Tensor, **kwargs: dict[str, Any]) -> Any: """Wrap the forward call of the underlying metric.""" args = self._wrap_transform(*args) return self.wrapped_metric.forward(*args, **kwargs) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index 42510b08747..ed7b940c315 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -43,7 +43,7 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O elif isinstance(tm_result, Sequence): for pl_res, ref_res in zip(tm_result, ref_result): _assert_allclose(pl_res, ref_res, atol=atol) - elif isinstance(tm_result, Dict): + elif isinstance(tm_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert np.allclose( @@ -61,7 +61,7 @@ def _assert_tensor(tm_result: Any, key: Optional[str] = None) -> None: if isinstance(tm_result, Sequence): for plr in tm_result: _assert_tensor(plr) - elif isinstance(tm_result, Dict): + elif isinstance(tm_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert isinstance(tm_result[key], Tensor) @@ -74,7 +74,7 @@ def _assert_requires_grad(metric: Metric, tm_result: Any, key: Optional[str] = N if isinstance(tm_result, Sequence): for plr in tm_result: _assert_requires_grad(metric, plr, key=key) - elif isinstance(tm_result, Dict): + elif isinstance(tm_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert metric.is_differentiable == tm_result[key].requires_grad @@ -85,8 +85,8 @@ def _assert_requires_grad(metric: Metric, tm_result: Any, key: Optional[str] = N def _class_test( rank: int, world_size: int, - preds: Union[Tensor, list, List[Dict[str, Tensor]]], - target: Union[Tensor, list, List[Dict[str, Tensor]]], + preds: Union[Tensor, list, list[dict[str, Tensor]]], + target: Union[Tensor, list, list[dict[str, Tensor]]], metric_class: Metric, reference_metric: Callable, dist_sync_on_step: bool, @@ -252,7 +252,7 @@ def _class_test( def _functional_test( preds: Union[Tensor, list], - target: Union[Tensor, list, List[Dict[str, Tensor]]], + target: Union[Tensor, list, list[dict[str, Tensor]]], metric_functional: Callable, reference_metric: Callable, metric_args: Optional[dict] = None, @@ -317,7 +317,7 @@ def _assert_dtype_support( metric_module: Optional[Metric], metric_functional: Optional[Callable], preds: Tensor, - target: Union[Tensor, List[Dict[str, Tensor]]], + target: Union[Tensor, list[dict[str, Tensor]]], device: str = "cpu", dtype: torch.dtype = torch.half, **kwargs_update: Any, @@ -420,8 +420,8 @@ def run_functional_metric_test( def run_class_metric_test( self, ddp: bool, - preds: Union[Tensor, List[Dict]], - target: Union[Tensor, List[Dict]], + preds: Union[Tensor, list[dict]], + target: Union[Tensor, list[dict]], metric_class: Metric, reference_metric: Callable, dist_sync_on_step: bool = False, @@ -685,7 +685,7 @@ def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: return x -def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> Tuple[Tensor, Tensor]: +def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> tuple[Tensor, Tensor]: """Remove samples that are equal to the ignore_index in comparison functions. Example: @@ -704,7 +704,7 @@ def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[in def remove_ignore_index_groups( target: Tensor, preds: Tensor, groups: Tensor, ignore_index: Optional[int] -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Version of the remove_ignore_index which includes groups.""" if ignore_index is not None: idx = target == ignore_index diff --git a/tests/unittests/audio/test_dnsmos.py b/tests/unittests/audio/test_dnsmos.py index c3c1df7a03b..c8239fe4776 100644 --- a/tests/unittests/audio/test_dnsmos.py +++ b/tests/unittests/audio/test_dnsmos.py @@ -43,7 +43,7 @@ class InferenceSession: # type:ignore """Dummy InferenceSession.""" - def __init__(self, **kwargs: Dict[str, Any]) -> None: ... + def __init__(self, **kwargs: dict[str, Any]) -> None: ... SAMPLING_RATE = 16000 @@ -82,7 +82,7 @@ def _get_polyfit_val(self, sig, bak, ovr, is_personalized): return sig_poly, bak_poly, ovr_poly - def __call__(self, aud, input_fs, is_personalized) -> Dict[str, Any]: + def __call__(self, aud, input_fs, is_personalized) -> dict[str, Any]: fs = SAMPLING_RATE audio = librosa.resample(aud, orig_sr=input_fs, target_sr=fs) if input_fs != fs else aud actual_audio_len = len(audio) @@ -143,7 +143,7 @@ def _reference_metric_batch( personalized: bool, device: Optional[str] = None, # for tester reduce_mean: bool = False, - **kwargs: Dict[str, Any], # for tester + **kwargs: dict[str, Any], # for tester ): # download onnx first _load_session(f"{DNSMOS_DIR}/{'p' if personalized else ''}DNSMOS/sig_bak_ovr.onnx", torch.device("cpu")) @@ -170,7 +170,7 @@ def _reference_metric_batch( return score.reshape(*shape[:-1], 4).reshape(shape[:-1] + (4,)).numpy() -def _dnsmos_cheat(preds, target, **kwargs: Dict[str, Any]): +def _dnsmos_cheat(preds, target, **kwargs: dict[str, Any]): # cheat the MetricTester as the deep_noise_suppression_mean_opinion_score doesn't need target return deep_noise_suppression_mean_opinion_score(preds, **kwargs) diff --git a/tests/unittests/audio/test_nisqa.py b/tests/unittests/audio/test_nisqa.py index a36ad0db1c6..8d4e47512c0 100644 --- a/tests/unittests/audio/test_nisqa.py +++ b/tests/unittests/audio/test_nisqa.py @@ -115,7 +115,7 @@ def _reference_metric(preds): return out.mean(dim=0) if mean else out.reshape(*preds.shape[:-1], 5) -def _nisqa_cheat(preds, target, **kwargs: Dict[str, Any]): +def _nisqa_cheat(preds, target, **kwargs: dict[str, Any]): # cheat the MetricTester as non_intrusive_speech_quality_assessment does not need a target return non_intrusive_speech_quality_assessment(preds, **kwargs) @@ -156,7 +156,7 @@ def test_nisqa_functional(self, preds: Tensor, reference: Tensor, fs: int, devic @pytest.mark.parametrize("shape", [(3000,), (2, 3000), (1, 2, 3000), (2, 3, 1, 3000)]) -def test_shape(shape: Tuple[int]): +def test_shape(shape: tuple[int]): """Test output shape.""" preds = torch.rand(*shape) out = non_intrusive_speech_quality_assessment(preds, 16000) diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 85baab5e045..b431b9e8d6c 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -57,7 +57,7 @@ def _reference_scipy_pit( target: Tensor, metric_func: Callable, eval_func: str, -) -> Tuple[Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: """Naive implementation of `Permutation Invariant Training` based on Scipy. Args: @@ -87,7 +87,7 @@ def _reference_scipy_pit( return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms)) -def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: return _reference_scipy_pit( preds=preds, target=target, @@ -96,7 +96,7 @@ def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Ten ) -def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: +def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: return _reference_scipy_pit( preds=preds, target=target, diff --git a/tests/unittests/audio/test_srmr.py b/tests/unittests/audio/test_srmr.py index e7370546478..b627e6b8642 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -30,7 +30,7 @@ def _reference_srmr_batch( - preds: Tensor, target: Tensor, fs: int, fast: bool, norm: bool, reduce_mean: bool = False, **kwargs: Dict[str, Any] + preds: Tensor, target: Tensor, fs: int, fast: bool, norm: bool, reduce_mean: bool = False, **kwargs: dict[str, Any] ): # shape: preds [BATCH_SIZE, Time] shape = preds.shape @@ -51,7 +51,7 @@ def _reference_srmr_batch( return srmr -def _speech_reverberation_modulation_energy_ratio_cheat(preds, target, **kwargs: Dict[str, Any]): +def _speech_reverberation_modulation_energy_ratio_cheat(preds, target, **kwargs: dict[str, Any]): # cheat the MetricTester as the speech_reverberation_modulation_energy_ratio doesn't need target return speech_reverberation_modulation_energy_ratio(preds, **kwargs) diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 9e89d041fd7..42e64e21262 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -72,7 +72,7 @@ def _reference_fairlearn_binary(preds, target, groups, ignore_index): } -def _assert_tensor(pl_result: Dict[str, Tensor], key: Optional[str] = None) -> None: +def _assert_tensor(pl_result: dict[str, Tensor], key: Optional[str] = None) -> None: if isinstance(pl_result, dict) and key is None: for key, val in pl_result.items(): assert isinstance(val, Tensor), f"{key!r} is not a Tensor!" @@ -81,7 +81,7 @@ def _assert_tensor(pl_result: Dict[str, Tensor], key: Optional[str] = None) -> N def _assert_allclose( # todo: unify with the general assert_allclose - pl_result: Dict[str, Tensor], sk_result: Dict[str, Tensor], atol: float = 1e-8, key: Optional[str] = None + pl_result: dict[str, Tensor], sk_result: dict[str, Tensor], atol: float = 1e-8, key: Optional[str] = None ) -> None: if isinstance(pl_result, dict) and key is None: for (pl_key, pl_val), (sk_key, sk_val) in zip(pl_result.items(), sk_result.items()): diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 4c864d0e9af..a5d0767c342 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -180,7 +180,7 @@ def test_extreme_values(): (_INPUTS_1, _ARGS_2, 1), ], ) -def test_ignore_mask(inputs: _Input, args: Dict[str, Any], cat_dim: int): +def test_ignore_mask(inputs: _Input, args: dict[str, Any], cat_dim: int): """Test that the metric correctly ignores regions of the inputs that do not map to a know category ID.""" preds = inputs.preds[0] target = inputs.target[0] diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 245fb4097fc..452a43f7ee5 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -214,7 +214,7 @@ def test_extreme_values(): (_INPUTS_1, _ARGS_2, 1), ], ) -def test_ignore_mask(inputs: _Input, args: Dict[str, Any], cat_dim: int): +def test_ignore_mask(inputs: _Input, args: dict[str, Any], cat_dim: int): """Test that the metric correctly ignores regions of the inputs that do not map to a know category ID.""" preds = inputs.preds[0] target = inputs.target[0] diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index e14cb2c96a0..6161df001fa 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -34,7 +34,7 @@ class _Input(NamedTuple): preds: Tensor - target: List[Dict[str, Tensor]] + target: list[dict[str, Tensor]] ms: Tensor pan: Tensor pan_lr: Tensor diff --git a/tests/unittests/image/test_qnr.py b/tests/unittests/image/test_qnr.py index 4cb42cf36be..89c9113adac 100644 --- a/tests/unittests/image/test_qnr.py +++ b/tests/unittests/image/test_qnr.py @@ -32,7 +32,7 @@ class _Input(NamedTuple): preds: Tensor - target: List[Dict[str, Tensor]] + target: list[dict[str, Tensor]] ms: Tensor pan: Tensor pan_lr: Tensor diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index e2804ecebb9..6e5e3f3da50 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -33,7 +33,7 @@ class _InputImagesCaptions(NamedTuple): images: Tensor - captions: List[List[str]] + captions: list[list[str]] captions = [ diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 03d5429ba47..8075c5e6ce0 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -59,7 +59,7 @@ def _retrieval_aggregate( return aggregation(values, dim=dim) -def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]: +def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> list[Union[Tensor, np.ndarray]]: """Extract group indexes. Given an integer :class:`~torch.Tensor` or `np.ndarray` `indexes`, return a :class:`~torch.Tensor` or @@ -151,7 +151,7 @@ def _compute_sklearn_metric( return np.array(0.0) -def _concat_tests(*tests: Tuple[Dict]) -> Dict: +def _concat_tests(*tests: tuple[dict]) -> dict: """Concat tests composed by a string and a list of arguments.""" assert len(tests), "`_concat_tests` expects at least an argument" assert all(tests[0]["argnames"] == x["argnames"] for x in tests[1:]), "the header must be the same for all tests" @@ -408,7 +408,7 @@ def _errors_test_class_metric( metric_class: Metric, message: str = "", metric_args: Optional[dict] = None, - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ): """Check types, parameters and errors. @@ -437,7 +437,7 @@ def _errors_test_functional_metric( target: Tensor, metric_functional: Metric, message: str = "", - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ): """Check types, parameters and errors. @@ -564,7 +564,7 @@ def run_metric_class_arguments_test( metric_class: Metric, message: str = "", metric_args: Optional[dict] = None, - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ) -> None: """Test that specific errors are raised for incorrect input.""" @@ -585,7 +585,7 @@ def run_functional_metric_arguments_test( target: Tensor, metric_functional: Callable, message: str = "", - exception_type: Type[Exception] = ValueError, + exception_type: type[Exception] = ValueError, kwargs_update: Optional[dict] = None, ) -> None: """Test that specific errors are raised for incorrect input.""" diff --git a/tests/unittests/retrieval/test_precision_recall_curve.py b/tests/unittests/retrieval/test_precision_recall_curve.py index 691334d135e..12f7e60f8e2 100644 --- a/tests/unittests/retrieval/test_precision_recall_curve.py +++ b/tests/unittests/retrieval/test_precision_recall_curve.py @@ -42,7 +42,7 @@ def _compute_precision_recall_curve( empty_target_action: str = "skip", reverse: bool = False, aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean", -) -> Tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor]: """Compute metric with multiple iterations over every query predictions set. Didn't find a reliable implementation of precision-recall curve in Information Retrieval, diff --git a/tests/unittests/text/_helpers.py b/tests/unittests/text/_helpers.py index d580a62fa39..15723b529e0 100644 --- a/tests/unittests/text/_helpers.py +++ b/tests/unittests/text/_helpers.py @@ -48,7 +48,7 @@ def _assert_all_close_regardless_of_order( elif isinstance(pl_result, Sequence): for pl_res, ref_res in zip(pl_result, ref_result): _assert_allclose(pl_res, ref_res, atol=atol) - elif isinstance(pl_result, Dict): + elif isinstance(pl_result, dict): if key is None: raise KeyError("Provide Key for Dict based metric results.") assert np.allclose( diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index 6ef3f7390be..a4d5de091a0 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_cer(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_cer(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import cer except ImportError: diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index 69e595465a7..d4c5a8c2d09 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -24,7 +24,7 @@ seed_all(42) -def _reference_jiwer_mer(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_mer(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import compute_measures except ImportError: diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index 37278b829f1..59657956cd0 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -21,7 +21,7 @@ from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -def _reference_jiwer_wil(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_wil(preds: Union[str, list[str]], target: Union[str, list[str]]): try: from jiwer import wil except ImportError: From 6132d57de5a800727e40f708d3f06952e3a7c952 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Thu, 7 Nov 2024 16:17:19 +0000 Subject: [PATCH 06/15] imports --- .github/assistant.py | 2 +- _samples/bert_score-own_model.py | 2 +- examples/audio/signal_to_noise_ratio.py | 1 - pyproject.toml | 3 --- setup.py | 2 +- src/torchmetrics/aggregation.py | 2 +- src/torchmetrics/audio/pit.py | 2 +- src/torchmetrics/classification/accuracy.py | 2 +- src/torchmetrics/classification/auroc.py | 2 +- src/torchmetrics/classification/average_precision.py | 2 +- src/torchmetrics/classification/calibration_error.py | 2 +- src/torchmetrics/classification/cohen_kappa.py | 2 +- src/torchmetrics/classification/confusion_matrix.py | 2 +- src/torchmetrics/classification/dice.py | 2 +- src/torchmetrics/classification/exact_match.py | 2 +- src/torchmetrics/classification/f_beta.py | 2 +- src/torchmetrics/classification/group_fairness.py | 2 +- src/torchmetrics/classification/hamming.py | 2 +- src/torchmetrics/classification/hinge.py | 2 +- src/torchmetrics/classification/jaccard.py | 2 +- src/torchmetrics/classification/matthews_corrcoef.py | 2 +- src/torchmetrics/classification/negative_predictive_value.py | 2 +- src/torchmetrics/classification/precision_fixed_recall.py | 2 +- src/torchmetrics/classification/precision_recall.py | 2 +- src/torchmetrics/classification/precision_recall_curve.py | 2 +- src/torchmetrics/classification/recall_fixed_precision.py | 2 +- src/torchmetrics/classification/roc.py | 2 +- src/torchmetrics/classification/sensitivity_specificity.py | 2 +- src/torchmetrics/classification/specificity.py | 2 +- src/torchmetrics/classification/specificity_sensitivity.py | 2 +- src/torchmetrics/classification/stat_scores.py | 2 +- src/torchmetrics/clustering/adjusted_mutual_info_score.py | 2 +- src/torchmetrics/clustering/adjusted_rand_score.py | 2 +- src/torchmetrics/clustering/calinski_harabasz_score.py | 2 +- src/torchmetrics/clustering/davies_bouldin_score.py | 2 +- src/torchmetrics/clustering/dunn_index.py | 2 +- src/torchmetrics/clustering/fowlkes_mallows_index.py | 2 +- .../clustering/homogeneity_completeness_v_measure.py | 2 +- src/torchmetrics/clustering/mutual_info_score.py | 2 +- src/torchmetrics/clustering/normalized_mutual_info_score.py | 2 +- src/torchmetrics/clustering/rand_score.py | 2 +- src/torchmetrics/collections.py | 2 +- src/torchmetrics/detection/_mean_ap.py | 2 +- src/torchmetrics/detection/helpers.py | 2 +- src/torchmetrics/detection/iou.py | 2 +- src/torchmetrics/detection/mean_ap.py | 2 +- src/torchmetrics/functional/audio/_deprecated.py | 2 +- src/torchmetrics/functional/audio/dnsmos.py | 2 +- src/torchmetrics/functional/audio/nisqa.py | 2 +- src/torchmetrics/functional/audio/pit.py | 2 +- src/torchmetrics/functional/audio/sdr.py | 2 +- src/torchmetrics/functional/audio/srmr.py | 2 +- src/torchmetrics/functional/classification/auroc.py | 2 +- .../functional/classification/average_precision.py | 2 +- .../functional/classification/calibration_error.py | 2 +- src/torchmetrics/functional/classification/confusion_matrix.py | 2 +- src/torchmetrics/functional/classification/exact_match.py | 2 +- src/torchmetrics/functional/classification/group_fairness.py | 2 +- src/torchmetrics/functional/classification/hinge.py | 2 +- .../functional/classification/precision_fixed_recall.py | 2 +- .../functional/classification/precision_recall_curve.py | 2 +- src/torchmetrics/functional/classification/ranking.py | 2 +- .../functional/classification/recall_fixed_precision.py | 2 +- src/torchmetrics/functional/classification/roc.py | 2 +- .../functional/classification/sensitivity_specificity.py | 2 +- .../functional/classification/specificity_sensitivity.py | 2 +- src/torchmetrics/functional/classification/stat_scores.py | 2 +- src/torchmetrics/functional/clustering/dunn_index.py | 1 - .../functional/clustering/fowlkes_mallows_index.py | 1 - .../clustering/homogeneity_completeness_v_measure.py | 1 - .../functional/detection/_panoptic_quality_common.py | 2 +- src/torchmetrics/functional/image/_deprecated.py | 2 +- src/torchmetrics/functional/image/d_lambda.py | 1 - src/torchmetrics/functional/image/d_s.py | 2 +- src/torchmetrics/functional/image/ergas.py | 1 - src/torchmetrics/functional/image/gradients.py | 1 - src/torchmetrics/functional/image/lpips.py | 2 +- src/torchmetrics/functional/image/perceptual_path_length.py | 2 +- src/torchmetrics/functional/image/psnr.py | 2 +- src/torchmetrics/functional/image/psnrb.py | 1 - src/torchmetrics/functional/image/rase.py | 1 - src/torchmetrics/functional/image/rmse_sw.py | 2 +- src/torchmetrics/functional/image/sam.py | 1 - src/torchmetrics/functional/image/scc.py | 2 +- src/torchmetrics/functional/image/ssim.py | 2 +- src/torchmetrics/functional/image/tv.py | 2 +- src/torchmetrics/functional/image/uqi.py | 2 +- src/torchmetrics/functional/image/utils.py | 2 +- src/torchmetrics/functional/multimodal/clip_iqa.py | 2 +- src/torchmetrics/functional/multimodal/clip_score.py | 2 +- src/torchmetrics/functional/nominal/utils.py | 2 +- src/torchmetrics/functional/pairwise/helpers.py | 2 +- src/torchmetrics/functional/regression/cosine_similarity.py | 2 +- src/torchmetrics/functional/regression/csi.py | 2 +- src/torchmetrics/functional/regression/explained_variance.py | 2 +- src/torchmetrics/functional/regression/kendall.py | 2 +- src/torchmetrics/functional/regression/kl_divergence.py | 2 +- src/torchmetrics/functional/regression/log_cosh.py | 1 - src/torchmetrics/functional/regression/log_mse.py | 2 +- src/torchmetrics/functional/regression/mae.py | 2 +- src/torchmetrics/functional/regression/mape.py | 2 +- src/torchmetrics/functional/regression/mse.py | 2 +- src/torchmetrics/functional/regression/nrmse.py | 2 +- src/torchmetrics/functional/regression/pearson.py | 1 - src/torchmetrics/functional/regression/r2.py | 2 +- src/torchmetrics/functional/regression/spearman.py | 1 - src/torchmetrics/functional/regression/symmetric_mape.py | 2 +- src/torchmetrics/functional/regression/tweedie_deviance.py | 1 - src/torchmetrics/functional/regression/wmape.py | 1 - src/torchmetrics/functional/retrieval/_deprecated.py | 2 +- .../functional/retrieval/precision_recall_curve.py | 2 +- src/torchmetrics/functional/segmentation/dice.py | 2 +- src/torchmetrics/functional/segmentation/hausdorff_distance.py | 2 +- src/torchmetrics/functional/segmentation/mean_iou.py | 1 - src/torchmetrics/functional/segmentation/utils.py | 2 +- src/torchmetrics/functional/shape/procrustes.py | 2 +- src/torchmetrics/functional/text/_deprecated.py | 2 +- src/torchmetrics/functional/text/bert.py | 2 +- src/torchmetrics/functional/text/bleu.py | 2 +- src/torchmetrics/functional/text/cer.py | 2 +- src/torchmetrics/functional/text/chrf.py | 2 +- src/torchmetrics/functional/text/eed.py | 2 +- src/torchmetrics/functional/text/helper.py | 2 +- src/torchmetrics/functional/text/helper_embedding_metric.py | 2 +- src/torchmetrics/functional/text/infolm.py | 2 +- src/torchmetrics/functional/text/mer.py | 2 +- src/torchmetrics/functional/text/perplexity.py | 2 +- src/torchmetrics/functional/text/rouge.py | 2 +- src/torchmetrics/functional/text/sacre_bleu.py | 2 +- src/torchmetrics/functional/text/squad.py | 2 +- src/torchmetrics/functional/text/ter.py | 2 +- src/torchmetrics/functional/text/wer.py | 2 +- src/torchmetrics/functional/text/wil.py | 2 +- src/torchmetrics/functional/text/wip.py | 2 +- src/torchmetrics/image/_deprecated.py | 2 +- src/torchmetrics/image/d_lambda.py | 2 +- src/torchmetrics/image/d_s.py | 2 +- src/torchmetrics/image/ergas.py | 2 +- src/torchmetrics/image/fid.py | 2 +- src/torchmetrics/image/inception.py | 2 +- src/torchmetrics/image/kid.py | 2 +- src/torchmetrics/image/lpip.py | 2 +- src/torchmetrics/image/mifid.py | 2 +- src/torchmetrics/image/perceptual_path_length.py | 2 +- src/torchmetrics/image/psnr.py | 2 +- src/torchmetrics/image/qnr.py | 2 +- src/torchmetrics/image/rase.py | 2 +- src/torchmetrics/image/rmse_sw.py | 2 +- src/torchmetrics/image/sam.py | 2 +- src/torchmetrics/image/ssim.py | 2 +- src/torchmetrics/image/tv.py | 2 +- src/torchmetrics/image/uqi.py | 2 +- src/torchmetrics/metric.py | 2 +- src/torchmetrics/multimodal/clip_iqa.py | 2 +- src/torchmetrics/multimodal/clip_score.py | 2 +- src/torchmetrics/nominal/fleiss_kappa.py | 2 +- src/torchmetrics/regression/cosine_similarity.py | 2 +- src/torchmetrics/regression/csi.py | 2 +- src/torchmetrics/regression/kendall.py | 2 +- src/torchmetrics/regression/pearson.py | 2 +- src/torchmetrics/regression/spearman.py | 2 +- src/torchmetrics/retrieval/base.py | 2 +- src/torchmetrics/retrieval/precision_recall_curve.py | 2 +- src/torchmetrics/segmentation/dice.py | 2 +- src/torchmetrics/segmentation/hausdorff_distance.py | 2 +- src/torchmetrics/text/bert.py | 2 +- src/torchmetrics/text/cer.py | 2 +- src/torchmetrics/text/chrf.py | 2 +- src/torchmetrics/text/edit.py | 2 +- src/torchmetrics/text/eed.py | 2 +- src/torchmetrics/text/infolm.py | 2 +- src/torchmetrics/text/mer.py | 2 +- src/torchmetrics/text/perplexity.py | 2 +- src/torchmetrics/text/rouge.py | 2 +- src/torchmetrics/text/squad.py | 2 +- src/torchmetrics/text/ter.py | 2 +- src/torchmetrics/text/wer.py | 2 +- src/torchmetrics/text/wil.py | 2 +- src/torchmetrics/text/wip.py | 2 +- src/torchmetrics/utilities/checks.py | 2 +- src/torchmetrics/utilities/compute.py | 2 +- src/torchmetrics/utilities/data.py | 2 +- src/torchmetrics/utilities/distributed.py | 2 +- src/torchmetrics/utilities/enums.py | 1 - src/torchmetrics/utilities/plot.py | 2 +- src/torchmetrics/wrappers/bootstrapping.py | 2 +- src/torchmetrics/wrappers/classwise.py | 2 +- src/torchmetrics/wrappers/feature_share.py | 2 +- src/torchmetrics/wrappers/minmax.py | 2 +- src/torchmetrics/wrappers/multioutput.py | 2 +- src/torchmetrics/wrappers/multitask.py | 2 +- src/torchmetrics/wrappers/tracker.py | 2 +- src/torchmetrics/wrappers/transformations.py | 2 +- tests/unittests/_helpers/testers.py | 2 +- tests/unittests/audio/test_dnsmos.py | 2 +- tests/unittests/audio/test_nisqa.py | 2 +- tests/unittests/audio/test_pit.py | 2 +- tests/unittests/audio/test_srmr.py | 2 +- tests/unittests/classification/test_group_fairness.py | 2 +- tests/unittests/detection/test_modified_panoptic_quality.py | 2 +- tests/unittests/detection/test_panoptic_quality.py | 2 +- tests/unittests/image/test_d_s.py | 2 +- tests/unittests/image/test_qnr.py | 2 +- tests/unittests/multimodal/test_clip_score.py | 2 +- tests/unittests/retrieval/helpers.py | 2 +- tests/unittests/retrieval/test_precision_recall_curve.py | 2 +- tests/unittests/text/_helpers.py | 2 +- tests/unittests/text/test_cer.py | 2 +- tests/unittests/text/test_mer.py | 2 +- tests/unittests/text/test_wer.py | 2 +- tests/unittests/text/test_wil.py | 2 +- tests/unittests/text/test_wip.py | 2 +- 212 files changed, 194 insertions(+), 214 deletions(-) diff --git a/.github/assistant.py b/.github/assistant.py index 1718727c452..174ed2117a4 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -16,7 +16,7 @@ import os import re import sys -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import fire from packaging.version import parse diff --git a/_samples/bert_score-own_model.py b/_samples/bert_score-own_model.py index 74799c41acc..982d7f63876 100644 --- a/_samples/bert_score-own_model.py +++ b/_samples/bert_score-own_model.py @@ -18,7 +18,7 @@ """ from pprint import pprint -from typing import Dict, List, Union +from typing import Union import torch from torch import Tensor, nn diff --git a/examples/audio/signal_to_noise_ratio.py b/examples/audio/signal_to_noise_ratio.py index b203efb87d5..7099fc08d2b 100644 --- a/examples/audio/signal_to_noise_ratio.py +++ b/examples/audio/signal_to_noise_ratio.py @@ -8,7 +8,6 @@ # %% # Import necessary libraries -from typing import Tuple import matplotlib.animation as animation import matplotlib.pyplot as plt diff --git a/pyproject.toml b/pyproject.toml index ca69b1c1de3..9aa3f5c4e53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,9 +77,6 @@ lint.per-file-ignores."tests/**" = [ "S101", "S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue ] -lint.unfixable = [ - "F401", -] # Unlike Flake8, default to a complexity level of 10. lint.mccabe.max-complexity = 10 # Use Google-style docstrings. diff --git a/setup.py b/setup.py index dc66936b98a..915045028ac 100755 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from importlib.util import module_from_spec, spec_from_file_location from itertools import chain from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union from pkg_resources import Requirement, yield_lines from setuptools import find_packages, setup diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index ae30429bc20..312197ccbc4 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 56cd28b5ae0..6c28738a3f9 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 1baec43a9e5..bc1a8bb5e36 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index e82498b221b..8e5a69092fb 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index e37f6989fa3..221f400918c 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 7de354e54cd..3cd760d4c17 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 154cae505c1..aa1d1d0780c 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index f713f6c1c40..b5f304dbd82 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Type +from typing import Any, Optional import torch from torch import Tensor diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 281324767af..ad160aa0d09 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, Optional, Tuple, Union, no_type_check +from typing import Any, Callable, Optional, Union, no_type_check import torch from torch import Tensor diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 189b5b44822..e167b22219c 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 46b1dd6d297..dcbe8a8b69d 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 0235ef123af..d966063856c 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index c0dc94d3c21..183af336ae8 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 4fed17aa0f1..878ea271049 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 7098db7896c..5f9a15b2e4f 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index c84ebed69d3..ea1b7a23cb5 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/negative_predictive_value.py b/src/torchmetrics/classification/negative_predictive_value.py index a4bfb9bc4c4..cdff97f86e2 100644 --- a/src/torchmetrics/classification/negative_predictive_value.py +++ b/src/torchmetrics/classification/negative_predictive_value.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/precision_fixed_recall.py b/src/torchmetrics/classification/precision_fixed_recall.py index 5ce97e7effc..a17f19aa39b 100644 --- a/src/torchmetrics/classification/precision_fixed_recall.py +++ b/src/torchmetrics/classification/precision_fixed_recall.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index b9b790008cd..0bd6f8b0d99 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 86f2c66967b..0149f78bbd4 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index 88a4fbdd74f..f34b3bb580a 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 73be641c66d..cf22c8c5646 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py index bce575a1014..23851affe24 100644 --- a/src/torchmetrics/classification/sensitivity_specificity.py +++ b/src/torchmetrics/classification/sensitivity_specificity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 59b7baac794..dab5fde8a60 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/specificity_sensitivity.py b/src/torchmetrics/classification/specificity_sensitivity.py index 2fd7b4c3f70..d8ed0f27cd4 100644 --- a/src/torchmetrics/classification/specificity_sensitivity.py +++ b/src/torchmetrics/classification/specificity_sensitivity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 3a18a8b481e..483cbabda25 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/clustering/adjusted_mutual_info_score.py b/src/torchmetrics/clustering/adjusted_mutual_info_score.py index 94d66609d14..966b2d00540 100644 --- a/src/torchmetrics/clustering/adjusted_mutual_info_score.py +++ b/src/torchmetrics/clustering/adjusted_mutual_info_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py index 5f202536bae..3f614c25c69 100644 --- a/src/torchmetrics/clustering/adjusted_rand_score.py +++ b/src/torchmetrics/clustering/adjusted_rand_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/calinski_harabasz_score.py b/src/torchmetrics/clustering/calinski_harabasz_score.py index 48463b94988..4a3d25138f2 100644 --- a/src/torchmetrics/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/clustering/calinski_harabasz_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/davies_bouldin_score.py b/src/torchmetrics/clustering/davies_bouldin_score.py index c2b30c93e01..98f373b9558 100644 --- a/src/torchmetrics/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/clustering/davies_bouldin_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index ddc6b1867ba..db635dee23a 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/fowlkes_mallows_index.py b/src/torchmetrics/clustering/fowlkes_mallows_index.py index 4c82f892da3..276c28f5456 100644 --- a/src/torchmetrics/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/clustering/fowlkes_mallows_index.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py index f9b7b5cc5c7..b610f134b85 100644 --- a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 8d206ed8886..549c2f8376b 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/normalized_mutual_info_score.py b/src/torchmetrics/clustering/normalized_mutual_info_score.py index f829b3a3512..eedda19f784 100644 --- a/src/torchmetrics/clustering/normalized_mutual_info_score.py +++ b/src/torchmetrics/clustering/normalized_mutual_info_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index c949dd06e5c..d625dec55ac 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 930540473bc..23cf92a7f9c 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -15,7 +15,7 @@ from collections import OrderedDict from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index d9bba820185..5704fae1209 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from collections.abc import Sequence -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import numpy as np import torch diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index ae4c6c2b88c..ddb54463f54 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Dict, Literal, Tuple, Union +from typing import Literal, Union from torch import Tensor diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index 9579eeae06d..26c48bd42c1 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 2a3c1095a53..f8234a08d69 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -16,7 +16,7 @@ import json from collections.abc import Sequence from types import ModuleType -from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, ClassVar, Optional, Union import numpy as np import torch diff --git a/src/torchmetrics/functional/audio/_deprecated.py b/src/torchmetrics/functional/audio/_deprecated.py index 4b31c5db37d..ff3c18e5458 100644 --- a/src/torchmetrics/functional/audio/_deprecated.py +++ b/src/torchmetrics/functional/audio/_deprecated.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/functional/audio/dnsmos.py b/src/torchmetrics/functional/audio/dnsmos.py index 13935d9cdd2..2e23fba23a2 100644 --- a/src/torchmetrics/functional/audio/dnsmos.py +++ b/src/torchmetrics/functional/audio/dnsmos.py @@ -13,7 +13,7 @@ # limitations under the License. import os from functools import lru_cache -from typing import Any, Dict, Optional +from typing import Any, Optional import numpy as np import torch diff --git a/src/torchmetrics/functional/audio/nisqa.py b/src/torchmetrics/functional/audio/nisqa.py index a900d538d59..8718c97d8e6 100644 --- a/src/torchmetrics/functional/audio/nisqa.py +++ b/src/torchmetrics/functional/audio/nisqa.py @@ -40,7 +40,7 @@ import os import warnings from functools import lru_cache -from typing import Any, Dict, Tuple +from typing import Any import numpy as np import torch diff --git a/src/torchmetrics/functional/audio/pit.py b/src/torchmetrics/functional/audio/pit.py index 6fc431811c1..2d126c7000f 100644 --- a/src/torchmetrics/functional/audio/pit.py +++ b/src/torchmetrics/functional/audio/pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from itertools import permutations -from typing import Any, Callable, Tuple +from typing import Any, Callable import numpy as np import torch diff --git a/src/torchmetrics/functional/audio/sdr.py b/src/torchmetrics/functional/audio/sdr.py index a68c3e5f047..85da9597998 100644 --- a/src/torchmetrics/functional/audio/sdr.py +++ b/src/torchmetrics/functional/audio/sdr.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/audio/srmr.py b/src/torchmetrics/functional/audio/srmr.py index daf7befed02..20ad898fe8d 100644 --- a/src/torchmetrics/functional/audio/srmr.py +++ b/src/torchmetrics/functional/audio/srmr.py @@ -17,7 +17,7 @@ from functools import lru_cache from math import ceil, pi -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 03bd8ff3aa0..c4c55769be5 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index cd941e0a9df..ff7a46f0a11 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 9d994679a42..ca55bb2f79b 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 31d9bbb7f46..92059072490 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index 2b0339e0adc..d22b4cc6f89 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/group_fairness.py b/src/torchmetrics/functional/classification/group_fairness.py index 2d5da8582ed..5e52e94733f 100644 --- a/src/torchmetrics/functional/classification/group_fairness.py +++ b/src/torchmetrics/functional/classification/group_fairness.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch from typing_extensions import Literal diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index d08df7d550d..8fe7cf840b8 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/classification/precision_fixed_recall.py b/src/torchmetrics/functional/classification/precision_fixed_recall.py index d16beba4664..078d970299e 100644 --- a/src/torchmetrics/functional/classification/precision_fixed_recall.py +++ b/src/torchmetrics/functional/classification/precision_fixed_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index c498ffcc864..f1765916b36 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/classification/ranking.py b/src/torchmetrics/functional/classification/ranking.py index d78fe807a88..57dd389ec3a 100644 --- a/src/torchmetrics/functional/classification/ranking.py +++ b/src/torchmetrics/functional/classification/ranking.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/recall_fixed_precision.py b/src/torchmetrics/functional/classification/recall_fixed_precision.py index 72faf8b8b6d..196209bce21 100644 --- a/src/torchmetrics/functional/classification/recall_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_fixed_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index f2374778f01..2ce36bb3643 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py index bc590a4c6b9..940691a51f7 100644 --- a/src/torchmetrics/functional/classification/sensitivity_specificity.py +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index 3f93cea467d..1f252cdd98e 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index d6079de55d3..e9d78b1a1b7 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/clustering/dunn_index.py b/src/torchmetrics/functional/clustering/dunn_index.py index 05f7c87df33..ac073b7c273 100644 --- a/src/torchmetrics/functional/clustering/dunn_index.py +++ b/src/torchmetrics/functional/clustering/dunn_index.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from itertools import combinations -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py index c2820e9001a..88b9288f9ed 100644 --- a/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py index 7eb7b478430..a85d9ab2a11 100644 --- a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/detection/_panoptic_quality_common.py b/src/torchmetrics/functional/detection/_panoptic_quality_common.py index 9b6660e5629..16d0463ba45 100644 --- a/src/torchmetrics/functional/detection/_panoptic_quality_common.py +++ b/src/torchmetrics/functional/detection/_panoptic_quality_common.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Collection, Iterator -from typing import Dict, List, Optional, Set, Tuple, cast +from typing import Optional, cast import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index 55485f47639..8ee36b77a43 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Optional, Tuple, Union +from typing import Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/functional/image/d_lambda.py b/src/torchmetrics/functional/image/d_lambda.py index 668dde7c0fc..478455c0a68 100644 --- a/src/torchmetrics/functional/image/d_lambda.py +++ b/src/torchmetrics/functional/image/d_lambda.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index f8f54873dbe..1c2797a2aab 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/ergas.py b/src/torchmetrics/functional/image/ergas.py index c69773b06ba..ae940250dbe 100644 --- a/src/torchmetrics/functional/image/ergas.py +++ b/src/torchmetrics/functional/image/ergas.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/gradients.py b/src/torchmetrics/functional/image/gradients.py index 683c67c153d..68045663fa9 100644 --- a/src/torchmetrics/functional/image/gradients.py +++ b/src/torchmetrics/functional/image/gradients.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 44006874120..768116613ea 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -24,7 +24,7 @@ # License under BSD 2-clause import inspect import os -from typing import List, NamedTuple, Optional, Tuple, Union +from typing import NamedTuple, Optional, Union import torch from torch import Tensor, nn diff --git a/src/torchmetrics/functional/image/perceptual_path_length.py b/src/torchmetrics/functional/image/perceptual_path_length.py index b1a9e3f7857..035425539d2 100644 --- a/src/torchmetrics/functional/image/perceptual_path_length.py +++ b/src/torchmetrics/functional/image/perceptual_path_length.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union import torch from torch import Tensor, nn diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index 01348c425c3..e058d34e7b6 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/image/psnrb.py b/src/torchmetrics/functional/image/psnrb.py index 1b67df53519..4a8469e19fe 100644 --- a/src/torchmetrics/functional/image/psnrb.py +++ b/src/torchmetrics/functional/image/psnrb.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/image/rase.py b/src/torchmetrics/functional/image/rase.py index fd30d455724..832a759b562 100644 --- a/src/torchmetrics/functional/image/rase.py +++ b/src/torchmetrics/functional/image/rase.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/rmse_sw.py b/src/torchmetrics/functional/image/rmse_sw.py index 6d9b9eae235..3b0eaf6221f 100644 --- a/src/torchmetrics/functional/image/rmse_sw.py +++ b/src/torchmetrics/functional/image/rmse_sw.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/sam.py b/src/torchmetrics/functional/image/sam.py index 82c304c9543..21efde6f9c4 100644 --- a/src/torchmetrics/functional/image/sam.py +++ b/src/torchmetrics/functional/image/sam.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index f0c3db35295..d680a22b676 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 9dbf38b0627..6a1ae9deefb 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index 21c7b5f6f31..a0f310d498d 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index 1b0711bb969..30b5e781e55 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor, nn diff --git a/src/torchmetrics/functional/image/utils.py b/src/torchmetrics/functional/image/utils.py index 13dee62d2ae..24ed9cd0de8 100644 --- a/src/torchmetrics/functional/image/utils.py +++ b/src/torchmetrics/functional/image/utils.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py index b22b3472b82..99985c5b022 100644 --- a/src/torchmetrics/functional/multimodal/clip_iqa.py +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Literal, Tuple, Union +from typing import TYPE_CHECKING, Literal, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index ead25db6f5a..bae7cb7b849 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/nominal/utils.py b/src/torchmetrics/functional/nominal/utils.py index d98d7e8806c..9d8dd8dc4af 100644 --- a/src/torchmetrics/functional/nominal/utils.py +++ b/src/torchmetrics/functional/nominal/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/pairwise/helpers.py b/src/torchmetrics/functional/pairwise/helpers.py index a0a0dbf2f9f..703b5ddb083 100644 --- a/src/torchmetrics/functional/pairwise/helpers.py +++ b/src/torchmetrics/functional/pairwise/helpers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional from torch import Tensor diff --git a/src/torchmetrics/functional/regression/cosine_similarity.py b/src/torchmetrics/functional/regression/cosine_similarity.py index c90885aab86..c57623931a4 100644 --- a/src/torchmetrics/functional/regression/cosine_similarity.py +++ b/src/torchmetrics/functional/regression/cosine_similarity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/csi.py b/src/torchmetrics/functional/regression/csi.py index f58f615e62b..65d38e6f573 100644 --- a/src/torchmetrics/functional/regression/csi.py +++ b/src/torchmetrics/functional/regression/csi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/explained_variance.py b/src/torchmetrics/functional/regression/explained_variance.py index bf93bf94112..d401bb5a349 100644 --- a/src/torchmetrics/functional/regression/explained_variance.py +++ b/src/torchmetrics/functional/regression/explained_variance.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 804c4cb10ef..4d5b2028a1d 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/kl_divergence.py b/src/torchmetrics/functional/regression/kl_divergence.py index 8d6aee5c001..c4ed9efb25a 100644 --- a/src/torchmetrics/functional/regression/kl_divergence.py +++ b/src/torchmetrics/functional/regression/kl_divergence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/log_cosh.py b/src/torchmetrics/functional/regression/log_cosh.py index a0931bcaecb..e13ee2a3e7a 100644 --- a/src/torchmetrics/functional/regression/log_cosh.py +++ b/src/torchmetrics/functional/regression/log_cosh.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/log_mse.py b/src/torchmetrics/functional/regression/log_mse.py index 34f3d9d71de..7c3a3585127 100644 --- a/src/torchmetrics/functional/regression/log_mse.py +++ b/src/torchmetrics/functional/regression/log_mse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/mae.py b/src/torchmetrics/functional/regression/mae.py index 17db54f242c..8774d67fbe1 100644 --- a/src/torchmetrics/functional/regression/mae.py +++ b/src/torchmetrics/functional/regression/mae.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/mape.py b/src/torchmetrics/functional/regression/mape.py index 89cbd865c43..c109e898eee 100644 --- a/src/torchmetrics/functional/regression/mape.py +++ b/src/torchmetrics/functional/regression/mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py index 09f4fadb490..4ea9490c841 100644 --- a/src/torchmetrics/functional/regression/mse.py +++ b/src/torchmetrics/functional/regression/mse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py index a7e8c28a1af..ccbe81c333d 100644 --- a/src/torchmetrics/functional/regression/nrmse.py +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index 6dffb193340..e8fa9ef3408 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/r2.py b/src/torchmetrics/functional/regression/r2.py index f8d036ece99..f52d082fc74 100644 --- a/src/torchmetrics/functional/regression/r2.py +++ b/src/torchmetrics/functional/regression/r2.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/spearman.py b/src/torchmetrics/functional/regression/spearman.py index 7e57b93c06f..b72bd7f03f7 100644 --- a/src/torchmetrics/functional/regression/spearman.py +++ b/src/torchmetrics/functional/regression/spearman.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/symmetric_mape.py b/src/torchmetrics/functional/regression/symmetric_mape.py index 9d919c13b2a..3fca98258c7 100644 --- a/src/torchmetrics/functional/regression/symmetric_mape.py +++ b/src/torchmetrics/functional/regression/symmetric_mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/tweedie_deviance.py b/src/torchmetrics/functional/regression/tweedie_deviance.py index e369d7d5ad5..328829dffe3 100644 --- a/src/torchmetrics/functional/regression/tweedie_deviance.py +++ b/src/torchmetrics/functional/regression/tweedie_deviance.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/regression/wmape.py b/src/torchmetrics/functional/regression/wmape.py index c3834047777..1781f306608 100644 --- a/src/torchmetrics/functional/regression/wmape.py +++ b/src/torchmetrics/functional/regression/wmape.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/retrieval/_deprecated.py b/src/torchmetrics/functional/retrieval/_deprecated.py index 4621b7f5ed2..6284470d1b2 100644 --- a/src/torchmetrics/functional/retrieval/_deprecated.py +++ b/src/torchmetrics/functional/retrieval/_deprecated.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional from torch import Tensor diff --git a/src/torchmetrics/functional/retrieval/precision_recall_curve.py b/src/torchmetrics/functional/retrieval/precision_recall_curve.py index ed204d136cd..269dca04016 100644 --- a/src/torchmetrics/functional/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/functional/retrieval/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index a9ecf99f69e..de9a8068bf9 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/segmentation/hausdorff_distance.py b/src/torchmetrics/functional/segmentation/hausdorff_distance.py index 18d7c45ddff..ca58e5f036c 100644 --- a/src/torchmetrics/functional/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/functional/segmentation/hausdorff_distance.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 5acbd445851..9cfed0fa1bf 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torch import Tensor diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index b6ee4c1195e..b6dd123a4d6 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -13,7 +13,7 @@ # limitations under the License. import functools import math -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/shape/procrustes.py b/src/torchmetrics/functional/shape/procrustes.py index c17871ed251..e72ca339306 100644 --- a/src/torchmetrics/functional/shape/procrustes.py +++ b/src/torchmetrics/functional/shape/procrustes.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch import Tensor, linalg diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index 8ff582f52fc..c34419ac613 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -1,6 +1,6 @@ import os from collections.abc import Sequence -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index 4380f447da7..e4a38b29755 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -16,7 +16,7 @@ import urllib from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/bleu.py b/src/torchmetrics/functional/text/bleu.py index 6ed88ae05c8..52b8bb17432 100644 --- a/src/torchmetrics/functional/text/bleu.py +++ b/src/torchmetrics/functional/text/bleu.py @@ -18,7 +18,7 @@ # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter from collections.abc import Sequence -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/cer.py b/src/torchmetrics/functional/text/cer.py index 2b5e10f6c55..e9ee191c449 100644 --- a/src/torchmetrics/functional/text/cer.py +++ b/src/torchmetrics/functional/text/cer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 70f09142d7c..66ee3b7a449 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -23,7 +23,7 @@ from collections import defaultdict from collections.abc import Sequence from itertools import chain -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index abdc4e01de8..ca562c5212a 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -90,7 +90,7 @@ import unicodedata from collections.abc import Sequence from math import inf -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from torch import Tensor, stack, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/functional/text/helper.py b/src/torchmetrics/functional/text/helper.py index 3266876d129..a61f06232ab 100644 --- a/src/torchmetrics/functional/text/helper.py +++ b/src/torchmetrics/functional/text/helper.py @@ -31,7 +31,7 @@ import math from collections.abc import Sequence from enum import Enum, unique -from typing import Dict, List, Tuple, Union +from typing import Union # Tercom-inspired limits _BEAM_WIDTH = 25 diff --git a/src/torchmetrics/functional/text/helper_embedding_metric.py b/src/torchmetrics/functional/text/helper_embedding_metric.py index ad57c13a2f4..17c89558163 100644 --- a/src/torchmetrics/functional/text/helper_embedding_metric.py +++ b/src/torchmetrics/functional/text/helper_embedding_metric.py @@ -14,7 +14,7 @@ import math import os from collections import Counter, defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 33c20f97e53..48a42e32b2f 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -14,7 +14,7 @@ import os from collections.abc import Sequence from enum import unique -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/mer.py b/src/torchmetrics/functional/text/mer.py index 34d0eaff0ed..46f30331b85 100644 --- a/src/torchmetrics/functional/text/mer.py +++ b/src/torchmetrics/functional/text/mer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py index 5a58a8da4c0..5931b7b6944 100644 --- a/src/torchmetrics/functional/text/perplexity.py +++ b/src/torchmetrics/functional/text/perplexity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index e9f926b4d66..791de8dbe26 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -14,7 +14,7 @@ import re from collections import Counter from collections.abc import Sequence -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/sacre_bleu.py b/src/torchmetrics/functional/text/sacre_bleu.py index 28e36b2ea3b..a398acaf676 100644 --- a/src/torchmetrics/functional/text/sacre_bleu.py +++ b/src/torchmetrics/functional/text/sacre_bleu.py @@ -42,7 +42,7 @@ import tempfile from collections.abc import Sequence from functools import partial -from typing import Any, ClassVar, Dict, Optional, Type +from typing import Any, ClassVar, Optional import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/squad.py b/src/torchmetrics/functional/text/squad.py index d317a7ce806..c52f0860e14 100644 --- a/src/torchmetrics/functional/text/squad.py +++ b/src/torchmetrics/functional/text/squad.py @@ -17,7 +17,7 @@ import re import string from collections import Counter -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/ter.py b/src/torchmetrics/functional/text/ter.py index 3cb862d046a..02617f2cbaf 100644 --- a/src/torchmetrics/functional/text/ter.py +++ b/src/torchmetrics/functional/text/ter.py @@ -36,7 +36,7 @@ import re from collections.abc import Iterator, Sequence from functools import lru_cache -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/wer.py b/src/torchmetrics/functional/text/wer.py index 0479ad0d945..b61bdb4c105 100644 --- a/src/torchmetrics/functional/text/wer.py +++ b/src/torchmetrics/functional/text/wer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/wil.py b/src/torchmetrics/functional/text/wil.py index e7ca50abec4..3d8c370facb 100644 --- a/src/torchmetrics/functional/text/wil.py +++ b/src/torchmetrics/functional/text/wil.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union from torch import Tensor, tensor diff --git a/src/torchmetrics/functional/text/wip.py b/src/torchmetrics/functional/text/wip.py index 2f6f635b053..77dae42e5ed 100644 --- a/src/torchmetrics/functional/text/wip.py +++ b/src/torchmetrics/functional/text/wip.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Union from torch import Tensor, tensor diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 18f3e1840ba..cab7692f5e7 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from typing_extensions import Literal diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 0330d203e33..65824c19600 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 926929fe2be..f6b12efd5ad 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 6e8ba2624d8..bdf884c2719 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 1eac1bc9fcf..c15b3302af9 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence from copy import deepcopy -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 02c5cf04a35..fd11a6afe03 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 5aa3645bc77..f3c3d16bdd4 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index ba792ab5d19..811476dd75d 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, ClassVar, List, Optional, Union +from typing import Any, ClassVar, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index 31105fd290b..6ce894cffd1 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index f6d909bf477..5c48b57c3fe 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Literal, Optional, Tuple, Union +from typing import Any, Literal, Optional, Union from torch import Tensor, nn diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index fa76d677133..5f00d21c7cb 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence from functools import partial -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/image/qnr.py b/src/torchmetrics/image/qnr.py index 4d99396d9de..649a69617ef 100644 --- a/src/torchmetrics/image/qnr.py +++ b/src/torchmetrics/image/qnr.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index dbe51d5d969..0a06041f94d 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index fdae49a72ba..6312174b1be 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index e1407120b2f..07449cd6939 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 1a2a6f858c8..034dabf7821 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 087b6a06cb4..eb6122fb41f 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index c2cf917301d..b094f14e851 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 3b8868adb54..2f0b3b617ed 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -21,7 +21,7 @@ from collections.abc import Generator, Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, ClassVar, Optional, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index 947972ff02e..1378b203040 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index f4cf857bbf2..382e68db53a 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/nominal/fleiss_kappa.py b/src/torchmetrics/nominal/fleiss_kappa.py index cf6f9058326..715565a9fdc 100644 --- a/src/torchmetrics/nominal/fleiss_kappa.py +++ b/src/torchmetrics/nominal/fleiss_kappa.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/cosine_similarity.py b/src/torchmetrics/regression/cosine_similarity.py index f7bf201d7ea..7d2eec54470 100644 --- a/src/torchmetrics/regression/cosine_similarity.py +++ b/src/torchmetrics/regression/cosine_similarity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index 9014cb2daab..854ce500d7b 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, Optional import torch diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 8b651737e8a..e3266f23623 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 8ae82e17662..adff24aab9e 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 59755846cc9..83ea08d5f50 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/retrieval/base.py b/src/torchmetrics/retrieval/base.py index 94bb49982f7..d63be19ff1b 100644 --- a/src/torchmetrics/retrieval/base.py +++ b/src/torchmetrics/retrieval/base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py index 5b6af216843..fb4d54a4f50 100644 --- a/src/torchmetrics/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/retrieval/precision_recall_curve.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index b28ade07f1e..70a38009ea2 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py index 666790f914d..4e646bb6a93 100644 --- a/src/torchmetrics/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index 6df6f20eb6b..f6550749f6b 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/cer.py b/src/torchmetrics/text/cer.py index 47370328dbc..e0337801916 100644 --- a/src/torchmetrics/text/cer.py +++ b/src/torchmetrics/text/cer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 88a791e9714..8b70e1b3269 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -19,7 +19,7 @@ import itertools from collections.abc import Iterator, Sequence -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/edit.py b/src/torchmetrics/text/edit.py index 060a8fc6b26..dd1750436b1 100644 --- a/src/torchmetrics/text/edit.py +++ b/src/torchmetrics/text/edit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index 9dfaaa9edcf..dd12a852ccf 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union from torch import Tensor, stack from typing_extensions import Literal diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index 6c1320248d7..ae585161bc2 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -13,7 +13,7 @@ # limitations under the License. import os from collections.abc import Sequence -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/mer.py b/src/torchmetrics/text/mer.py index b3f086e0d00..a898c9c4758 100644 --- a/src/torchmetrics/text/mer.py +++ b/src/torchmetrics/text/mer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/perplexity.py b/src/torchmetrics/text/perplexity.py index 950bb9ad449..af8d9e70795 100644 --- a/src/torchmetrics/text/perplexity.py +++ b/src/torchmetrics/text/perplexity.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/text/rouge.py b/src/torchmetrics/text/rouge.py index be8399f6fad..ec1ef711afc 100644 --- a/src/torchmetrics/text/rouge.py +++ b/src/torchmetrics/text/rouge.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union from torch import Tensor from typing_extensions import Literal diff --git a/src/torchmetrics/text/squad.py b/src/torchmetrics/text/squad.py index 124c87f06ce..a545da95803 100644 --- a/src/torchmetrics/text/squad.py +++ b/src/torchmetrics/text/squad.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py index bb477a65bfb..3773e3af923 100644 --- a/src/torchmetrics/text/ter.py +++ b/src/torchmetrics/text/ter.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/wer.py b/src/torchmetrics/text/wer.py index b96044ef1bc..93bc7b8da13 100644 --- a/src/torchmetrics/text/wer.py +++ b/src/torchmetrics/text/wer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor, tensor diff --git a/src/torchmetrics/text/wil.py b/src/torchmetrics/text/wil.py index 0a0c9f7cd13..edd00b7f657 100644 --- a/src/torchmetrics/text/wil.py +++ b/src/torchmetrics/text/wil.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/text/wip.py b/src/torchmetrics/text/wip.py index fc4ea5b8f15..b9b4351f548 100644 --- a/src/torchmetrics/text/wip.py +++ b/src/torchmetrics/text/wip.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 6da328c01e7..a1f7e47632f 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -17,7 +17,7 @@ from collections.abc import Mapping, Sequence from functools import partial from time import perf_counter -from typing import Any, Callable, Dict, Optional, Tuple, no_type_check +from typing import Any, Callable, Optional, no_type_check from unittest.mock import Mock import torch diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index f16600f35b6..cbb648a8844 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index d3ec64d621a..2f98de35d19 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -13,7 +13,7 @@ # limitations under the License. import sys from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 150138e2198..31a1d0dca5b 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, Optional import torch from torch import Tensor diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 14ec7135a86..155f1bb8f60 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type from lightning_utilities.core.enums import StrEnum from typing_extensions import Literal diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 98ef71aa4d2..e3ced02575c 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -14,7 +14,7 @@ from collections.abc import Generator, Sequence from itertools import product from math import ceil, floor, sqrt -from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check +from typing import Any, Optional, Union, no_type_check import numpy as np import torch diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index 47192dad5f7..566c4d4ab66 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence from copy import deepcopy -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 517d4c060d0..c37e3c7fa2b 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -13,7 +13,7 @@ # limitations under the License. import typing from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from torch import Tensor diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py index caeb2e93217..62302fff140 100644 --- a/src/torchmetrics/wrappers/feature_share.py +++ b/src/torchmetrics/wrappers/feature_share.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence from functools import lru_cache -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from torch.nn import Module diff --git a/src/torchmetrics/wrappers/minmax.py b/src/torchmetrics/wrappers/minmax.py index 236179a6cbe..25300e8fbe0 100644 --- a/src/torchmetrics/wrappers/minmax.py +++ b/src/torchmetrics/wrappers/minmax.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py index 86f53af6908..12a135099a2 100644 --- a/src/torchmetrics/wrappers/multioutput.py +++ b/src/torchmetrics/wrappers/multioutput.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence from copy import deepcopy -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from lightning_utilities import apply_to_collection diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index c0b312cbedd..66e9516591e 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -14,7 +14,7 @@ # this is just a bypass for this module name collision with built-in one from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from torch import Tensor, nn diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index a223f803397..0a8ca7eac1d 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor diff --git a/src/torchmetrics/wrappers/transformations.py b/src/torchmetrics/wrappers/transformations.py index d4fb0f270d0..8a86e3224af 100644 --- a/src/torchmetrics/wrappers/transformations.py +++ b/src/torchmetrics/wrappers/transformations.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index ed7b940c315..1622e4ad8a3 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -16,7 +16,7 @@ from collections.abc import Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import numpy as np import pytest diff --git a/tests/unittests/audio/test_dnsmos.py b/tests/unittests/audio/test_dnsmos.py index c8239fe4776..80607057467 100644 --- a/tests/unittests/audio/test_dnsmos.py +++ b/tests/unittests/audio/test_dnsmos.py @@ -13,7 +13,7 @@ # limitations under the License. import os from functools import partial -from typing import Any, Dict, Optional +from typing import Any, Optional import numpy as np import pytest diff --git a/tests/unittests/audio/test_nisqa.py b/tests/unittests/audio/test_nisqa.py index 8d4e47512c0..06eac64710c 100644 --- a/tests/unittests/audio/test_nisqa.py +++ b/tests/unittests/audio/test_nisqa.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Any, Dict, Tuple +from typing import Any import pytest import torch diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index b431b9e8d6c..70c5b44c6ab 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Tuple +from typing import Callable import numpy as np import pytest diff --git a/tests/unittests/audio/test_srmr.py b/tests/unittests/audio/test_srmr.py index b627e6b8642..d3a18cca357 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Dict +from typing import Any import pytest import torch diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 42e64e21262..5f627676f09 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect from functools import partial -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional from unittest import mock import numpy as np diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index a5d0767c342..f4fe1d1ee06 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import numpy as np import pytest diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 452a43f7ee5..58287aa7fcd 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import numpy as np import pytest diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index 6161df001fa..09a9675d380 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Dict, List, NamedTuple +from typing import NamedTuple import numpy as np import pytest diff --git a/tests/unittests/image/test_qnr.py b/tests/unittests/image/test_qnr.py index 89c9113adac..52d77896235 100644 --- a/tests/unittests/image/test_qnr.py +++ b/tests/unittests/image/test_qnr.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial -from typing import Dict, List, NamedTuple +from typing import NamedTuple import pytest import torch diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index 6e5e3f3da50..9e71a30ca0a 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import List, NamedTuple +from typing import NamedTuple import matplotlib import matplotlib.pyplot as plt diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 8075c5e6ce0..6747be475ed 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import partial from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import numpy as np import pytest diff --git a/tests/unittests/retrieval/test_precision_recall_curve.py b/tests/unittests/retrieval/test_precision_recall_curve.py index 12f7e60f8e2..d8e9817bd51 100644 --- a/tests/unittests/retrieval/test_precision_recall_curve.py +++ b/tests/unittests/retrieval/test_precision_recall_curve.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np import pytest diff --git a/tests/unittests/text/_helpers.py b/tests/unittests/text/_helpers.py index 15723b529e0..2b9f8381d42 100644 --- a/tests/unittests/text/_helpers.py +++ b/tests/unittests/text/_helpers.py @@ -15,7 +15,7 @@ import sys from collections.abc import Sequence from functools import partial -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union import numpy as np import pytest diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index a4d5de091a0..99de422ba26 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.cer import char_error_rate diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index d4c5a8c2d09..56592c34762 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.mer import match_error_rate diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index bae4c91c11d..e57c51af5fd 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.wer import word_error_rate diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index 59657956cd0..9b1615e5ee6 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.wil import word_information_lost diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index fe4bedb482c..b0393d12b29 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import pytest from torchmetrics.functional.text.wip import word_information_preserved From 3c2d8c12fcabfbe1c3a36955cd116adf0a577508 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Thu, 7 Nov 2024 18:03:46 +0000 Subject: [PATCH 07/15] Tensor --- src/torchmetrics/detection/mean_ap.py | 12 ++++++------ src/torchmetrics/regression/csi.py | 25 +++++++++++++------------ 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index f8234a08d69..25e598ad57b 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -866,12 +866,12 @@ def _get_classes(self) -> list: def _get_coco_format( self, - labels: list[torch.Tensor], - boxes: Optional[list[torch.Tensor]] = None, - masks: Optional[list[torch.Tensor]] = None, - scores: Optional[list[torch.Tensor]] = None, - crowds: Optional[list[torch.Tensor]] = None, - area: Optional[list[torch.Tensor]] = None, + labels: list[Tensor], + boxes: Optional[list[Tensor]] = None, + masks: Optional[list[Tensor]] = None, + scores: Optional[list[Tensor]] = None, + crowds: Optional[list[Tensor]] = None, + area: Optional[list[Tensor]] = None, ) -> dict: """Transforms and returns all cached targets or predictions in COCO format. diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index 854ce500d7b..e54b4d8d6ea 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -14,6 +14,7 @@ from typing import Any, Optional import torch +from torch import Tensor from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update from torchmetrics.metric import Metric @@ -40,8 +41,8 @@ class CriticalSuccessIndex(Metric): Example: >>> import torch >>> from torchmetrics.regression import CriticalSuccessIndex - >>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]]) - >>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]]) + >>> x = Tensor([[0.2, 0.7], [0.9, 0.3]]) + >>> y = Tensor([[0.4, 0.2], [0.8, 0.6]]) >>> csi = CriticalSuccessIndex(0.5) >>> csi(x, y) tensor(0.3333) @@ -49,8 +50,8 @@ class CriticalSuccessIndex(Metric): Example: >>> import torch >>> from torchmetrics.regression import CriticalSuccessIndex - >>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]]) - >>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]]) + >>> x = Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]]) + >>> y = Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]]) >>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0) >>> csi(x, y) tensor([0.3333, 0.3333]) @@ -60,12 +61,12 @@ class CriticalSuccessIndex(Metric): is_differentiable: bool = False higher_is_better: bool = True - hits: torch.Tensor - misses: torch.Tensor - false_alarms: torch.Tensor - hits_list: list[torch.Tensor] - misses_list: list[torch.Tensor] - false_alarms_list: list[torch.Tensor] + hits:Tensor + misses: Tensor + false_alarms: Tensor + hits_list: list[Tensor] + misses_list: list[Tensor] + false_alarms_list: list[Tensor] def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -84,7 +85,7 @@ def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, ** self.add_state("misses_list", default=[], dist_reduce_fx="cat") self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat") - def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" hits, misses, false_alarms = _critical_success_index_update( preds, target, self.threshold, self.keep_sequence_dim @@ -98,7 +99,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: self.misses_list.append(misses) self.false_alarms_list.append(false_alarms) - def compute(self) -> torch.Tensor: + def compute(self) -> Tensor: """Compute critical success index over state.""" if self.keep_sequence_dim is None: hits = self.hits From 94631234786863ba74e78ae333d40564e9524ffc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 18:04:12 +0000 Subject: [PATCH 08/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/regression/csi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index e54b4d8d6ea..44ad0b8fad3 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -61,7 +61,7 @@ class CriticalSuccessIndex(Metric): is_differentiable: bool = False higher_is_better: bool = True - hits:Tensor + hits: Tensor misses: Tensor false_alarms: Tensor hits_list: list[Tensor] From 58ac4c47f7031568d78c9ff712dcc00db335fedb Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Nov 2024 10:59:12 +0000 Subject: [PATCH 09/15] Revert "Tensor" This reverts commit 3c2d8c12 --- src/torchmetrics/detection/mean_ap.py | 12 ++++++------ src/torchmetrics/regression/csi.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 25e598ad57b..f8234a08d69 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -866,12 +866,12 @@ def _get_classes(self) -> list: def _get_coco_format( self, - labels: list[Tensor], - boxes: Optional[list[Tensor]] = None, - masks: Optional[list[Tensor]] = None, - scores: Optional[list[Tensor]] = None, - crowds: Optional[list[Tensor]] = None, - area: Optional[list[Tensor]] = None, + labels: list[torch.Tensor], + boxes: Optional[list[torch.Tensor]] = None, + masks: Optional[list[torch.Tensor]] = None, + scores: Optional[list[torch.Tensor]] = None, + crowds: Optional[list[torch.Tensor]] = None, + area: Optional[list[torch.Tensor]] = None, ) -> dict: """Transforms and returns all cached targets or predictions in COCO format. diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index 44ad0b8fad3..37216587c4c 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -41,8 +41,8 @@ class CriticalSuccessIndex(Metric): Example: >>> import torch >>> from torchmetrics.regression import CriticalSuccessIndex - >>> x = Tensor([[0.2, 0.7], [0.9, 0.3]]) - >>> y = Tensor([[0.4, 0.2], [0.8, 0.6]]) + >>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]]) + >>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]]) >>> csi = CriticalSuccessIndex(0.5) >>> csi(x, y) tensor(0.3333) @@ -50,8 +50,8 @@ class CriticalSuccessIndex(Metric): Example: >>> import torch >>> from torchmetrics.regression import CriticalSuccessIndex - >>> x = Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]]) - >>> y = Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]]) + >>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]]) + >>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]]) >>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0) >>> csi(x, y) tensor([0.3333, 0.3333]) @@ -64,9 +64,9 @@ class CriticalSuccessIndex(Metric): hits: Tensor misses: Tensor false_alarms: Tensor - hits_list: list[Tensor] - misses_list: list[Tensor] - false_alarms_list: list[Tensor] + hits_list: list[torch.Tensor] + misses_list: list[torch.Tensor] + false_alarms_list: list[torch.Tensor] def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None: super().__init__(**kwargs) From 79e9d9c0e2c9dc0c964b3832b6a64acee80846ff Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Nov 2024 11:00:52 +0000 Subject: [PATCH 10/15] List --- src/torchmetrics/detection/mean_ap.py | 14 +++++++------- src/torchmetrics/regression/csi.py | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index f8234a08d69..2567d7405f5 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -16,7 +16,7 @@ import json from collections.abc import Sequence from types import ModuleType -from typing import Any, Callable, ClassVar, Optional, Union +from typing import Any, Callable, ClassVar, Optional, Union, List import numpy as np import torch @@ -866,12 +866,12 @@ def _get_classes(self) -> list: def _get_coco_format( self, - labels: list[torch.Tensor], - boxes: Optional[list[torch.Tensor]] = None, - masks: Optional[list[torch.Tensor]] = None, - scores: Optional[list[torch.Tensor]] = None, - crowds: Optional[list[torch.Tensor]] = None, - area: Optional[list[torch.Tensor]] = None, + labels: List[Tensor], + boxes: Optional[List[Tensor]] = None, + masks: Optional[List[Tensor]] = None, + scores: Optional[List[Tensor]] = None, + crowds: Optional[List[Tensor]] = None, + area: Optional[List[Tensor]] = None, ) -> dict: """Transforms and returns all cached targets or predictions in COCO format. diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index 37216587c4c..6316ee3aa3b 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, List import torch from torch import Tensor @@ -64,9 +64,9 @@ class CriticalSuccessIndex(Metric): hits: Tensor misses: Tensor false_alarms: Tensor - hits_list: list[torch.Tensor] - misses_list: list[torch.Tensor] - false_alarms_list: list[torch.Tensor] + hits_list: List[Tensor] + misses_list: List[Tensor] + false_alarms_list: List[Tensor] def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None: super().__init__(**kwargs) From 890155a9eab1a6e6188937956dbb6ba893727915 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:01:17 +0000 Subject: [PATCH 11/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/detection/mean_ap.py | 2 +- src/torchmetrics/regression/csi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 2567d7405f5..1bc89f35329 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -16,7 +16,7 @@ import json from collections.abc import Sequence from types import ModuleType -from typing import Any, Callable, ClassVar, Optional, Union, List +from typing import Any, Callable, ClassVar, List, Optional, Union import numpy as np import torch diff --git a/src/torchmetrics/regression/csi.py b/src/torchmetrics/regression/csi.py index 6316ee3aa3b..b5c7356aaab 100644 --- a/src/torchmetrics/regression/csi.py +++ b/src/torchmetrics/regression/csi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, List +from typing import Any, List, Optional import torch from torch import Tensor From a8b9541272f72d40f462e6260d1a2158b9704575 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Nov 2024 12:22:43 +0000 Subject: [PATCH 12/15] List... --- .../classification/calibration_error.py | 10 ++++----- .../classification/precision_recall_curve.py | 22 +++++++++---------- src/torchmetrics/classification/roc.py | 10 ++++----- .../classification/stat_scores.py | 10 ++++----- .../clustering/adjusted_mutual_info_score.py | 6 ++--- .../clustering/adjusted_rand_score.py | 6 ++--- .../clustering/calinski_harabasz_score.py | 6 ++--- .../clustering/davies_bouldin_score.py | 6 ++--- src/torchmetrics/clustering/dunn_index.py | 6 ++--- .../clustering/fowlkes_mallows_index.py | 6 ++--- .../homogeneity_completeness_v_measure.py | 14 ++++++------ .../clustering/mutual_info_score.py | 6 ++--- .../normalized_mutual_info_score.py | 6 ++--- src/torchmetrics/clustering/rand_score.py | 6 ++--- src/torchmetrics/detection/_mean_ap.py | 12 +++++----- src/torchmetrics/detection/iou.py | 6 ++--- src/torchmetrics/detection/mean_ap.py | 18 +++++++-------- .../functional/classification/auroc.py | 6 ++--- .../classification/average_precision.py | 6 ++--- .../classification/precision_recall_curve.py | 12 +++++----- .../functional/classification/roc.py | 12 +++++----- .../classification/sensitivity_specificity.py | 4 ++-- .../classification/specificity_sensitivity.py | 6 ++--- src/torchmetrics/functional/image/lpips.py | 4 ++-- src/torchmetrics/functional/image/ssim.py | 4 ++-- .../functional/multimodal/clip_score.py | 6 ++--- .../functional/regression/kendall.py | 8 +++---- .../functional/text/_deprecated.py | 4 ++-- src/torchmetrics/functional/text/bert.py | 6 ++--- src/torchmetrics/functional/text/chrf.py | 8 +++---- src/torchmetrics/functional/text/eed.py | 8 +++---- src/torchmetrics/functional/text/infolm.py | 6 ++--- src/torchmetrics/functional/text/rouge.py | 8 +++---- src/torchmetrics/functional/text/ter.py | 10 ++++----- src/torchmetrics/image/d_lambda.py | 6 ++--- src/torchmetrics/image/d_s.py | 10 ++++----- src/torchmetrics/image/ergas.py | 6 ++--- src/torchmetrics/image/kid.py | 6 ++--- src/torchmetrics/image/mifid.py | 6 ++--- src/torchmetrics/image/qnr.py | 10 ++++----- src/torchmetrics/image/rase.py | 6 ++--- src/torchmetrics/image/sam.py | 6 ++--- src/torchmetrics/image/ssim.py | 10 ++++----- src/torchmetrics/image/tv.py | 4 ++-- src/torchmetrics/image/uqi.py | 6 ++--- src/torchmetrics/metric.py | 6 ++--- src/torchmetrics/multimodal/clip_iqa.py | 4 ++-- src/torchmetrics/multimodal/clip_score.py | 4 ++-- src/torchmetrics/nominal/fleiss_kappa.py | 4 ++-- .../regression/cosine_similarity.py | 6 ++--- src/torchmetrics/regression/kendall.py | 6 ++--- src/torchmetrics/regression/pearson.py | 6 ++--- src/torchmetrics/regression/spearman.py | 6 ++--- src/torchmetrics/retrieval/base.py | 8 +++---- .../retrieval/precision_recall_curve.py | 8 +++---- src/torchmetrics/segmentation/dice.py | 8 +++---- src/torchmetrics/text/bert.py | 12 +++++----- src/torchmetrics/text/chrf.py | 4 ++-- src/torchmetrics/text/edit.py | 4 ++-- src/torchmetrics/text/eed.py | 4 ++-- src/torchmetrics/text/infolm.py | 10 ++++----- src/torchmetrics/text/ter.py | 4 ++-- src/torchmetrics/utilities/data.py | 4 ++-- src/torchmetrics/utilities/distributed.py | 8 +++---- src/torchmetrics/utilities/plot.py | 4 ++-- 65 files changed, 235 insertions(+), 235 deletions(-) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 3cd760d4c17..da2dc4d3d40 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -106,8 +106,8 @@ class BinaryCalibrationError(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - confidences: list[Tensor] - accuracies: list[Tensor] + confidences: List[Tensor] + accuracies: List[Tensor] def __init__( self, @@ -259,8 +259,8 @@ class MulticlassCalibrationError(Metric): plot_upper_bound: float = 1.0 plot_legend_name: str = "Class" - confidences: list[Tensor] - accuracies: list[Tensor] + confidences: List[Tensor] + accuracies: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 0149f78bbd4..7a8430eeab2 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -130,8 +130,8 @@ class BinaryPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] confmat: Tensor def __init__( @@ -323,8 +323,8 @@ class MulticlassPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] confmat: Tensor def __init__( @@ -374,14 +374,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average) def plot( self, - curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: @@ -523,8 +523,8 @@ class MultilabelPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] confmat: Tensor def __init__( @@ -570,14 +570,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(state[0]) self.target.append(state[1]) - def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index) def plot( self, - curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, ) -> _PLOT_OUT_TYPE: diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index cf22c8c5646..b56bc856744 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -287,14 +287,14 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): plot_upper_bound: float = 1.0 plot_legend_name: str = "Class" - def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average) # type: ignore[arg-type] def plot( self, - curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, labels: Optional[list[str]] = None, @@ -449,14 +449,14 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): plot_upper_bound: float = 1.0 plot_legend_name: str = "Label" - def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: + def compute(self) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) # type: ignore[arg-type] def plot( self, - curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]] = None, + curve: Optional[Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None, score: Optional[Union[Tensor, bool]] = None, ax: Optional[_AX_TYPE] = None, labels: Optional[list[str]] = None, diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 483cbabda25..3c55c431fa2 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -41,10 +41,10 @@ class _AbstractStatScores(Metric): - tp: Union[list[Tensor], Tensor] - fp: Union[list[Tensor], Tensor] - tn: Union[list[Tensor], Tensor] - fn: Union[list[Tensor], Tensor] + tp: Union[List[Tensor], Tensor] + fp: Union[List[Tensor], Tensor] + tn: Union[List[Tensor], Tensor] + fn: Union[List[Tensor], Tensor] # define common functions def _create_state( diff --git a/src/torchmetrics/clustering/adjusted_mutual_info_score.py b/src/torchmetrics/clustering/adjusted_mutual_info_score.py index 966b2d00540..ebcf4749d08 100644 --- a/src/torchmetrics/clustering/adjusted_mutual_info_score.py +++ b/src/torchmetrics/clustering/adjusted_mutual_info_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union from torch import Tensor @@ -73,8 +73,8 @@ class AdjustedMutualInfoScore(MutualInfoScore): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] contingency: Tensor def __init__( diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py index 3f614c25c69..20278f74bc3 100644 --- a/src/torchmetrics/clustering/adjusted_rand_score.py +++ b/src/torchmetrics/clustering/adjusted_rand_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -68,8 +68,8 @@ class AdjustedRandScore(Metric): full_state_update: bool = False plot_lower_bound: float = -0.5 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/calinski_harabasz_score.py b/src/torchmetrics/clustering/calinski_harabasz_score.py index 4a3d25138f2..c331fba7866 100644 --- a/src/torchmetrics/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/clustering/calinski_harabasz_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -69,8 +69,8 @@ class CalinskiHarabaszScore(Metric): higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 - data: list[Tensor] - labels: list[Tensor] + data: List[Tensor] + labels: List[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/davies_bouldin_score.py b/src/torchmetrics/clustering/davies_bouldin_score.py index 98f373b9558..ddd079793cd 100644 --- a/src/torchmetrics/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/clustering/davies_bouldin_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -79,8 +79,8 @@ class DaviesBouldinScore(Metric): higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 - data: list[Tensor] - labels: list[Tensor] + data: List[Tensor] + labels: List[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index db635dee23a..65d1c0c9a94 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -67,8 +67,8 @@ class DunnIndex(Metric): higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 - data: list[Tensor] - labels: list[Tensor] + data: List[Tensor] + labels: List[Tensor] def __init__(self, p: float = 2, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/fowlkes_mallows_index.py b/src/torchmetrics/clustering/fowlkes_mallows_index.py index 276c28f5456..1317a0cee1c 100644 --- a/src/torchmetrics/clustering/fowlkes_mallows_index.py +++ b/src/torchmetrics/clustering/fowlkes_mallows_index.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -63,8 +63,8 @@ class FowlkesMallowsIndex(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] contingency: Tensor def __init__(self, **kwargs: Any) -> None: diff --git a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py index b610f134b85..260ab522245 100644 --- a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py +++ b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -68,8 +68,8 @@ class HomogeneityScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -164,8 +164,8 @@ class CompletenessScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -267,8 +267,8 @@ class VMeasureScore(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__(self, beta: float = 1.0, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 549c2f8376b..a2be02f834e 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -68,8 +68,8 @@ class MutualInfoScore(Metric): higher_is_better: Optional[bool] = True full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] contingency: Tensor def __init__(self, **kwargs: Any) -> None: diff --git a/src/torchmetrics/clustering/normalized_mutual_info_score.py b/src/torchmetrics/clustering/normalized_mutual_info_score.py index eedda19f784..2583b0b2a9e 100644 --- a/src/torchmetrics/clustering/normalized_mutual_info_score.py +++ b/src/torchmetrics/clustering/normalized_mutual_info_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union from torch import Tensor @@ -72,8 +72,8 @@ class NormalizedMutualInfoScore(MutualInfoScore): full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 0.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] contingency: Tensor def __init__( diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index d625dec55ac..724a38b227c 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -66,8 +66,8 @@ class RandScore(Metric): higher_is_better = None full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] contingency: Tensor def __init__(self, **kwargs: Any) -> None: diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index 5704fae1209..9831842734d 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import numpy as np import torch @@ -306,11 +306,11 @@ class MeanAveragePrecision(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - detections: list[Tensor] - detection_scores: list[Tensor] - detection_labels: list[Tensor] - groundtruths: list[Tensor] - groundtruth_labels: list[Tensor] + detections: List[Tensor] + detection_scores: List[Tensor] + detection_labels: List[Tensor] + groundtruths: List[Tensor] + groundtruth_labels: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index 26c48bd42c1..22d7e5225d4 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -132,8 +132,8 @@ class IntersectionOverUnion(Metric): higher_is_better: Optional[bool] = True full_state_update: bool = True - groundtruth_labels: list[Tensor] - iou_matrix: list[Tensor] + groundtruth_labels: List[Tensor] + iou_matrix: List[Tensor] _iou_type: str = "iou" _invalid_val: float = -1.0 diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 1bc89f35329..a60be809bac 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -344,15 +344,15 @@ class MeanAveragePrecision(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - detection_box: list[Tensor] - detection_mask: list[Tensor] - detection_scores: list[Tensor] - detection_labels: list[Tensor] - groundtruth_box: list[Tensor] - groundtruth_mask: list[Tensor] - groundtruth_labels: list[Tensor] - groundtruth_crowds: list[Tensor] - groundtruth_area: list[Tensor] + detection_box: List[Tensor] + detection_mask: List[Tensor] + detection_scores: List[Tensor] + detection_labels: List[Tensor] + groundtruth_box: List[Tensor] + groundtruth_mask: List[Tensor] + groundtruth_labels: List[Tensor] + groundtruth_crowds: List[Tensor] + groundtruth_area: List[Tensor] warn_on_many_detections: bool = True diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index c4c55769be5..f88b79e2a28 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor, tensor @@ -43,8 +43,8 @@ def _reduce_auroc( - fpr: Union[Tensor, list[Tensor]], - tpr: Union[Tensor, list[Tensor]], + fpr: Union[Tensor, List[Tensor]], + tpr: Union[Tensor, List[Tensor]], average: Optional[Literal["macro", "weighted", "none"]] = "macro", weights: Optional[Tensor] = None, direction: float = 1.0, diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index ff7a46f0a11..1cddf8833ed 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -41,8 +41,8 @@ def _reduce_average_precision( - precision: Union[Tensor, list[Tensor]], - recall: Union[Tensor, list[Tensor]], + precision: Union[Tensor, List[Tensor]], + recall: Union[Tensor, List[Tensor]], average: Optional[Literal["macro", "weighted", "none"]] = "macro", weights: Optional[Tensor] = None, ) -> Tensor: diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index f1765916b36..3c5a840efa1 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor, tensor @@ -540,7 +540,7 @@ def _multiclass_precision_recall_curve_compute( num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -599,7 +599,7 @@ def multiclass_precision_recall_curve( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve for multiclass tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -806,7 +806,7 @@ def _multilabel_precision_recall_curve_compute( num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int] = None, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the final pr-curve. If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is @@ -845,7 +845,7 @@ def multilabel_precision_recall_curve( thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve for multilabel tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the @@ -953,7 +953,7 @@ def precision_recall_curve( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 2ce36bb3643..741733ec98f 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -164,7 +164,7 @@ def _multiclass_roc_compute( num_classes: int, thresholds: Optional[Tensor], average: Optional[Literal["micro", "macro"]] = None, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: if average == "micro": return _binary_roc_compute(state, thresholds, pos_label=1) @@ -212,7 +212,7 @@ def multiclass_roc( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multiclass tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -331,7 +331,7 @@ def _multilabel_roc_compute( num_labels: int, thresholds: Optional[Tensor], ignore_index: Optional[int] = None, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: if isinstance(state, Tensor) and thresholds is not None: tps = state[:, :, 1, 1] fps = state[:, :, 0, 1] @@ -363,7 +363,7 @@ def multilabel_roc( thresholds: Optional[Union[int, list[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multilabel tasks. The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -478,7 +478,7 @@ def roc( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC). The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py index 940691a51f7..546984262b3 100644 --- a/src/torchmetrics/functional/classification/sensitivity_specificity.py +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -413,7 +413,7 @@ def sensitivity_at_specificity( num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index 1f252cdd98e..07e7bc4243c 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -414,7 +414,7 @@ def specicity_at_sensitivity( num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. .. warning:: @@ -450,7 +450,7 @@ def specificity_at_sensitivity( num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, -) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 768116613ea..0e07cec3d28 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -24,7 +24,7 @@ # License under BSD 2-clause import inspect import os -from typing import NamedTuple, Optional, Union +from typing import List, NamedTuple, Optional, Union import torch from torch import Tensor, nn @@ -331,7 +331,7 @@ def __init__( def forward( self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False - ) -> Union[Tensor, tuple[Tensor, list[Tensor]]]: + ) -> Union[Tensor, tuple[Tensor, List[Tensor]]]: if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 6a1ae9deefb..ccaafe66065 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -372,7 +372,7 @@ def _multiscale_ssim_update( If the image width is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``. """ - mcs_list: list[Tensor] = [] + mcs_list: List[Tensor] = [] is_3d = preds.ndim == 5 diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index bae7cb7b849..070d81bf54c 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, List, Union import torch from torch import Tensor @@ -42,7 +42,7 @@ def _download_clip_for_clip_score() -> None: def _clip_score_update( - images: Union[Tensor, list[Tensor]], + images: Union[Tensor, List[Tensor]], text: Union[str, list[str]], model: _CLIPModel, processor: _CLIPProcessor, @@ -113,7 +113,7 @@ def _get_clip_model_and_processor( def clip_score( - images: Union[Tensor, list[Tensor]], + images: Union[Tensor, List[Tensor]], text: Union[str, list[str]], model_name_or_path: Literal[ "openai/clip-vit-base-patch16", diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 4d5b2028a1d..80afcc23b69 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor @@ -225,10 +225,10 @@ def _calculate_p_value( def _kendall_corrcoef_update( preds: Tensor, target: Tensor, - concat_preds: Optional[list[Tensor]] = None, - concat_target: Optional[list[Tensor]] = None, + concat_preds: Optional[List[Tensor]] = None, + concat_target: Optional[List[Tensor]] = None, num_outputs: int = 1, -) -> tuple[list[Tensor], list[Tensor]]: +) -> tuple[List[Tensor], List[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. Args: diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index c34419ac613..380c7048914 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -1,6 +1,6 @@ import os from collections.abc import Sequence -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union import torch from torch import Tensor @@ -349,7 +349,7 @@ def _translation_edit_rate( lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, tuple[Tensor, list[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, List[Tensor]]]: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index e4a38b29755..9835723fae4 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -16,7 +16,7 @@ import urllib from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -106,8 +106,8 @@ def _get_embeddings_and_idf_scale( If ``all_layers = True`` and a model, which is not from the ``transformers`` package, is used. """ - embeddings_list: list[Tensor] = [] - idf_scale_list: list[Tensor] = [] + embeddings_list: List[Tensor] = [] + idf_scale_list: List[Tensor] = [] for batch in _get_progress_bar(dataloader, verbose): with torch.no_grad(): batch = _input_data_collator(batch, device) diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 66ee3b7a449..0e5e4978f1f 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -23,7 +23,7 @@ from collections import defaultdict from collections.abc import Sequence from itertools import chain -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor, tensor @@ -386,7 +386,7 @@ def _chrf_score_update( beta: float, lowercase: bool, whitespace: bool, - sentence_chrf_score: Optional[list[Tensor]] = None, + sentence_chrf_score: Optional[List[Tensor]] = None, ) -> tuple[ dict[int, Tensor], dict[int, Tensor], @@ -394,7 +394,7 @@ def _chrf_score_update( dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], - Optional[list[Tensor]], + Optional[List[Tensor]], ]: """Update function for chrf score. @@ -594,7 +594,7 @@ def chrf_score( total_matching_word_n_grams, ) = _prepare_n_grams_dicts(n_char_order, n_word_order) - sentence_chrf_score: Optional[list[Tensor]] = [] if return_sentence_level_score else None + sentence_chrf_score: Optional[List[Tensor]] = [] if return_sentence_level_score else None ( total_preds_char_n_grams, diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index ca562c5212a..bde77680bea 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -90,7 +90,7 @@ import unicodedata from collections.abc import Sequence from math import inf -from typing import Optional, Union +from typing import List, Optional, Union from torch import Tensor, stack, tensor from typing_extensions import Literal @@ -234,7 +234,7 @@ def _preprocess_ja(sentence: str) -> str: return unicodedata.normalize("NFKC", sentence) -def _eed_compute(sentence_level_scores: list[Tensor]) -> Tensor: +def _eed_compute(sentence_level_scores: List[Tensor]) -> Tensor: """Reduction for extended edit distance. Args: @@ -328,8 +328,8 @@ def _eed_update( rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, - sentence_eed: Optional[list[Tensor]] = None, -) -> list[Tensor]: + sentence_eed: Optional[List[Tensor]] = None, +) -> List[Tensor]: """Compute scores for ExtendedEditDistance. Args: diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 48a42e32b2f..94452f4886e 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -14,7 +14,7 @@ import os from collections.abc import Sequence from enum import unique -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch from torch import Tensor @@ -393,7 +393,7 @@ def _get_batch_distribution( """ seq_len = batch["input_ids"].shape[1] - prob_distribution_batch_list: list[Tensor] = [] + prob_distribution_batch_list: List[Tensor] = [] token_mask = _get_token_mask( batch["input_ids"], special_tokens_map["pad_token_id"], @@ -454,7 +454,7 @@ def _get_data_distribution( """ device = model.device - prob_distribution: list[Tensor] = [] + prob_distribution: List[Tensor] = [] for batch in _get_progress_bar(dataloader, verbose): batch = _input_data_collator(batch, device) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 791de8dbe26..23a7aeb0856 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -14,7 +14,7 @@ import re from collections import Counter from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor, tensor @@ -374,7 +374,7 @@ def _rouge_score_update( rouge_key: {} for rouge_key in rouge_keys_values } for rouge_key, metrics in result_avg.items(): - _dict_metric_score_batch: dict[str, list[Tensor]] = {} + _dict_metric_score_batch: dict[str, List[Tensor]] = {} for metric in metrics: for _type, value in metric.items(): if _type not in _dict_metric_score_batch: @@ -391,7 +391,7 @@ def _rouge_score_update( return results -def _rouge_score_compute(sentence_results: dict[str, list[Tensor]]) -> dict[str, Tensor]: +def _rouge_score_compute(sentence_results: dict[str, List[Tensor]]) -> dict[str, Tensor]: """Compute the combined ROUGE metric for all the input set of predicted and target sentences. Args: @@ -505,7 +505,7 @@ def rouge_score( accumulate=accumulate, ) - output: dict[str, list[Tensor]] = { + output: dict[str, List[Tensor]] = { f"rouge{rouge_key}_{tp}": [] for rouge_key in rouge_keys_values for tp in ["fmeasure", "precision", "recall"] } for rouge_key, metrics in sentence_results.items(): diff --git a/src/torchmetrics/functional/text/ter.py b/src/torchmetrics/functional/text/ter.py index 02617f2cbaf..2d7a6211e0d 100644 --- a/src/torchmetrics/functional/text/ter.py +++ b/src/torchmetrics/functional/text/ter.py @@ -36,7 +36,7 @@ import re from collections.abc import Iterator, Sequence from functools import lru_cache -from typing import Optional, Union +from typing import List, Optional, Union from torch import Tensor, tensor @@ -477,8 +477,8 @@ def _ter_update( tokenizer: _TercomTokenizer, total_num_edits: Tensor, total_tgt_length: Tensor, - sentence_ter: Optional[list[Tensor]] = None, -) -> tuple[Tensor, Tensor, Optional[list[Tensor]]]: + sentence_ter: Optional[List[Tensor]] = None, +) -> tuple[Tensor, Tensor, Optional[List[Tensor]]]: """Update TER statistics. Args: @@ -537,7 +537,7 @@ def translation_edit_rate( lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, -) -> Union[Tensor, tuple[Tensor, list[Tensor]]]: +) -> Union[Tensor, tuple[Tensor, List[Tensor]]]: """Calculate Translation edit rate (`TER`_) of machine translated text with one or more references. This implementation follows the implementations from @@ -581,7 +581,7 @@ def translation_edit_rate( total_num_edits = tensor(0.0) total_tgt_length = tensor(0.0) - sentence_ter: Optional[list[Tensor]] = [] if return_sentence_level_score else None + sentence_ter: Optional[List[Tensor]] = [] if return_sentence_level_score else None total_num_edits, total_tgt_length, sentence_ter = _ter_update( preds, diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 65824c19600..97d95ccd926 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -70,8 +70,8 @@ class SpectralDistortionIndex(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, p: int = 1, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index f6b12efd5ad..9143810f545 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -94,10 +94,10 @@ class SpatialDistortionIndex(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - ms: list[Tensor] - pan: list[Tensor] - pan_lr: list[Tensor] + preds: List[Tensor] + ms: List[Tensor] + pan: List[Tensor] + pan_lr: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index bdf884c2719..22c24b164f1 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -78,8 +78,8 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index f3c3d16bdd4..99c2b04bf7b 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -169,8 +169,8 @@ class KernelInceptionDistance(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - real_features: list[Tensor] - fake_features: list[Tensor] + real_features: List[Tensor] + fake_features: List[Tensor] inception: Module feature_network: str = "inception" diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index 6ce894cffd1..5d344b57dd1 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -149,8 +149,8 @@ class MemorizationInformedFrechetInceptionDistance(Metric): is_differentiable: bool = False full_state_update: bool = False - real_features: list[Tensor] - fake_features: list[Tensor] + real_features: List[Tensor] + fake_features: List[Tensor] inception: Module feature_network: str = "inception" diff --git a/src/torchmetrics/image/qnr.py b/src/torchmetrics/image/qnr.py index 649a69617ef..f28e61b2fbf 100644 --- a/src/torchmetrics/image/qnr.py +++ b/src/torchmetrics/image/qnr.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -90,10 +90,10 @@ class QualityWithNoReference(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - ms: list[Tensor] - pan: list[Tensor] - pan_lr: list[Tensor] + preds: List[Tensor] + ms: List[Tensor] + pan: List[Tensor] + pan_lr: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index 0a06041f94d..bca9504c1aa 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -64,8 +64,8 @@ class RelativeAverageSpectralError(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index 07449cd6939..b313158f80b 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal @@ -73,8 +73,8 @@ class SpectralAngleMapper(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] sum_sam: Tensor numel: Tensor diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 034dabf7821..fd9d12d770d 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -84,8 +84,8 @@ class StructuralSimilarityIndexMeasure(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, @@ -285,8 +285,8 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index eb6122fb41f..287e58a3a43 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor @@ -69,7 +69,7 @@ class TotalVariation(Metric): plot_lower_bound: float = 0.0 num_elements: Tensor - score_list: list[Tensor] + score_list: List[Tensor] score: Tensor def __init__(self, reduction: Optional[Literal["mean", "sum", "none"]] = "sum", **kwargs: Any) -> None: diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index b094f14e851..c503cc1f394 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor, tensor from typing_extensions import Literal @@ -72,8 +72,8 @@ class UniversalImageQualityIndex(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] sum_uqi: Tensor numel: Tensor diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 2f0b3b617ed..b270903eafd 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -21,7 +21,7 @@ from collections.abc import Generator, Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, ClassVar, Optional, Union +from typing import Any, Callable, ClassVar, List, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -172,7 +172,7 @@ def __init__( # state management self._is_synced = False - self._cache: Optional[dict[str, Union[list[Tensor], Tensor]]] = None + self._cache: Optional[dict[str, Union[List[Tensor], Tensor]]] = None @property def _update_called(self) -> bool: @@ -194,7 +194,7 @@ def update_count(self) -> int: return self._update_count @property - def metric_state(self) -> dict[str, Union[list[Tensor], Tensor]]: + def metric_state(self) -> dict[str, Union[List[Tensor], Tensor]]: """Get the current state of the metric.""" return {attr: getattr(self, attr) for attr in self._defaults} diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index 1378b203040..cc9c0715be6 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union import torch from torch import Tensor @@ -166,7 +166,7 @@ class CLIPImageQualityAssessment(Metric): plot_upper_bound = 100.0 anchors: Tensor - probs_list: list[Tensor] + probs_list: List[Tensor] feature_network: str = "model" def __init__( diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index 382e68db53a..c89384fbb35 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -118,7 +118,7 @@ def __init__( self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") - def update(self, images: Union[Tensor, list[Tensor]], text: Union[str, list[str]]) -> None: + def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, list[str]]) -> None: """Update CLIP score on a batch of images and text. Args: diff --git a/src/torchmetrics/nominal/fleiss_kappa.py b/src/torchmetrics/nominal/fleiss_kappa.py index 715565a9fdc..254796e96c9 100644 --- a/src/torchmetrics/nominal/fleiss_kappa.py +++ b/src/torchmetrics/nominal/fleiss_kappa.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -77,7 +77,7 @@ class FleissKappa(Metric): is_differentiable: bool = False higher_is_better: bool = True plot_upper_bound: float = 1.0 - counts: list[Tensor] + counts: List[Tensor] def __init__(self, mode: Literal["counts", "probs"] = "counts", **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/regression/cosine_similarity.py b/src/torchmetrics/regression/cosine_similarity.py index 7d2eec54470..5c86ac00cab 100644 --- a/src/torchmetrics/regression/cosine_similarity.py +++ b/src/torchmetrics/regression/cosine_similarity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -66,8 +66,8 @@ class CosineSimilarity(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index e3266f23623..8a102dee08f 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -120,8 +120,8 @@ class KendallRankCorrCoef(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index adff24aab9e..8cf165471fe 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -117,8 +117,8 @@ class PearsonCorrCoef(Metric): full_state_update: bool = True plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] mean_x: Tensor mean_y: Tensor var_x: Tensor diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 83ea08d5f50..de94903c8c0 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor @@ -75,8 +75,8 @@ class SpearmanCorrCoef(Metric): plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 - preds: list[Tensor] - target: list[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/retrieval/base.py b/src/torchmetrics/retrieval/base.py index d63be19ff1b..f9a0a4f8cc4 100644 --- a/src/torchmetrics/retrieval/base.py +++ b/src/torchmetrics/retrieval/base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor, tensor @@ -98,9 +98,9 @@ class RetrievalMetric(Metric, ABC): higher_is_better: bool = True full_state_update: bool = False - indexes: list[Tensor] - preds: list[Tensor] - target: list[Tensor] + indexes: List[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py index fb4d54a4f50..9eef71c154e 100644 --- a/src/torchmetrics/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/retrieval/precision_recall_curve.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -143,9 +143,9 @@ class RetrievalPrecisionRecallCurve(Metric): higher_is_better: bool = True full_state_update: bool = False - indexes: list[Tensor] - preds: list[Tensor] - target: list[Tensor] + indexes: List[Tensor] + preds: List[Tensor] + target: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 70a38009ea2..05a6e29b387 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor from typing_extensions import Literal @@ -100,9 +100,9 @@ class DiceScore(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - numerator: list[Tensor] - denominator: list[Tensor] - support: list[Tensor] + numerator: List[Tensor] + denominator: List[Tensor] + support: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py index f6550749f6b..103c5a47bc7 100644 --- a/src/torchmetrics/text/bert.py +++ b/src/torchmetrics/text/bert.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor @@ -47,7 +47,7 @@ def _download_model_for_bert_score() -> None: __doctest_skip__ = ["BERTScore", "BERTScore.plot"] -def _get_input_dict(input_ids: list[Tensor], attention_mask: list[Tensor]) -> dict[str, Tensor]: +def _get_input_dict(input_ids: List[Tensor], attention_mask: List[Tensor]) -> dict[str, Tensor]: """Create an input dictionary of ``input_ids`` and ``attention_mask`` for BERTScore calculation.""" return {"input_ids": torch.cat(input_ids), "attention_mask": torch.cat(attention_mask)} @@ -128,10 +128,10 @@ class BERTScore(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - preds_input_ids: list[Tensor] - preds_attention_mask: list[Tensor] - target_input_ids: list[Tensor] - target_attention_mask: list[Tensor] + preds_input_ids: List[Tensor] + preds_attention_mask: List[Tensor] + target_input_ids: List[Tensor] + target_attention_mask: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 8b70e1b3269..64eb7c6b1d4 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -19,7 +19,7 @@ import itertools from collections.abc import Iterator, Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor @@ -101,7 +101,7 @@ class CHRFScore(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - sentence_chrf_score: Optional[list[Tensor]] = None + sentence_chrf_score: Optional[List[Tensor]] = None def __init__( self, diff --git a/src/torchmetrics/text/edit.py b/src/torchmetrics/text/edit.py index dd1750436b1..947fc79cd6c 100644 --- a/src/torchmetrics/text/edit.py +++ b/src/torchmetrics/text/edit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union import torch from torch import Tensor @@ -90,7 +90,7 @@ class EditDistance(Metric): full_state_update: bool = False plot_lower_bound: float = 0.0 - edit_scores_list: list[Tensor] + edit_scores_list: List[Tensor] edit_scores: Tensor num_elements: Tensor diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index dd12a852ccf..c776eba2331 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from torch import Tensor, stack from typing_extensions import Literal @@ -65,7 +65,7 @@ class ExtendedEditDistance(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - sentence_eed: list[Tensor] + sentence_eed: List[Tensor] def __init__( self, diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index ae585161bc2..74e931c7cba 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -13,7 +13,7 @@ # limitations under the License. import os from collections.abc import Sequence -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, List, Optional, Union import torch from torch import Tensor @@ -112,10 +112,10 @@ class InfoLM(Metric): """ is_differentiable = False - preds_input_ids: list[Tensor] - preds_attention_mask: list[Tensor] - target_input_ids: list[Tensor] - target_attention_mask: list[Tensor] + preds_input_ids: List[Tensor] + preds_attention_mask: List[Tensor] + target_input_ids: List[Tensor] + target_attention_mask: List[Tensor] _information_measure_higher_is_better: ClassVar = { # following values are <0 diff --git a/src/torchmetrics/text/ter.py b/src/torchmetrics/text/ter.py index 3773e3af923..6cdd1d02118 100644 --- a/src/torchmetrics/text/ter.py +++ b/src/torchmetrics/text/ter.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor, tensor @@ -69,7 +69,7 @@ class TranslationEditRate(Metric): total_num_edits: Tensor total_tgt_len: Tensor - sentence_ter: Optional[list[Tensor]] = None + sentence_ter: Optional[List[Tensor]] = None def __init__( self, diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 2f98de35d19..e5bb148a9ce 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -13,7 +13,7 @@ # limitations under the License. import sys from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -26,7 +26,7 @@ METRIC_EPS = 1e-6 -def dim_zero_cat(x: Union[Tensor, list[Tensor]]) -> Tensor: +def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: """Concatenation along the zero dimension.""" if isinstance(x, torch.Tensor): return x diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 31a1d0dca5b..a68cffec3f1 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, List, Optional import torch from torch import Tensor @@ -88,7 +88,7 @@ def class_reduce( raise ValueError(f"Reduction parameter {class_reduction} unknown. Choose between one of these: {valid_reduction}") -def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> list[Tensor]: +def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: with torch.no_grad(): gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) @@ -97,8 +97,8 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> l return gathered_result -def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Tensor]: - """Gather all tensors from several ddp processes onto a list that is broadcasted to all processes. +def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: + """Gather all tensors from several ddp processes onto a list that is broadcast to all processes. Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case tensors are padded, gathered and then trimmed to secure equal workload for all processes. diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index e3ced02575c..4d14349b7f9 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -14,7 +14,7 @@ from collections.abc import Generator, Sequence from itertools import product from math import ceil, floor, sqrt -from typing import Any, Optional, Union, no_type_check +from typing import Any, List, Optional, Union, no_type_check import numpy as np import torch @@ -295,7 +295,7 @@ def plot_confusion_matrix( @style_change(_style) def plot_curve( - curve: Union[tuple[Tensor, Tensor, Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]], + curve: Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]], score: Optional[Tensor] = None, ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type] label_names: Optional[tuple[str, str]] = None, From df066c0c2873d096c0552068c17806d8c3b620c3 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Nov 2024 13:00:11 +0000 Subject: [PATCH 13/15] List... --- src/torchmetrics/collections.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 23cf92a7f9c..01185223bbd 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -15,7 +15,7 @@ from collections import OrderedDict from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, List, Optional, Union import torch from torch import Tensor @@ -192,7 +192,7 @@ class name of the metric: """ _modules: dict[str, Metric] # type: ignore[assignment] - _groups: dict[int, list[str]] + _groups: dict[int, List[str]] __jit_unused_properties__: ClassVar[list[str]] = ["metric_state"] def __init__( @@ -516,7 +516,7 @@ def _init_compute_groups(self) -> None: self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))} @property - def compute_groups(self) -> dict[int, list[str]]: + def compute_groups(self) -> dict[int, List[str]]: """Return a dict with the current compute groups in the collection.""" return self._groups From ee8be60ed19fef462b25633675cc177fdb1f89f5 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Nov 2024 16:04:18 +0000 Subject: [PATCH 14/15] Dict --- src/torchmetrics/collections.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 01185223bbd..e034027fe0d 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -15,7 +15,7 @@ from collections import OrderedDict from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy -from typing import Any, ClassVar, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import torch from torch import Tensor @@ -192,7 +192,7 @@ class name of the metric: """ _modules: dict[str, Metric] # type: ignore[assignment] - _groups: dict[int, List[str]] + _groups: Dict[int, List[str]] __jit_unused_properties__: ClassVar[list[str]] = ["metric_state"] def __init__( @@ -516,7 +516,7 @@ def _init_compute_groups(self) -> None: self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))} @property - def compute_groups(self) -> dict[int, List[str]]: + def compute_groups(self) -> Dict[int, List[str]]: """Return a dict with the current compute groups in the collection.""" return self._groups From ee25b3bc20657ca7b3f3e3f4bd693f0404824a5c Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Nov 2024 16:31:57 +0000 Subject: [PATCH 15/15] chlog + linter --- CHANGELOG.md | 3 +++ pyproject.toml | 2 ++ 2 files changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98b55fab971..96db8376cb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671)) +- Dropped support for Python 3.8 ([#2827](https://github.com/Lightning-AI/torchmetrics/pull/2827)) + + - Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800)) diff --git a/pyproject.toml b/pyproject.toml index 9aa3f5c4e53..d8e25bd895a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,8 @@ lint.per-file-ignores."setup.py" = [ lint.per-file-ignores."src/**" = [ "ANN401", "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. + "UP006", # todo: Use `list` instead of `List` for type annotation + "UP035", # todo: `typing.List` is deprecated, use `list` instead ] lint.per-file-ignores."tests/**" = [ "ANN001",