Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Imports #1681

Merged
merged 8 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 23 additions & 121 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging as __logging
import os

from torchmetrics.__about__ import * # noqa: F401, F403
from torchmetrics.__about__ import * # noqa: F403

_logger = __logging.getLogger("torchmetrics")
_logger.addHandler(__logging.StreamHandler())
Expand All @@ -13,13 +13,6 @@

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PermutationInvariantTraining # noqa: E402
from torchmetrics.audio import ( # noqa: E402
ScaleInvariantSignalDistortionRatio,
ScaleInvariantSignalNoiseRatio,
SignalDistortionRatio,
SignalNoiseRatio,
)
from torchmetrics.classification import ( # noqa: E402
AUROC,
ROC,
Expand All @@ -43,26 +36,10 @@
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality # noqa: E402
from torchmetrics.image import ( # noqa: E402
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
RelativeAverageSpectralError,
RootMeanSquaredErrorUsingSlidingWindow,
SpectralAngleMapper,
SpectralDistortionIndex,
StructuralSimilarityIndexMeasure,
TotalVariation,
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.nominal import CramersV # noqa: E402
from torchmetrics.nominal import PearsonsContingencyCoefficient # noqa: E402
from torchmetrics.nominal import TheilsU, TschuprowsT # noqa: E402
from torchmetrics.regression import ConcordanceCorrCoef # noqa: E402
from torchmetrics.regression import CosineSimilarity # noqa: E402
from torchmetrics.regression import ( # noqa: E402
ConcordanceCorrCoef,
CosineSimilarity,
ExplainedVariance,
KendallRankCorrCoef,
KLDivergence,
Expand All @@ -79,126 +56,51 @@
TweedieDevianceScore,
WeightedMeanAbsolutePercentageError,
)
from torchmetrics.retrieval import RetrievalFallOut # noqa: E402
from torchmetrics.retrieval import RetrievalHitRate # noqa: E402
from torchmetrics.retrieval import ( # noqa: E402
RetrievalMAP,
RetrievalMRR,
RetrievalNormalizedDCG,
RetrievalPrecision,
RetrievalPrecisionRecallCurve,
RetrievalRecall,
RetrievalRecallAtFixedPrecision,
RetrievalRPrecision,
)
from torchmetrics.text import ( # noqa: E402
BLEUScore,
CharErrorRate,
CHRFScore,
ExtendedEditDistance,
MatchErrorRate,
Perplexity,
SacreBLEUScore,
SQuAD,
TranslationEditRate,
WordErrorRate,
WordInfoLost,
WordInfoPreserved,
)
from torchmetrics.wrappers import BootStrapper # noqa: E402
from torchmetrics.wrappers import ClasswiseWrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402

__all__ = [
"Metric",
"MetricCollection",
"functional",
"Accuracy",
"CatMetric",
"MaxMetric",
"MeanMetric",
"MinMetric",
"SumMetric",
"AUROC",
"ROC",
"Accuracy",
"AveragePrecision",
"BLEUScore",
"BootStrapper",
"CalibrationError",
"CatMetric",
"ClasswiseWrapper",
"CharErrorRate",
"CHRFScore",
"ConcordanceCorrCoef",
"CohenKappa",
"ConfusionMatrix",
"CosineSimilarity",
"CramersV",
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
"ExactMatch",
"ExplainedVariance",
"ExtendedEditDistance",
"F1Score",
"FBetaScore",
"HammingDistance",
"HingeLoss",
"JaccardIndex",
"MatthewsCorrCoef",
"Precision",
"PrecisionRecallCurve",
"Recall",
"Specificity",
"StatScores",
"ConcordanceCorrCoef",
"CosineSimilarity",
"ExplainedVariance",
"KendallRankCorrCoef",
"KLDivergence",
"LogCoshError",
"MatchErrorRate",
"MatthewsCorrCoef",
"MaxMetric",
"MeanAbsoluteError",
"MeanAbsolutePercentageError",
"MeanMetric",
"MeanSquaredError",
"MeanSquaredLogError",
"Metric",
"MetricCollection",
"MetricTracker",
"MinMaxMetric",
"MinMetric",
"ModifiedPanopticQuality",
"MultioutputWrapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"PanopticQuality",
"MinkowskiDistance",
"PearsonCorrCoef",
"PearsonsContingencyCoefficient",
"PermutationInvariantTraining",
"Perplexity",
"Precision",
"PrecisionRecallCurve",
"PeakSignalNoiseRatio",
"R2Score",
"Recall",
"RelativeAverageSpectralError",
"RetrievalFallOut",
"RetrievalHitRate",
"RetrievalMAP",
"RetrievalMRR",
"RetrievalNormalizedDCG",
"RetrievalPrecision",
"RetrievalRecall",
"RetrievalRPrecision",
"RetrievalPrecisionRecallCurve",
"RetrievalRecallAtFixedPrecision",
"ROC",
"RootMeanSquaredErrorUsingSlidingWindow",
"SacreBLEUScore",
"SignalDistortionRatio",
"ScaleInvariantSignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SignalNoiseRatio",
"SpearmanCorrCoef",
"Specificity",
"SpectralAngleMapper",
"SpectralDistortionIndex",
"SQuAD",
"StructuralSimilarityIndexMeasure",
"StatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TheilsU",
"TotalVariation",
"TranslationEditRate",
"TschuprowsT",
"UniversalImageQualityIndex",
"TweedieDevianceScore",
"WeightedMeanAbsolutePercentageError",
"WordErrorRate",
"WordInfoLost",
"WordInfoPreserved",
]
6 changes: 3 additions & 3 deletions src/torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class ScaleInvariantSignalDistortionRatio(Metric):

Example:
>>> from torch import tensor
>>> from torchmetrics import ScaleInvariantSignalDistortionRatio
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_sdr = ScaleInvariantSignalDistortionRatio()
Expand Down Expand Up @@ -242,7 +242,7 @@ def plot(

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = torch.randn(5)
>>> preds = torch.randn(5)
>>> metric = ScaleInvariantSignalDistortionRatio()
Expand All @@ -254,7 +254,7 @@ def plot(

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = torch.randn(5)
>>> preds = torch.randn(5)
>>> metric = ScaleInvariantSignalDistortionRatio()
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SignalNoiseRatio(Metric):

Example:
>>> from torch import tensor
>>> from torchmetrics import SignalNoiseRatio
>>> from torchmetrics.audio import SignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SignalNoiseRatio()
Expand Down Expand Up @@ -111,7 +111,7 @@ def plot(

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio.snr import SignalNoiseRatio
>>> from torchmetrics.audio import SignalNoiseRatio
>>> metric = SignalNoiseRatio()
>>> metric.update(torch.rand(4), torch.rand(4))
>>> fig_, ax_ = metric.plot()
Expand All @@ -121,7 +121,7 @@ def plot(

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio.snr import SignalNoiseRatio
>>> from torchmetrics.audio import SignalNoiseRatio
>>> metric = SignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
Expand Down Expand Up @@ -152,7 +152,7 @@ class ScaleInvariantSignalNoiseRatio(Metric):

Example:
>>> from torch import tensor
>>> from torchmetrics import ScaleInvariantSignalNoiseRatio
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = ScaleInvariantSignalNoiseRatio()
Expand Down Expand Up @@ -207,7 +207,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> metric = ScaleInvariantSignalNoiseRatio()
>>> metric.update(torch.rand(4), torch.rand(4))
>>> fig_, ax_ = metric.plot()
Expand All @@ -217,7 +217,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> metric = ScaleInvariantSignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ShortTimeObjectiveIntelligibility(Metric):

Example:
>>> import torch
>>> from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
Expand Down Expand Up @@ -131,7 +131,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
Expand All @@ -144,7 +144,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> metric = ShortTimeObjectiveIntelligibility(8000, False)
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ class MeanAveragePrecision(Metric):

Example:
>>> from torch import tensor
>>> from torchmetrics.detection.mean_ap import MeanAveragePrecision
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds = [
... dict(
... boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/detection/modified_panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ModifiedPanopticQuality(Metric):

Example:
>>> from torch import tensor
>>> from torchmetrics import ModifiedPanopticQuality
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
>>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
>>> pq_modified = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
Expand Down Expand Up @@ -172,7 +172,7 @@ def plot(
:scale: 75

>>> from torch import tensor
>>> from torchmetrics import ModifiedPanopticQuality
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
Expand All @@ -192,7 +192,7 @@ def plot(

>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics import ModifiedPanopticQuality
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/detection/panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class PanopticQuality(Metric):

Example:
>>> from torch import tensor
>>> from torchmetrics import PanopticQuality
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
Expand Down Expand Up @@ -175,7 +175,7 @@ def plot(
:scale: 75

>>> from torch import tensor
>>> from torchmetrics import PanopticQuality
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
Expand All @@ -195,7 +195,7 @@ def plot(

>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics import PanopticQuality
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class SpectralDistortionIndex(Metric):
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import SpectralDistortionIndex
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> sdi = SpectralDistortionIndex()
Expand Down Expand Up @@ -127,7 +127,7 @@ def plot(
>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import SpectralDistortionIndex
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
Expand All @@ -140,7 +140,7 @@ def plot(
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import SpectralDistortionIndex
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
Expand Down
Loading