Docs: update formatting (#954)
* map
* audio
* classify
* image
* regression
* text
* utils
* Apply suggestions from code review

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Apr 14, 2022
1 parent 5216c50 commit 13e6781
Showing 148 changed files with 1,216 additions and 1,739 deletions.
1 change: 1 addition & 0 deletions requirements/audio.txt
@@ -1,3 +1,4 @@
 pesq>=0.0.3
 pystoi
 fast-bss-eval>=0.1.0
+torch_complex  # needed for fast-bss-eval torch<=1.7
4 changes: 2 additions & 2 deletions tests/audio/test_pit.py
@@ -192,14 +192,14 @@ def test_error_on_wrong_shape() -> None:
 
 def test_consistency_of_two_implementations() -> None:
     from torchmetrics.functional.audio.pit import (
-        _find_best_perm_by_exhuastive_method,
+        _find_best_perm_by_exhaustive_method,
         _find_best_perm_by_linear_sum_assignment,
     )
 
     shapes_test = [(5, 2, 2), (4, 3, 3), (4, 4, 4), (3, 5, 5)]
     for shp in shapes_test:
         metric_mtx = torch.randn(size=shp)
         bm1, bp1 = _find_best_perm_by_linear_sum_assignment(metric_mtx, torch.max)
-        bm2, bp2 = _find_best_perm_by_exhuastive_method(metric_mtx, torch.max)
+        bm2, bp2 = _find_best_perm_by_exhaustive_method(metric_mtx, torch.max)
         assert torch.allclose(bm1, bm2)
         assert (bp1 == bp2).all()
18 changes: 6 additions & 12 deletions torchmetrics/aggregation.py
@@ -39,8 +39,7 @@ class BaseAggregator(Metric):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
@@ -115,8 +114,7 @@ class MaxMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
@@ -173,8 +171,7 @@ class MinMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
@@ -231,8 +228,7 @@ class SumMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
@@ -288,8 +284,7 @@ class CatMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
@@ -346,8 +341,7 @@ class MeanMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
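For context, a minimal usage sketch of these aggregation metrics (illustrative only, not part of this commit; it assumes the torchmetrics v0.8 API documented above):

```python
import torch
from torchmetrics import MeanMetric

# MeanMetric keeps a running (optionally weighted) mean across updates
metric = MeanMetric()
metric.update(torch.tensor([1.0, 2.0, 3.0]))
metric.update(5.0)
print(metric.compute())  # tensor(2.7500), the mean of 1, 2, 3 and 5
```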
14 changes: 5 additions & 9 deletions torchmetrics/audio/pesq.py
@@ -25,7 +25,7 @@
 class PerceptualEvaluationSpeechQuality(Metric):
     """Perceptual Evaluation of Speech Quality (PESQ)
 
-    This is a wrapper for the pesq package [1]. . Note that input will be moved to `cpu`
+    This is a wrapper for the pesq package [1]. Note that input will be moved to `cpu`
     to perform the metric calculation.
 
     .. note:: using this metrics requires you to have ``pesq`` install. Either install as ``pip install
@@ -39,20 +39,16 @@ class PerceptualEvaluationSpeechQuality(Metric):
     - ``target``: ``shape [...,time]``
 
     Args:
-        fs:
-            sampling frequency, should be 16000 or 8000 (Hz)
-        mode:
-            'wb' (wide-band) or 'nb' (narrow-band)
-        keep_same_device:
-            whether to move the pesq value to the device of preds
+        fs: sampling frequency, should be 16000 or 8000 (Hz)
+        mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band)
+        keep_same_device: whether to move the pesq value to the device of preds
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False.
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ModuleNotFoundError:
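A hedged usage sketch of the wrapper described above (not part of this diff; assumes torchmetrics v0.8 with the ``pesq`` package installed):

```python
import torch
from torchmetrics.audio import PerceptualEvaluationSpeechQuality

# Inputs are raw waveforms of shape [..., time]; computation runs on CPU
preds = torch.randn(8000)   # degraded speech
target = torch.randn(8000)  # clean reference
nb_pesq = PerceptualEvaluationSpeechQuality(fs=8000, mode="nb")
print(nb_pesq(preds, target))  # scalar MOS-LQO score, roughly in [-0.5, 4.5]
```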
7 changes: 3 additions & 4 deletions torchmetrics/audio/pit.py
@@ -32,8 +32,8 @@ class PermutationInvariantTraining(Metric):
 
     Args:
         metric_func:
-            a metric function accept a batch of target and estimate, i.e. metric_func(preds[:, i, ...],
-            target[:, j, ...]), and returns a batch of metric tensors [batch]
+            a metric function accept a batch of target and estimate,
+            i.e. ``metric_func(preds[:, i, ...], target[:, j, ...])``, and returns a batch of metric tensors ``[batch]``
         eval_func:
             the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better
             or the larger the better.
@@ -43,8 +43,7 @@ class PermutationInvariantTraining(Metric):
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments for either the `metric_func` or distributed communication,
+        kwargs: Additional keyword arguments for either the ``metric_func`` or distributed communication,
             see :ref:`Metric kwargs` for more info.
 
     Returns:
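A minimal sketch of ``metric_func`` and ``eval_func`` in practice (illustrative, assuming the v0.8 functional API; not part of the commit):

```python
import torch
from torchmetrics import PermutationInvariantTraining
from torchmetrics.functional import scale_invariant_signal_distortion_ratio

# 3 samples, 2 speakers, 5 time steps; PIT searches over speaker permutations
preds = torch.randn(3, 2, 5)
target = torch.randn(3, 2, 5)
pit = PermutationInvariantTraining(scale_invariant_signal_distortion_ratio, eval_func="max")
print(pit(preds, target))  # best average SI-SDR over all permutations
```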
27 changes: 10 additions & 17 deletions torchmetrics/audio/sdr.py
@@ -32,29 +32,24 @@ class SignalDistortionRatio(Metric):
 
     Args:
         use_cg_iter:
-            If provided, an iterative method is used to solve for the distortion
-            filter coefficients instead of direct Gaussian elimination.
-            This can speed up the computation of the metrics in case the filters
-            are long. Using a value of 10 here has been shown to provide
-            good accuracy in most cases and is sufficient when using this
-            loss to train neural separation networks.
-        filter_length:
-            The length of the distortion filter allowed
+            If provided, an iterative method is used to solve for the distortion filter coefficients instead
+            of direct Gaussian elimination. This can speed up the computation of the metrics in case the filters
+            are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient
+            when using this loss to train neural separation networks.
+        filter_length: The length of the distortion filter allowed
         zero_mean:
             When set to True, the mean of all signals is subtracted prior to computation of the metrics
         load_diag:
-            If provided, this small value is added to the diagonal coefficients of
-            the system metrics when solving for the filter coefficients.
-            This can help stabilize the metric in the case where some of the reference
+            If provided, this small value is added to the diagonal coefficients of the system metrics when solving
+            for the filter coefficients. This can help stabilize the metric in the case where some reference
             signals may sometimes be zero
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False.
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ModuleNotFoundError:
@@ -155,16 +150,14 @@ class ScaleInvariantSignalDistortionRatio(Metric):
     - ``target``: ``shape [...,time]``
 
     Args:
-        zero_mean:
-            if to zero mean target and preds or not
+        zero_mean: if to zero mean target and preds or not
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False.
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         TypeError:
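An illustrative sketch of the two SDR metrics above (not part of this diff; assumes the v0.8 API and the ``fast-bss-eval`` backend from requirements/audio.txt):

```python
import torch
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio

preds = torch.randn(8000)
target = torch.randn(8000)
sdr = SignalDistortionRatio(use_cg_iter=10, zero_mean=True)  # iterative solver, per the docstring
si_sdr = ScaleInvariantSignalDistortionRatio(zero_mean=True)
print(sdr(preds, target), si_sdr(preds, target))  # both reported in dB
```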
9 changes: 3 additions & 6 deletions torchmetrics/audio/snr.py
@@ -35,16 +35,14 @@ class SignalNoiseRatio(Metric):
     - ``target``: ``shape [..., time]``
 
     Args:
-        zero_mean:
-            if to zero mean target and preds or not
+        zero_mean: if to zero mean target and preds or not
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False.
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         TypeError:
@@ -116,8 +114,7 @@ class ScaleInvariantSignalNoiseRatio(Metric):
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         TypeError:
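A short usage sketch for the two SNR metrics above (illustrative only, assuming the v0.8 API):

```python
import torch
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
snr = SignalNoiseRatio()                  # optionally zero_mean=True
si_snr = ScaleInvariantSignalNoiseRatio()
print(snr(preds, target), si_snr(preds, target))  # dB values
```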
13 changes: 5 additions & 8 deletions torchmetrics/audio/stoi.py
@@ -27,11 +27,11 @@ class ShortTimeObjectiveIntelligibility(Metric):
     Note that input will be moved to `cpu` to perform the metric calculation.
 
     Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due
-    to additive noise, single/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations.
+    to additive noise, single-/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations.
     The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good
     alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are
     interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms,
-    on speech intelligibility. Description taken from [Cees Taal's website](http://www.ceestaal.nl/code/).
+    on speech intelligibility. Description taken from `Cees Taal's website <http://www.ceestaal.nl/code/>`_.
 
     .. note:: using this metrics requires you to have ``pystoi`` install. Either install as ``pip install
        torchmetrics[audio]`` or ``pip install pystoi``
@@ -42,18 +42,15 @@ class ShortTimeObjectiveIntelligibility(Metric):
     - ``target``: ``shape [...,time]``
 
     Args:
-        fs:
-            sampling frequency (Hz)
-        extended:
-            whether to use the extended STOI described in [4]
+        fs: sampling frequency (Hz)
+        extended: whether to use the extended STOI described in [4]
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False.
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Returns:
         average STOI value
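A minimal sketch of the STOI metric above (illustrative; assumes v0.8 with ``pystoi`` installed):

```python
import torch
from torchmetrics.audio import ShortTimeObjectiveIntelligibility

# Inputs are waveforms of shape [..., time]; moved to CPU internally
preds = torch.randn(8000)
target = torch.randn(8000)
stoi = ShortTimeObjectiveIntelligibility(fs=8000, extended=False)
print(stoi(preds, target))  # average STOI value, higher means more intelligible
```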
7 changes: 3 additions & 4 deletions torchmetrics/classification/accuracy.py
@@ -71,7 +71,7 @@ class Accuracy(StatScores):
             .. note:: What is considered a sample in the multi-dimensional multi-class case
                 depends on the value of ``mdmc_average``.
 
-            .. note:: If ``'none'`` and a given class doesn't occur in the `preds` or `target`,
+            .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``,
                 the value for the class will be ``nan``.
 
         mdmc_average:
@@ -98,7 +98,7 @@ class Accuracy(StatScores):
             or ``'none'``, the score for the ignored class will be returned as ``nan``.
 
         top_k:
-            Number of highest probability or logit score predictions considered to find the correct label,
+            Number of the highest probability or logit score predictions considered finding the correct label,
             relevant only for (multi-dimensional) multi-class inputs. The
             default value (``None``) will be interpreted as 1 for these inputs.
@@ -132,8 +132,7 @@ class Accuracy(StatScores):
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
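A sketch of the ``top_k`` behavior described above (illustrative, assuming the v0.8 API):

```python
import torch
from torchmetrics import Accuracy

target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.5, 0.3, 0.2],
                      [0.2, 0.5, 0.3],
                      [0.3, 0.4, 0.3]])
acc_top1 = Accuracy()        # top_k=None is interpreted as 1 for these inputs
acc_top2 = Accuracy(top_k=2) # correct if the label is among the 2 highest scores
print(acc_top1(preds, target), acc_top2(preds, target))
```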
3 changes: 1 addition & 2 deletions torchmetrics/classification/auc.py
@@ -38,8 +38,7 @@ class AUC(Metric):
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
     """
 
     is_differentiable = False
     x: List[Tensor]
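A minimal sketch of the AUC metric touched above (illustrative; the ``reorder`` flag is assumed from the v0.8 API, not shown in this hunk):

```python
import torch
from torchmetrics import AUC

# AUC integrates y over x with the trapezoidal rule
auc = AUC(reorder=True)  # reorder sorts x before integrating
x = torch.tensor([0.0, 0.5, 1.0])
y = torch.tensor([0.0, 0.5, 1.0])
print(auc(x, y))  # tensor(0.5000), area under y = x on [0, 1]
```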
12 changes: 6 additions & 6 deletions torchmetrics/classification/auroc.py
@@ -48,11 +48,12 @@ class AUROC(Metric):
 
     Args:
         num_classes: integer with number of classes for multi-label and multiclass problems.
             Should be set to ``None`` for binary problems
         pos_label: integer determining the positive class. Default is ``None``
-            which for binary problem is translate to 1. For multiclass problems
+            which for binary problem is translated to 1. For multiclass problems
             this argument should not be set as we iteratively change it in the
-            range [0,num_classes-1]
+            range ``[0, num_classes-1]``
         average:
             - ``'micro'`` computes metric globally. Only works for multilabel problems
             - ``'macro'`` computes metric for each class and uniformly averages them
@@ -61,23 +62,22 @@ class AUROC(Metric):
             - ``None`` computes and returns the metric per class
         max_fpr:
             If not ``None``, calculates standardized partial AUC over the
-            range [0, max_fpr]. Should be a float between 0 and 1.
+            range ``[0, max_fpr]``. Should be a float between 0 and 1.
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False.
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
         ValueError:
             If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``.
         ValueError:
             If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``.
         RuntimeError:
-            If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize``
+            If ``PyTorch version`` is ``below 1.6`` since ``max_fpr`` requires ``torch.bucketize``
             which is not available below 1.6.
         ValueError:
             If the mode of data (binary, multi-label, multi-class) changes between batches.
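For context, a binary-case sketch of the ``pos_label`` argument documented above (illustrative, assuming the v0.8 API):

```python
import torch
from torchmetrics import AUROC

# Binary case: preds are probabilities of the positive class
preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
target = torch.tensor([0, 0, 1, 1, 1])
auroc = AUROC(pos_label=1)
print(auroc(preds, target))  # tensor(0.5000)
```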
10 changes: 4 additions & 6 deletions torchmetrics/classification/avg_precision.py
@@ -41,9 +41,9 @@ class AveragePrecision(Metric):
         num_classes: integer with number of classes. Not nessesary to provide
             for binary problems.
         pos_label: integer determining the positive class. Default is ``None``
-            which for binary problem is translate to 1. For multiclass problems
+            which for binary problem is translated to 1. For multiclass problems
             this argument should not be set as we iteratively change it in the
-            range [0,num_classes-1]
+            range ``[0, num_classes-1]``
         average:
             defines the reduction that is applied in the case of multiclass and multilabel input.
             Should be one of the following:
@@ -63,8 +63,7 @@ class AveragePrecision(Metric):
 
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
 
-        kwargs:
-            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Example (binary case):
         >>> from torchmetrics import AveragePrecision
@@ -133,8 +132,7 @@ def compute(self) -> Union[Tensor, List[Tensor]]:
         """Compute the average precision score.
 
         Returns:
-            tensor with average precision. If multiclass will return list
-            of such tensors, one for each class
+            tensor with average precision. If multiclass return list of such tensors, one for each class
         """
         preds = dim_zero_cat(self.preds)
         target = dim_zero_cat(self.target)
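Complementing the binary example in the hunk above, a multiclass sketch of the per-class return described in ``compute()`` (illustrative, assuming the v0.8 API):

```python
import torch
from torchmetrics import AveragePrecision

# Multiclass case: with average=None, one score is returned per class
preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
                      [0.05, 0.75, 0.05, 0.05, 0.05],
                      [0.05, 0.05, 0.75, 0.05, 0.05],
                      [0.05, 0.05, 0.05, 0.75, 0.05]])
target = torch.tensor([0, 1, 3, 2])
ap = AveragePrecision(num_classes=5, average=None)
print(ap(preds, target))  # list of per-class average-precision tensors
```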