Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/super_reset_multioutput_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Jan 24, 2023
2 parents 199b9e3 + d233c9d commit 0add0ce
Show file tree
Hide file tree
Showing 71 changed files with 1,031 additions and 1,065 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ jobs:
- name: Set PyTorch version
if: inputs.requires != 'oldest'
run: |
pip install packaging
pip install packaging -q
python ./requirements/adjust-versions.py requirements.txt ${{ matrix.pytorch-version }}
- name: full chashing
Expand Down
12 changes: 8 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ repos:
hooks:
- id: yesqa

- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.226'
hooks:
- id: flake8
name: PEP8
- id: ruff
# Respect `exclude` and `extend-exclude` settings.
args:
- "--fix"
- "--respect-gitignore"
- "--force-exclude"
97 changes: 95 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,54 @@
[metadata]
license_file = "LICENSE"
description-file = "README.md"

[build-system]
requires = [
"setuptools",
"wheel",
]


[tool.check-manifest]
ignore = [
"*.yml",
".github",
".github/*"
]


[tool.pytest.ini_options]
norecursedirs = [
".git",
".github",
"dist",
"build",
"docs",
]
addopts = [
"--strict-markers",
"--doctest-modules",
"--doctest-plus",
"--color=yes",
"--disable-pytest-warnings",
]
# ToDo
#filterwarnings = ["error::FutureWarning"]
xfail_strict = true
junit_duration_report = "call"

[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"pass",
]

[tool.coverage.run]
parallel = true
concurrency = "thread"
relative_files = true


[tool.black]
# https://github.com/psf/black
line-length = 120
Expand All @@ -20,8 +65,56 @@ skip_glob = []
profile = "black"
line_length = 120

[tool.autopep8]
ignore = ["E731"]

[tool.ruff]
line-length = 120
# Enable Pyflakes `E` and `F` codes by default.
select = [
"E", "W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
# TODO
# "D", # see: https://pypi.org/project/pydocstyle
# "N", # see: https://pypi.org/project/pep8-naming
]
#extend-select = [
# "C4", # see: https://pypi.org/project/flake8-comprehensions
# "PT", # see: https://pypi.org/project/flake8-pytest-style
# "RET", # see: https://pypi.org/project/flake8-return
# "SIM", # see: https://pypi.org/project/flake8-simplify
#]
ignore = [
"E731", # Do not assign a lambda expression, use a def
]
# Exclude a variety of commonly ignored directories.
exclude = [
".eggs",
".git",
".mypy_cache",
".ruff_cache",
"__pypackages__",
"_build",
"build",
"dist",
"docs"
]
ignore-init-module-imports = true

[tool.ruff.per-file-ignores]
"setup.py" = ["D100", "SIM115"]
"__about__.py" = ["D100"]
"__init__.py" = ["D100"]

[tool.ruff.pydocstyle]
# Use Google-style docstrings.
convention = "google"

#[tool.ruff.pycodestyle]
#ignore-overlong-task-comments = true

[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10


[tool.mypy]
files = [
Expand Down
6 changes: 3 additions & 3 deletions requirements/doctest.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pytest
pytest-doctestplus
pytest-rerunfailures
pytest>=6.0.0, <7.2.0
pytest-doctestplus>=0.9.0
pytest-rerunfailures>=10.0
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
coverage>5.2
pytest==6.*
pytest>=6.0.0, <7.2.0
pytest-cov>2.10
# pytest-flake8
pytest-doctestplus>=0.9.0
Expand Down
58 changes: 0 additions & 58 deletions setup.cfg

This file was deleted.

16 changes: 8 additions & 8 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ class MaxMetric(BaseAggregator):
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics import MaxMetric
>>> metric = MaxMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(3.)
"""
Expand Down Expand Up @@ -179,11 +179,11 @@ class MinMetric(BaseAggregator):
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics import MinMetric
>>> metric = MinMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(1.)
"""
Expand Down Expand Up @@ -240,11 +240,11 @@ class SumMetric(BaseAggregator):
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(6.)
"""
Expand Down Expand Up @@ -299,11 +299,11 @@ class CatMetric(BaseAggregator):
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics import CatMetric
>>> metric = CatMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor([1., 2., 3.])
"""
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class PerceptualEvaluationSpeechQuality(Metric):
If ``mode`` is not either ``"wb"`` or ``"nb"``
Example:
>>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
>>> import torch
>>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class SignalDistortionRatio(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
Expand Down Expand Up @@ -134,10 +134,10 @@ class ScaleInvariantSignalDistortionRatio(Metric):
if target and preds have a different shape
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics import ScaleInvariantSignalDistortionRatio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_sdr = ScaleInvariantSignalDistortionRatio()
>>> si_sdr(preds, target)
tensor(18.4030)
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ class SignalNoiseRatio(Metric):
if target and preds have a different shape
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics import SignalNoiseRatio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)
Expand Down Expand Up @@ -103,10 +103,10 @@ class ScaleInvariantSignalNoiseRatio(Metric):
if target and preds have a different shape
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics import ScaleInvariantSignalNoiseRatio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = ScaleInvariantSignalNoiseRatio()
>>> si_snr(preds, target)
tensor(15.0918)
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class ShortTimeObjectiveIntelligibility(Metric):
If ``pystoi`` package is not installed
Example:
>>> from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
>>> import torch
>>> from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
Expand Down
Loading

0 comments on commit 0add0ce

Please sign in to comment.