From fc484f3149b4829ed372481ccd96a148a7a6bcd8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Jun 2022 22:00:28 +0200 Subject: [PATCH 1/2] typing for kwargs --- torchmetrics/aggregation.py | 12 ++++++------ torchmetrics/audio/pesq.py | 2 +- torchmetrics/audio/pit.py | 2 +- torchmetrics/audio/sdr.py | 4 ++-- torchmetrics/audio/snr.py | 4 ++-- torchmetrics/audio/stoi.py | 2 +- torchmetrics/classification/accuracy.py | 2 +- torchmetrics/classification/auc.py | 2 +- torchmetrics/classification/auroc.py | 2 +- torchmetrics/classification/avg_precision.py | 2 +- .../classification/binned_precision_recall.py | 4 ++-- torchmetrics/classification/calibration_error.py | 2 +- torchmetrics/classification/cohen_kappa.py | 2 +- torchmetrics/classification/confusion_matrix.py | 2 +- torchmetrics/classification/dice.py | 2 +- torchmetrics/classification/f_beta.py | 4 ++-- torchmetrics/classification/hamming.py | 2 +- torchmetrics/classification/hinge.py | 2 +- torchmetrics/classification/jaccard.py | 2 +- torchmetrics/classification/kl_divergence.py | 2 +- torchmetrics/classification/matthews_corrcoef.py | 2 +- torchmetrics/classification/precision_recall.py | 4 ++-- .../classification/precision_recall_curve.py | 2 +- torchmetrics/classification/roc.py | 2 +- torchmetrics/classification/specificity.py | 2 +- torchmetrics/classification/stat_scores.py | 2 +- torchmetrics/detection/mean_ap.py | 2 +- torchmetrics/functional/audio/pit.py | 2 +- torchmetrics/image/fid.py | 2 +- torchmetrics/image/inception.py | 2 +- torchmetrics/image/kid.py | 2 +- torchmetrics/image/lpip.py | 2 +- torchmetrics/image/psnr.py | 2 +- torchmetrics/image/ssim.py | 4 ++-- torchmetrics/image/uqi.py | 2 +- torchmetrics/regression/cosine_similarity.py | 2 +- torchmetrics/regression/explained_variance.py | 2 +- torchmetrics/regression/log_mse.py | 2 +- torchmetrics/regression/mae.py | 2 +- torchmetrics/regression/mape.py | 2 +- torchmetrics/regression/mse.py | 2 +- torchmetrics/regression/pearson.py | 2 +- torchmetrics/regression/r2.py | 2 +- torchmetrics/regression/spearman.py | 2 +- torchmetrics/regression/symmetric_mape.py | 2 +- torchmetrics/regression/tweedie_deviance.py | 2 +- torchmetrics/retrieval/base.py | 2 +- torchmetrics/retrieval/fall_out.py | 2 +- torchmetrics/retrieval/hit_rate.py | 2 +- torchmetrics/retrieval/ndcg.py | 2 +- torchmetrics/retrieval/precision.py | 2 +- torchmetrics/retrieval/precision_recall_curve.py | 4 ++-- torchmetrics/retrieval/recall.py | 2 +- torchmetrics/text/bert.py | 2 +- torchmetrics/text/bleu.py | 2 +- torchmetrics/text/cer.py | 2 +- torchmetrics/text/chrf.py | 2 +- torchmetrics/text/eed.py | 2 +- torchmetrics/text/mer.py | 2 +- torchmetrics/text/rouge.py | 2 +- torchmetrics/text/sacre_bleu.py | 2 +- torchmetrics/text/squad.py | 2 +- torchmetrics/text/ter.py | 2 +- torchmetrics/text/wer.py | 2 +- torchmetrics/text/wil.py | 2 +- torchmetrics/text/wip.py | 2 +- torchmetrics/wrappers/bootstrapping.py | 2 +- torchmetrics/wrappers/minmax.py | 2 +- 68 files changed, 80 insertions(+), 80 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index ac06452da57..dcf74b48fef 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -50,7 +50,7 @@ def __init__( fn: Union[Callable, str], default_value: Union[Tensor, List], nan_strategy: Union[str, float] = "error", - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) allowed_nan_strategy = ("error", "warn", "ignore") @@ -122,7 +122,7 @@ class MaxMetric(BaseAggregator): def __init__( self, nan_strategy: Union[str, float] = "warn", - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__( "max", @@ -173,7 +173,7 @@ class MinMetric(BaseAggregator): def __init__( self, nan_strategy: Union[str, float] = "warn", - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__( "min", @@ -222,7 +222,7 @@ class SumMetric(BaseAggregator): def __init__( self, nan_strategy: Union[str, float] = "warn", - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__( "sum", @@ -271,7 +271,7 @@ class CatMetric(BaseAggregator): def __init__( self, nan_strategy: Union[str, float] = "warn", - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__("cat", [], nan_strategy, **kwargs) @@ -321,7 +321,7 @@ class MeanMetric(BaseAggregator): def __init__( self, nan_strategy: Union[str, float] = "warn", - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__( "sum", diff --git a/torchmetrics/audio/pesq.py b/torchmetrics/audio/pesq.py index f16e3714803..3c1c18c7fc1 100644 --- a/torchmetrics/audio/pesq.py +++ b/torchmetrics/audio/pesq.py @@ -80,7 +80,7 @@ def __init__( self, fs: int, mode: str, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if not _PESQ_AVAILABLE: diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index db0181208a2..cf175a6d859 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -70,7 +70,7 @@ def __init__( self, metric_func: Callable, eval_func: str = "max", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: base_kwargs: Dict[str, Any] = { "dist_sync_on_step": kwargs.pop("dist_sync_on_step", False), diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index 88368cdfdc5..4798a2d320a 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -88,7 +88,7 @@ def __init__( filter_length: int = 512, zero_mean: bool = False, load_diag: Optional[float] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -162,7 +162,7 @@ class ScaleInvariantSignalDistortionRatio(Metric): def __init__( self, zero_mean: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.zero_mean = zero_mean diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index 6bde6067b50..6f6b36b6e10 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -69,7 +69,7 @@ class SignalNoiseRatio(Metric): def __init__( self, zero_mean: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.zero_mean = zero_mean @@ -134,7 +134,7 @@ class ScaleInvariantSignalNoiseRatio(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index a82c1342c65..5ebe10345c8 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -87,7 +87,7 @@ def __init__( self, fs: int, extended: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if not _PYSTOI_AVAILABLE: diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index 8f77235dc39..1ea5b524885 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -170,7 +170,7 @@ def __init__( top_k: Optional[int] = None, multiclass: Optional[bool] = None, subset_accuracy: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/classification/auc.py b/torchmetrics/classification/auc.py index 451c1efbb58..be1a3e95e0f 100644 --- a/torchmetrics/classification/auc.py +++ b/torchmetrics/classification/auc.py @@ -44,7 +44,7 @@ class AUC(Metric): def __init__( self, reorder: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/classification/auroc.py b/torchmetrics/classification/auroc.py index 2d86a33272c..7e88bb02e42 100644 --- a/torchmetrics/classification/auroc.py +++ b/torchmetrics/classification/auroc.py @@ -109,7 +109,7 @@ def __init__( pos_label: Optional[int] = None, average: Optional[str] = "macro", max_fpr: Optional[float] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/classification/avg_precision.py b/torchmetrics/classification/avg_precision.py index d3188ae6a91..d5d42ef3042 100644 --- a/torchmetrics/classification/avg_precision.py +++ b/torchmetrics/classification/avg_precision.py @@ -89,7 +89,7 @@ def __init__( num_classes: Optional[int] = None, pos_label: Optional[int] = None, average: Optional[str] = "macro", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index d7cdee35db1..26ace5fcd5d 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -120,7 +120,7 @@ def __init__( self, num_classes: int, thresholds: Union[int, Tensor, List[float]] = 100, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -281,7 +281,7 @@ def __init__( num_classes: int, min_precision: float, thresholds: Union[int, Tensor, List[float]] = 100, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(num_classes=num_classes, thresholds=thresholds, **kwargs) self.min_precision = min_precision diff --git a/torchmetrics/classification/calibration_error.py b/torchmetrics/classification/calibration_error.py index 98906b216c5..4f7b2953dd6 100644 --- a/torchmetrics/classification/calibration_error.py +++ b/torchmetrics/classification/calibration_error.py @@ -66,7 +66,7 @@ def __init__( self, n_bins: int = 15, norm: str = "l1", - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) diff --git a/torchmetrics/classification/cohen_kappa.py b/torchmetrics/classification/cohen_kappa.py index 62dbea8a65c..50c5fba1a4f 100644 --- a/torchmetrics/classification/cohen_kappa.py +++ b/torchmetrics/classification/cohen_kappa.py @@ -77,7 +77,7 @@ def __init__( num_classes: int, weights: Optional[str] = None, threshold: float = 0.5, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.num_classes = num_classes diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index e4052e7e643..814fd324145 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -96,7 +96,7 @@ def __init__( normalize: Optional[str] = None, threshold: float = 0.5, multilabel: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.num_classes = num_classes diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index 09ff5be205b..ec12a60589c 100644 --- a/torchmetrics/classification/dice.py +++ b/torchmetrics/classification/dice.py @@ -128,7 +128,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 754436aac04..d28ea98ef67 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -130,7 +130,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: self.beta = beta allowed_average = list(AverageMethod) @@ -256,7 +256,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( num_classes=num_classes, diff --git a/torchmetrics/classification/hamming.py b/torchmetrics/classification/hamming.py index b475dbb4db1..ac24568eecb 100644 --- a/torchmetrics/classification/hamming.py +++ b/torchmetrics/classification/hamming.py @@ -65,7 +65,7 @@ class HammingDistance(Metric): def __init__( self, threshold: float = 0.5, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/classification/hinge.py b/torchmetrics/classification/hinge.py index 8037c4f2971..3a1582d4d84 100644 --- a/torchmetrics/classification/hinge.py +++ b/torchmetrics/classification/hinge.py @@ -97,7 +97,7 @@ def __init__( self, squared: bool = False, multiclass_mode: Optional[Union[str, MulticlassMode]] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index 58f2438f671..14260c58b67 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -88,7 +88,7 @@ def __init__( absent_score: float = 0.0, threshold: float = 0.5, multilabel: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( num_classes=num_classes, diff --git a/torchmetrics/classification/kl_divergence.py b/torchmetrics/classification/kl_divergence.py index be7c4e18136..6766caa091d 100644 --- a/torchmetrics/classification/kl_divergence.py +++ b/torchmetrics/classification/kl_divergence.py @@ -74,7 +74,7 @@ def __init__( self, log_prob: bool = False, reduction: Literal["mean", "sum", "none", None] = "mean", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if not isinstance(log_prob, bool): diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py index 2a760a5a14b..b802901da7e 100644 --- a/torchmetrics/classification/matthews_corrcoef.py +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -72,7 +72,7 @@ def __init__( self, num_classes: int, threshold: float = 0.5, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.num_classes = num_classes diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 69d6cfe0ac1..ff6c55dda77 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -121,7 +121,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: @@ -256,7 +256,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/classification/precision_recall_curve.py b/torchmetrics/classification/precision_recall_curve.py index 27e163d0966..4a0640c194f 100644 --- a/torchmetrics/classification/precision_recall_curve.py +++ b/torchmetrics/classification/precision_recall_curve.py @@ -85,7 +85,7 @@ def __init__( self, num_classes: Optional[int] = None, pos_label: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index fdc0b42cfca..579a82b2757 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -110,7 +110,7 @@ def __init__( self, num_classes: Optional[int] = None, pos_label: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 9f8eabb89f6..e5dd88b6ad0 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -123,7 +123,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index ead35d1c7cb..b65e8a36954 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -128,7 +128,7 @@ def __init__( ignore_index: Optional[int] = None, mdmc_reduce: Optional[str] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/detection/mean_ap.py b/torchmetrics/detection/mean_ap.py index 8ade4b12677..557135aaad4 100644 --- a/torchmetrics/detection/mean_ap.py +++ b/torchmetrics/detection/mean_ap.py @@ -301,7 +301,7 @@ def __init__( rec_thresholds: Optional[List[float]] = None, max_detection_thresholds: Optional[List[int]] = None, class_metrics: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: # type: ignore super().__init__(**kwargs) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 8d232f54066..6a04af7beb3 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -93,7 +93,7 @@ def _find_best_perm_by_exhaustive_method( def permutation_invariant_training( - preds: Tensor, target: Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any] + preds: Tensor, target: Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Any ) -> Tuple[Tensor, Tensor]: """Permutation invariant training (PIT). The ``permutation_invariant_training`` implements the famous Permutation Invariant Training method. diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 36c798aa46d..0053f5ecce6 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -208,7 +208,7 @@ def __init__( self, feature: Union[int, Module] = 2048, reset_real_features: bool = True, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py index 2dbdeb227c3..0fe54fe95da 100644 --- a/torchmetrics/image/inception.py +++ b/torchmetrics/image/inception.py @@ -98,7 +98,7 @@ def __init__( self, feature: Union[str, int, Module] = "logits_unbiased", splits: int = 10, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 2f5a5c4a18c..74bc92f10fb 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -165,7 +165,7 @@ def __init__( gamma: Optional[float] = None, # type: ignore coef: float = 1.0, reset_real_features: bool = True, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/image/lpip.py b/torchmetrics/image/lpip.py index 4d6a12a8ff1..058aac78fbd 100644 --- a/torchmetrics/image/lpip.py +++ b/torchmetrics/image/lpip.py @@ -95,7 +95,7 @@ def __init__( self, net_type: str = "alex", reduction: Literal["sum", "mean"] = "mean", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/image/psnr.py b/torchmetrics/image/psnr.py index 79f3a31b4e8..5be28c97b61 100644 --- a/torchmetrics/image/psnr.py +++ b/torchmetrics/image/psnr.py @@ -75,7 +75,7 @@ def __init__( base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", dim: Optional[Union[int, Tuple[int, ...]]] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/image/ssim.py b/torchmetrics/image/ssim.py index 9ae07e98d65..08575a7a4cb 100644 --- a/torchmetrics/image/ssim.py +++ b/torchmetrics/image/ssim.py @@ -80,7 +80,7 @@ def __init__( k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) rank_zero_warn( @@ -198,7 +198,7 @@ def __init__( k2: float = 0.03, betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Literal["relu", "simple", None] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) rank_zero_warn( diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 54abf29a038..330420b9417 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -64,7 +64,7 @@ def __init__( sigma: Sequence[float] = (1.5, 1.5), reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", data_range: Optional[float] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) rank_zero_warn( diff --git a/torchmetrics/regression/cosine_similarity.py b/torchmetrics/regression/cosine_similarity.py index 0824cc08a4b..71bbd621df0 100644 --- a/torchmetrics/regression/cosine_similarity.py +++ b/torchmetrics/regression/cosine_similarity.py @@ -60,7 +60,7 @@ class CosineSimilarity(Metric): def __init__( self, reduction: Literal["mean", "sum", "none", None] = "sum", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) allowed_reduction = ("sum", "mean", "none", None) diff --git a/torchmetrics/regression/explained_variance.py b/torchmetrics/regression/explained_variance.py index 8844aa55871..6733177351f 100644 --- a/torchmetrics/regression/explained_variance.py +++ b/torchmetrics/regression/explained_variance.py @@ -80,7 +80,7 @@ class ExplainedVariance(Metric): def __init__( self, multioutput: str = "uniform_average", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) allowed_multioutput = ("raw_values", "uniform_average", "variance_weighted") diff --git a/torchmetrics/regression/log_mse.py b/torchmetrics/regression/log_mse.py index 176420177ee..d3f9e9ae94d 100644 --- a/torchmetrics/regression/log_mse.py +++ b/torchmetrics/regression/log_mse.py @@ -50,7 +50,7 @@ class MeanSquaredLogError(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/regression/mae.py b/torchmetrics/regression/mae.py index 780873204f2..3260166c669 100644 --- a/torchmetrics/regression/mae.py +++ b/torchmetrics/regression/mae.py @@ -46,7 +46,7 @@ class MeanAbsoluteError(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/regression/mape.py b/torchmetrics/regression/mape.py index fda394d9830..77e30fb0e30 100644 --- a/torchmetrics/regression/mape.py +++ b/torchmetrics/regression/mape.py @@ -58,7 +58,7 @@ class MeanAbsolutePercentageError(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/regression/mse.py b/torchmetrics/regression/mse.py index 152e9e5c97e..9a8c5a40070 100644 --- a/torchmetrics/regression/mse.py +++ b/torchmetrics/regression/mse.py @@ -49,7 +49,7 @@ class MeanSquaredError(Metric): def __init__( self, squared: bool = True, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 1461d73635f..3fc13466956 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -102,7 +102,7 @@ class PearsonCorrCoef(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/regression/r2.py b/torchmetrics/regression/r2.py index 4919889f9d7..11a3be7dae8 100644 --- a/torchmetrics/regression/r2.py +++ b/torchmetrics/regression/r2.py @@ -85,7 +85,7 @@ def __init__( num_outputs: int = 1, adjusted: int = 0, multioutput: str = "uniform_average", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index a44b29b0287..28ffdaaea5f 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -52,7 +52,7 @@ class SpearmanCorrCoef(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) rank_zero_warn( diff --git a/torchmetrics/regression/symmetric_mape.py b/torchmetrics/regression/symmetric_mape.py index 55615ccadda..faa40203526 100644 --- a/torchmetrics/regression/symmetric_mape.py +++ b/torchmetrics/regression/symmetric_mape.py @@ -55,7 +55,7 @@ class SymmetricMeanAbsolutePercentageError(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/torchmetrics/regression/tweedie_deviance.py b/torchmetrics/regression/tweedie_deviance.py index 0794c22720f..7084ed41e7d 100644 --- a/torchmetrics/regression/tweedie_deviance.py +++ b/torchmetrics/regression/tweedie_deviance.py @@ -74,7 +74,7 @@ class TweedieDevianceScore(Metric): def __init__( self, power: float = 0.0, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if 0 < power < 1: diff --git a/torchmetrics/retrieval/base.py b/torchmetrics/retrieval/base.py index 0a4d847012d..7ae07331574 100644 --- a/torchmetrics/retrieval/base.py +++ b/torchmetrics/retrieval/base.py @@ -74,7 +74,7 @@ def __init__( self, empty_target_action: str = "neg", ignore_index: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.allow_non_binary_target = False diff --git a/torchmetrics/retrieval/fall_out.py b/torchmetrics/retrieval/fall_out.py index 0c734b5f779..ba6e49d14d2 100644 --- a/torchmetrics/retrieval/fall_out.py +++ b/torchmetrics/retrieval/fall_out.py @@ -78,7 +78,7 @@ def __init__( empty_target_action: str = "pos", ignore_index: Optional[int] = None, k: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( empty_target_action=empty_target_action, diff --git a/torchmetrics/retrieval/hit_rate.py b/torchmetrics/retrieval/hit_rate.py index ab620ca0929..40f07885ea9 100644 --- a/torchmetrics/retrieval/hit_rate.py +++ b/torchmetrics/retrieval/hit_rate.py @@ -76,7 +76,7 @@ def __init__( empty_target_action: str = "neg", ignore_index: Optional[int] = None, k: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( empty_target_action=empty_target_action, diff --git a/torchmetrics/retrieval/ndcg.py b/torchmetrics/retrieval/ndcg.py index 2b4e47243c7..dbf31dafc52 100644 --- a/torchmetrics/retrieval/ndcg.py +++ b/torchmetrics/retrieval/ndcg.py @@ -76,7 +76,7 @@ def __init__( empty_target_action: str = "neg", ignore_index: Optional[int] = None, k: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( empty_target_action=empty_target_action, diff --git a/torchmetrics/retrieval/precision.py b/torchmetrics/retrieval/precision.py index e9ffe841422..f0197e9a9f1 100644 --- a/torchmetrics/retrieval/precision.py +++ b/torchmetrics/retrieval/precision.py @@ -80,7 +80,7 @@ def __init__( ignore_index: Optional[int] = None, k: Optional[int] = None, adaptive_k: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( empty_target_action=empty_target_action, diff --git a/torchmetrics/retrieval/precision_recall_curve.py b/torchmetrics/retrieval/precision_recall_curve.py index 854a87738b6..fca3f31c5f1 100644 --- a/torchmetrics/retrieval/precision_recall_curve.py +++ b/torchmetrics/retrieval/precision_recall_curve.py @@ -125,7 +125,7 @@ def __init__( adaptive_k: bool = False, empty_target_action: str = "neg", ignore_index: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.allow_non_binary_target = False @@ -270,7 +270,7 @@ def __init__( adaptive_k: bool = False, empty_target_action: str = "neg", ignore_index: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( max_k=max_k, diff --git a/torchmetrics/retrieval/recall.py b/torchmetrics/retrieval/recall.py index 034609ea774..1266c0bcf09 100644 --- a/torchmetrics/retrieval/recall.py +++ b/torchmetrics/retrieval/recall.py @@ -75,7 +75,7 @@ def __init__( empty_target_action: str = "neg", ignore_index: Optional[int] = None, k: Optional[int] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__( empty_target_action=empty_target_action, diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 2884c74aa78..1ce9365e506 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -127,7 +127,7 @@ def __init__( rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) self.model_name_or_path = model_name_or_path or _DEFAULT_MODEL diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 1f365d326e9..a07354795ed 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -69,7 +69,7 @@ def __init__( n_gram: int = 4, smooth: bool = False, weights: Optional[Sequence[float]] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) self.n_gram = n_gram diff --git a/torchmetrics/text/cer.py b/torchmetrics/text/cer.py index c1158eabb2c..31466da72e9 100644 --- a/torchmetrics/text/cer.py +++ b/torchmetrics/text/cer.py @@ -64,7 +64,7 @@ class CharErrorRate(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index f6eb281412a..019265ba93d 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -96,7 +96,7 @@ def __init__( lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) diff --git a/torchmetrics/text/eed.py b/torchmetrics/text/eed.py index 649d56424e8..47d0589e50d 100644 --- a/torchmetrics/text/eed.py +++ b/torchmetrics/text/eed.py @@ -65,7 +65,7 @@ def __init__( rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) diff --git a/torchmetrics/text/mer.py b/torchmetrics/text/mer.py index 76e768a9509..3d48b2ad27a 100644 --- a/torchmetrics/text/mer.py +++ b/torchmetrics/text/mer.py @@ -61,7 +61,7 @@ class MatchErrorRate(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py index ddde5ae0bac..db5572cf51b 100644 --- a/torchmetrics/text/rouge.py +++ b/torchmetrics/text/rouge.py @@ -93,7 +93,7 @@ def __init__( tokenizer: Callable[[str], Sequence[str]] = None, accumulate: Literal["avg", "best"] = "best", rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) if use_stemmer or "rougeLsum" in rouge_keys: diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 9707beeb528..b725d9a20af 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -84,7 +84,7 @@ def __init__( tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", lowercase: bool = False, weights: Optional[Sequence[float]] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs) if tokenize not in AVAILABLE_TOKENIZERS: diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 2d5e33ebefc..d63e4a8ee84 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -56,7 +56,7 @@ class SQuAD(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) diff --git a/torchmetrics/text/ter.py b/torchmetrics/text/ter.py index 07853b86bfc..a50b9c2ba78 100644 --- a/torchmetrics/text/ter.py +++ b/torchmetrics/text/ter.py @@ -63,7 +63,7 @@ def __init__( lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) if not isinstance(normalize, bool): diff --git a/torchmetrics/text/wer.py b/torchmetrics/text/wer.py index 645aff65fed..880cc67464a 100644 --- a/torchmetrics/text/wer.py +++ b/torchmetrics/text/wer.py @@ -60,7 +60,7 @@ class WordErrorRate(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 6e69b1e3c36..a45c290be17 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -60,7 +60,7 @@ class WordInfoLost(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) self.add_state("errors", tensor(0.0), dist_reduce_fx="sum") diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index a4f971add7b..f8b0a1568a5 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -59,7 +59,7 @@ class WordInfoPreserved(Metric): def __init__( self, - **kwargs: Dict[str, Any], + **kwargs: Any, ): super().__init__(**kwargs) self.add_state("errors", tensor(0.0), dist_reduce_fx="sum") diff --git a/torchmetrics/wrappers/bootstrapping.py b/torchmetrics/wrappers/bootstrapping.py index fd43cfb7c6e..0d258e11b4f 100644 --- a/torchmetrics/wrappers/bootstrapping.py +++ b/torchmetrics/wrappers/bootstrapping.py @@ -91,7 +91,7 @@ def __init__( quantile: Optional[Union[float, Tensor]] = None, raw: bool = False, sampling_strategy: str = "poisson", - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if not isinstance(base_metric, Metric): diff --git a/torchmetrics/wrappers/minmax.py b/torchmetrics/wrappers/minmax.py index 920ea1a08b2..2c75b259506 100644 --- a/torchmetrics/wrappers/minmax.py +++ b/torchmetrics/wrappers/minmax.py @@ -57,7 +57,7 @@ class MinMaxMetric(Metric): def __init__( self, base_metric: Metric, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if not isinstance(base_metric, Metric): From 41254f5f44f756a6b96b035249581e4e803176be Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Jun 2022 22:06:11 +0200 Subject: [PATCH 2/2] imports --- torchmetrics/aggregation.py | 2 +- torchmetrics/audio/pesq.py | 2 +- torchmetrics/audio/sdr.py | 2 +- torchmetrics/audio/snr.py | 2 +- torchmetrics/audio/stoi.py | 2 +- torchmetrics/classification/accuracy.py | 2 +- torchmetrics/classification/auc.py | 2 +- torchmetrics/classification/auroc.py | 2 +- torchmetrics/classification/avg_precision.py | 2 +- torchmetrics/classification/binned_precision_recall.py | 2 +- torchmetrics/classification/calibration_error.py | 2 +- torchmetrics/classification/cohen_kappa.py | 2 +- torchmetrics/classification/confusion_matrix.py | 2 +- torchmetrics/classification/dice.py | 2 +- torchmetrics/classification/f_beta.py | 2 +- torchmetrics/classification/hamming.py | 2 +- torchmetrics/classification/hinge.py | 2 +- torchmetrics/classification/jaccard.py | 2 +- torchmetrics/classification/kl_divergence.py | 2 +- torchmetrics/classification/matthews_corrcoef.py | 2 +- torchmetrics/classification/precision_recall.py | 2 +- torchmetrics/classification/precision_recall_curve.py | 2 +- torchmetrics/classification/roc.py | 2 +- torchmetrics/classification/specificity.py | 2 +- torchmetrics/classification/stat_scores.py | 2 +- torchmetrics/functional/audio/pit.py | 2 +- torchmetrics/image/fid.py | 2 +- torchmetrics/image/inception.py | 2 +- torchmetrics/image/kid.py | 2 +- torchmetrics/image/lpip.py | 2 +- torchmetrics/image/psnr.py | 2 +- torchmetrics/image/ssim.py | 2 +- torchmetrics/image/uqi.py | 2 +- torchmetrics/regression/cosine_similarity.py | 2 +- torchmetrics/regression/explained_variance.py | 2 +- torchmetrics/regression/log_mse.py | 2 +- torchmetrics/regression/mae.py | 2 +- torchmetrics/regression/mape.py | 2 +- torchmetrics/regression/mse.py | 2 +- torchmetrics/regression/pearson.py | 2 +- torchmetrics/regression/r2.py | 2 +- torchmetrics/regression/spearman.py | 2 +- torchmetrics/regression/symmetric_mape.py | 2 +- torchmetrics/regression/tweedie_deviance.py | 2 +- torchmetrics/retrieval/base.py | 2 +- torchmetrics/retrieval/fall_out.py | 2 +- torchmetrics/retrieval/hit_rate.py | 2 +- torchmetrics/retrieval/ndcg.py | 2 +- torchmetrics/retrieval/precision.py | 2 +- torchmetrics/retrieval/precision_recall_curve.py | 2 +- torchmetrics/retrieval/recall.py | 2 +- torchmetrics/text/bleu.py | 2 +- torchmetrics/text/cer.py | 2 +- torchmetrics/text/eed.py | 2 +- torchmetrics/text/mer.py | 2 +- torchmetrics/text/sacre_bleu.py | 2 +- torchmetrics/text/ter.py | 2 +- torchmetrics/text/wer.py | 2 +- torchmetrics/text/wil.py | 2 +- torchmetrics/text/wip.py | 2 +- 60 files changed, 60 insertions(+), 60 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index dcf74b48fef..18e45deb1e7 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, List, Union import torch from torch import Tensor diff --git a/torchmetrics/audio/pesq.py b/torchmetrics/audio/pesq.py index 3c1c18c7fc1..8d7d8fa5043 100644 --- a/torchmetrics/audio/pesq.py +++ b/torchmetrics/audio/pesq.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any from torch import Tensor, tensor diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index 4798a2d320a..17b5e0e92a8 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index 6f6b36b6e10..49108779c60 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any from torch import Tensor, tensor diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index 5ebe10345c8..db883e18219 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any from torch import Tensor, tensor diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index 1ea5b524885..ac00b6ccd17 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor diff --git a/torchmetrics/classification/auc.py b/torchmetrics/classification/auc.py index be1a3e95e0f..975ad64dd4e 100644 --- a/torchmetrics/classification/auc.py +++ b/torchmetrics/classification/auc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from torch import Tensor diff --git a/torchmetrics/classification/auroc.py b/torchmetrics/classification/auroc.py index 7e88bb02e42..4a754c7c849 100644 --- a/torchmetrics/classification/auroc.py +++ b/torchmetrics/classification/auroc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional import torch from torch import Tensor diff --git a/torchmetrics/classification/avg_precision.py b/torchmetrics/classification/avg_precision.py index d5d42ef3042..ea29d65fdbe 100644 --- a/torchmetrics/classification/avg_precision.py +++ b/torchmetrics/classification/avg_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 26ace5fcd5d..21be67c0475 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/torchmetrics/classification/calibration_error.py b/torchmetrics/classification/calibration_error.py index 4f7b2953dd6..f9b7aa23dde 100644 --- a/torchmetrics/classification/calibration_error.py +++ b/torchmetrics/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, List import torch from torch import Tensor diff --git a/torchmetrics/classification/cohen_kappa.py b/torchmetrics/classification/cohen_kappa.py index 50c5fba1a4f..7146ac14480 100644 --- a/torchmetrics/classification/cohen_kappa.py +++ b/torchmetrics/classification/cohen_kappa.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index 814fd324145..a847b04044b 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index ec12a60589c..8122fb2d879 100644 --- a/torchmetrics/classification/dice.py +++ b/torchmetrics/classification/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index d28ea98ef67..dd033c888b5 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor diff --git a/torchmetrics/classification/hamming.py b/torchmetrics/classification/hamming.py index ac24568eecb..be9a0bf430e 100644 --- a/torchmetrics/classification/hamming.py +++ b/torchmetrics/classification/hamming.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor, tensor diff --git a/torchmetrics/classification/hinge.py b/torchmetrics/classification/hinge.py index 3a1582d4d84..9fb128a6f73 100644 --- a/torchmetrics/classification/hinge.py +++ b/torchmetrics/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from torch import Tensor, tensor diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index 14260c58b67..cdc2a9b0022 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor diff --git a/torchmetrics/classification/kl_divergence.py b/torchmetrics/classification/kl_divergence.py index 6766caa091d..0dbb17f0e92 100644 --- a/torchmetrics/classification/kl_divergence.py +++ b/torchmetrics/classification/kl_divergence.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py index b802901da7e..e16778099fe 100644 --- a/torchmetrics/classification/matthews_corrcoef.py +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index ff6c55dda77..fa646431c4b 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor diff --git a/torchmetrics/classification/precision_recall_curve.py b/torchmetrics/classification/precision_recall_curve.py index 4a0640c194f..ee4e29aecbc 100644 --- a/torchmetrics/classification/precision_recall_curve.py +++ b/torchmetrics/classification/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index 579a82b2757..7682dd758ac 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index e5dd88b6ad0..732f6dbf4a1 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index b65e8a36954..eca2150d63b 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Optional, Tuple import torch from torch import Tensor diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 6a04af7beb3..ea06de3bee6 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from itertools import permutations -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Tuple, Union from warnings import warn import torch diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 0053f5ecce6..0229fb8fdf2 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union import numpy as np import torch diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py index 0fe54fe95da..6000e075908 100644 --- a/torchmetrics/image/inception.py +++ b/torchmetrics/image/inception.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Tuple, Union +from typing import Any, List, Tuple, Union import torch from torch import Tensor diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 74bc92f10fb..189bc02109e 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor diff --git a/torchmetrics/image/lpip.py b/torchmetrics/image/lpip.py index 058aac78fbd..c4c7e50aa83 100644 --- a/torchmetrics/image/lpip.py +++ b/torchmetrics/image/lpip.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, List import torch from torch import Tensor diff --git a/torchmetrics/image/psnr.py b/torchmetrics/image/psnr.py index 5be28c97b61..652d1f2a4b5 100644 --- a/torchmetrics/image/psnr.py +++ b/torchmetrics/image/psnr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union import torch from torch import Tensor, tensor diff --git a/torchmetrics/image/ssim.py b/torchmetrics/image/ssim.py index 08575a7a4cb..6eab726651f 100644 --- a/torchmetrics/image/ssim.py +++ b/torchmetrics/image/ssim.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union from torch import Tensor from typing_extensions import Literal diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 330420b9417..0832c045847 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, List, Optional, Sequence from torch import Tensor from typing_extensions import Literal diff --git a/torchmetrics/regression/cosine_similarity.py b/torchmetrics/regression/cosine_similarity.py index 71bbd621df0..1fe18d6e456 100644 --- a/torchmetrics/regression/cosine_similarity.py +++ b/torchmetrics/regression/cosine_similarity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, List import torch from torch import Tensor diff --git a/torchmetrics/regression/explained_variance.py b/torchmetrics/regression/explained_variance.py index 6733177351f..1f35ad0f024 100644 --- a/torchmetrics/regression/explained_variance.py +++ b/torchmetrics/regression/explained_variance.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Sequence, Union +from typing import Any, Sequence, Union import torch from torch import Tensor, tensor diff --git a/torchmetrics/regression/log_mse.py b/torchmetrics/regression/log_mse.py index d3f9e9ae94d..d1f5f843463 100644 --- a/torchmetrics/regression/log_mse.py +++ b/torchmetrics/regression/log_mse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor, tensor diff --git a/torchmetrics/regression/mae.py b/torchmetrics/regression/mae.py index 3260166c669..bac1d7026a3 100644 --- a/torchmetrics/regression/mae.py +++ b/torchmetrics/regression/mae.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor, tensor diff --git a/torchmetrics/regression/mape.py b/torchmetrics/regression/mape.py index 77e30fb0e30..e77f8acd595 100644 --- a/torchmetrics/regression/mape.py +++ b/torchmetrics/regression/mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor, tensor diff --git a/torchmetrics/regression/mse.py b/torchmetrics/regression/mse.py index 9a8c5a40070..33c754152b4 100644 --- a/torchmetrics/regression/mse.py +++ b/torchmetrics/regression/mse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor, tensor diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 3fc13466956..463b8ed540a 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple import torch from torch import Tensor diff --git a/torchmetrics/regression/r2.py b/torchmetrics/regression/r2.py index 11a3be7dae8..1ceaabe6d04 100644 --- a/torchmetrics/regression/r2.py +++ b/torchmetrics/regression/r2.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor, tensor diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index 28ffdaaea5f..4351d0444dc 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, List import torch from torch import Tensor diff --git a/torchmetrics/regression/symmetric_mape.py b/torchmetrics/regression/symmetric_mape.py index faa40203526..128c226dcab 100644 --- a/torchmetrics/regression/symmetric_mape.py +++ b/torchmetrics/regression/symmetric_mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any from torch import Tensor, tensor diff --git a/torchmetrics/regression/tweedie_deviance.py b/torchmetrics/regression/tweedie_deviance.py index 7084ed41e7d..62174b0577a 100644 --- a/torchmetrics/regression/tweedie_deviance.py +++ b/torchmetrics/regression/tweedie_deviance.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from torch import Tensor diff --git a/torchmetrics/retrieval/base.py b/torchmetrics/retrieval/base.py index 7ae07331574..9c0f27d3362 100644 --- a/torchmetrics/retrieval/base.py +++ b/torchmetrics/retrieval/base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional import torch from torch import Tensor, tensor diff --git a/torchmetrics/retrieval/fall_out.py b/torchmetrics/retrieval/fall_out.py index ba6e49d14d2..17890669a2f 100644 --- a/torchmetrics/retrieval/fall_out.py +++ b/torchmetrics/retrieval/fall_out.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor, tensor diff --git a/torchmetrics/retrieval/hit_rate.py b/torchmetrics/retrieval/hit_rate.py index 40f07885ea9..0256c951e4c 100644 --- a/torchmetrics/retrieval/hit_rate.py +++ b/torchmetrics/retrieval/hit_rate.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor diff --git a/torchmetrics/retrieval/ndcg.py b/torchmetrics/retrieval/ndcg.py index dbf31dafc52..e20d5493400 100644 --- a/torchmetrics/retrieval/ndcg.py +++ b/torchmetrics/retrieval/ndcg.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor diff --git a/torchmetrics/retrieval/precision.py b/torchmetrics/retrieval/precision.py index f0197e9a9f1..ef004c26b10 100644 --- a/torchmetrics/retrieval/precision.py +++ b/torchmetrics/retrieval/precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor diff --git a/torchmetrics/retrieval/precision_recall_curve.py b/torchmetrics/retrieval/precision_recall_curve.py index fca3f31c5f1..5cb2946369f 100644 --- a/torchmetrics/retrieval/precision_recall_curve.py +++ b/torchmetrics/retrieval/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional, Tuple import torch from torch import Tensor, tensor diff --git a/torchmetrics/retrieval/recall.py b/torchmetrics/retrieval/recall.py index 1266c0bcf09..e9724621dd4 100644 --- a/torchmetrics/retrieval/recall.py +++ b/torchmetrics/retrieval/recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index a07354795ed..899b45f7872 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -16,7 +16,7 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Dict, Optional, Sequence +from typing import Any, Optional, Sequence import torch from torch import Tensor, tensor diff --git a/torchmetrics/text/cer.py b/torchmetrics/text/cer.py index 31466da72e9..183d0cc2494 100644 --- a/torchmetrics/text/cer.py +++ b/torchmetrics/text/cer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +from typing import Any, List, Union import torch from torch import Tensor, tensor diff --git a/torchmetrics/text/eed.py b/torchmetrics/text/eed.py index 47d0589e50d..0a35098674e 100644 --- a/torchmetrics/text/eed.py +++ b/torchmetrics/text/eed.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Sequence, Tuple, Union +from typing import Any, List, Sequence, Tuple, Union from torch import Tensor, stack from typing_extensions import Literal diff --git a/torchmetrics/text/mer.py b/torchmetrics/text/mer.py index 3d48b2ad27a..43b6158fa05 100644 --- a/torchmetrics/text/mer.py +++ b/torchmetrics/text/mer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +from typing import Any, List, Union import torch from torch import Tensor, tensor diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index b725d9a20af..28dae1fb856 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -17,7 +17,7 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Dict, Optional, Sequence +from typing import Any, Optional, Sequence from typing_extensions import Literal diff --git a/torchmetrics/text/ter.py b/torchmetrics/text/ter.py index a50b9c2ba78..29fe89d91f5 100644 --- a/torchmetrics/text/ter.py +++ b/torchmetrics/text/ter.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor, tensor diff --git a/torchmetrics/text/wer.py b/torchmetrics/text/wer.py index 880cc67464a..2f02e144fa6 100644 --- a/torchmetrics/text/wer.py +++ b/torchmetrics/text/wer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +from typing import Any, List, Union import torch from torch import Tensor, tensor diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index a45c290be17..e6bff807dd3 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +from typing import Any, List, Union from torch import Tensor, tensor diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index f8b0a1568a5..6609282bec8 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +from typing import Any, List, Union from torch import Tensor, tensor