From 491b9c9fa9634ab968a3da6c125b95ddcbcb4789 Mon Sep 17 00:00:00 2001 From: LPeurey Date: Thu, 14 Nov 2024 17:02:50 +0100 Subject: [PATCH] tests for sampler periodic --- tests/test_samplers.py | 28 +++++++++++++++++---------- tests/truth/sampler/periodic_rec.csv | 11 +++++++++++ tests/truth/sampler/periodic_sess.csv | 12 ++++++++++++ 3 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 tests/truth/sampler/periodic_rec.csv create mode 100644 tests/truth/sampler/periodic_sess.csv diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 2d5d36660..bb1c600b1 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -3,6 +3,7 @@ import pytest import shutil from functools import partial +from pathlib import Path from ChildProject.projects import ChildProject from ChildProject.annotations import AnnotationManager @@ -15,6 +16,8 @@ SamplerPipeline, ) +TRUTH = Path('tests', 'truth') +PATH = Path('output', 'samplers') def fake_conversation(data, filename): return data @@ -22,25 +25,30 @@ def fake_conversation(data, filename): @pytest.fixture(scope="function") def project(request): - if not os.path.exists("output/samplers"): - shutil.copytree(src="examples/valid_raw_data", dst="output/samplers") + if os.path.exists(PATH): + # shutil.copytree(src="examples/valid_raw_data", dst="output/annotations") + shutil.rmtree(PATH) + shutil.copytree(src="examples/valid_raw_data", dst=PATH) - project = ChildProject("output/samplers") + project = ChildProject(PATH) project.read() - yield project + yield project -def test_periodic(project): +@pytest.mark.parametrize("by,truth", + [('recording_filename', TRUTH / 'sampler' / 'periodic_rec.csv'), + ('session_id', TRUTH / 'sampler' / 'periodic_sess.csv'), + ]) +def test_periodic(project, by, truth): sampler = PeriodicSampler( - project=project, length=1000, period=1000, recordings=["sound.wav"] + project=project, offset=1000, length=500, period=200, recordings=["sound.wav",'sound2.wav'], by=by ) sampler.sample() - duration = project.recordings[ - project.recordings["recording_filename"] == "sound.wav" - ]["duration"].iloc[0] + # sampler.segments.to_csv(truth, index=False) + truth = pd.read_csv(truth) - assert len(sampler.segments) == int(duration / (1000 + 1000)) + pd.testing.assert_frame_equal(sampler.segments.reset_index(drop=True), truth, check_like=True) def test_energy_detection(project): diff --git a/tests/truth/sampler/periodic_rec.csv b/tests/truth/sampler/periodic_rec.csv new file mode 100644 index 000000000..3026448cf --- /dev/null +++ b/tests/truth/sampler/periodic_rec.csv @@ -0,0 +1,11 @@ +segment_onset,segment_offset,recording_filename +1000,1500,sound.wav +1700,2200,sound.wav +2400,2900,sound.wav +3100,3600,sound.wav +3800,4000,sound.wav +1000,1500,sound2.wav +1700,2200,sound2.wav +2400,2900,sound2.wav +3100,3600,sound2.wav +3800,4000,sound2.wav diff --git a/tests/truth/sampler/periodic_sess.csv b/tests/truth/sampler/periodic_sess.csv new file mode 100644 index 000000000..4c5dc61b9 --- /dev/null +++ b/tests/truth/sampler/periodic_sess.csv @@ -0,0 +1,12 @@ +segment_onset,segment_offset,recording_filename +1000,1500,sound.wav +1700,2200,sound.wav +2400,2900,sound.wav +3100,3600,sound.wav +3800,4000,sound.wav +0,400,sound2.wav +600,1100,sound2.wav +1300,1800,sound2.wav +2000,2500,sound2.wav +2700,3200,sound2.wav +3400,3900,sound2.wav