Skip to content

Commit

Permalink
Apply deterministic RNG to more unit tests (#1173)
Browse files Browse the repository at this point in the history
  • Loading branch information
pzelasko authored Oct 8, 2023
1 parent 499ddd3 commit 9c80a1e
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions test/cut/test_custom_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
dummy_recording,
dummy_supervision,
)
from lhotse.testing.random import deterministic_rng


@pytest.mark.parametrize("cut", [dummy_cut(1), dummy_cut(2).pad(300)])
Expand Down Expand Up @@ -103,7 +104,7 @@ def test_cut_custom_nonarray_attr_serialization():
assert restored_cut.SNR == 7.3


def test_cut_load_temporal_array():
def test_cut_load_temporal_array(deterministic_rng):
"""Check that we can read a TemporalArray from a cut when their durations match."""
alignment = np.random.randint(500, size=131)
with TemporaryDirectory() as d, NumpyFilesWriter(d) as writer:
Expand All @@ -120,7 +121,7 @@ def test_cut_load_temporal_array():
np.testing.assert_equal(alignment, restored_alignment)


def test_cut_load_temporal_array_truncate():
def test_cut_load_temporal_array_truncate(deterministic_rng):
"""Check the array loaded via TemporalArray is truncated along with the cut."""
with TemporaryDirectory() as d, NumpyFilesWriter(d) as writer:
expected_duration = 52.4 # 131 frames x 0.4s frame shift == 52.4s
Expand All @@ -138,7 +139,7 @@ def test_cut_load_temporal_array_truncate():


@pytest.mark.parametrize("pad_value", [-1, 0])
def test_cut_load_temporal_array_pad(pad_value):
def test_cut_load_temporal_array_pad(deterministic_rng, pad_value):
"""Check the array loaded via TemporalArray is padded along with the cut."""
with TemporaryDirectory() as d, NumpyFilesWriter(d) as writer:
cut = MonoCut(
Expand All @@ -161,7 +162,7 @@ def test_cut_load_temporal_array_pad(pad_value):
np.testing.assert_equal(alignment_pad[131:], pad_value)


def test_validate_cut_with_temporal_array(caplog):
def test_validate_cut_with_temporal_array(caplog, deterministic_rng):
# Note: "caplog" is a special variable in pytest that captures logs.
caplog.set_level(logging.WARNING)
with TemporaryDirectory() as d, NumpyFilesWriter(d) as writer:
Expand All @@ -185,7 +186,7 @@ def test_validate_cut_with_temporal_array(caplog):
)


def test_cut_load_custom_recording():
def test_cut_load_custom_recording(deterministic_rng):
sampling_rate = 16000
duration = 52.4
audio = np.random.randn(1, compute_num_samples(duration, sampling_rate)).astype(
Expand All @@ -208,7 +209,7 @@ def test_cut_load_custom_recording():
np.testing.assert_allclose(audio, restored_audio, atol=4e-5)


def test_cut_load_custom_recording_truncate():
def test_cut_load_custom_recording_truncate(deterministic_rng):
sampling_rate = 16000
duration = 52.4
audio = np.random.randn(1, compute_num_samples(duration, sampling_rate)).astype(
Expand All @@ -235,7 +236,7 @@ def test_cut_load_custom_recording_truncate():
np.testing.assert_allclose(audio[:, :80000], restored_audio, atol=3e-5)


def test_cut_load_custom_recording_pad_right():
def test_cut_load_custom_recording_pad_right(deterministic_rng):
sampling_rate = 16000
duration = 52.4
audio = np.random.randn(1, compute_num_samples(duration, sampling_rate)).astype(
Expand Down Expand Up @@ -271,7 +272,7 @@ def test_cut_load_custom_recording_pad_right():
np.testing.assert_allclose(0, restored_audio[:, audio.shape[1] :], atol=4e-5)


def test_cut_load_custom_recording_pad_left():
def test_cut_load_custom_recording_pad_left(deterministic_rng):
sampling_rate = 16000
duration = 52.4
audio = np.random.randn(1, compute_num_samples(duration, sampling_rate)).astype(
Expand Down Expand Up @@ -307,7 +308,7 @@ def test_cut_load_custom_recording_pad_left():
)


def test_cut_load_custom_recording_pad_both():
def test_cut_load_custom_recording_pad_both(deterministic_rng):
sampling_rate = 16000
duration = 52.4
audio = np.random.randn(1, compute_num_samples(duration, sampling_rate)).astype(
Expand Down Expand Up @@ -350,7 +351,7 @@ def test_cut_load_custom_recording_pad_both():
)


def test_cut_attach_tensor():
def test_cut_attach_tensor(deterministic_rng):
alignment = np.random.randint(500, size=131)
expected_duration = 52.4 # 131 frames x 0.4s frame shift == 52.4s
cut = MonoCut(id="x", start=0, duration=expected_duration, channel=0)
Expand Down

0 comments on commit 9c80a1e

Please sign in to comment.