Skip to content

Commit

Permalink
Merge branch 'master' into ter
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Dec 6, 2021
2 parents 0d4828e + 4e26593 commit a7dc808
Show file tree
Hide file tree
Showing 28 changed files with 78 additions and 62 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))


- Metrics having third party dependencies removed from global import ([#463](https://github.com/PyTorchLightning/metrics/pull/463))


- `BLEUScore` now expects untokenized input to stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))


Expand Down
8 changes: 4 additions & 4 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Audio Metrics
pesq [func]
~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.pesq
.. autofunction:: torchmetrics.functional.audio.pesq.pesq


pit [func]
Expand Down Expand Up @@ -55,7 +55,7 @@ snr [func]
stoi [func]
~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.stoi
.. autofunction:: torchmetrics.functional.audio.stoi.stoi
:noindex:


Expand Down Expand Up @@ -433,7 +433,7 @@ Text
bert_score [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.bert_score
.. autofunction:: torchmetrics.functional.text.bert.bert_score

bleu_score [func]
~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -462,7 +462,7 @@ match_error_rate [func]
rouge_score [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.rouge_score
.. autofunction:: torchmetrics.functional.text.rouge.rouge_score
:noindex:

sacre_bleu_score [func]
Expand Down
18 changes: 9 additions & 9 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ the metric will be computed over the ``time`` dimension.
PESQ
~~~~

.. autoclass:: torchmetrics.PESQ
.. autoclass:: torchmetrics.audio.pesq.PESQ

PIT
~~~
Expand Down Expand Up @@ -112,7 +112,7 @@ SNR
STOI
~~~~

.. autoclass:: torchmetrics.STOI
.. autoclass:: torchmetrics.audio.stoi.STOI
:noindex:


Expand Down Expand Up @@ -369,25 +369,25 @@ learning algorithms such as `Generative Adverserial Networks (GANs) <https://en.
FID
~~~

.. autoclass:: torchmetrics.FID
.. autoclass:: torchmetrics.image.fid.FID
:noindex:

IS
~~

.. autoclass:: torchmetrics.IS
.. autoclass:: torchmetrics.image.inception.IS
:noindex:

KID
~~~

.. autoclass:: torchmetrics.KID
.. autoclass:: torchmetrics.image.kid.KID
:noindex:

LPIPS
~~~~~

.. autoclass:: torchmetrics.LPIPS
.. autoclass:: torchmetrics.image.lpip_similarity.LPIPS
:noindex:

PSNR
Expand All @@ -411,7 +411,7 @@ Object detection metrics can be used to evaluate the predicted detections with g
MAP
~~~

.. autoclass:: torchmetrics.MAP
.. autoclass:: torchmetrics.detection.map.MAP
:noindex:

******************
Expand Down Expand Up @@ -613,7 +613,7 @@ Text
BERTScore
~~~~~~~~~~

.. autoclass:: torchmetrics.BERTScore
.. autoclass:: torchmetrics.text.bert.BERTScore
:noindex:

BLEUScore
Expand Down Expand Up @@ -643,7 +643,7 @@ MatchErrorRate
ROUGEScore
~~~~~~~~~~

.. autoclass:: torchmetrics.ROUGEScore
.. autoclass:: torchmetrics.text.rouge.ROUGEScore
:noindex:

SacreBLEUScore
Expand Down
4 changes: 2 additions & 2 deletions tests/audio/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio import PESQ
from torchmetrics.functional import pesq
from torchmetrics.audio.pesq import PESQ
from torchmetrics.functional.audio.pesq import pesq
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down
4 changes: 2 additions & 2 deletions tests/audio/test_stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio import STOI
from torchmetrics.functional import stoi
from torchmetrics.audio.stoi import STOI
from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down
2 changes: 1 addition & 1 deletion tests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def generate_cov(n):

scipy_res = scipy_sqrtm((cov1 @ cov2).numpy()).real
tm_res = sqrtm(cov1 @ cov2)
assert torch.allclose(torch.tensor(scipy_res).float(), tm_res, atol=1e-3)
assert torch.allclose(torch.tensor(scipy_res).float().trace(), tm_res.trace())


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
Expand Down
4 changes: 2 additions & 2 deletions tests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch.distributed as dist
import torch.multiprocessing as mp

from torchmetrics.functional import bert_score as metrics_bert_score
from torchmetrics.text import BERTScore
from torchmetrics.functional.text.bert import bert_score as metrics_bert_score
from torchmetrics.text.bert import BERTScore
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE

if _BERTSCORE_AVAILABLE:
Expand Down
3 changes: 2 additions & 1 deletion tests/wrappers/test_minmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class TestMinMaxWrapper(MetricTester):

atol = 1e-6

@pytest.mark.parametrize("ddp", [True, False])
# TODO: fix ddp=True case, difference in how compare function works and wrapper metric
@pytest.mark.parametrize("ddp", [False])
def test_minmax_wrapper(self, preds, target, base_metric, ddp):
self.run_class_metric_test(
ddp,
Expand Down
16 changes: 2 additions & 14 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PESQ, PIT, SDR, SI_SDR, SI_SNR, SNR, STOI # noqa: E402
from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR # noqa: E402
from torchmetrics.classification import ( # noqa: E402
AUC,
AUROC,
Expand All @@ -40,8 +40,7 @@
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.detection import MAP # noqa: E402
from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402
from torchmetrics.image import PSNR, SSIM # noqa: E402
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.regression import ( # noqa: E402
CosineSimilarity,
Expand Down Expand Up @@ -69,12 +68,10 @@
from torchmetrics.text import ( # noqa: E402
TER,
WER,
BERTScore,
BLEUScore,
CharErrorRate,
CHRFScore,
MatchErrorRate,
ROUGEScore,
SacreBLEUScore,
SQuAD,
WordInfoLost,
Expand All @@ -91,7 +88,6 @@
"BinnedAveragePrecision",
"BinnedPrecisionRecallCurve",
"BinnedRecallAtFixedPrecision",
"BERTScore",
"BLEUScore",
"BootStrapper",
"CalibrationError",
Expand All @@ -104,15 +100,10 @@
"ExplainedVariance",
"F1",
"FBeta",
"FID",
"HammingDistance",
"Hinge",
"IoU",
"IS",
"KID",
"KLDivergence",
"LPIPS",
"MAP",
"MatthewsCorrcoef",
"MaxMetric",
"MeanAbsoluteError",
Expand All @@ -127,7 +118,6 @@
"MinMetric",
"MultioutputWrapper",
"PearsonCorrcoef",
"PESQ",
"PIT",
"Precision",
"PrecisionRecallCurve",
Expand All @@ -143,7 +133,6 @@
"RetrievalRecall",
"RetrievalRPrecision",
"ROC",
"ROUGEScore",
"SacreBLEUScore",
"SDR",
"SI_SDR",
Expand All @@ -154,7 +143,6 @@
"SQuAD",
"SSIM",
"StatScores",
"STOI",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TER",
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.audio.pesq import PESQ # noqa: F401
from torchmetrics.audio.pit import PIT # noqa: F401
from torchmetrics.audio.sdr import SDR # noqa: F401
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR # noqa: F401
from torchmetrics.audio.stoi import STOI # noqa: F401
2 changes: 1 addition & 1 deletion torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PESQ(Metric):
If ``mode`` is not either ``"wb"`` or ``"nb"``
Example:
>>> from torchmetrics.audio import PESQ
>>> from torchmetrics.audio.pesq import PESQ
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class STOI(Metric):
If ``pystoi`` package is not installed
Example:
>>> from torchmetrics.audio import STOI
>>> from torchmetrics.audio.stoi import STOI
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
Expand Down
1 change: 0 additions & 1 deletion torchmetrics/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.detection.map import MAP # noqa: F401
36 changes: 36 additions & 0 deletions torchmetrics/detection/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,42 @@ class MAP(Metric):
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
Example:
>>> import torch
>>> from torchmetrics.detection.map import MAP
>>> preds = [
... dict(
... boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
... scores=torch.Tensor([0.536]),
... labels=torch.IntTensor([0]),
... )
... ]
>>> target = [
... dict(
... boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
... labels=torch.IntTensor([0]),
... )
... ]
>>> metric = MAP() # doctest: +SKIP
>>> metric.update(preds, target) # doctest: +SKIP
>>> from pprint import pprint
>>> pprint(metric.compute()) # doctest: +SKIP
{'map': tensor(0.6000),
'map_50': tensor(1.),
'map_75': tensor(1.),
'map_small': tensor(-1.),
'map_medium': tensor(-1.),
'map_large': tensor(0.6000),
'mar_1': tensor(0.6000),
'mar_10': tensor(0.6000),
'mar_100': tensor(0.6000),
'mar_small': tensor(-1.),
'mar_medium': tensor(-1.),
'mar_large': tensor(0.6000),
'map_per_class': tensor(-1.),
'mar_100_per_class': tensor(-1.)
}
Raises:
ImportError:
If ``torchvision`` is not installed or version installed is lower than 0.8.0
Expand Down
4 changes: 0 additions & 4 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.audio.pesq import pesq
from torchmetrics.functional.audio.pit import pit, pit_permutate
from torchmetrics.functional.audio.sdr import sdr
from torchmetrics.functional.audio.si_sdr import si_sdr
from torchmetrics.functional.audio.si_snr import si_snr
from torchmetrics.functional.audio.snr import snr
from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.functional.classification.accuracy import accuracy
from torchmetrics.functional.classification.auc import auc
from torchmetrics.functional.classification.auroc import auroc
Expand Down Expand Up @@ -110,7 +108,6 @@
"pairwise_linear_similarity",
"pairwise_manhatten_distance",
"pearson_corrcoef",
"pesq",
"pit",
"pit_permutate",
"precision",
Expand Down Expand Up @@ -139,7 +136,6 @@
"squad",
"ssim",
"stat_scores",
"stoi",
"symmetric_mean_absolute_percentage_error",
"ter",
"wer",
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.audio.pesq import pesq # noqa: F401
from torchmetrics.functional.audio.pit import pit, pit_permutate # noqa: F401
from torchmetrics.functional.audio.sdr import sdr # noqa: F401
from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401
from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401
from torchmetrics.functional.audio.snr import snr # noqa: F401
from torchmetrics.functional.audio.stoi import stoi # noqa: F401
2 changes: 1 addition & 1 deletion torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def pesq(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bo
If ``mode`` is not either ``"wb"`` or ``"nb"``
Example:
>>> from torchmetrics.functional.audio import pesq
>>> from torchmetrics.functional.audio.pesq import pesq
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa
If ``pystoi`` package is not installed
Example:
>>> from torchmetrics.functional.audio import stoi
>>> from torchmetrics.functional.audio.stoi import stoi
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def bert_score(
If invalid input is provided.
Example:
>>> from torchmetrics.functional.text.bert import bert_score
>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "master kenobi"]
>>> bert_score(predictions=predictions, references=references, lang="en") # doctest: +SKIP
Expand Down
Loading

0 comments on commit a7dc808

Please sign in to comment.