Skip to content

Commit

Permalink
tests: move reset seed to fixture [1/2] (#2702)
Browse files Browse the repository at this point in the history
* move reset seed to fixture
* updating test outputs
* Update src/conftest.py
  • Loading branch information
Borda authored Aug 29, 2024
1 parent 773fb92 commit 79b33ce
Show file tree
Hide file tree
Showing 51 changed files with 284 additions and 297 deletions.
40 changes: 40 additions & 0 deletions src/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path
from typing import Optional

from lightning_utilities.core.imports import package_available

if package_available("pytest") and package_available("doctest"):
import doctest

import pytest

MANUAL_SEED = doctest.register_optionflag("MANUAL_SEED")

@pytest.fixture(autouse=True)
def reset_random_seed(seed: int = 42) -> None: # noqa: PT004
"""Reset the random seed before running each doctest."""
import random

import numpy as np
import torch

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

class DoctestModule(pytest.Module):
"""A custom module class that augments collected doctests with the reset_random_seed fixture."""

def collect(self) -> GeneratorExit:
"""Augment collected doctests with the reset_random_seed fixture."""
for item in super().collect():
if isinstance(item, pytest.DoctestItem):
item.add_marker(pytest.mark.usefixtures("reset_random_seed"))
yield item

def pytest_collect_file(parent: Path, path: Path) -> Optional[DoctestModule]:
"""Collect doctests and add the reset_random_seed fixture."""
if path.ext == ".py":
return DoctestModule.from_parent(parent, fspath=path)
return None
6 changes: 2 additions & 4 deletions src/torchmetrics/audio/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class _PermutationInvariantTraining(PermutationInvariantTraining):
>>> import torch
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
Expand Down Expand Up @@ -79,20 +78,19 @@ class _SignalDistortionRatio(SignalDistortionRatio):
"""Wrapper for deprecated import.
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = _SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
tensor(-11.9930)
>>> # use with pit
>>> from torchmetrics.functional import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
... mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-11.6051)
tensor(-11.7277)
"""

Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/audio/dnsmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,10 @@ class DeepNoiseSuppressionMeanOpinionScore(Metric):
Example:
>>> from torch import randn
>>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore
>>> g = torch.manual_seed(1)
>>> preds = randn(8000)
>>> dnsmos = DeepNoiseSuppressionMeanOpinionScore(8000, False)
>>> dnsmos(preds)
tensor([2.2285, 2.1132, 1.3972, 1.3652], dtype=torch.float64)
tensor([2.2687, 2.0766, 1.1375, 1.2722], dtype=torch.float64)
"""

Expand Down
11 changes: 5 additions & 6 deletions src/torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,16 @@ class PerceptualEvaluationSpeechQuality(Metric):
If ``mode`` is not either ``"wb"`` or ``"nb"``
Example:
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> preds = randn(8000)
>>> target = randn(8000)
>>> pesq = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> pesq(preds, target)
tensor(2.2076)
tensor(2.2885)
>>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb')
>>> wb_pesq(preds, target)
tensor(1.7359)
tensor(1.6805)
"""

Expand Down
7 changes: 3 additions & 4 deletions src/torchmetrics/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@ class PermutationInvariantTraining(Metric):
see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> preds = randn(3, 2, 5) # [batch, spk, time]
>>> target = randn(3, 2, 5) # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
... mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
Expand Down
30 changes: 14 additions & 16 deletions src/torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,22 @@ class SignalDistortionRatio(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import SignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> preds = randn(8000)
>>> target = randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
tensor(-11.9930)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> preds = randn(4, 2, 8000) # [batch, spk, time]
>>> target = randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio,
... mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-11.6051)
tensor(-11.7277)
"""

Expand Down Expand Up @@ -302,23 +301,22 @@ class SourceAggregatedSignalDistortionRatio(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000) # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> preds = randn(2, 8000) # [..., spk, time]
>>> target = randn(2, 8000)
>>> sasdr = SourceAggregatedSignalDistortionRatio()
>>> sasdr(preds, target)
tensor(-41.6579)
tensor(-50.8171)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> preds = randn(4, 2, 8000) # [batch, spk, time]
>>> target = randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio,
... mode="permutation-wise", eval_func="max")
>>> pit(preds, target)
tensor(-41.2790)
tensor(-43.9780)
"""

Expand Down
10 changes: 4 additions & 6 deletions src/torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,13 @@ class ComplexScaleInvariantSignalNoiseRatio(Metric):
If ``preds`` and ``target`` does not have the same shape.
Example:
>>> import torch
>>> from torch import tensor
>>> from torch import randn
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn((1,257,100,2))
>>> target = torch.randn((1,257,100,2))
>>> preds = randn((1,257,100,2))
>>> target = randn((1,257,100,2))
>>> c_si_snr = ComplexScaleInvariantSignalNoiseRatio()
>>> c_si_snr(preds, target)
tensor(-63.4849)
tensor(-38.8832)
"""

Expand Down
7 changes: 3 additions & 4 deletions src/torchmetrics/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ class SpeechReverberationModulationEnergyRatio(Metric):
If ``gammatone`` or ``torchaudio`` package is not installed
Example:
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> preds = randn(8000)
>>> srmr = SpeechReverberationModulationEnergyRatio(8000)
>>> srmr(preds)
tensor(0.3354)
tensor(0.3191)
"""

Expand Down
23 changes: 10 additions & 13 deletions src/torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ class ShortTimeObjectiveIntelligibility(Metric):
If ``pystoi`` package is not installed
Example:
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> preds = randn(8000)
>>> target = randn(8000)
>>> stoi = ShortTimeObjectiveIntelligibility(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)
tensor(-0.0842)
"""

Expand Down Expand Up @@ -132,11 +131,10 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> preds = randn(8000)
>>> target = randn(8000)
>>> metric = ShortTimeObjectiveIntelligibility(8000, False)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
Expand All @@ -145,12 +143,11 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torch import randn
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> metric = ShortTimeObjectiveIntelligibility(8000, False)
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> preds = randn(8000)
>>> target = randn(8000)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(preds, target))
Expand Down
10 changes: 4 additions & 6 deletions src/torchmetrics/classification/group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,25 +302,23 @@ def plot(
.. plot::
:scale: 75
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> metric.update(torch.rand(20), torch.randint(2,(20,)), torch.randint(2,(20,)))
>>> metric.update(rand(20), randint(2, (20,)), randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import ones, rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(20), torch.randint(2,(20,)), torch.ones(20).long()))
... values.append(metric(rand(20), randint(2, (20,) ), ones(20).long()))
>>> fig_, ax_ = metric.plot(values)
"""
Expand Down
18 changes: 9 additions & 9 deletions src/torchmetrics/classification/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class MultilabelCoverageError(Metric):
Set to ``False`` for faster computations.
Example:
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelCoverageError
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> preds = rand(10, 5)
>>> target = randint(2, (10, 5))
>>> mlce = MultilabelCoverageError(num_labels=5)
>>> mlce(preds, target)
tensor(3.9000)
Expand Down Expand Up @@ -186,10 +186,10 @@ class MultilabelRankingAveragePrecision(Metric):
Set to ``False`` for faster computations.
Example:
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelRankingAveragePrecision
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> preds = rand(10, 5)
>>> target = randint(2, (10, 5))
>>> mlrap = MultilabelRankingAveragePrecision(num_labels=5)
>>> mlrap(preds, target)
tensor(0.7744)
Expand Down Expand Up @@ -308,10 +308,10 @@ class MultilabelRankingLoss(Metric):
Set to ``False`` for faster computations.
Example:
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelRankingLoss
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> preds = rand(10, 5)
>>> target = randint(2, (10, 5))
>>> mlrl = MultilabelRankingLoss(num_labels=5)
>>> mlrl(preds, target)
tensor(0.4167)
Expand Down
7 changes: 3 additions & 4 deletions src/torchmetrics/clustering/calinski_harabasz_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ class CalinskiHarabaszScore(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example::
>>> import torch
>>> from torch import randn, randint
>>> from torchmetrics.clustering import CalinskiHarabaszScore
>>> _ = torch.manual_seed(42)
>>> data = torch.randn(10, 3)
>>> labels = torch.randint(3, (10,))
>>> data = randn(10, 3)
>>> labels = randint(3, (10,))
>>> metric = CalinskiHarabaszScore()
>>> metric(data, labels)
tensor(3.0053)
Expand Down
7 changes: 3 additions & 4 deletions src/torchmetrics/clustering/davies_bouldin_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ class DaviesBouldinScore(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example::
>>> import torch
>>> from torch import randn, randint
>>> from torchmetrics.clustering import DaviesBouldinScore
>>> _ = torch.manual_seed(42)
>>> data = torch.randn(10, 3)
>>> labels = torch.randint(3, (10,))
>>> data = randn(10, 3)
>>> labels = randint(3, (10,))
>>> metric = DaviesBouldinScore()
>>> metric(data, labels)
tensor(1.2540)
Expand Down
Loading

0 comments on commit 79b33ce

Please sign in to comment.